diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index c8b95517..7037fadf 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -48,7 +48,7 @@ class RedisStorage(BaseStorage): self._loop = loop or asyncio.get_event_loop() self._kwargs = kwargs - self._redis: typing.Optional[aioredis.RedisConnection] = None + self._redis: typing.Optional["aioredis.RedisConnection"] = None self._connection_lock = asyncio.Lock(loop=self._loop) async def close(self): @@ -62,7 +62,7 @@ class RedisStorage(BaseStorage): return await self._redis.wait_closed() return True - async def redis(self) -> aioredis.RedisConnection: + async def redis(self) -> "aioredis.RedisConnection": """ Get Redis connection """ @@ -220,9 +220,9 @@ class AioRedisAdapterBase(ABC): pool_size: int = 10, loop: typing.Optional[asyncio.AbstractEventLoop] = None, prefix: str = "fsm", - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, + state_ttl: typing.Optional[int] = None, + data_ttl: typing.Optional[int] = None, + bucket_ttl: typing.Optional[int] = None, **kwargs, ): self._host = host @@ -247,7 +247,7 @@ class AioRedisAdapterBase(ABC): """Get Redis connection.""" pass - def close(self): + async def close(self): """Grace shutdown.""" pass @@ -257,6 +257,8 @@ class AioRedisAdapterBase(ABC): async def set(self, name, value, ex=None, **kwargs): """Set the value at key ``name`` to ``value``.""" + if ex == 0: + ex = None return await self._redis.set(name, value, ex=ex, **kwargs) async def get(self, name, **kwargs): @@ -295,7 +297,7 @@ class AioRedisAdapterV1(AioRedisAdapterBase): ) return self._redis - def close(self): + async def close(self): async with self._connection_lock: if self._redis and not self._redis.closed: self._redis.close() @@ -310,6 +312,8 @@ class AioRedisAdapterV1(AioRedisAdapterBase): return await self._redis.get(name, encoding="utf8", **kwargs) async def set(self, name, value, ex=None, **kwargs): + if ex == 0: + ex = None return await self._redis.set(name, value, expire=ex, **kwargs) async def keys(self, pattern, **kwargs): @@ -367,9 +371,9 @@ class RedisStorage2(BaseStorage): pool_size: int = 10, loop: typing.Optional[asyncio.AbstractEventLoop] = None, prefix: str = "fsm", - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, + state_ttl: typing.Optional[int] = None, + data_ttl: typing.Optional[int] = None, + bucket_ttl: typing.Optional[int] = None, **kwargs, ): self._host = host @@ -413,6 +417,9 @@ class RedisStorage2(BaseStorage): self._redis = AioRedisAdapterV1(**connection_data) elif redis_version == 2: self._redis = AioRedisAdapterV2(**connection_data) + else: + raise RuntimeError(f"Unsupported aioredis version: {redis_version}") + await self._redis.get_redis() return self._redis def generate_key(self, *parts): @@ -420,11 +427,12 @@ class RedisStorage2(BaseStorage): async def close(self): if self._redis: - return self._redis.close() + return await self._redis.close() async def wait_closed(self): if self._redis: - return await self._redis.wait_closed() + await self._redis.wait_closed() + self._redis = None async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: @@ -451,7 +459,7 @@ class RedisStorage2(BaseStorage): if state is None: await redis.delete(key) else: - await redis.set(key, self.resolve_state(state), expire=self._state_ttl) + await redis.set(key, self.resolve_state(state), ex=self._state_ttl) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): @@ -459,7 +467,7 @@ class RedisStorage2(BaseStorage): key = self.generate_key(chat, user, STATE_DATA_KEY) redis = await self._get_adapter() if data: - await redis.set(key, json.dumps(data), expire=self._data_ttl) + await redis.set(key, json.dumps(data), ex=self._data_ttl) else: await redis.delete(key) @@ -490,7 +498,7 @@ class RedisStorage2(BaseStorage): key = self.generate_key(chat, user, STATE_BUCKET_KEY) redis = await self._get_adapter() if bucket: - await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) + await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl) else: await redis.delete(key) diff --git a/tests/conftest.py b/tests/conftest.py index 03c8dbe4..b56c7b77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,28 +1,53 @@ +import aioredis import pytest from _pytest.config import UsageError -import aioredis.util + +try: + import aioredis.util +except ImportError: + pass def pytest_addoption(parser): - parser.addoption("--redis", default=None, - help="run tests which require redis connection") + parser.addoption( + "--redis", + default=None, + help="run tests which require redis connection", + ) def pytest_configure(config): - config.addinivalue_line("markers", "redis: marked tests require redis connection to run") + config.addinivalue_line( + "markers", + "redis: marked tests require redis connection to run", + ) def pytest_collection_modifyitems(config, items): redis_uri = config.getoption("--redis") if redis_uri is None: - skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run") + skip_redis = pytest.mark.skip( + reason="need --redis option with redis URI to run" + ) for item in items: if "redis" in item.keywords: item.add_marker(skip_redis) return + + redis_version = int(aioredis.__version__.split(".")[0]) + options = None + if redis_version == 1: + (host, port), options = aioredis.util.parse_url(redis_uri) + options.update({'host': host, 'port': port}) + elif redis_version == 2: + try: + options = aioredis.connection.parse_url(redis_uri) + except ValueError as e: + raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") + try: - address, options = aioredis.util.parse_url(redis_uri) - assert isinstance(address, tuple), "Only redis and rediss schemas are supported, eg redis://foo." + assert isinstance(options, dict), \ + "Only redis and rediss schemas are supported, eg redis://foo." except AssertionError as e: raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") @@ -30,6 +55,20 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope='session') def redis_options(request): redis_uri = request.config.getoption("--redis") - (host, port), options = aioredis.util.parse_url(redis_uri) - options.update({'host': host, 'port': port}) - return options + if redis_uri is None: + pytest.skip("need --redis option with redis URI to run") + return + + redis_version = int(aioredis.__version__.split(".")[0]) + if redis_version == 1: + (host, port), options = aioredis.util.parse_url(redis_uri) + options.update({'host': host, 'port': port}) + return options + + if redis_version == 2: + try: + return aioredis.connection.parse_url(redis_uri) + except ValueError as e: + raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") + + raise UsageError("Unsupported aioredis version") diff --git a/tests/contrib/fsm_storage/test_storage.py b/tests/contrib/fsm_storage/test_storage.py index 0cde2de2..2668cdab 100644 --- a/tests/contrib/fsm_storage/test_storage.py +++ b/tests/contrib/fsm_storage/test_storage.py @@ -1,12 +1,16 @@ +import aioredis import pytest - +from pytest_lazyfixture import lazy_fixture from aiogram.contrib.fsm_storage.memory import MemoryStorage -from aiogram.contrib.fsm_storage.redis import RedisStorage2, RedisStorage +from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2 @pytest.fixture() @pytest.mark.redis async def redis_store(redis_options): + if int(aioredis.__version__.split(".")[0]) == 2: + pytest.skip('aioredis v2 is not supported.') + return s = RedisStorage(**redis_options) try: yield s @@ -37,9 +41,9 @@ async def memory_store(): @pytest.mark.parametrize( "store", [ - pytest.lazy_fixture('redis_store'), - pytest.lazy_fixture('redis_store2'), - pytest.lazy_fixture('memory_store'), + lazy_fixture('redis_store'), + lazy_fixture('redis_store2'), + lazy_fixture('memory_store'), ] ) class TestStorage: @@ -63,8 +67,8 @@ class TestStorage: @pytest.mark.parametrize( "store", [ - pytest.lazy_fixture('redis_store'), - pytest.lazy_fixture('redis_store2'), + lazy_fixture('redis_store'), + lazy_fixture('redis_store2'), ] ) class TestRedisStorage2: @@ -74,6 +78,7 @@ class TestRedisStorage2: assert await store.get_data(chat='1234') == {'foo': 'bar'} pool_id = id(store._redis) await store.close() + await store.wait_closed() assert await store.get_data(chat='1234') == { 'foo': 'bar'} # new pool was opened at this point assert id(store._redis) != pool_id