Added MongoStorage for FSM

This commit is contained in:
Rishat Fayzullin 2024-03-06 22:45:09 +03:00
parent 329c7f4c97
commit a5ccf28498
11 changed files with 312 additions and 102 deletions

View file

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, Union
from typing import Any, AsyncGenerator, Dict, Literal, Optional, Union
from aiogram.fsm.state import State
@ -19,6 +19,68 @@ class StorageKey:
destiny: str = DEFAULT_DESTINY
class KeyBuilder(ABC):
"""Base class for key builder."""
@abstractmethod
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
"""
This method should be implemented in subclasses
:param key: contextual key
:param part: part of the record
:return: key to be used in storage's db queries
"""
pass
class DefaultKeyBuilder(KeyBuilder):
"""
Simple key builder with default prefix.
Generates a colon-joined string with prefix, chat_id, user_id,
optional bot_id and optional destiny.
"""
def __init__(
self,
*,
prefix: str = "fsm",
separator: str = ":",
with_bot_id: bool = False,
with_destiny: bool = False,
) -> None:
"""
:param prefix: prefix for all records
:param separator: separator
:param with_bot_id: include Bot id in the key
:param with_destiny: include destiny key
"""
self.prefix = prefix
self.separator = separator
self.with_bot_id = with_bot_id
self.with_destiny = with_destiny
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.append(str(key.chat_id))
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_destiny:
parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY:
raise ValueError(
"Default key builder is not configured to use key destiny other the default.\n"
"\n"
"Probably, you should set `with_destiny=True` in for DefaultKeyBuilder."
)
parts.append(part)
return self.separator.join(parts)
class BaseStorage(ABC):
"""
Base class for all FSM storages

View file

@ -0,0 +1,120 @@
from typing import Any, Dict, Optional
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
BaseStorage,
DefaultKeyBuilder,
KeyBuilder,
StateType,
StorageKey,
)
class MongoStorage(BaseStorage):
"""
MongoDB storage required :code:`motor` package installed (:code:`pip install motor`)
"""
def __init__(
self,
client: AsyncIOMotorClient,
key_builder: Optional[KeyBuilder] = None,
db_name: str = "aiogram_fsm",
states_collection_name: str = "states",
data_collection_name: str = "data",
) -> None:
"""
:param client: Instance of AsyncIOMotorClient
:param key_builder: builder that helps to convert contextual key to string
:param db_name: name of the MongoDB database for FSM
:param states_collection_name: name of the collection for storing FSM states.
:param data_collection_name: name of the collection for storing additional data.
"""
if key_builder is None:
key_builder = DefaultKeyBuilder()
self._client = client
self._db_name = db_name
self._states_collection: AsyncIOMotorCollection = self._client[db_name][
states_collection_name
]
self._data_collection: AsyncIOMotorCollection = self._client[db_name][data_collection_name]
self._key_builder = key_builder
@classmethod
def from_url(
cls, url: str, connection_kwargs: Dict[str, Any] = {}, **kwargs: Any
) -> "MongoStorage":
"""
Create an instance of :class:`MongoStorage` with specifying the connection string
:param url: for example :code:`mongodb://user:password@host:port`
:param connection_kwargs: see :code:`motor` docs
:param kwargs: arguments to be passed to :class:`MongoStorage`
:return: an instance of :class:`MongoStorage`
"""
client = AsyncIOMotorClient(url, **connection_kwargs)
return cls(client=client, **kwargs)
async def close(self) -> None:
"""Cleanup client resources and disconnect from MongoDB."""
self._client.close()
def resolve_state(self, value: StateType) -> Optional[str]:
if value is None:
return None
elif isinstance(value, State):
return value.state
return str(value)
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
document_id = self._key_builder.build(key, "state")
if state is None:
await self._states_collection.delete_one({"_id": document_id})
else:
await self._states_collection.update_one(
{"_id": document_id},
{"$set": {"state": self.resolve_state(state)}},
upsert=True,
)
async def get_state(self, key: StorageKey) -> Optional[str]:
document_id = self._key_builder.build(key, "state")
document = await self._states_collection.find_one({"_id": document_id})
if document is None or document["state"] is None:
return None
else:
return str(document["state"])
async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
document_id = self._key_builder.build(key, "data")
if not data:
await self._data_collection.delete_one({"_id": document_id})
else:
await self._data_collection.update_one(
{"_id": document_id},
{"$set": data},
upsert=True,
)
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
document_id = self._key_builder.build(key, "data")
document = await self._data_collection.find_one({"_id": document_id}, {"_id": 0})
if not document:
return {}
else:
return document # type: ignore
async def update_data(self, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
document_id = self._key_builder.build(key, "data")
update_result = await self._data_collection.find_one_and_update(
{"_id": document_id},
{"$set": data},
upsert=True,
return_document=True,
projection={"_id": 0},
)
if not update_result:
await self._data_collection.delete_one({"_id": document_id})
return update_result # type: ignore

View file

@ -1,7 +1,6 @@
import json
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Literal, Optional, cast
from typing import Any, AsyncGenerator, Callable, Dict, Optional, cast
from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool
@ -10,9 +9,10 @@ from redis.typing import ExpiryT
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
DEFAULT_DESTINY,
BaseEventIsolation,
BaseStorage,
DefaultKeyBuilder,
KeyBuilder,
StateType,
StorageKey,
)
@ -22,71 +22,6 @@ _JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]
class KeyBuilder(ABC):
"""
Base class for Redis key builder
"""
@abstractmethod
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
"""
This method should be implemented in subclasses
:param key: contextual key
:param part: part of the record
:return: key to be used in Redis queries
"""
pass
class DefaultKeyBuilder(KeyBuilder):
"""
Simple Redis key builder with default prefix.
Generates a colon-joined string with prefix, chat_id, user_id,
optional bot_id and optional destiny.
"""
def __init__(
self,
*,
prefix: str = "fsm",
separator: str = ":",
with_bot_id: bool = False,
with_destiny: bool = False,
) -> None:
"""
:param prefix: prefix for all records
:param separator: separator
:param with_bot_id: include Bot id in the key
:param with_destiny: include destiny key
"""
self.prefix = prefix
self.separator = separator
self.with_bot_id = with_bot_id
self.with_destiny = with_destiny
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.append(str(key.chat_id))
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_destiny:
parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY:
raise ValueError(
"Redis key builder is not configured to use key destiny other the default.\n"
"\n"
"Probably, you should set `with_destiny=True` in for DefaultKeyBuilder.\n"
"E.g: `RedisStorage(redis, key_builder=DefaultKeyBuilder(with_destiny=True))`"
)
parts.append(part)
return self.separator.join(parts)
class RedisStorage(BaseStorage):
"""
Redis storage required :code:`redis` package installed (:code:`pip install redis`)

View file

@ -78,13 +78,13 @@ Linux / macOS:
.. code-block:: bash
pip install -e ."[dev,test,docs,fast,redis,proxy,i18n]"
pip install -e ."[dev,test,docs,fast,redis,mongo,proxy,i18n]"
Windows:
.. code-block:: bash
pip install -e .[dev,test,docs,fast,redis,proxy,i18n]
pip install -e .[dev,test,docs,fast,redis,mongo,proxy,i18n]
It will install :code:`aiogram` in editable mode into your virtual environment and all dependencies.
@ -116,11 +116,12 @@ All changes should be tested:
pytest tests
Also if you are doing something with Redis-storage, you will need to test everything works with Redis:
Also if you are doing something with Redis-storage or/and MongoDB-storage,
you will need to test everything works with Redis or/and MongoDB:
.. code-block:: bash
pytest --redis redis://<host>:<port>/<db> tests
pytest --redis redis://<host>:<port>/<db> --mongo mongodb://<user>:<password>@<host>:<port> tests
Docs
----

View file

@ -19,13 +19,23 @@ RedisStorage
:members: __init__, from_url
:member-order: bysource
Keys inside storage can be customized via key builders:
MongoStorage
------------
.. autoclass:: aiogram.fsm.storage.redis.KeyBuilder
.. autoclass:: aiogram.fsm.storage.mongo.MongoStorage
:members: __init__, from_url
:member-order: bysource
KeyBuilder
------------
Keys inside Redis and Mongo storages can be customized via key builders:
.. autoclass:: aiogram.fsm.storage.base.KeyBuilder
:members:
:member-order: bysource
.. autoclass:: aiogram.fsm.storage.redis.DefaultKeyBuilder
.. autoclass:: aiogram.fsm.storage.base.DefaultKeyBuilder
:members:
:member-order: bysource

View file

@ -61,6 +61,9 @@ fast = [
redis = [
"redis[hiredis]~=5.0.1",
]
mongo = [
"motor~=3.3.2",
]
proxy = [
"aiohttp-socks~=0.8.3",
]
@ -105,6 +108,7 @@ dev = [
"toml~=0.10.2",
"pre-commit~=3.5.0",
"packaging~=23.1",
"motor-types~=1.0.0b4",
]
[project.urls]
@ -117,6 +121,7 @@ features = [
"dev",
"fast",
"redis",
"mongo",
"proxy",
"i18n",
"cli",
@ -136,6 +141,7 @@ lint = "ruff aiogram"
features = [
"fast",
"redis",
"mongo",
"proxy",
"i18n",
"docs",
@ -150,6 +156,7 @@ features = [
"dev",
"fast",
"redis",
"mongo",
"proxy",
"i18n",
"test",
@ -167,6 +174,7 @@ update = [
features = [
"fast",
"redis",
"mongo",
"proxy",
"i18n",
"test",
@ -182,6 +190,10 @@ cov-redis = [
"pytest --cov-config pyproject.toml --cov=aiogram --html=reports/py{matrix:python}/tests/index.html --redis {env:REDIS_DNS:'redis://localhost:6379'} {args}",
"coverage html -d reports/py{matrix:python}/coverage",
]
cov-mongo = [
"pytest --cov-config pyproject.toml --cov=aiogram --html=reports/py{matrix:python}/tests/index.html --mongo {env:MONGO_DNS:'mongodb://mongo:mongo@localhost:27017'} {args}",
"coverage html -d reports/py{matrix:python}/coverage",
]
view-cov = "google-chrome-stable reports/py{matrix:python}/coverage/index.html"

View file

@ -2,6 +2,8 @@ from pathlib import Path
import pytest
from _pytest.config import UsageError
from pymongo.errors import InvalidURI, PyMongoError
from pymongo.uri_parser import parse_uri as parse_mongo_url
from redis.asyncio.connection import parse_url as parse_redis_url
from aiogram import Dispatcher
@ -10,6 +12,7 @@ from aiogram.fsm.storage.memory import (
MemoryStorage,
SimpleEventIsolation,
)
from aiogram.fsm.storage.mongo import MongoStorage
from aiogram.fsm.storage.redis import RedisEventIsolation, RedisStorage
from tests.mocked_bot import MockedBot
@ -18,24 +21,27 @@ DATA_DIR = Path(__file__).parent / "data"
def pytest_addoption(parser):
parser.addoption("--redis", default=None, help="run tests which require redis connection")
parser.addoption("--mongo", default=None, help="run tests which require mongo connection")
def pytest_configure(config):
config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
config.addinivalue_line("markers", "mongo: marked tests require mongo connection to run")
def pytest_collection_modifyitems(config, items):
redis_uri = config.getoption("--redis")
if redis_uri is None:
skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
for item in items:
if "redis" in item.keywords:
item.add_marker(skip_redis)
return
try:
parse_redis_url(redis_uri)
except ValueError as e:
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
for db, parse_uri in [("redis", parse_redis_url), ("mongo", parse_mongo_url)]:
uri = config.getoption(f"--{db}")
if uri is None:
skip = pytest.mark.skip(reason=f"need --{db} option with {db} URI to run")
for item in items:
if db in item.keywords:
item.add_marker(skip)
else:
try:
parse_uri(uri)
except (ValueError, InvalidURI) as e:
raise UsageError(f"Invalid {db} URI {uri!r}: {e}")
@pytest.fixture()
@ -62,6 +68,29 @@ async def redis_storage(redis_server):
await storage.close()
@pytest.fixture()
def mongo_server(request):
mongo_uri = request.config.getoption("--mongo")
return mongo_uri
@pytest.fixture()
@pytest.mark.mongo
async def mongo_storage(mongo_server):
if not mongo_server:
pytest.skip("MongoDB is not available here")
storage = MongoStorage.from_url(mongo_server)
try:
await storage._client.server_info()
except PyMongoError as e:
pytest.skip(str(e))
else:
yield storage
await storage._client.drop_database(storage._db_name)
finally:
await storage.close()
@pytest.fixture()
async def memory_storage():
storage = MemoryStorage()

View file

@ -5,3 +5,11 @@ services:
image: redis:6-alpine
ports:
- "${REDIS_PORT-6379}:6379"
mongo:
image: mongo:7.0.6
environment:
MONGO_INITDB_ROOT_USERNAME: mongo
MONGO_INITDB_ROOT_PASSWORD: mongo
ports:
- "${MONGODB_PORT-27017}:27017"

View file

@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from aiogram.fsm.storage.base import BaseEventIsolation, StorageKey
from aiogram.fsm.storage.redis import RedisEventIsolation
from aiogram.fsm.storage.redis import RedisEventIsolation, RedisStorage
from tests.mocked_bot import MockedBot
@ -32,6 +32,14 @@ class TestIsolations:
class TestRedisEventIsolation:
def test_create_isolation(self):
fake_redis = object()
storage = RedisStorage(redis=fake_redis)
isolation = storage.create_isolation()
assert isinstance(isolation, RedisEventIsolation)
assert isolation.redis is fake_redis
assert isolation.key_builder is storage.key_builder
def test_init_without_key_builder(self):
redis = AsyncMock()
isolation = RedisEventIsolation(redis=redis)

View file

@ -1,11 +1,7 @@
import pytest
from aiogram.fsm.storage.base import DEFAULT_DESTINY, StorageKey
from aiogram.fsm.storage.redis import (
DefaultKeyBuilder,
RedisEventIsolation,
RedisStorage,
)
from aiogram.fsm.storage.redis import DefaultKeyBuilder
PREFIX = "test"
BOT_ID = 42
@ -15,7 +11,7 @@ THREAD_ID = 3
FIELD = "data"
class TestRedisDefaultKeyBuilder:
class TestDefaultKeyBuilder:
@pytest.mark.parametrize(
"with_bot_id,with_destiny,result",
[
@ -59,11 +55,3 @@ class TestRedisDefaultKeyBuilder:
destiny=DEFAULT_DESTINY,
)
assert key_builder.build(key, FIELD) == f"{PREFIX}:{CHAT_ID}:{THREAD_ID}:{USER_ID}:{FIELD}"
def test_create_isolation(self):
fake_redis = object()
storage = RedisStorage(redis=fake_redis)
isolation = storage.create_isolation()
assert isinstance(isolation, RedisEventIsolation)
assert isolation.redis is fake_redis
assert isolation.key_builder is storage.key_builder

View file

@ -0,0 +1,37 @@
import pytest
from aiogram.fsm.state import State
from aiogram.fsm.storage.mongo import MongoStorage, StorageKey
from tests.mocked_bot import MockedBot
@pytest.fixture(name="storage_key")
def create_storage_key(bot: MockedBot):
return StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)
async def test_update_not_existing_data_with_empty_dictionary(
mongo_storage: MongoStorage,
storage_key: StorageKey,
):
assert await mongo_storage._data_collection.find_one({}) is None
assert await mongo_storage.get_data(key=storage_key) == {}
assert await mongo_storage.update_data(key=storage_key, data={}) == {}
assert await mongo_storage._data_collection.find_one({}) is None
@pytest.mark.parametrize(
"value,result",
[
[None, None],
["", ""],
["text", "text"],
[State(), None],
[State(state="*"), "*"],
[State("text"), "@:text"],
[State("test", group_name="Test"), "Test:test"],
[[1, 2, 3], "[1, 2, 3]"],
],
)
def test_resolve_state(value, result, mongo_storage: MongoStorage):
assert mongo_storage.resolve_state(value) == result