From 5cf7a3996747dc15c52989d5c046b9f789c53440 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Sun, 1 Aug 2021 22:42:30 +0300 Subject: [PATCH] chore: separate get_adapter method --- aiogram/contrib/fsm_storage/redis.py | 30 ++++++++++++++++------------ 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index 23f5bf71..1c0d731f 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -383,7 +383,11 @@ class RedisStorage2(BaseStorage): self._redis: typing.Optional[AioRedisAdapterBase] = None self._connection_lock = asyncio.Lock(loop=self._loop) - async def redis(self) -> AioRedisAdapterBase: + async def redis(self) -> aioredis.Redis: + adapter = await self._get_adapter() + return await adapter.get_redis() + + async def _get_adapter(self) -> AioRedisAdapterBase: """Get adapter based on aioredis version.""" if self._redis is None: redis_version = aioredis.__version__.split(".")[0] @@ -419,14 +423,14 @@ class RedisStorage2(BaseStorage): default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self.redis() + redis = await self._get_adapter() return await redis.get(key) or self.resolve_state(default) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self.redis() + redis = await self._get_adapter() raw_result = await redis.get(key) if raw_result: return json.loads(raw_result) @@ -436,7 +440,7 @@ class RedisStorage2(BaseStorage): state: typing.Optional[typing.AnyStr] = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) - redis = await self.redis() + redis = await self._get_adapter() if state is None: await redis.delete(key) else: @@ -446,7 +450,7 @@ class RedisStorage2(BaseStorage): data: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_DATA_KEY) - redis = await self.redis() + redis = await self._get_adapter() if data: await redis.set(key, json.dumps(data), expire=self._data_ttl) else: @@ -467,7 +471,7 @@ class RedisStorage2(BaseStorage): default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self.redis() + redis = await self._get_adapter() raw_result = await redis.get(key) if raw_result: return json.loads(raw_result) @@ -477,7 +481,7 @@ class RedisStorage2(BaseStorage): bucket: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_BUCKET_KEY) - redis = await self.redis() + redis = await self._get_adapter() if bucket: await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) else: @@ -499,13 +503,13 @@ class RedisStorage2(BaseStorage): :param full: clean DB or clean only states :return: """ - conn = await self.redis() + redis = await self._get_adapter() if full: - await conn.flushdb() + await redis.flushdb() else: - keys = await conn.keys(self.generate_key('*')) - await conn.delete(*keys) + keys = await redis.keys(self.generate_key('*')) + await redis.delete(*keys) async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ @@ -513,10 +517,10 @@ class RedisStorage2(BaseStorage): :return: list of tuples where first element is chat id and second is user id """ - conn = await self.redis() + redis = await self._get_adapter() result = [] - keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8') + keys = await redis.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8') for item in keys: *_, chat, user, _ = item.split(':') result.append((chat, user))