diff --git a/aiogram/contrib/fsm_storage/mongo.py b/aiogram/contrib/fsm_storage/mongo.py index 9ec18090..d9c2f354 100644 --- a/aiogram/contrib/fsm_storage/mongo.py +++ b/aiogram/contrib/fsm_storage/mongo.py @@ -1,12 +1,12 @@ """ This module has mongo storage for finite-state machine - based on `aiomongo `_ driver """ from typing import Union, Dict, Optional, List, Tuple, AnyStr -import aiomongo -from aiomongo import AioMongoClient, Database +import motor +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase from ...dispatcher.storage import BaseStorage @@ -35,22 +35,27 @@ class MongoStorage(BaseStorage): """ - def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm', + def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm', uri=None, username=None, password=None, index=True, **kwargs): self._host = host self._port = port self._db_name: str = db_name + self._uri = uri self._username = username self._password = password self._kwargs = kwargs - self._mongo: Union[AioMongoClient, None] = None - self._db: Union[Database, None] = None + self._mongo: Union[AsyncIOMotorClient, None] = None + self._db: Union[AsyncIOMotorDatabase, None] = None self._index = index - async def get_client(self) -> AioMongoClient: - if isinstance(self._mongo, AioMongoClient): + async def get_client(self) -> AsyncIOMotorClient: + if isinstance(self._mongo, AsyncIOMotorClient): + return self._mongo + + if self._uri: + self._mongo = AsyncIOMotorClient(self._uri) return self._mongo uri = 'mongodb://' @@ -63,16 +68,16 @@ class MongoStorage(BaseStorage): uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}' # define and return client - self._mongo = await aiomongo.create_client(uri) + self._mongo = AsyncIOMotorClient(uri) return self._mongo - async def get_db(self) -> Database: + async def get_db(self) -> AsyncIOMotorDatabase: """ Get Mongo db This property is awaitable. """ - if isinstance(self._db, Database): + if isinstance(self._db, AsyncIOMotorDatabase): return self._db mongo = await self.get_client()