From 51c1bf48e284ca062367c961930dda35ff4bf609 Mon Sep 17 00:00:00 2001 From: Mystic-Mirage Date: Sun, 15 Jan 2023 02:48:10 +0200 Subject: [PATCH] Implement AWS DynamoDB storage --- aiogram/contrib/fsm_storage/dynamodb.py | 213 ++++++++++++++++++ tests/conftest.py | 24 ++ .../test_fsm_storage/test_storage.py | 16 ++ 3 files changed, 253 insertions(+) create mode 100644 aiogram/contrib/fsm_storage/dynamodb.py diff --git a/aiogram/contrib/fsm_storage/dynamodb.py b/aiogram/contrib/fsm_storage/dynamodb.py new file mode 100644 index 00000000..6bbb5a29 --- /dev/null +++ b/aiogram/contrib/fsm_storage/dynamodb.py @@ -0,0 +1,213 @@ +""" +This module has dynamodb storage for finite-state machine + based on `aiodynamo `_ driver +""" + +from typing import Any, AnyStr, Dict, Optional, Tuple, Union + +from aiohttp import ClientSession + +try: + from aiodynamo.client import Client, Table, URL + from aiodynamo.credentials import Credentials + from aiodynamo.errors import ItemNotFound + from aiodynamo.http.aiohttp import AIOHTTP + from aiodynamo.models import ( + KeySchema, + KeySpec, + KeyType, + PayPerRequest, + RetryConfig, + Throughput, + ) + from aiodynamo.types import NumericTypeConverter +except ModuleNotFoundError as e: + import warnings + + warnings.warn("Install aiodynamo with `pip install aiodynamo`") + raise e + +from ...dispatcher.storage import BaseStorage + +TABLE_PREFIX = "aiogram_fsm" +STATE = "state" +DATA = "data" +BUCKET = "bucket" + + +class DynamoDBStorage(BaseStorage): + def __init__( + self, + *, + credentials: Optional[Credentials] = None, + region: str, + endpoint: Union[URL, str, None] = None, + numeric_type: NumericTypeConverter = float, + throttle_config: RetryConfig = RetryConfig.default(), + table_prefix: str = TABLE_PREFIX, + table_throughput: Union[Throughput, PayPerRequest, None] = None, + ): + if not credentials: + credentials = Credentials.auto() + if isinstance(endpoint, str): + endpoint = URL(endpoint) + if not table_throughput: + table_throughput = Throughput(read=5, write=5) + + self._client_session = ClientSession() + self._dynamo = Client( + AIOHTTP(self._client_session), + credentials, + region, + endpoint, + numeric_type, + throttle_config, + ) + self._table_prefix = table_prefix + self._table_throughput = table_throughput + + def resolve_address( + self, *, chat: Union[str, int, None], user: Union[str, int, None] + ) -> Tuple[str, str]: + chat, user = map(str, self.check_address(chat=chat, user=user)) + return chat, user + + async def get_table(self, name: str) -> Table: + table = self._dynamo.table("_".join([self._table_prefix, name])) + if not await table.exists(): + await table.create( + self._table_throughput, + KeySchema( + KeySpec("chat", KeyType.string), + KeySpec("user", KeyType.string), + ), + ) + return table + + @staticmethod + async def get_item(table: Table, key: Dict[str, str]) -> Optional[Dict[str, Any]]: + try: + return await table.get_item(key) + except ItemNotFound: + pass + + async def close(self): + await self._client_session.close() + + async def wait_closed(self): + pass + + async def set_state( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + state: Optional[AnyStr] = None, + ): + chat, user = self.resolve_address(chat=chat, user=user) + table = await self.get_table(STATE) + + if state is None: + await table.delete_item({"chat": chat, "user": user}) + else: + await table.put_item( + { + "chat": chat, + "user": user, + "state": self.resolve_state(state), + }, + ) + + async def get_state( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + default: Optional[str] = None, + ) -> Optional[str]: + chat, user = self.resolve_address(chat=chat, user=user) + table = await self.get_table(STATE) + result = await self.get_item(table, {"chat": chat, "user": user}) + + return result.get("state") if result else self.resolve_state(default) + + async def set_data( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + data: Dict = None, + ): + chat, user = self.resolve_address(chat=chat, user=user) + table = await self.get_table(DATA) + if not data: + await table.delete_item({"chat": chat, "user": user}) + else: + await table.put_item({"chat": chat, "user": user, "data": data}) + + async def get_data( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + default: Optional[dict] = None, + ) -> Dict: + chat, user = self.resolve_address(chat=chat, user=user) + table = await self.get_table(DATA) + result = await self.get_item(table, {"chat": chat, "user": user}) + return (result.get("data") if result else default) or {} + + async def update_data( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + data: Dict = None, + **kwargs, + ): + if data is None: + data = {} + temp_data = await self.get_data(chat=chat, user=user, default={}) + temp_data.update(data, **kwargs) + await self.set_data(chat=chat, user=user, data=temp_data) + + def has_bucket(self): + return True + + async def get_bucket( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + default: Optional[dict] = None, + ) -> Dict: + chat, user = self.resolve_address(chat=chat, user=user) + table = await self.get_table(BUCKET) + result = await self.get_item(table, {"chat": chat, "user": user}) + return (result.get("bucket") if result else default) or {} + + async def set_bucket( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + bucket: Dict = None, + ): + chat, user = self.resolve_address(chat=chat, user=user) + table = await self.get_table(BUCKET) + + await table.put_item({"chat": chat, "user": user, "bucket": bucket}) + + async def update_bucket( + self, + *, + chat: Union[str, int, None] = None, + user: Union[str, int, None] = None, + bucket: Dict = None, + **kwargs, + ): + if bucket is None: + bucket = {} + temp_bucket = await self.get_bucket(chat=chat, user=user) + temp_bucket.update(bucket, **kwargs) + await self.set_bucket(chat=chat, user=user, bucket=temp_bucket) diff --git a/tests/conftest.py b/tests/conftest.py index 61cd2b2f..04c638b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import pytest_asyncio import aioredis import pytest from _pytest.config import UsageError +from aiodynamo.credentials import Key, StaticCredentials from aiogram import Bot from . import TOKEN @@ -20,6 +21,11 @@ def pytest_addoption(parser): default=None, help="run tests which require redis connection", ) + parser.addoption( + "--dynamodb", + default=None, + help="run tests which require dynamodb-local connection", + ) parser.addini("asyncio_mode", "", default='auto') @@ -81,6 +87,24 @@ def redis_options(request): raise UsageError("Unsupported aioredis version") +@pytest_asyncio.fixture(scope="session") +def dynamodb_options(request): + dynamodb_uri = request.config.getoption("--dynamodb") + if dynamodb_uri is None: + pytest.skip("need --dynamodb option with dynamodb-local URI to run") + return + return { + "credentials": StaticCredentials( + Key( + "DUMMYIDEXAMPLE", + "DUMMYEXAMPLEKEY", + ) + ), + "region": "dummy-region-1", + "endpoint": dynamodb_uri, + } + + @pytest_asyncio.fixture(name='bot') async def bot_fixture(): """Bot fixture.""" diff --git a/tests/test_contrib/test_fsm_storage/test_storage.py b/tests/test_contrib/test_fsm_storage/test_storage.py index 68ce8a81..ed25b4a4 100644 --- a/tests/test_contrib/test_fsm_storage/test_storage.py +++ b/tests/test_contrib/test_fsm_storage/test_storage.py @@ -4,6 +4,7 @@ import pytest_asyncio from pytest_lazyfixture import lazy_fixture from redis.asyncio.connection import Connection, ConnectionPool +from aiogram.contrib.fsm_storage.dynamodb import DynamoDBStorage from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2 @@ -42,11 +43,17 @@ async def memory_store(): yield MemoryStorage() +@pytest_asyncio.fixture() +async def dynamodb_store(dynamodb_options): + yield DynamoDBStorage(**dynamodb_options) + + @pytest.mark.parametrize( "store", [ lazy_fixture('redis_store'), lazy_fixture('redis_store2'), lazy_fixture('memory_store'), + lazy_fixture("dynamodb_store"), ] ) class TestStorage: @@ -89,6 +96,15 @@ class TestRedisStorage2: connection: Connection = pool._available_connections[0] assert connection.is_connected is False + +@pytest.mark.parametrize( + "store", [ + lazy_fixture("redis_store"), + lazy_fixture("redis_store2"), + lazy_fixture("dynamodb_store"), + ] +) +class TestStorage2: @pytest.mark.parametrize( "chat_id,user_id,state", [