Merge branch 'aiogram:dev-2.x' into dev-2.x-bot-download

This commit is contained in:
darksidecat 2021-09-21 10:48:05 +03:00 committed by GitHub
commit 962a7edb52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 259 additions and 46 deletions

View file

@ -43,5 +43,5 @@ __all__ = (
'utils',
)
__version__ = '2.14.3'
__version__ = '2.15'
__api_version__ = '5.3'

View file

@ -48,7 +48,7 @@ class RedisStorage(BaseStorage):
self._loop = loop or asyncio.get_event_loop()
self._kwargs = kwargs
self._redis: typing.Optional[aioredis.RedisConnection] = None
self._redis: typing.Optional["aioredis.RedisConnection"] = None
self._connection_lock = asyncio.Lock(loop=self._loop)
async def close(self):
@ -62,7 +62,7 @@ class RedisStorage(BaseStorage):
return await self._redis.wait_closed()
return True
async def redis(self) -> aioredis.RedisConnection:
async def redis(self) -> "aioredis.RedisConnection":
"""
Get Redis connection
"""
@ -220,9 +220,9 @@ class AioRedisAdapterBase(ABC):
pool_size: int = 10,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
prefix: str = "fsm",
state_ttl: int = 0,
data_ttl: int = 0,
bucket_ttl: int = 0,
state_ttl: typing.Optional[int] = None,
data_ttl: typing.Optional[int] = None,
bucket_ttl: typing.Optional[int] = None,
**kwargs,
):
self._host = host
@ -247,7 +247,7 @@ class AioRedisAdapterBase(ABC):
"""Get Redis connection."""
pass
def close(self):
async def close(self):
"""Grace shutdown."""
pass
@ -257,6 +257,8 @@ class AioRedisAdapterBase(ABC):
async def set(self, name, value, ex=None, **kwargs):
"""Set the value at key ``name`` to ``value``."""
if ex == 0:
ex = None
return await self._redis.set(name, value, ex=ex, **kwargs)
async def get(self, name, **kwargs):
@ -295,7 +297,7 @@ class AioRedisAdapterV1(AioRedisAdapterBase):
)
return self._redis
def close(self):
async def close(self):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()
@ -310,6 +312,8 @@ class AioRedisAdapterV1(AioRedisAdapterBase):
return await self._redis.get(name, encoding="utf8", **kwargs)
async def set(self, name, value, ex=None, **kwargs):
if ex == 0:
ex = None
return await self._redis.set(name, value, expire=ex, **kwargs)
async def keys(self, pattern, **kwargs):
@ -331,6 +335,7 @@ class AioRedisAdapterV2(AioRedisAdapterBase):
password=self._password,
ssl=self._ssl,
max_connections=self._pool_size,
decode_responses=True,
**self._kwargs,
)
return self._redis
@ -367,9 +372,9 @@ class RedisStorage2(BaseStorage):
pool_size: int = 10,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
prefix: str = "fsm",
state_ttl: int = 0,
data_ttl: int = 0,
bucket_ttl: int = 0,
state_ttl: typing.Optional[int] = None,
data_ttl: typing.Optional[int] = None,
bucket_ttl: typing.Optional[int] = None,
**kwargs,
):
self._host = host
@ -413,6 +418,9 @@ class RedisStorage2(BaseStorage):
self._redis = AioRedisAdapterV1(**connection_data)
elif redis_version == 2:
self._redis = AioRedisAdapterV2(**connection_data)
else:
raise RuntimeError(f"Unsupported aioredis version: {redis_version}")
await self._redis.get_redis()
return self._redis
def generate_key(self, *parts):
@ -420,11 +428,12 @@ class RedisStorage2(BaseStorage):
async def close(self):
if self._redis:
return self._redis.close()
return await self._redis.close()
async def wait_closed(self):
if self._redis:
return await self._redis.wait_closed()
await self._redis.wait_closed()
self._redis = None
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
@ -451,7 +460,7 @@ class RedisStorage2(BaseStorage):
if state is None:
await redis.delete(key)
else:
await redis.set(key, self.resolve_state(state), expire=self._state_ttl)
await redis.set(key, self.resolve_state(state), ex=self._state_ttl)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
@ -459,7 +468,7 @@ class RedisStorage2(BaseStorage):
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self._get_adapter()
if data:
await redis.set(key, json.dumps(data), expire=self._data_ttl)
await redis.set(key, json.dumps(data), ex=self._data_ttl)
else:
await redis.delete(key)
@ -490,7 +499,7 @@ class RedisStorage2(BaseStorage):
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self._get_adapter()
if bucket:
await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl)
await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
else:
await redis.delete(key)

View file

@ -1,5 +1,9 @@
import os
import pathlib
from io import IOBase
from typing import Union, Optional
from aiogram.utils.deprecated import warn_deprecated
class Downloadable:
@ -7,32 +11,86 @@ class Downloadable:
Mixin for files
"""
async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True):
async def download(
self,
destination=None,
timeout=30,
chunk_size=65536,
seek=True,
make_dirs=True,
*,
destination_dir: Optional[Union[str, pathlib.Path]] = None,
destination_file: Optional[Union[str, pathlib.Path, IOBase]] = None
):
"""
Download file
:param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
At most one of these parameters can be used: :param destination_dir:, :param destination_file:
:param destination: deprecated, use :param destination_dir: or :param destination_file: instead
:param timeout: Integer
:param chunk_size: Integer
:param seek: Boolean - go to start of file when downloading is finished.
:param make_dirs: Make dirs if not exist
:param destination_dir: directory for saving files
:param destination_file: path to the file or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
:return: destination
"""
if destination:
warn_deprecated(
"destination parameter is deprecated, please use destination_dir or destination_file."
)
if destination_dir and destination_file:
raise ValueError(
"Use only one of the parameters: destination_dir or destination_file."
)
file, destination = await self._prepare_destination(
destination,
destination_dir,
destination_file,
make_dirs
)
return await self.bot.download_file(
file_path=file.file_path,
destination=destination,
timeout=timeout,
chunk_size=chunk_size,
seek=seek,
)
async def _prepare_destination(self, dest, destination_dir, destination_file, make_dirs):
file = await self.get_file()
is_path = True
if destination is None:
if not(any((dest, destination_dir, destination_file))):
destination = file.file_path
elif isinstance(destination, (str, pathlib.Path)) and os.path.isdir(destination):
destination = os.path.join(destination, file.file_path)
else:
is_path = False
if is_path and make_dirs:
elif dest: # backward compatibility
if isinstance(dest, IOBase):
return file, dest
if isinstance(dest, (str, pathlib.Path)) and os.path.isdir(dest):
destination = os.path.join(dest, file.file_path)
else:
destination = dest
elif destination_dir:
if isinstance(destination_dir, (str, pathlib.Path)):
destination = os.path.join(destination_dir, file.file_path)
else:
raise TypeError("destination_dir must be str or pathlib.Path")
else:
if isinstance(destination_file, IOBase):
return file, destination_file
elif isinstance(destination_file, (str, pathlib.Path)):
destination = destination_file
else:
raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type")
if make_dirs and os.path.dirname(destination):
os.makedirs(os.path.dirname(destination), exist_ok=True)
return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout,
chunk_size=chunk_size, seek=seek)
return file, destination
async def get_file(self):
"""

View file

@ -64,10 +64,10 @@ class User(base.TelegramObject):
return getattr(self, '_locale')
@property
def url(self):
def url(self) -> str:
return f"tg://user?id={self.id}"
def get_mention(self, name=None, as_html=None):
def get_mention(self, name: Optional[str] = None, as_html: Optional[bool] = None) -> str:
if as_html is None and self.bot.parse_mode and self.bot.parse_mode.lower() == 'html':
as_html = True

View file

@ -1,28 +1,53 @@
import aioredis
import pytest
from _pytest.config import UsageError
import aioredis.util
try:
import aioredis.util
except ImportError:
pass
def pytest_addoption(parser):
parser.addoption("--redis", default=None,
help="run tests which require redis connection")
parser.addoption(
"--redis",
default=None,
help="run tests which require redis connection",
)
def pytest_configure(config):
config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
config.addinivalue_line(
"markers",
"redis: marked tests require redis connection to run",
)
def pytest_collection_modifyitems(config, items):
redis_uri = config.getoption("--redis")
if redis_uri is None:
skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
skip_redis = pytest.mark.skip(
reason="need --redis option with redis URI to run"
)
for item in items:
if "redis" in item.keywords:
item.add_marker(skip_redis)
return
redis_version = int(aioredis.__version__.split(".")[0])
options = None
if redis_version == 1:
(host, port), options = aioredis.util.parse_url(redis_uri)
options.update({'host': host, 'port': port})
elif redis_version == 2:
try:
options = aioredis.connection.parse_url(redis_uri)
except ValueError as e:
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
try:
address, options = aioredis.util.parse_url(redis_uri)
assert isinstance(address, tuple), "Only redis and rediss schemas are supported, eg redis://foo."
assert isinstance(options, dict), \
"Only redis and rediss schemas are supported, eg redis://foo."
except AssertionError as e:
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
@ -30,6 +55,20 @@ def pytest_collection_modifyitems(config, items):
@pytest.fixture(scope='session')
def redis_options(request):
redis_uri = request.config.getoption("--redis")
(host, port), options = aioredis.util.parse_url(redis_uri)
options.update({'host': host, 'port': port})
return options
if redis_uri is None:
pytest.skip("need --redis option with redis URI to run")
return
redis_version = int(aioredis.__version__.split(".")[0])
if redis_version == 1:
(host, port), options = aioredis.util.parse_url(redis_uri)
options.update({'host': host, 'port': port})
return options
if redis_version == 2:
try:
return aioredis.connection.parse_url(redis_uri)
except ValueError as e:
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
raise UsageError("Unsupported aioredis version")

View file

@ -1,12 +1,16 @@
import aioredis
import pytest
from pytest_lazyfixture import lazy_fixture
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.contrib.fsm_storage.redis import RedisStorage2, RedisStorage
from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2
@pytest.fixture()
@pytest.mark.redis
async def redis_store(redis_options):
if int(aioredis.__version__.split(".")[0]) == 2:
pytest.skip('aioredis v2 is not supported.')
return
s = RedisStorage(**redis_options)
try:
yield s
@ -37,9 +41,9 @@ async def memory_store():
@pytest.mark.parametrize(
"store", [
pytest.lazy_fixture('redis_store'),
pytest.lazy_fixture('redis_store2'),
pytest.lazy_fixture('memory_store'),
lazy_fixture('redis_store'),
lazy_fixture('redis_store2'),
lazy_fixture('memory_store'),
]
)
class TestStorage:
@ -63,8 +67,8 @@ class TestStorage:
@pytest.mark.parametrize(
"store", [
pytest.lazy_fixture('redis_store'),
pytest.lazy_fixture('redis_store2'),
lazy_fixture('redis_store'),
lazy_fixture('redis_store2'),
]
)
class TestRedisStorage2:
@ -74,6 +78,7 @@ class TestRedisStorage2:
assert await store.get_data(chat='1234') == {'foo': 'bar'}
pool_id = id(store._redis)
await store.close()
await store.wait_closed()
assert await store.get_data(chat='1234') == {
'foo': 'bar'} # new pool was opened at this point
assert id(store._redis) != pool_id

102
tests/types/test_mixins.py Normal file
View file

@ -0,0 +1,102 @@
import os
from io import BytesIO
from pathlib import Path
import pytest
from aiogram import Bot
from aiogram.types import File
from aiogram.types.mixins import Downloadable
from tests import TOKEN
from tests.types.dataset import FILE
pytestmark = pytest.mark.asyncio
@pytest.fixture(name='bot')
async def bot_fixture():
""" Bot fixture """
_bot = Bot(TOKEN)
yield _bot
await _bot.session.close()
@pytest.fixture
def tmppath(tmpdir, request):
os.chdir(tmpdir)
yield Path(tmpdir)
os.chdir(request.config.invocation_dir)
@pytest.fixture
def downloadable(bot):
async def get_file():
return File(**FILE)
downloadable = Downloadable()
downloadable.get_file = get_file
downloadable.bot = bot
return downloadable
class TestDownloadable:
async def test_download_make_dirs_false_nodir(self, tmppath, downloadable):
with pytest.raises(FileNotFoundError):
await downloadable.download(make_dirs=False)
async def test_download_make_dirs_false_mkdir(self, tmppath, downloadable):
os.mkdir('voice')
await downloadable.download(make_dirs=False)
assert os.path.isfile(tmppath.joinpath(FILE["file_path"]))
async def test_download_make_dirs_true(self, tmppath, downloadable):
await downloadable.download(make_dirs=True)
assert os.path.isfile(tmppath.joinpath(FILE["file_path"]))
async def test_download_deprecation_warning(self, tmppath, downloadable):
with pytest.deprecated_call():
await downloadable.download("test.file")
async def test_download_destination(self, tmppath, downloadable):
with pytest.deprecated_call():
await downloadable.download("test.file")
assert os.path.isfile(tmppath.joinpath('test.file'))
async def test_download_destination_dir_exist(self, tmppath, downloadable):
os.mkdir("test_folder")
with pytest.deprecated_call():
await downloadable.download("test_folder")
assert os.path.isfile(tmppath.joinpath('test_folder', FILE["file_path"]))
async def test_download_destination_with_dir(self, tmppath, downloadable):
with pytest.deprecated_call():
await downloadable.download(os.path.join('dir_name', 'file_name'))
assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name'))
async def test_download_destination_io_bytes(self, tmppath, downloadable):
file = BytesIO()
with pytest.deprecated_call():
await downloadable.download(file)
assert len(file.read()) != 0
async def test_download_raise_value_error(self, tmppath, downloadable):
with pytest.raises(ValueError):
await downloadable.download(destination_dir="a", destination_file="b")
async def test_download_destination_dir(self, tmppath, downloadable):
await downloadable.download(destination_dir='test_dir')
assert os.path.isfile(tmppath.joinpath('test_dir', FILE["file_path"]))
async def test_download_destination_file(self, tmppath, downloadable):
await downloadable.download(destination_file='file_name')
assert os.path.isfile(tmppath.joinpath('file_name'))
async def test_download_destination_file_with_dir(self, tmppath, downloadable):
await downloadable.download(destination_file=os.path.join('dir_name', 'file_name'))
assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name'))
async def test_download_io_bytes(self, tmppath, downloadable):
file = BytesIO()
await downloadable.download(destination_file=file)
assert len(file.read()) != 0