diff --git a/aiogram/dispatcher/fsm/storage/mongo.py b/aiogram/dispatcher/fsm/storage/mongo.py new file mode 100644 index 00000000..918801c9 --- /dev/null +++ b/aiogram/dispatcher/fsm/storage/mongo.py @@ -0,0 +1,103 @@ +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, Dict, Optional + +try: + import pymongo + import motor + from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase +except ModuleNotFoundError as e: + import warnings + warnings.warn("Install motor with `pip install motor`") + raise e + +from aiogram import Bot +from aiogram.dispatcher.fsm.state import State +from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey + +STATE = 'aiogram_state' +DATA = 'aiogram_data' +BUCKET = 'aiogram_bucket' +COLLECTIONS = (STATE, DATA, BUCKET) + + +class MongoStorage(BaseStorage): + """ + Mongo storage required :code:`motor` package installed (:code:`pip install motor`) + """ + + def __init__( + self, + mongo: AsyncIOMotorClient, + db_name: str = 'aiogram_fsm' + ) -> None: + """ + :param mongo: Instance of mongo connection + + """ + self._mongo = mongo + self._db = mongo.get_database(db_name) + + @classmethod + def from_url( + cls, url: str + ) -> "MongoStorage": + """ + Create an instance of :class:`MongoStorage` with specifying the connection string + + :param url: for example :code:`mongodb://user:password@host:port/db` + """ + + return cls(mongo=AsyncIOMotorClient(url)) + + async def close(self) -> None: + await self._mongo.close() + + @asynccontextmanager + async def lock( + self, + bot: Bot, + key: StorageKey, + ) -> AsyncGenerator[None, None]: + yield None + + async def set_state( + self, + bot: Bot, + key: StorageKey, + state: StateType = None, + ) -> None: + if state is None: + await self._db[STATE].delete_one(filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id}) + else: + await self._db[STATE].update_one( + filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id}, + update={'$set': {'state': state.state if isinstance(state, State) else state}}, + upsert=True, + ) + + async def get_state( + self, + bot: Bot, + key: StorageKey, + ) -> Optional[str]: + result = await self._db[STATE].find_one(filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id}) + return result.get('state') if result else None + + async def set_data( + self, + bot: Bot, + key: StorageKey, + data: Dict[str, Any], + ) -> None: + await self._db[DATA].insert_one( + {'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id, 'data': data} + ) + + async def get_data( + self, + bot: Bot, + key: StorageKey, + ) -> Dict[str, Any]: + result = await self._db[DATA].find_one(filter={'chat': key.chat_id, 'user': key.user_id, 'bot_id': key.bot_id}) + + return result.get('data') if result else {} diff --git a/pyproject.toml b/pyproject.toml index 297f02a4..c47ecd80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ Babel = { version = "^2.9.1", optional = true } aiohttp-socks = { version = "^0.5.5", optional = true } # Redis aioredis = { version = "^2.0.0", optional = true } +# Mongodb +motor = "^2.5.1" + # Docs Sphinx = { version = "^4.2.0", optional = true } sphinx-intl = { version = "^2.0.1", optional = true } diff --git a/tests/conftest.py b/tests/conftest.py index e57ec632..78bb43d7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from aioredis.connection import parse_url as parse_redis_url from aiogram import Bot from aiogram.dispatcher.fsm.storage.memory import MemoryStorage +from aiogram.dispatcher.fsm.storage.mongo import MongoStorage from aiogram.dispatcher.fsm.storage.redis import RedisStorage from tests.mocked_bot import MockedBot @@ -67,6 +68,19 @@ async def memory_storage(): await storage.close() +@pytest.fixture() +@pytest.mark.mongo +async def mongo_storage(redis_server): + if not redis_server: + pytest.skip("Mongo is not available here") + storage = MongoStorage.from_url(redis_server) + + try: + yield storage + finally: + await storage.close() + + @pytest.fixture() def bot(): bot = MockedBot() diff --git a/tests/test_dispatcher/test_fsm/storage/test_storages.py b/tests/test_dispatcher/test_fsm/storage/test_storages.py index 428f6d02..65b4742e 100644 --- a/tests/test_dispatcher/test_fsm/storage/test_storages.py +++ b/tests/test_dispatcher/test_fsm/storage/test_storages.py @@ -13,7 +13,7 @@ def create_storate_key(bot: MockedBot): @pytest.mark.parametrize( "storage", - [pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")], + [pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage"), pytest.lazy_fixture("mongo_storage")], ) class TestStorages: async def test_lock(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):