Set state via storage (#542)

* refactor: simplified check_address (removed redundant states check)

* refactor: FSM resolve_state become public, removed redundant elif

* fix: resolve `filters.State` on `set_state`

* refactor: moved state resolution to storage

* fix: return default state on get_state
This commit is contained in:
Oleg A 2021-04-28 01:28:53 +03:00 committed by GitHub
parent ba095f0b9f
commit 4120408aa3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 54 additions and 33 deletions

View file

@ -35,7 +35,7 @@ class MemoryStorage(BaseStorage):
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = self.resolve_address(chat=chat, user=user)
return self.data[chat][user]['state']
return self.data[chat][user].get("state", self.resolve_state(default))
async def get_data(self, *,
chat: typing.Union[str, int, None] = None,
@ -58,7 +58,7 @@ class MemoryStorage(BaseStorage):
user: typing.Union[str, int, None] = None,
state: typing.AnyStr = None):
chat, user = self.resolve_address(chat=chat, user=user)
self.data[chat][user]['state'] = state
self.data[chat][user]['state'] = self.resolve_state(state)
async def set_data(self, *,
chat: typing.Union[str, int, None] = None,

View file

@ -65,7 +65,7 @@ class MongoStorage(BaseStorage):
try:
self._mongo = AsyncIOMotorClient(self._uri)
except pymongo.errors.ConfigurationError as e:
if "query() got an unexpected keyword argument 'lifetime'" in e.args[0]:
if "query() got an unexpected keyword argument 'lifetime'" in e.args[0]:
import logging
logger = logging.getLogger("aiogram")
logger.warning("Run `pip install dnspython==1.16.0` in order to fix ConfigurationError. More information: https://github.com/mongodb/mongo-python-driver/pull/423#issuecomment-528998245")
@ -114,7 +114,9 @@ class MongoStorage(BaseStorage):
async def wait_closed(self):
return True
async def set_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
async def set_state(self, *,
chat: Union[str, int, None] = None,
user: Union[str, int, None] = None,
state: Optional[AnyStr] = None):
chat, user = self.check_address(chat=chat, user=user)
db = await self.get_db()
@ -122,8 +124,11 @@ class MongoStorage(BaseStorage):
if state is None:
await db[STATE].delete_one(filter={'chat': chat, 'user': user})
else:
await db[STATE].update_one(filter={'chat': chat, 'user': user},
update={'$set': {'state': state}}, upsert=True)
await db[STATE].update_one(
filter={'chat': chat, 'user': user},
update={'$set': {'state': self.resolve_state(state)}},
upsert=True,
)
async def get_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
default: Optional[str] = None) -> Optional[str]:
@ -131,7 +136,7 @@ class MongoStorage(BaseStorage):
db = await self.get_db()
result = await db[STATE].find_one(filter={'chat': chat, 'user': user})
return result.get('state') if result else default
return result.get('state') if result else self.resolve_state(default)
async def set_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
data: Dict = None):

View file

@ -118,16 +118,19 @@ class RedisStorage(BaseStorage):
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
record = await self.get_record(chat=chat, user=user)
return record['state']
return record.get('state', self.resolve_state(default))
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
record = await self.get_record(chat=chat, user=user)
return record['data']
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
async def set_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
record = await self.get_record(chat=chat, user=user)
state = self.resolve_state(state)
await self.set_record(chat=chat, user=user, state=state, data=record['data'])
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
@ -274,7 +277,7 @@ class RedisStorage2(BaseStorage):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis()
return await redis.get(key, encoding='utf8') or None
return await redis.get(key, encoding='utf8') or self.resolve_state(default)
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
@ -294,7 +297,7 @@ class RedisStorage2(BaseStorage):
if state is None:
await redis.delete(key)
else:
await redis.set(key, state, expire=self._state_ttl)
await redis.set(key, self.resolve_state(state), expire=self._state_ttl)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):

View file

@ -95,7 +95,9 @@ class RethinkDBStorage(BaseStorage):
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
return await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn)
return await r.table(self._table).get(chat)[user]['state'].default(
self.resolve_state(default) or None
).run(conn)
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
@ -103,11 +105,16 @@ class RethinkDBStorage(BaseStorage):
async with self.connection() as conn:
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
async def set_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
await r.table(self._table).insert({'id': chat, user: {'state': state}}, conflict="update").run(conn)
await r.table(self._table).insert(
{'id': chat, user: {'state': self.resolve_state(state)}},
conflict="update",
).run(conn)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):

View file

@ -40,24 +40,27 @@ class BaseStorage:
@classmethod
def check_address(cls, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None) -> (typing.Union[str, int], typing.Union[str, int]):
user: typing.Union[str, int, None] = None,
) -> (typing.Union[str, int], typing.Union[str, int]):
"""
In all storage's methods chat or user is always required.
If one of them is not provided, you have to set missing value based on the provided one.
This method performs the check described above.
:param chat:
:param user:
:param chat: chat_id
:param user: user_id
:return:
"""
if chat is None and user is None:
raise ValueError('`user` or `chat` parameter is required but no one is provided!')
if user is None and chat is not None:
if user is None:
user = chat
elif user is not None and chat is None:
elif chat is None:
chat = user
return chat, user
async def get_state(self, *,
@ -270,6 +273,21 @@ class BaseStorage:
"""
await self.set_data(chat=chat, user=user, data={})
@staticmethod
def resolve_state(value):
from .filters.state import State
if value is None:
return
if isinstance(value, str):
return value
if isinstance(value, State):
return value.state
return str(value)
class FSMContext:
def __init__(self, storage, chat, user):
@ -279,20 +297,8 @@ class FSMContext:
def proxy(self):
return FSMContextProxy(self)
@staticmethod
def _resolve_state(value):
from .filters.state import State
if value is None:
return
elif isinstance(value, str):
return value
elif isinstance(value, State):
return value.state
return str(value)
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default))
return await self.storage.get_state(chat=self.chat, user=self.user, default=default)
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
@ -301,7 +307,7 @@ class FSMContext:
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
async def set_state(self, state: typing.Optional[typing.AnyStr] = None):
await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state))
await self.storage.set_state(chat=self.chat, user=self.user, state=state)
async def set_data(self, data: typing.Dict = None):
await self.storage.set_data(chat=self.chat, user=self.user, data=data)