Rework FSM storage key

This commit is contained in:
Alex Root Junior 2021-10-11 01:30:19 +03:00
parent 8c4d4ef30a
commit 7c6cf3c122
10 changed files with 213 additions and 160 deletions

View file

@ -37,5 +37,5 @@ __all__ = (
"md",
)
__version__ = "3.0.0a17"
__version__ = "3.0.0a18"
__api_version__ = "5.3"

View file

@ -1,44 +1,33 @@
from typing import Any, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
class FSMContext:
def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
def __init__(self, bot: Bot, storage: BaseStorage, key: StorageKey) -> None:
self.bot = bot
self.storage = storage
self.chat_id = chat_id
self.user_id = user_id
self.key = key
async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
)
await self.storage.set_state(bot=self.bot, key=self.key, state=state)
async def get_state(self) -> Optional[str]:
return await self.storage.get_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
return await self.storage.get_state(bot=self.bot, key=self.key)
async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
)
await self.storage.set_data(bot=self.bot, key=self.key, data=data)
async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
return await self.storage.get_data(bot=self.bot, key=self.key)
async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]:
if data:
kwargs.update(data)
return await self.storage.update_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
)
return await self.storage.update_data(bot=self.bot, key=self.key, data=kwargs)
async def clear(self) -> None:
await self.set_state(state=None)

View file

@ -2,7 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StorageKey
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import TelegramObject
@ -31,21 +31,28 @@ class FSMContextMiddleware(BaseMiddleware):
if context:
data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events:
async with self.storage.lock(
bot=bot, chat_id=context.chat_id, user_id=context.user_id
):
async with self.storage.lock(bot=bot, key=context.key):
return await handler(event, data)
return await handler(event, data)
def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
def resolve_event_context(
self,
bot: Bot,
data: Dict[str, Any],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
chat_id = chat.id if chat else None
user_id = user.id if user else None
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
def resolve_context(
self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
self,
bot: Bot,
chat_id: Optional[int],
user_id: Optional[int],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
@ -54,8 +61,23 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return None
def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)
def get_context(
self,
bot: Bot,
chat_id: int,
user_id: int,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(
bot=bot,
storage=self.storage,
key=StorageKey(
user_id=user_id,
chat_id=chat_id,
bot_id=bot.id,
destiny=destiny,
),
)

View file

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, Union
from aiogram import Bot
@ -7,45 +8,43 @@ from aiogram.dispatcher.fsm.state import State
StateType = Optional[Union[str, State]]
DEFAULT_DESTINY = "default"
@dataclass(frozen=True)
class StorageKey:
bot_id: int
chat_id: int
user_id: int
destiny: str = DEFAULT_DESTINY
class BaseStorage(ABC):
@abstractmethod
@asynccontextmanager
async def lock(
self, bot: Bot, chat_id: int, user_id: int
) -> AsyncGenerator[None, None]: # pragma: no cover
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
yield None
@abstractmethod
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None: # pragma: no cover
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
pass
@abstractmethod
async def get_state(
self, bot: Bot, chat_id: int, user_id: int
) -> Optional[str]: # pragma: no cover
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
pass
@abstractmethod
async def set_data(
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> None: # pragma: no cover
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
pass
@abstractmethod
async def get_data(
self, bot: Bot, chat_id: int, user_id: int
) -> Dict[str, Any]: # pragma: no cover
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
pass
async def update_data(
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> Dict[str, Any]:
current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
async def update_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
current_data = await self.get_data(bot=bot, key=key)
current_data.update(data)
await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
await self.set_data(bot=bot, key=key, data=current_data)
return current_data.copy()
@abstractmethod

View file

@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
@dataclass
@ -18,30 +18,26 @@ class MemoryStorageRecord:
class MemoryStorage(BaseStorage):
def __init__(self) -> None:
self.storage: DefaultDict[
Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
MemoryStorageRecord
)
async def close(self) -> None:
pass
@asynccontextmanager
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
async with self.storage[bot][chat_id][user_id].lock:
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
async with self.storage[key].lock:
yield None
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
self.storage[bot][chat_id][user_id].state = (
state.state if isinstance(state, State) else state
)
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
self.storage[key].state = state.state if isinstance(state, State) else state
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
return self.storage[bot][chat_id][user_id].state
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
return self.storage[key].state
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
self.storage[bot][chat_id][user_id].data = data.copy()
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
self.storage[key].data = data.copy()
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
return self.storage[bot][chat_id][user_id].data.copy()
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
return self.storage[key].data.copy()

View file

@ -1,35 +1,67 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
from aioredis import ConnectionPool, Redis
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
PrefixFactoryType = Callable[[Bot], str]
STATE_KEY = "state"
STATE_DATA_KEY = "data"
STATE_LOCK_KEY = "lock"
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
class KeyBuilder(ABC):
"""
Base class for Redis key builder
"""
@abstractmethod
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
pass
class DefaultKeyBuilder(KeyBuilder):
"""
Simple Redis key builder with default prefix.
Generates a colon-joined string with prefix, chat_id, user_id,
optional bot_id and optional destiny.
"""
def __init__(
self, prefix: str = "fsm", with_bot_id: bool = False, with_destiny: bool = False
) -> None:
self.prefix = prefix
self.with_bot_id = with_bot_id
self.with_destiny = with_destiny
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.extend([str(key.chat_id), str(key.user_id)])
if self.with_destiny:
parts.append(key.destiny)
parts.append(part)
return ":".join(parts)
class RedisStorage(BaseStorage):
def __init__(
self,
redis: Redis,
prefix: str = "fsm",
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
key_builder: Optional[KeyBuilder] = None,
state_ttl: Optional[int] = None,
data_ttl: Optional[int] = None,
lock_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if key_builder is None:
key_builder = DefaultKeyBuilder()
if lock_kwargs is None:
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
self.redis = redis
self.prefix = prefix
self.prefix_bot = prefix_bot
self.key_builder = key_builder
self.state_ttl = state_ttl
self.data_ttl = data_ttl
self.lock_kwargs = lock_kwargs
@ -47,40 +79,28 @@ class RedisStorage(BaseStorage):
async def close(self) -> None:
await self.redis.close() # type: ignore
def generate_key(self, bot: Bot, *parts: Any) -> str:
prefix_parts = [self.prefix]
if self.prefix_bot:
if isinstance(self.prefix_bot, dict):
prefix_parts.append(self.prefix_bot[bot.id])
elif callable(self.prefix_bot):
prefix_parts.append(self.prefix_bot(bot))
else:
prefix_parts.append(str(bot.id))
prefix_parts.extend(parts)
return ":".join(map(str, prefix_parts))
@asynccontextmanager
async def lock(
self, bot: Bot, chat_id: int, user_id: int, state_lock_key: str = STATE_LOCK_KEY
self,
bot: Bot,
key: StorageKey,
) -> AsyncGenerator[None, None]:
key = self.generate_key(bot, chat_id, user_id, state_lock_key)
async with self.redis.lock(name=key, **self.lock_kwargs):
redis_key = self.key_builder.build(key, "lock")
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
yield None
async def set_state(
self,
bot: Bot,
chat_id: int,
user_id: int,
key: StorageKey,
state: StateType = None,
state_key: str = STATE_KEY,
) -> None:
key = self.generate_key(bot, chat_id, user_id, state_key)
redis_key = self.key_builder.build(key, "state")
if state is None:
await self.redis.delete(key)
await self.redis.delete(redis_key)
else:
await self.redis.set(
key,
redis_key,
state.state if isinstance(state, State) else state, # type: ignore[arg-type]
ex=self.state_ttl, # type: ignore[arg-type]
)
@ -88,12 +108,10 @@ class RedisStorage(BaseStorage):
async def get_state(
self,
bot: Bot,
chat_id: int,
user_id: int,
state_key: str = STATE_KEY,
key: StorageKey,
) -> Optional[str]:
key = self.generate_key(bot, chat_id, user_id, state_key)
value = await self.redis.get(key)
redis_key = self.key_builder.build(key, "state")
value = await self.redis.get(redis_key)
if isinstance(value, bytes):
return value.decode("utf-8")
return cast(Optional[str], value)
@ -101,27 +119,26 @@ class RedisStorage(BaseStorage):
async def set_data(
self,
bot: Bot,
chat_id: int,
user_id: int,
key: StorageKey,
data: Dict[str, Any],
state_data_key: str = STATE_DATA_KEY,
) -> None:
key = self.generate_key(bot, chat_id, user_id, state_data_key)
redis_key = self.key_builder.build(key, "data")
if not data:
await self.redis.delete(key)
await self.redis.delete(redis_key)
return
json_data = bot.session.json_dumps(data)
await self.redis.set(key, json_data, ex=self.data_ttl) # type: ignore[arg-type]
await self.redis.set(
redis_key,
bot.session.json_dumps(data),
ex=self.data_ttl, # type: ignore[arg-type]
)
async def get_data(
self,
bot: Bot,
chat_id: int,
user_id: int,
state_data_key: str = STATE_DATA_KEY,
key: StorageKey,
) -> Dict[str, Any]:
key = self.generate_key(bot, chat_id, user_id, state_data_key)
value = await self.redis.get(key)
redis_key = self.key_builder.build(key, "data")
value = await self.redis.get(redis_key)
if value is None:
return {}
if isinstance(value, bytes):