diff --git a/aiogram/__init__.py b/aiogram/__init__.py index ceca7f58..3ab83d9d 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -43,5 +43,5 @@ __all__ = ( 'utils', ) -__version__ = '2.14.3' +__version__ = '2.15' __api_version__ = '5.3' diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index c8b95517..ce25ee07 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -48,7 +48,7 @@ class RedisStorage(BaseStorage): self._loop = loop or asyncio.get_event_loop() self._kwargs = kwargs - self._redis: typing.Optional[aioredis.RedisConnection] = None + self._redis: typing.Optional["aioredis.RedisConnection"] = None self._connection_lock = asyncio.Lock(loop=self._loop) async def close(self): @@ -62,7 +62,7 @@ class RedisStorage(BaseStorage): return await self._redis.wait_closed() return True - async def redis(self) -> aioredis.RedisConnection: + async def redis(self) -> "aioredis.RedisConnection": """ Get Redis connection """ @@ -220,9 +220,9 @@ class AioRedisAdapterBase(ABC): pool_size: int = 10, loop: typing.Optional[asyncio.AbstractEventLoop] = None, prefix: str = "fsm", - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, + state_ttl: typing.Optional[int] = None, + data_ttl: typing.Optional[int] = None, + bucket_ttl: typing.Optional[int] = None, **kwargs, ): self._host = host @@ -247,7 +247,7 @@ class AioRedisAdapterBase(ABC): """Get Redis connection.""" pass - def close(self): + async def close(self): """Grace shutdown.""" pass @@ -257,6 +257,8 @@ class AioRedisAdapterBase(ABC): async def set(self, name, value, ex=None, **kwargs): """Set the value at key ``name`` to ``value``.""" + if ex == 0: + ex = None return await self._redis.set(name, value, ex=ex, **kwargs) async def get(self, name, **kwargs): @@ -295,7 +297,7 @@ class AioRedisAdapterV1(AioRedisAdapterBase): ) return self._redis - def close(self): + async def close(self): async with self._connection_lock: if self._redis and not self._redis.closed: self._redis.close() @@ -310,6 +312,8 @@ class AioRedisAdapterV1(AioRedisAdapterBase): return await self._redis.get(name, encoding="utf8", **kwargs) async def set(self, name, value, ex=None, **kwargs): + if ex == 0: + ex = None return await self._redis.set(name, value, expire=ex, **kwargs) async def keys(self, pattern, **kwargs): @@ -331,6 +335,7 @@ class AioRedisAdapterV2(AioRedisAdapterBase): password=self._password, ssl=self._ssl, max_connections=self._pool_size, + decode_responses=True, **self._kwargs, ) return self._redis @@ -367,9 +372,9 @@ class RedisStorage2(BaseStorage): pool_size: int = 10, loop: typing.Optional[asyncio.AbstractEventLoop] = None, prefix: str = "fsm", - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, + state_ttl: typing.Optional[int] = None, + data_ttl: typing.Optional[int] = None, + bucket_ttl: typing.Optional[int] = None, **kwargs, ): self._host = host @@ -413,6 +418,9 @@ class RedisStorage2(BaseStorage): self._redis = AioRedisAdapterV1(**connection_data) elif redis_version == 2: self._redis = AioRedisAdapterV2(**connection_data) + else: + raise RuntimeError(f"Unsupported aioredis version: {redis_version}") + await self._redis.get_redis() return self._redis def generate_key(self, *parts): @@ -420,11 +428,12 @@ class RedisStorage2(BaseStorage): async def close(self): if self._redis: - return self._redis.close() + return await self._redis.close() async def wait_closed(self): if self._redis: - return await self._redis.wait_closed() + await self._redis.wait_closed() + self._redis = None async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: @@ -451,7 +460,7 @@ class RedisStorage2(BaseStorage): if state is None: await redis.delete(key) else: - await redis.set(key, self.resolve_state(state), expire=self._state_ttl) + await redis.set(key, self.resolve_state(state), ex=self._state_ttl) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): @@ -459,7 +468,7 @@ class RedisStorage2(BaseStorage): key = self.generate_key(chat, user, STATE_DATA_KEY) redis = await self._get_adapter() if data: - await redis.set(key, json.dumps(data), expire=self._data_ttl) + await redis.set(key, json.dumps(data), ex=self._data_ttl) else: await redis.delete(key) @@ -490,7 +499,7 @@ class RedisStorage2(BaseStorage): key = self.generate_key(chat, user, STATE_BUCKET_KEY) redis = await self._get_adapter() if bucket: - await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) + await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl) else: await redis.delete(key) diff --git a/aiogram/types/mixins.py b/aiogram/types/mixins.py index 13f8412f..83c65032 100644 --- a/aiogram/types/mixins.py +++ b/aiogram/types/mixins.py @@ -1,5 +1,9 @@ import os import pathlib +from io import IOBase +from typing import Union, Optional + +from aiogram.utils.deprecated import warn_deprecated class Downloadable: @@ -7,32 +11,86 @@ class Downloadable: Mixin for files """ - async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True): + async def download( + self, + destination=None, + timeout=30, + chunk_size=65536, + seek=True, + make_dirs=True, + *, + destination_dir: Optional[Union[str, pathlib.Path]] = None, + destination_file: Optional[Union[str, pathlib.Path, IOBase]] = None + ): """ Download file - :param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO` + At most one of these parameters can be used: :param destination_dir:, :param destination_file: + + :param destination: deprecated, use :param destination_dir: or :param destination_file: instead :param timeout: Integer :param chunk_size: Integer :param seek: Boolean - go to start of file when downloading is finished. :param make_dirs: Make dirs if not exist + :param destination_dir: directory for saving files + :param destination_file: path to the file or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO` :return: destination """ + if destination: + warn_deprecated( + "destination parameter is deprecated, please use destination_dir or destination_file." + ) + if destination_dir and destination_file: + raise ValueError( + "Use only one of the parameters: destination_dir or destination_file." + ) + + file, destination = await self._prepare_destination( + destination, + destination_dir, + destination_file, + make_dirs + ) + + return await self.bot.download_file( + file_path=file.file_path, + destination=destination, + timeout=timeout, + chunk_size=chunk_size, + seek=seek, + ) + + async def _prepare_destination(self, dest, destination_dir, destination_file, make_dirs): file = await self.get_file() - is_path = True - if destination is None: + if not(any((dest, destination_dir, destination_file))): destination = file.file_path - elif isinstance(destination, (str, pathlib.Path)) and os.path.isdir(destination): - destination = os.path.join(destination, file.file_path) - else: - is_path = False - if is_path and make_dirs: + elif dest: # backward compatibility + if isinstance(dest, IOBase): + return file, dest + if isinstance(dest, (str, pathlib.Path)) and os.path.isdir(dest): + destination = os.path.join(dest, file.file_path) + else: + destination = dest + + elif destination_dir: + if isinstance(destination_dir, (str, pathlib.Path)): + destination = os.path.join(destination_dir, file.file_path) + else: + raise TypeError("destination_dir must be str or pathlib.Path") + else: + if isinstance(destination_file, IOBase): + return file, destination_file + elif isinstance(destination_file, (str, pathlib.Path)): + destination = destination_file + else: + raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type") + + if make_dirs and os.path.dirname(destination): os.makedirs(os.path.dirname(destination), exist_ok=True) - return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout, - chunk_size=chunk_size, seek=seek) + return file, destination async def get_file(self): """ diff --git a/aiogram/types/user.py b/aiogram/types/user.py index 8263cfc2..6bde2dcd 100644 --- a/aiogram/types/user.py +++ b/aiogram/types/user.py @@ -64,10 +64,10 @@ class User(base.TelegramObject): return getattr(self, '_locale') @property - def url(self): + def url(self) -> str: return f"tg://user?id={self.id}" - def get_mention(self, name=None, as_html=None): + def get_mention(self, name: Optional[str] = None, as_html: Optional[bool] = None) -> str: if as_html is None and self.bot.parse_mode and self.bot.parse_mode.lower() == 'html': as_html = True diff --git a/tests/conftest.py b/tests/conftest.py index 03c8dbe4..b56c7b77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,28 +1,53 @@ +import aioredis import pytest from _pytest.config import UsageError -import aioredis.util + +try: + import aioredis.util +except ImportError: + pass def pytest_addoption(parser): - parser.addoption("--redis", default=None, - help="run tests which require redis connection") + parser.addoption( + "--redis", + default=None, + help="run tests which require redis connection", + ) def pytest_configure(config): - config.addinivalue_line("markers", "redis: marked tests require redis connection to run") + config.addinivalue_line( + "markers", + "redis: marked tests require redis connection to run", + ) def pytest_collection_modifyitems(config, items): redis_uri = config.getoption("--redis") if redis_uri is None: - skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run") + skip_redis = pytest.mark.skip( + reason="need --redis option with redis URI to run" + ) for item in items: if "redis" in item.keywords: item.add_marker(skip_redis) return + + redis_version = int(aioredis.__version__.split(".")[0]) + options = None + if redis_version == 1: + (host, port), options = aioredis.util.parse_url(redis_uri) + options.update({'host': host, 'port': port}) + elif redis_version == 2: + try: + options = aioredis.connection.parse_url(redis_uri) + except ValueError as e: + raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") + try: - address, options = aioredis.util.parse_url(redis_uri) - assert isinstance(address, tuple), "Only redis and rediss schemas are supported, eg redis://foo." + assert isinstance(options, dict), \ + "Only redis and rediss schemas are supported, eg redis://foo." except AssertionError as e: raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") @@ -30,6 +55,20 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope='session') def redis_options(request): redis_uri = request.config.getoption("--redis") - (host, port), options = aioredis.util.parse_url(redis_uri) - options.update({'host': host, 'port': port}) - return options + if redis_uri is None: + pytest.skip("need --redis option with redis URI to run") + return + + redis_version = int(aioredis.__version__.split(".")[0]) + if redis_version == 1: + (host, port), options = aioredis.util.parse_url(redis_uri) + options.update({'host': host, 'port': port}) + return options + + if redis_version == 2: + try: + return aioredis.connection.parse_url(redis_uri) + except ValueError as e: + raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") + + raise UsageError("Unsupported aioredis version") diff --git a/tests/contrib/fsm_storage/test_storage.py b/tests/contrib/fsm_storage/test_storage.py index 0cde2de2..2668cdab 100644 --- a/tests/contrib/fsm_storage/test_storage.py +++ b/tests/contrib/fsm_storage/test_storage.py @@ -1,12 +1,16 @@ +import aioredis import pytest - +from pytest_lazyfixture import lazy_fixture from aiogram.contrib.fsm_storage.memory import MemoryStorage -from aiogram.contrib.fsm_storage.redis import RedisStorage2, RedisStorage +from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2 @pytest.fixture() @pytest.mark.redis async def redis_store(redis_options): + if int(aioredis.__version__.split(".")[0]) == 2: + pytest.skip('aioredis v2 is not supported.') + return s = RedisStorage(**redis_options) try: yield s @@ -37,9 +41,9 @@ async def memory_store(): @pytest.mark.parametrize( "store", [ - pytest.lazy_fixture('redis_store'), - pytest.lazy_fixture('redis_store2'), - pytest.lazy_fixture('memory_store'), + lazy_fixture('redis_store'), + lazy_fixture('redis_store2'), + lazy_fixture('memory_store'), ] ) class TestStorage: @@ -63,8 +67,8 @@ class TestStorage: @pytest.mark.parametrize( "store", [ - pytest.lazy_fixture('redis_store'), - pytest.lazy_fixture('redis_store2'), + lazy_fixture('redis_store'), + lazy_fixture('redis_store2'), ] ) class TestRedisStorage2: @@ -74,6 +78,7 @@ class TestRedisStorage2: assert await store.get_data(chat='1234') == {'foo': 'bar'} pool_id = id(store._redis) await store.close() + await store.wait_closed() assert await store.get_data(chat='1234') == { 'foo': 'bar'} # new pool was opened at this point assert id(store._redis) != pool_id diff --git a/tests/types/test_mixins.py b/tests/types/test_mixins.py new file mode 100644 index 00000000..4327e8aa --- /dev/null +++ b/tests/types/test_mixins.py @@ -0,0 +1,102 @@ +import os +from io import BytesIO +from pathlib import Path + +import pytest + +from aiogram import Bot +from aiogram.types import File +from aiogram.types.mixins import Downloadable +from tests import TOKEN +from tests.types.dataset import FILE + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(name='bot') +async def bot_fixture(): + """ Bot fixture """ + _bot = Bot(TOKEN) + yield _bot + await _bot.session.close() + + +@pytest.fixture +def tmppath(tmpdir, request): + os.chdir(tmpdir) + yield Path(tmpdir) + os.chdir(request.config.invocation_dir) + + +@pytest.fixture +def downloadable(bot): + async def get_file(): + return File(**FILE) + + downloadable = Downloadable() + downloadable.get_file = get_file + downloadable.bot = bot + + return downloadable + + +class TestDownloadable: + async def test_download_make_dirs_false_nodir(self, tmppath, downloadable): + with pytest.raises(FileNotFoundError): + await downloadable.download(make_dirs=False) + + async def test_download_make_dirs_false_mkdir(self, tmppath, downloadable): + os.mkdir('voice') + await downloadable.download(make_dirs=False) + assert os.path.isfile(tmppath.joinpath(FILE["file_path"])) + + async def test_download_make_dirs_true(self, tmppath, downloadable): + await downloadable.download(make_dirs=True) + assert os.path.isfile(tmppath.joinpath(FILE["file_path"])) + + async def test_download_deprecation_warning(self, tmppath, downloadable): + with pytest.deprecated_call(): + await downloadable.download("test.file") + + async def test_download_destination(self, tmppath, downloadable): + with pytest.deprecated_call(): + await downloadable.download("test.file") + assert os.path.isfile(tmppath.joinpath('test.file')) + + async def test_download_destination_dir_exist(self, tmppath, downloadable): + os.mkdir("test_folder") + with pytest.deprecated_call(): + await downloadable.download("test_folder") + assert os.path.isfile(tmppath.joinpath('test_folder', FILE["file_path"])) + + async def test_download_destination_with_dir(self, tmppath, downloadable): + with pytest.deprecated_call(): + await downloadable.download(os.path.join('dir_name', 'file_name')) + assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name')) + + async def test_download_destination_io_bytes(self, tmppath, downloadable): + file = BytesIO() + with pytest.deprecated_call(): + await downloadable.download(file) + assert len(file.read()) != 0 + + async def test_download_raise_value_error(self, tmppath, downloadable): + with pytest.raises(ValueError): + await downloadable.download(destination_dir="a", destination_file="b") + + async def test_download_destination_dir(self, tmppath, downloadable): + await downloadable.download(destination_dir='test_dir') + assert os.path.isfile(tmppath.joinpath('test_dir', FILE["file_path"])) + + async def test_download_destination_file(self, tmppath, downloadable): + await downloadable.download(destination_file='file_name') + assert os.path.isfile(tmppath.joinpath('file_name')) + + async def test_download_destination_file_with_dir(self, tmppath, downloadable): + await downloadable.download(destination_file=os.path.join('dir_name', 'file_name')) + assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name')) + + async def test_download_io_bytes(self, tmppath, downloadable): + file = BytesIO() + await downloadable.download(destination_file=file) + assert len(file.read()) != 0