diff --git a/.gitignore b/.gitignore index a8b34bd1..d20c39ba 100644 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,6 @@ docs/html # i18n/l10n *.mo + +# pynev +.python-version diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 106a7b97..bf88eff7 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -44,19 +44,19 @@ class RedisStorage(BaseStorage): self._loop = loop or asyncio.get_event_loop() self._kwargs = kwargs - self._redis: aioredis.RedisConnection = None + self._redis: typing.Optional[aioredis.RedisConnection] = None self._connection_lock = asyncio.Lock(loop=self._loop) async def close(self): - if self._redis and not self._redis.closed: - self._redis.close() - del self._redis - self._redis = None + async with self._connection_lock: + if self._redis and not self._redis.closed: + self._redis.close() async def wait_closed(self): - if self._redis: - return await self._redis.wait_closed() - return True + async with self._connection_lock: + if self._redis: + return await self._redis.wait_closed() + return True async def redis(self) -> aioredis.RedisConnection: """ @@ -64,7 +64,7 @@ class RedisStorage(BaseStorage): """ # Use thread-safe asyncio Lock because this method without that is not safe async with self._connection_lock: - if self._redis is None: + if self._redis is None or self._redis.closed: self._redis = await aioredis.create_connection((self._host, self._port), db=self._db, password=self._password, ssl=self._ssl, loop=self._loop, @@ -144,7 +144,7 @@ class RedisStorage(BaseStorage): record_data.update(data, **kwargs) await self.set_record(chat=chat, user=user, state=record['state'], data=record_data) - async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ Get list of all stored chat's and user's @@ -220,11 +220,11 @@ class RedisStorage2(BaseStorage): """ def __init__(self, host: str = 'localhost', port=6379, db=None, password=None, - ssl=None, pool_size=10, loop=None, prefix='fsm', - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, - **kwargs): + ssl=None, pool_size=10, loop=None, prefix='fsm', + state_ttl: int = 0, + data_ttl: int = 0, + bucket_ttl: int = 0, + **kwargs): self._host = host self._port = port self._db = db @@ -239,7 +239,7 @@ class RedisStorage2(BaseStorage): self._data_ttl = data_ttl self._bucket_ttl = bucket_ttl - self._redis: aioredis.RedisConnection = None + self._redis: typing.Optional[aioredis.RedisConnection] = None self._connection_lock = asyncio.Lock(loop=self._loop) async def redis(self) -> aioredis.Redis: @@ -248,7 +248,7 @@ class RedisStorage2(BaseStorage): """ # Use thread-safe asyncio Lock because this method without that is not safe async with self._connection_lock: - if self._redis is None: + if self._redis is None or self._redis.closed: self._redis = await aioredis.create_redis_pool((self._host, self._port), db=self._db, password=self._password, ssl=self._ssl, minsize=1, maxsize=self._pool_size, @@ -262,8 +262,6 @@ class RedisStorage2(BaseStorage): async with self._connection_lock: if self._redis and not self._redis.closed: self._redis.close() - del self._redis - self._redis = None async def wait_closed(self): async with self._connection_lock: @@ -357,7 +355,7 @@ class RedisStorage2(BaseStorage): keys = await conn.keys(self.generate_key('*')) await conn.delete(*keys) - async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ Get list of all stored chat's and user's diff --git a/tests/conftest.py b/tests/conftest.py index fe936e18..03c8dbe4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1 +1,35 @@ -# pytest_plugins = "pytest_asyncio.plugin" +import pytest +from _pytest.config import UsageError +import aioredis.util + + +def pytest_addoption(parser): + parser.addoption("--redis", default=None, + help="run tests which require redis connection") + + +def pytest_configure(config): + config.addinivalue_line("markers", "redis: marked tests require redis connection to run") + + +def pytest_collection_modifyitems(config, items): + redis_uri = config.getoption("--redis") + if redis_uri is None: + skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run") + for item in items: + if "redis" in item.keywords: + item.add_marker(skip_redis) + return + try: + address, options = aioredis.util.parse_url(redis_uri) + assert isinstance(address, tuple), "Only redis and rediss schemas are supported, eg redis://foo." + except AssertionError as e: + raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") + + +@pytest.fixture(scope='session') +def redis_options(request): + redis_uri = request.config.getoption("--redis") + (host, port), options = aioredis.util.parse_url(redis_uri) + options.update({'host': host, 'port': port}) + return options diff --git a/tests/contrib/fsm_storage/test_redis.py b/tests/contrib/fsm_storage/test_redis.py new file mode 100644 index 00000000..527c905e --- /dev/null +++ b/tests/contrib/fsm_storage/test_redis.py @@ -0,0 +1,33 @@ +import pytest + +from aiogram.contrib.fsm_storage.redis import RedisStorage2 + + +@pytest.fixture() +async def store(redis_options): + s = RedisStorage2(**redis_options) + try: + yield s + finally: + conn = await s.redis() + await conn.flushdb() + await s.close() + await s.wait_closed() + + +@pytest.mark.redis +class TestRedisStorage2: + @pytest.mark.asyncio + async def test_set_get(self, store): + assert await store.get_data(chat='1234') == {} + await store.set_data(chat='1234', data={'foo': 'bar'}) + assert await store.get_data(chat='1234') == {'foo': 'bar'} + + @pytest.mark.asyncio + async def test_close_and_open_connection(self, store): + await store.set_data(chat='1234', data={'foo': 'bar'}) + assert await store.get_data(chat='1234') == {'foo': 'bar'} + pool_id = id(store._redis) + await store.close() + assert await store.get_data(chat='1234') == {'foo': 'bar'} # new pool was opened at this point + assert id(store._redis) != pool_id