Reformatted

This commit is contained in:
asimaranov 2022-01-01 15:19:21 +03:00
parent b0ccc84a7f
commit fd16c05254
2 changed files with 27 additions and 25 deletions

View file

@ -6,11 +6,11 @@ from motor.motor_asyncio import AsyncIOMotorClient # type: ignore
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey, DEFAULT_DESTINY
from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StateType, StorageKey
STATE = 'aiogram_state'
DATA = 'aiogram_data'
BUCKET = 'aiogram_bucket'
STATE = "aiogram_state"
DATA = "aiogram_data"
BUCKET = "aiogram_bucket"
COLLECTIONS = (STATE, DATA, BUCKET)
@ -22,9 +22,9 @@ class MongoStorage(BaseStorage):
def __init__(
self,
mongo: AsyncIOMotorClient,
db_name: str = 'aiogram_fsm',
db_name: str = "aiogram_fsm",
with_bot_id: bool = True,
with_destiny: bool = True
with_destiny: bool = True,
) -> None:
"""
:param mongo: Instance of mongo connection
@ -40,11 +40,11 @@ class MongoStorage(BaseStorage):
@classmethod
def from_url(
cls, url: str,
db_name: str = 'aiogram_fsm',
cls,
url: str,
db_name: str = "aiogram_fsm",
with_bot_id: bool = True,
with_destiny: bool = True
with_destiny: bool = True,
) -> "MongoStorage":
"""
Create an instance of :class:`MongoStorage` with specifying the connection string
@ -60,16 +60,16 @@ class MongoStorage(BaseStorage):
mongo=AsyncIOMotorClient(url),
db_name=db_name,
with_bot_id=with_bot_id,
with_destiny=with_destiny
with_destiny=with_destiny,
)
def _get_db_filter(self, key: StorageKey) -> Dict[str, Any]:
db_filter: Dict[str, Any] = {'chat': key.chat_id, 'user': key.user_id}
db_filter: Dict[str, Any] = {"chat": key.chat_id, "user": key.user_id}
if self._with_bot_id:
db_filter['bot_id'] = key.bot_id
db_filter["bot_id"] = key.bot_id
if self._with_destiny:
db_filter['destiny'] = key.destiny
db_filter["destiny"] = key.destiny
elif key.destiny != DEFAULT_DESTINY:
raise ValueError(
@ -105,7 +105,7 @@ class MongoStorage(BaseStorage):
else:
await self._db[STATE].update_one(
filter=self._get_db_filter(key),
update={'$set': {'state': state.state if isinstance(state, State) else state}},
update={"$set": {"state": state.state if isinstance(state, State) else state}},
upsert=True,
)
@ -114,10 +114,8 @@ class MongoStorage(BaseStorage):
bot: Bot,
key: StorageKey,
) -> Optional[str]:
result = await self._db[STATE].find_one(
filter=self._get_db_filter(key)
)
return result.get('state') if result else None
result = await self._db[STATE].find_one(filter=self._get_db_filter(key))
return result.get("state") if result else None
async def set_data(
self,
@ -126,10 +124,10 @@ class MongoStorage(BaseStorage):
data: Dict[str, Any],
) -> None:
await self._db[DATA].update_one(
filter=self._get_db_filter(key),
update={'$set': {'data': data}},
upsert=True,
)
filter=self._get_db_filter(key),
update={"$set": {"data": data}},
upsert=True,
)
async def get_data(
self,
@ -140,4 +138,4 @@ class MongoStorage(BaseStorage):
filter=self._get_db_filter(key)
)
return result.get('data') or {} if result else {}
return result.get("data") or {} if result else {}

View file

@ -13,7 +13,11 @@ def create_storate_key(bot: MockedBot):
@pytest.mark.parametrize(
"storage",
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage"), pytest.lazy_fixture("mongo_storage")],
[
pytest.lazy_fixture("redis_storage"),
pytest.lazy_fixture("memory_storage"),
pytest.lazy_fixture("mongo_storage"),
],
)
class TestStorages:
async def test_lock(self, bot: MockedBot, storage: BaseStorage, storage_key: StorageKey):