Add support for URI-provided db_name in MongoStorage

This commit is contained in:
Grigory Statsenko 2022-02-20 16:55:41 +03:00
parent bb1c774bcc
commit 0ab6b15e5d

View file

@ -43,14 +43,18 @@ class MongoStorage(BaseStorage):
"""
def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm', uri=None,
username=None, password=None, index=True, **kwargs):
username=None, password=None, index=True, db_from_uri=False, **kwargs):
self._host = host
self._port = port
self._db_name: str = db_name
self._uri = uri
self._username = username
self._password = password
self._kwargs = kwargs # custom client options like SSL configuration, etc.
# allows db_name to be provided as part of the URI without passing it also as db_name,
# while also allowing compatibility for those relying on the 'aiogram_fsm' default
self._db_from_uri = db_from_uri
# custom client options like SSL configuration, etc.
self._kwargs = kwargs
self._mongo: Optional[AsyncIOMotorClient] = None
self._db: Optional[AsyncIOMotorDatabase] = None
@ -70,6 +74,11 @@ class MongoStorage(BaseStorage):
logger = logging.getLogger("aiogram")
logger.warning("Run `pip install dnspython==1.16.0` in order to fix ConfigurationError. More information: https://github.com/mongodb/mongo-python-driver/pull/423#issuecomment-528998245")
raise e
if self._db_from_uri:
# Since the URI was provided directly, and this flag was specified,
# extract db_name from the given URI
self._db_name = self._mongo.get_default_database().name
return self._mongo
uri = 'mongodb://'
@ -82,7 +91,7 @@ class MongoStorage(BaseStorage):
uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}'
# define and return client
self._mongo = AsyncIOMotorClient(uri)
self._mongo = AsyncIOMotorClient(uri, **self._kwargs)
return self._mongo
async def get_db(self) -> AsyncIOMotorDatabase: