mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Set state via storage (#542)
* refactor: simplified check_address (removed redundant states check) * refactor: FSM resolve_state become public, removed redundant elif * fix: resolve `filters.State` on `set_state` * refactor: moved state resolution to storage * fix: return default state on get_state
This commit is contained in:
parent
ba095f0b9f
commit
4120408aa3
5 changed files with 54 additions and 33 deletions
|
|
@ -35,7 +35,7 @@ class MemoryStorage(BaseStorage):
|
|||
user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
chat, user = self.resolve_address(chat=chat, user=user)
|
||||
return self.data[chat][user]['state']
|
||||
return self.data[chat][user].get("state", self.resolve_state(default))
|
||||
|
||||
async def get_data(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
|
|
@ -58,7 +58,7 @@ class MemoryStorage(BaseStorage):
|
|||
user: typing.Union[str, int, None] = None,
|
||||
state: typing.AnyStr = None):
|
||||
chat, user = self.resolve_address(chat=chat, user=user)
|
||||
self.data[chat][user]['state'] = state
|
||||
self.data[chat][user]['state'] = self.resolve_state(state)
|
||||
|
||||
async def set_data(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class MongoStorage(BaseStorage):
|
|||
try:
|
||||
self._mongo = AsyncIOMotorClient(self._uri)
|
||||
except pymongo.errors.ConfigurationError as e:
|
||||
if "query() got an unexpected keyword argument 'lifetime'" in e.args[0]:
|
||||
if "query() got an unexpected keyword argument 'lifetime'" in e.args[0]:
|
||||
import logging
|
||||
logger = logging.getLogger("aiogram")
|
||||
logger.warning("Run `pip install dnspython==1.16.0` in order to fix ConfigurationError. More information: https://github.com/mongodb/mongo-python-driver/pull/423#issuecomment-528998245")
|
||||
|
|
@ -114,7 +114,9 @@ class MongoStorage(BaseStorage):
|
|||
async def wait_closed(self):
|
||||
return True
|
||||
|
||||
async def set_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
|
||||
async def set_state(self, *,
|
||||
chat: Union[str, int, None] = None,
|
||||
user: Union[str, int, None] = None,
|
||||
state: Optional[AnyStr] = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
db = await self.get_db()
|
||||
|
|
@ -122,8 +124,11 @@ class MongoStorage(BaseStorage):
|
|||
if state is None:
|
||||
await db[STATE].delete_one(filter={'chat': chat, 'user': user})
|
||||
else:
|
||||
await db[STATE].update_one(filter={'chat': chat, 'user': user},
|
||||
update={'$set': {'state': state}}, upsert=True)
|
||||
await db[STATE].update_one(
|
||||
filter={'chat': chat, 'user': user},
|
||||
update={'$set': {'state': self.resolve_state(state)}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
async def get_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
|
|
@ -131,7 +136,7 @@ class MongoStorage(BaseStorage):
|
|||
db = await self.get_db()
|
||||
result = await db[STATE].find_one(filter={'chat': chat, 'user': user})
|
||||
|
||||
return result.get('state') if result else default
|
||||
return result.get('state') if result else self.resolve_state(default)
|
||||
|
||||
async def set_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
|
||||
data: Dict = None):
|
||||
|
|
|
|||
|
|
@ -118,16 +118,19 @@ class RedisStorage(BaseStorage):
|
|||
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
record = await self.get_record(chat=chat, user=user)
|
||||
return record['state']
|
||||
return record.get('state', self.resolve_state(default))
|
||||
|
||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Dict:
|
||||
record = await self.get_record(chat=chat, user=user)
|
||||
return record['data']
|
||||
|
||||
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
async def set_state(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
state: typing.Optional[typing.AnyStr] = None):
|
||||
record = await self.get_record(chat=chat, user=user)
|
||||
state = self.resolve_state(state)
|
||||
await self.set_record(chat=chat, user=user, state=state, data=record['data'])
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
|
|
@ -274,7 +277,7 @@ class RedisStorage2(BaseStorage):
|
|||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_KEY)
|
||||
redis = await self.redis()
|
||||
return await redis.get(key, encoding='utf8') or None
|
||||
return await redis.get(key, encoding='utf8') or self.resolve_state(default)
|
||||
|
||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
|
|
@ -294,7 +297,7 @@ class RedisStorage2(BaseStorage):
|
|||
if state is None:
|
||||
await redis.delete(key)
|
||||
else:
|
||||
await redis.set(key, state, expire=self._state_ttl)
|
||||
await redis.set(key, self.resolve_state(state), expire=self._state_ttl)
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None):
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ class RethinkDBStorage(BaseStorage):
|
|||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
return await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn)
|
||||
return await r.table(self._table).get(chat)[user]['state'].default(
|
||||
self.resolve_state(default) or None
|
||||
).run(conn)
|
||||
|
||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Dict:
|
||||
|
|
@ -103,11 +105,16 @@ class RethinkDBStorage(BaseStorage):
|
|||
async with self.connection() as conn:
|
||||
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
|
||||
|
||||
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
async def set_state(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
state: typing.Optional[typing.AnyStr] = None):
|
||||
chat, user = map(str, self.check_address(chat=chat, user=user))
|
||||
async with self.connection() as conn:
|
||||
await r.table(self._table).insert({'id': chat, user: {'state': state}}, conflict="update").run(conn)
|
||||
await r.table(self._table).insert(
|
||||
{'id': chat, user: {'state': self.resolve_state(state)}},
|
||||
conflict="update",
|
||||
).run(conn)
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None):
|
||||
|
|
|
|||
|
|
@ -40,24 +40,27 @@ class BaseStorage:
|
|||
@classmethod
|
||||
def check_address(cls, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None) -> (typing.Union[str, int], typing.Union[str, int]):
|
||||
user: typing.Union[str, int, None] = None,
|
||||
) -> (typing.Union[str, int], typing.Union[str, int]):
|
||||
"""
|
||||
In all storage's methods chat or user is always required.
|
||||
If one of them is not provided, you have to set missing value based on the provided one.
|
||||
|
||||
This method performs the check described above.
|
||||
|
||||
:param chat:
|
||||
:param user:
|
||||
:param chat: chat_id
|
||||
:param user: user_id
|
||||
:return:
|
||||
"""
|
||||
if chat is None and user is None:
|
||||
raise ValueError('`user` or `chat` parameter is required but no one is provided!')
|
||||
|
||||
if user is None and chat is not None:
|
||||
if user is None:
|
||||
user = chat
|
||||
elif user is not None and chat is None:
|
||||
|
||||
elif chat is None:
|
||||
chat = user
|
||||
|
||||
return chat, user
|
||||
|
||||
async def get_state(self, *,
|
||||
|
|
@ -270,6 +273,21 @@ class BaseStorage:
|
|||
"""
|
||||
await self.set_data(chat=chat, user=user, data={})
|
||||
|
||||
@staticmethod
|
||||
def resolve_state(value):
|
||||
from .filters.state import State
|
||||
|
||||
if value is None:
|
||||
return
|
||||
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if isinstance(value, State):
|
||||
return value.state
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
class FSMContext:
|
||||
def __init__(self, storage, chat, user):
|
||||
|
|
@ -279,20 +297,8 @@ class FSMContext:
|
|||
def proxy(self):
|
||||
return FSMContextProxy(self)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_state(value):
|
||||
from .filters.state import State
|
||||
|
||||
if value is None:
|
||||
return
|
||||
elif isinstance(value, str):
|
||||
return value
|
||||
elif isinstance(value, State):
|
||||
return value.state
|
||||
return str(value)
|
||||
|
||||
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default))
|
||||
return await self.storage.get_state(chat=self.chat, user=self.user, default=default)
|
||||
|
||||
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
|
||||
return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
|
||||
|
|
@ -301,7 +307,7 @@ class FSMContext:
|
|||
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
|
||||
|
||||
async def set_state(self, state: typing.Optional[typing.AnyStr] = None):
|
||||
await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state))
|
||||
await self.storage.set_state(chat=self.chat, user=self.user, state=state)
|
||||
|
||||
async def set_data(self, data: typing.Dict = None):
|
||||
await self.storage.set_data(chat=self.chat, user=self.user, data=data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue