diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 87d76374..da0b20d8 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -1,18 +1,18 @@ """ -This module has redis storage for finite-state machine based on `aioredis `_ driver +This module has redis storage for finite-state machine based on `redis `_ driver. """ import asyncio import logging import typing -from abc import ABC, abstractmethod - -import aioredis from ...dispatcher.storage import BaseStorage from ...utils import json from ...utils.deprecated import deprecated +if typing.TYPE_CHECKING: + import aioredis + STATE_KEY = 'state' STATE_DATA_KEY = 'data' STATE_BUCKET_KEY = 'bucket' @@ -67,6 +67,8 @@ class RedisStorage(BaseStorage): Get Redis connection """ # Use thread-safe asyncio Lock because this method without that is not safe + import aioredis + async with self._connection_lock: if self._redis is None or self._redis.closed: self._redis = await aioredis.create_connection((self._host, self._port), @@ -207,138 +209,6 @@ class RedisStorage(BaseStorage): await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket) -class AioRedisAdapterBase(ABC): - """Base aioredis adapter class.""" - - def __init__( - self, - host: str = "localhost", - port: int = 6379, - db: typing.Optional[int] = None, - password: typing.Optional[str] = None, - ssl: typing.Optional[bool] = None, - pool_size: int = 10, - loop: typing.Optional[asyncio.AbstractEventLoop] = None, - prefix: str = "fsm", - state_ttl: typing.Optional[int] = None, - data_ttl: typing.Optional[int] = None, - bucket_ttl: typing.Optional[int] = None, - **kwargs, - ): - self._host = host - self._port = port - self._db = db - self._password = password - self._ssl = ssl - self._pool_size = pool_size - self._kwargs = kwargs - self._prefix = (prefix,) - - self._state_ttl = state_ttl - self._data_ttl = data_ttl - self._bucket_ttl = bucket_ttl - - self._redis: typing.Optional["aioredis.Redis"] = None - self._connection_lock = asyncio.Lock() - - @abstractmethod - async def get_redis(self) -> aioredis.Redis: - """Get Redis connection.""" - pass - - async def close(self): - """Grace shutdown.""" - pass - - async def wait_closed(self): - """Wait for grace shutdown finishes.""" - pass - - async def set(self, name, value, ex=None, **kwargs): - """Set the value at key ``name`` to ``value``.""" - if ex == 0: - ex = None - return await self._redis.set(name, value, ex=ex, **kwargs) - - async def get(self, name, **kwargs): - """Return the value at key ``name`` or None.""" - return await self._redis.get(name, **kwargs) - - async def delete(self, *names): - """Delete one or more keys specified by ``names``""" - return await self._redis.delete(*names) - - async def keys(self, pattern, **kwargs): - """Returns a list of keys matching ``pattern``.""" - return await self._redis.keys(pattern, **kwargs) - - async def flushdb(self): - """Delete all keys in the current database.""" - return await self._redis.flushdb() - - -class AioRedisAdapterV1(AioRedisAdapterBase): - """Redis adapter for aioredis v1.""" - - async def get_redis(self) -> aioredis.Redis: - """Get Redis connection.""" - async with self._connection_lock: # to prevent race - if self._redis is None or self._redis.closed: - self._redis = await aioredis.create_redis_pool( - (self._host, self._port), - db=self._db, - password=self._password, - ssl=self._ssl, - minsize=1, - maxsize=self._pool_size, - **self._kwargs, - ) - return self._redis - - async def close(self): - async with self._connection_lock: - if self._redis and not self._redis.closed: - self._redis.close() - - async def wait_closed(self): - async with self._connection_lock: - if self._redis: - return await self._redis.wait_closed() - return True - - async def get(self, name, **kwargs): - return await self._redis.get(name, encoding="utf8", **kwargs) - - async def set(self, name, value, ex=None, **kwargs): - if ex == 0: - ex = None - return await self._redis.set(name, value, expire=ex, **kwargs) - - async def keys(self, pattern, **kwargs): - """Returns a list of keys matching ``pattern``.""" - return await self._redis.keys(pattern, encoding="utf8", **kwargs) - - -class AioRedisAdapterV2(AioRedisAdapterBase): - """Redis adapter for aioredis v2.""" - - async def get_redis(self) -> aioredis.Redis: - """Get Redis connection.""" - async with self._connection_lock: # to prevent race - if self._redis is None: - self._redis = aioredis.Redis( - host=self._host, - port=self._port, - db=self._db, - password=self._password, - ssl=self._ssl, - max_connections=self._pool_size, - decode_responses=True, - **self._kwargs, - ) - return self._redis - - class RedisStorage2(BaseStorage): """ Busted Redis-base storage for FSM. @@ -356,7 +226,6 @@ class RedisStorage2(BaseStorage): .. code-block:: python3 await dp.storage.close() - await dp.storage.wait_closed() """ @@ -375,75 +244,49 @@ class RedisStorage2(BaseStorage): bucket_ttl: typing.Optional[int] = None, **kwargs, ): - self._host = host - self._port = port - self._db = db - self._password = password - self._ssl = ssl - self._pool_size = pool_size - self._kwargs = kwargs - self._prefix = (prefix,) + from redis.asyncio import Redis + self._redis: typing.Optional[Redis] = Redis( + host=host, + port=port, + db=db, + password=password, + ssl=ssl, + max_connections=pool_size, + decode_responses=True, + **kwargs, + ) + + self._prefix = (prefix,) self._state_ttl = state_ttl self._data_ttl = data_ttl self._bucket_ttl = bucket_ttl - self._redis: typing.Optional[AioRedisAdapterBase] = None - self._connection_lock = asyncio.Lock() - @deprecated("This method will be removed in aiogram v3.0. " "You should use your own instance of Redis.", stacklevel=3) - async def redis(self) -> aioredis.Redis: - adapter = await self._get_adapter() - return await adapter.get_redis() - - async def _get_adapter(self) -> AioRedisAdapterBase: - """Get adapter based on aioredis version.""" - if self._redis is None: - redis_version = int(aioredis.__version__.split(".")[0]) - connection_data = dict( - host=self._host, - port=self._port, - db=self._db, - password=self._password, - ssl=self._ssl, - pool_size=self._pool_size, - **self._kwargs, - ) - if redis_version == 1: - self._redis = AioRedisAdapterV1(**connection_data) - elif redis_version == 2: - self._redis = AioRedisAdapterV2(**connection_data) - else: - raise RuntimeError(f"Unsupported aioredis version: {redis_version}") - await self._redis.get_redis() + async def redis(self) -> "aioredis.Redis": return self._redis def generate_key(self, *parts): return ':'.join(self._prefix + tuple(map(str, parts))) async def close(self): - if self._redis: - return await self._redis.close() + await self._redis.close() async def wait_closed(self): - if self._redis: - await self._redis.wait_closed() - self._redis = None + pass async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self._get_adapter() - return await redis.get(key) or self.resolve_state(default) + return await self._redis.get(key) or self.resolve_state(default) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self._get_adapter() - raw_result = await redis.get(key) + raw_result = await self._redis.get(key) if raw_result: return json.loads(raw_result) return default or {} @@ -452,21 +295,19 @@ class RedisStorage2(BaseStorage): state: typing.Optional[typing.AnyStr] = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self._get_adapter() if state is None: - await redis.delete(key) + await self._redis.delete(key) else: - await redis.set(key, self.resolve_state(state), ex=self._state_ttl) + await self._redis.set(key, self.resolve_state(state), ex=self._state_ttl) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self._get_adapter() if data: - await redis.set(key, json.dumps(data), ex=self._data_ttl) + await self._redis.set(key, json.dumps(data), ex=self._data_ttl) else: - await redis.delete(key) + await self._redis.delete(key) async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): @@ -483,8 +324,7 @@ class RedisStorage2(BaseStorage): default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self._get_adapter() - raw_result = await redis.get(key) + raw_result = await self._redis.get(key) if raw_result: return json.loads(raw_result) return default or {} @@ -493,11 +333,10 @@ class RedisStorage2(BaseStorage): bucket: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self._get_adapter() if bucket: - await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl) + await self._redis.set(key, json.dumps(bucket), ex=self._bucket_ttl) else: - await redis.delete(key) + await self._redis.delete(key) async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, @@ -515,13 +354,11 @@ class RedisStorage2(BaseStorage): :param full: clean DB or clean only states :return: """ - redis = await self._get_adapter() - if full: - await redis.flushdb() + await self._redis.flushdb() else: - keys = await redis.keys(self.generate_key('*')) - await redis.delete(*keys) + keys = await self._redis.keys(self.generate_key('*')) + await self._redis.delete(*keys) async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ @@ -529,10 +366,9 @@ class RedisStorage2(BaseStorage): :return: list of tuples where first element is chat id and second is user id """ - redis = await self._get_adapter() result = [] - keys = await redis.keys(self.generate_key('*', '*', STATE_KEY)) + keys = await self._redis.keys(self.generate_key('*', '*', STATE_KEY)) for item in keys: *_, chat, user, _ = item.split(':') result.append((chat, user)) diff --git a/tests/contrib/fsm_storage/test_storage.py b/tests/contrib/fsm_storage/test_storage.py index ae06025c..fba3102e 100644 --- a/tests/contrib/fsm_storage/test_storage.py +++ b/tests/contrib/fsm_storage/test_storage.py @@ -2,6 +2,7 @@ import aioredis import pytest import pytest_asyncio from pytest_lazyfixture import lazy_fixture +from redis.asyncio.connection import Connection, ConnectionPool from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2 @@ -75,15 +76,19 @@ class TestStorage: ) class TestRedisStorage2: @pytest.mark.asyncio - async def test_close_and_open_connection(self, store): + async def test_close_and_open_connection(self, store: RedisStorage2): await store.set_data(chat='1234', data={'foo': 'bar'}) assert await store.get_data(chat='1234') == {'foo': 'bar'} - pool_id = id(store._redis) await store.close() await store.wait_closed() - # new pool will be open at this point - assert await store.get_data(chat='1234') == { - 'foo': 'bar', - } - assert id(store._redis) != pool_id + pool: ConnectionPool = store._redis.connection_pool + + # noinspection PyUnresolvedReferences + assert not pool._in_use_connections + + # noinspection PyUnresolvedReferences + if pool._available_connections: + # noinspection PyUnresolvedReferences + connection: Connection = pool._available_connections[0] + assert connection.is_connected is False