From c7cdfc5ab49918ffc36bf776bce6b8a2c91146eb Mon Sep 17 00:00:00 2001 From: Anthony Byuraev Date: Sat, 4 Jul 2020 18:10:47 +0300 Subject: [PATCH] Change `user`, `chat` in redis --- aiogram/contrib/fsm_storage/redis.py | 178 ++++++++++++++++----------- 1 file changed, 106 insertions(+), 72 deletions(-) diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index bf88eff7..0207b975 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -1,5 +1,6 @@ """ -This module has redis storage for finite-state machine based on `aioredis `_ driver +This module has redis storage for finite-state machine + based on `aioredis `_ driver """ import asyncio @@ -35,7 +36,8 @@ class RedisStorage(BaseStorage): await dp.storage.wait_closed() """ - def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, loop=None, **kwargs): + def __init__(self, host='localhost', port=6379, db=None, + password=None, ssl=None, loop=None, **kwargs): self._host = host self._port = port self._db = db @@ -72,17 +74,17 @@ class RedisStorage(BaseStorage): return self._redis async def get_record(self, *, - chat: typing.Union[str, int, None] = None, - user: typing.Union[str, int, None] = None) -> typing.Dict: + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None) -> typing.Dict: """ Get record from storage - :param chat: - :param user: + :param chat_id: + :param user_id: :return: """ - chat, user = self.check_address(chat=chat, user=user) - addr = f"fsm:{chat}:{user}" + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + addr = f"fsm:{chat_id}:{user_id}" conn = await self.redis() data = await conn.execute('GET', addr) @@ -90,14 +92,16 @@ class RedisStorage(BaseStorage): return {'state': None, 'data': {}} return json.loads(data) - async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def set_record(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, state=None, data=None, bucket=None): """ Write record to storage :param bucket: - :param chat: - :param user: + :param chat_id: + :param user_id: :param state: :param data: :return: @@ -107,42 +111,52 @@ class RedisStorage(BaseStorage): if bucket is None: bucket = {} - chat, user = self.check_address(chat=chat, user=user) - addr = f"fsm:{chat}:{user}" + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + addr = f"fsm:{chat_id}:{user_id}" record = {'state': state, 'data': data, 'bucket': bucket} conn = await self.redis() await conn.execute('SET', addr, json.dumps(record)) - async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def get_state(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: - record = await self.get_record(chat=chat, user=user) + record = await self.get_record(chat_id=chat_id, user_id=user_id) return record['state'] - async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def get_data(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Dict: - record = await self.get_record(chat=chat, user=user) + record = await self.get_record(chat_id=chat_id, user_id=user_id) return record['data'] - async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def set_state(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, state: typing.Optional[typing.AnyStr] = None): - record = await self.get_record(chat=chat, user=user) - await self.set_record(chat=chat, user=user, state=state, data=record['data']) + record = await self.get_record(chat_id=chat_id, user_id=user_id) + await self.set_record(chat_id=chat_id, user_id=user_id, state=state, data=record['data']) - async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def set_data(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, data: typing.Dict = None): - record = await self.get_record(chat=chat, user=user) - await self.set_record(chat=chat, user=user, state=record['state'], data=data) + record = await self.get_record(chat_id=chat_id, user_id=user_id) + await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=data) - async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def update_data(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): if data is None: data = {} - record = await self.get_record(chat=chat, user=user) + record = await self.get_record(chat_id=chat_id, user_id=user_id) record_data = record.get('data', {}) record_data.update(data, **kwargs) - await self.set_record(chat=chat, user=user, state=record['state'], data=record_data) + await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=record_data) async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]: """ @@ -155,8 +169,8 @@ class RedisStorage(BaseStorage): keys = await conn.execute('KEYS', 'fsm:*') for item in keys: - *_, chat, user = item.decode('utf-8').split(':') - result.append((chat, user)) + *_, chat_id, user_id = item.decode('utf-8').split(':') + result.append((chat_id, user_id)) return result @@ -178,25 +192,30 @@ class RedisStorage(BaseStorage): def has_bucket(self): return True - async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def get_bucket(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Dict: - record = await self.get_record(chat=chat, user=user) + record = await self.get_record(chat_id=chat_id, user_id=user_id) return record.get('bucket', {}) - async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def set_bucket(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, bucket: typing.Dict = None): - record = await self.get_record(chat=chat, user=user) - await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket) + record = await self.get_record(chat_id=chat_id, user_id=user_id) + await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=record['data'], bucket=bucket) - async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, - user: typing.Union[str, int, None] = None, + async def update_bucket(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, bucket: typing.Dict = None, **kwargs): - record = await self.get_record(chat=chat, user=user) + record = await self.get_record(chat_id=chat_id, user_id=user_id) record_bucket = record.get('bucket', {}) if bucket is None: bucket = {} record_bucket.update(bucket, **kwargs) - await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket) + await self.set_record(chat_id=chat_id, user_id=user_id, state=record['state'], data=record_bucket, bucket=bucket) class RedisStorage2(BaseStorage): @@ -269,76 +288,91 @@ class RedisStorage2(BaseStorage): return await self._redis.wait_closed() return True - async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def get_state(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: - chat, user = self.check_address(chat=chat, user=user) - key = self.generate_key(chat, user, STATE_KEY) + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + key = self.generate_key(chat_id, user_id, STATE_KEY) redis = await self.redis() return await redis.get(key, encoding='utf8') or None - async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def get_data(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: - chat, user = self.check_address(chat=chat, user=user) - key = self.generate_key(chat, user, STATE_DATA_KEY) + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + key = self.generate_key(chat_id, user_id, STATE_DATA_KEY) redis = await self.redis() raw_result = await redis.get(key, encoding='utf8') if raw_result: return json.loads(raw_result) return default or {} - async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def set_state(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, state: typing.Optional[typing.AnyStr] = None): - chat, user = self.check_address(chat=chat, user=user) - key = self.generate_key(chat, user, STATE_KEY) + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + key = self.generate_key(chat_id, user_id, STATE_KEY) redis = await self.redis() if state is None: await redis.delete(key) else: await redis.set(key, state, expire=self._state_ttl) - async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def set_data(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, data: typing.Dict = None): - chat, user = self.check_address(chat=chat, user=user) - key = self.generate_key(chat, user, STATE_DATA_KEY) + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + key = self.generate_key(chat_id, user_id, STATE_DATA_KEY) redis = await self.redis() await redis.set(key, json.dumps(data), expire=self._data_ttl) - async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def update_data(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): if data is None: data = {} - temp_data = await self.get_data(chat=chat, user=user, default={}) + temp_data = await self.get_data(chat_id=chat_id, user_id=user_id, default={}) temp_data.update(data, **kwargs) - await self.set_data(chat=chat, user=user, data=temp_data) + await self.set_data(chat_id=chat_id, user_id=user_id, data=temp_data) def has_bucket(self): return True - async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def get_bucket(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: - chat, user = self.check_address(chat=chat, user=user) - key = self.generate_key(chat, user, STATE_BUCKET_KEY) + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + key = self.generate_key(chat_id, user_id, STATE_BUCKET_KEY) redis = await self.redis() raw_result = await redis.get(key, encoding='utf8') if raw_result: return json.loads(raw_result) return default or {} - async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + async def set_bucket(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, bucket: typing.Dict = None): - chat, user = self.check_address(chat=chat, user=user) - key = self.generate_key(chat, user, STATE_BUCKET_KEY) + chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id) + key = self.generate_key(chat_id, user_id, STATE_BUCKET_KEY) redis = await self.redis() await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl) - async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, - user: typing.Union[str, int, None] = None, + async def update_bucket(self, *, + chat_id: typing.Union[str, int, None] = None, + user_id: typing.Union[str, int, None] = None, bucket: typing.Dict = None, **kwargs): if bucket is None: bucket = {} - temp_bucket = await self.get_bucket(chat=chat, user=user) + temp_bucket = await self.get_bucket(chat_id=chat_id, user_id=user_id) temp_bucket.update(bucket, **kwargs) - await self.set_bucket(chat=chat, user=user, bucket=temp_bucket) + await self.set_bucket(chat_id=chat_id, user_id=user_id, bucket=temp_bucket) async def reset_all(self, full=True): """ @@ -366,8 +400,8 @@ class RedisStorage2(BaseStorage): keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8') for item in keys: - *_, chat, user, _ = item.split(':') - result.append((chat, user)) + *_, chat_id, user_id, _ = item.split(':') + result.append((chat_id, user_id)) return result @@ -390,14 +424,14 @@ async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorag log = logging.getLogger('aiogram.RedisStorage') - for chat, user in await storage1.get_states_list(): - state = await storage1.get_state(chat=chat, user=user) - await storage2.set_state(chat=chat, user=user, state=state) + for chat_id, user_id in await storage1.get_states_list(): + state = await storage1.get_state(chat_id=chat_id, user_id=user_id) + await storage2.set_state(chat_id=chat_id, user_id=user_id, state=state) - data = await storage1.get_data(chat=chat, user=user) - await storage2.set_data(chat=chat, user=user, data=data) + data = await storage1.get_data(chat_id=chat_id, user_id=user_id) + await storage2.set_data(chat_id=chat_id, user_id=user_id, data=data) - bucket = await storage1.get_bucket(chat=chat, user=user) - await storage2.set_bucket(chat=chat, user=user, bucket=bucket) + bucket = await storage1.get_bucket(chat_id=chat_id, user_id=user_id) + await storage2.set_bucket(chat_id=chat_id, user_id=user_id, bucket=bucket) - log.info(f"Migrated user {user} in chat {chat}") + log.info(f"Migrated user {user_id} in chat {chat_id}")