From 1dfc00ee4d8fa30f3700aa168e6eb2f493bb0f3a Mon Sep 17 00:00:00 2001 From: Oleg A Date: Thu, 25 Mar 2021 15:41:20 +0300 Subject: [PATCH] fix: return default state on get_state --- aiogram/contrib/fsm_storage/memory.py | 2 +- aiogram/contrib/fsm_storage/mongo.py | 2 +- aiogram/contrib/fsm_storage/redis.py | 4 ++-- aiogram/contrib/fsm_storage/rethinkdb.py | 4 +++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/aiogram/contrib/fsm_storage/memory.py b/aiogram/contrib/fsm_storage/memory.py index 9343c6f3..e1d6bdc0 100644 --- a/aiogram/contrib/fsm_storage/memory.py +++ b/aiogram/contrib/fsm_storage/memory.py @@ -35,7 +35,7 @@ class MemoryStorage(BaseStorage): user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = self.resolve_address(chat=chat, user=user) - return self.data[chat][user]['state'] + return self.data[chat][user].get("state", self.resolve_state(default)) async def get_data(self, *, chat: typing.Union[str, int, None] = None, diff --git a/aiogram/contrib/fsm_storage/mongo.py b/aiogram/contrib/fsm_storage/mongo.py index ab43963f..992e2e70 100644 --- a/aiogram/contrib/fsm_storage/mongo.py +++ b/aiogram/contrib/fsm_storage/mongo.py @@ -136,7 +136,7 @@ class MongoStorage(BaseStorage): db = await self.get_db() result = await db[STATE].find_one(filter={'chat': chat, 'user': user}) - return result.get('state') if result else default + return result.get('state') if result else self.resolve_state(default) async def set_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None, data: Dict = None): diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index b32dacaa..01a0fe5c 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -118,7 +118,7 @@ class RedisStorage(BaseStorage): async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: record = await self.get_record(chat=chat, user=user) - return record['state'] + return record.get('state', self.resolve_state(default)) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Dict: @@ -277,7 +277,7 @@ class RedisStorage2(BaseStorage): chat, user = self.check_address(chat=chat, user=user) key = self.generate_key(chat, user, STATE_KEY) redis = await self.redis() - return await redis.get(key, encoding='utf8') or None + return await redis.get(key, encoding='utf8') or self.resolve_state(default) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index 6d7e2109..c600074e 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -95,7 +95,9 @@ class RethinkDBStorage(BaseStorage): default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = map(str, self.check_address(chat=chat, user=user)) async with self.connection() as conn: - return await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn) + return await r.table(self._table).get(chat)[user]['state'].default( + self.resolve_state(default) or None + ).run(conn) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Dict: