From 0ab6b15e5dc1a9fc1c0cd6cf82f85bbdbfa7c198 Mon Sep 17 00:00:00 2001 From: Grigory Statsenko Date: Sun, 20 Feb 2022 16:55:41 +0300 Subject: [PATCH] Add support for URI-provided db_name in MongoStorage --- aiogram/contrib/fsm_storage/mongo.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/aiogram/contrib/fsm_storage/mongo.py b/aiogram/contrib/fsm_storage/mongo.py index 7a128f1c..0f873b1f 100644 --- a/aiogram/contrib/fsm_storage/mongo.py +++ b/aiogram/contrib/fsm_storage/mongo.py @@ -43,14 +43,18 @@ class MongoStorage(BaseStorage): """ def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm', uri=None, - username=None, password=None, index=True, **kwargs): + username=None, password=None, index=True, db_from_uri=False, **kwargs): self._host = host self._port = port self._db_name: str = db_name self._uri = uri self._username = username self._password = password - self._kwargs = kwargs # custom client options like SSL configuration, etc. + # allows db_name to be provided as part of the URI without passing it also as db_name, + # while also allowing compatibility for those relying on the 'aiogram_fsm' default + self._db_from_uri = db_from_uri + # custom client options like SSL configuration, etc. + self._kwargs = kwargs self._mongo: Optional[AsyncIOMotorClient] = None self._db: Optional[AsyncIOMotorDatabase] = None @@ -70,6 +74,11 @@ class MongoStorage(BaseStorage): logger = logging.getLogger("aiogram") logger.warning("Run `pip install dnspython==1.16.0` in order to fix ConfigurationError. More information: https://github.com/mongodb/mongo-python-driver/pull/423#issuecomment-528998245") raise e + + if self._db_from_uri: + # Since the URI was provided directly, and this flag was specified, + # extract db_name from the given URI + self._db_name = self._mongo.get_default_database().name return self._mongo uri = 'mongodb://' @@ -82,7 +91,7 @@ class MongoStorage(BaseStorage): uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}' # define and return client - self._mongo = AsyncIOMotorClient(uri) + self._mongo = AsyncIOMotorClient(uri, **self._kwargs) return self._mongo async def get_db(self) -> AsyncIOMotorDatabase: