diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 5d0b762c..4f0208ec 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -5,6 +5,8 @@ This module has redis storage for finite-state machine based on `aioredis aioredis.Redis: + """Get Redis connection.""" + pass + + def close(self): + """Grace shutdown.""" + pass + + async def wait_closed(self): + pass + + @abstractmethod + async def get(self, name): + pass + + @abstractmethod + async def set(self, name, value, expire=None): + pass + + @abstractmethod + async def delete(self, name): + pass + + +class AioRedisAdapterV1(AioRedisAdapterBase): + """Redis adapter for aioredis v1.""" + + async def redis(self) -> aioredis.Redis: + """Get Redis connection.""" + async with self._connection_lock: # to prevent race + if self._redis is None or self._redis.closed: + self._redis = await aioredis.create_redis_pool( + (self._host, self._port), + db=self._db, + password=self._password, + ssl=self._ssl, + minsize=1, + maxsize=self._pool_size, + loop=self._loop, + **self._kwargs, + ) + return self._redis + + def close(self): + async with self._connection_lock: + if self._redis and not self._redis.closed: + self._redis.close() + + async def wait_closed(self): + async with self._connection_lock: + if self._redis: + return await self._redis.wait_closed() + return True + + async def get(self, name): + redis = await self.redis() + return await redis.get(name, encoding="utf8") + + async def set(self, name, value, expire=None): + redis = await self.redis() + return await redis.set(name, value, expire=expire) + + async def delete(self, name): + redis = await self.redis() + return await redis.delete(name) + + +class AioRedisAdapterV2(AioRedisAdapterBase): + """Redis adapter for aioredis v2.""" + + async def redis(self) -> aioredis.Redis: + """Get Redis connection.""" + async with self._connection_lock: # to prevent race + if self._redis is None: + self._redis = aioredis.Redis( + host=self._host, + port=self._port, + db=self._db, + password=self._password, + ssl=self._ssl, + max_connections=self._pool_size, + **self._kwargs, + ) + return self._redis + + async def get(self, name): + redis = await self.redis() + return await redis.get(name) + + async def set(self, name, value, expire=None): + redis = await self.redis() + return await redis.set(name, value, ex=expire) + + async def delete(self, name): + redis = await self.redis() + return await redis.delete(name) + + class RedisStorage2(BaseStorage): """ Busted Redis-base storage for FSM. @@ -224,12 +358,22 @@ class RedisStorage2(BaseStorage): await dp.storage.wait_closed() """ - def __init__(self, host: str = 'localhost', port=6379, db=None, password=None, - ssl=None, pool_size=10, loop=None, prefix='fsm', - state_ttl: int = 0, - data_ttl: int = 0, - bucket_ttl: int = 0, - **kwargs): + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: typing.Optional[int] = None, + password: typing.Optional[str] = None, + ssl: typing.Optional[bool] = None, + pool_size: int = 10, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + prefix: str = "fsm", + state_ttl: int = 0, + data_ttl: int = 0, + bucket_ttl: int = 0, + **kwargs, + ): self._host = host self._port = port self._db = db @@ -244,49 +388,53 @@ class RedisStorage2(BaseStorage): self._data_ttl = data_ttl self._bucket_ttl = bucket_ttl - self._redis: typing.Optional[aioredis.RedisConnection] = None + self._redis: typing.Optional[AioRedisAdapterBase] = None self._connection_lock = asyncio.Lock(loop=self._loop) - async def redis(self) -> aioredis.Redis: - """ - Get Redis connection - """ - # Use thread-safe asyncio Lock because this method without that is not safe - async with self._connection_lock: - if self._redis is None or self._redis.closed: - self._redis = await aioredis.create_redis_pool((self._host, self._port), - db=self._db, password=self._password, ssl=self._ssl, - minsize=1, maxsize=self._pool_size, - loop=self._loop, **self._kwargs) + async def redis(self) -> AioRedisAdapterBase: + """Get adapter based on aioredis version.""" + if self._redis is None: + redis_version = version("aioredis").split(".")[0] + connection_data = dict( + host=self._host, + port=self._port, + db=self._db, + password=self._password, + ssl=self._ssl, + pool_size=self._pool_size, + loop=self._loop, + **self._kwargs, + ) + if redis_version == 1: + self._redis = AioRedisAdapterV1(**connection_data) + elif redis_version == 2: + self._redis = AioRedisAdapterV2(**connection_data) return self._redis def generate_key(self, *parts): return ':'.join(self._prefix + tuple(map(str, parts))) async def close(self): - async with self._connection_lock: - if self._redis and not self._redis.closed: - self._redis.close() + if self._redis: + return self._redis.close() async def wait_closed(self): - async with self._connection_lock: - if self._redis: - return await self._redis.wait_closed() - return True + if self._redis: + return await self._redis.wait_closed() async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) redis = await self.redis() - return await redis.get(key, encoding='utf8') or self.resolve_state(default) + return await redis.get(key) or self.resolve_state(default) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) redis = await self.redis() - raw_result = await redis.get(key, encoding='utf8') + raw_result = await redis.get(key) if raw_result: return json.loads(raw_result) return default or {} @@ -327,7 +475,7 @@ class RedisStorage2(BaseStorage): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) redis = await self.redis() - raw_result = await redis.get(key, encoding='utf8') + raw_result = await redis.get(key) if raw_result: return json.loads(raw_result) return default or {}