From f50c4dac9c6b6674b2412352a1c5a77e4955593b Mon Sep 17 00:00:00 2001 From: asimaranov Date: Thu, 16 Dec 2021 18:25:53 +0300 Subject: [PATCH] Added arguments to save backward compatibility --- aiogram/dispatcher/fsm/storage/mongo.py | 54 ++++++++++++++++++++----- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/aiogram/dispatcher/fsm/storage/mongo.py b/aiogram/dispatcher/fsm/storage/mongo.py index 7850ea02..4a6ea500 100644 --- a/aiogram/dispatcher/fsm/storage/mongo.py +++ b/aiogram/dispatcher/fsm/storage/mongo.py @@ -3,9 +3,10 @@ from typing import Any, AsyncGenerator, Dict, Optional try: import motor - from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase + from motor.motor_asyncio import AsyncIOMotorClient except ModuleNotFoundError as e: import warnings + warnings.warn("Install motor with `pip install motor`") raise e @@ -27,28 +28,54 @@ class MongoStorage(BaseStorage): def __init__( self, mongo: AsyncIOMotorClient, - db_name: str = 'aiogram_fsm' + db_name: str = 'aiogram_fsm', + with_bot_id: bool = True, + with_destiny: bool = True ) -> None: """ :param mongo: Instance of mongo connection + :param with_bot_id: include Bot id in the database + :param with_destiny: include destiny in the database """ self._mongo = mongo self._db = mongo.get_database(db_name) + self._with_bot_id = with_bot_id + self._with_destiny = with_destiny @classmethod def from_url( - cls, url: str, db_name: str = 'aiogram_fsm' + cls, url: str, + db_name: str = 'aiogram_fsm', + with_bot_id: bool = True, + with_destiny: bool = True + ) -> "MongoStorage": """ Create an instance of :class:`MongoStorage` with specifying the connection string :param url: for example :code:`mongodb://user:password@host:port` :param db_name: name of database to store aiogram data` + :param with_bot_id: include Bot id in the database + :param with_destiny: include destiny in the database """ - return cls(mongo=AsyncIOMotorClient(url), db_name=db_name) + return cls( + mongo=AsyncIOMotorClient(url), + db_name=db_name, + with_bot_id=with_bot_id, + with_destiny=with_destiny + ) + + def _get_db_filter(self, key: StorageKey): + db_filter = {'chat': key.chat_id, 'user': key.user_id} + if self._with_bot_id: + db_filter['bot_id'] = key.bot_id + + if self._with_destiny: + db_filter['destiny'] = key.destiny + return db_filter async def close(self) -> None: await self._mongo.close() @@ -67,11 +94,12 @@ class MongoStorage(BaseStorage): key: StorageKey, state: StateType = None, ) -> None: + if state is None: - await self._db[STATE].delete_one(filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id, 'destiny': key.destiny}) + await self._db[STATE].delete_one(filter=self._get_db_filter(key)) else: await self._db[STATE].update_one( - filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id, 'destiny': key.destiny}, + filter=self._get_db_filter(key), update={'$set': {'state': state.state if isinstance(state, State) else state}}, upsert=True, ) @@ -81,7 +109,9 @@ class MongoStorage(BaseStorage): bot: Bot, key: StorageKey, ) -> Optional[str]: - result = await self._db[STATE].find_one(filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id, 'destiny': key.destiny}) + result = await self._db[STATE].find_one( + filter=self._get_db_filter(key) + ) return result.get('state') if result else None async def set_data( @@ -90,15 +120,17 @@ class MongoStorage(BaseStorage): key: StorageKey, data: Dict[str, Any], ) -> None: - await self._db[DATA].insert_one( - {'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id, 'data': data, 'destiny': key.destiny} - ) + data_to_insert = self._get_db_filter(key) + data_to_insert['data'] = data + await self._db[DATA].insert_one(data_to_insert) async def get_data( self, bot: Bot, key: StorageKey, ) -> Dict[str, Any]: - result = await self._db[DATA].find_one(filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id, 'destiny': key.destiny}) + result = await self._db[DATA].find_one( + filter=self._get_db_filter(key) + ) return result.get('data') if result else {}