mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Rework FSM storage key
This commit is contained in:
parent
8c4d4ef30a
commit
7c6cf3c122
10 changed files with 213 additions and 160 deletions
|
|
@ -37,5 +37,5 @@ __all__ = (
|
|||
"md",
|
||||
)
|
||||
|
||||
__version__ = "3.0.0a17"
|
||||
__version__ = "3.0.0a18"
|
||||
__api_version__ = "5.3"
|
||||
|
|
|
|||
|
|
@ -1,44 +1,33 @@
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||
|
||||
|
||||
class FSMContext:
|
||||
def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
|
||||
def __init__(self, bot: Bot, storage: BaseStorage, key: StorageKey) -> None:
|
||||
self.bot = bot
|
||||
self.storage = storage
|
||||
self.chat_id = chat_id
|
||||
self.user_id = user_id
|
||||
self.key = key
|
||||
|
||||
async def set_state(self, state: StateType = None) -> None:
|
||||
await self.storage.set_state(
|
||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
|
||||
)
|
||||
await self.storage.set_state(bot=self.bot, key=self.key, state=state)
|
||||
|
||||
async def get_state(self) -> Optional[str]:
|
||||
return await self.storage.get_state(
|
||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
|
||||
)
|
||||
return await self.storage.get_state(bot=self.bot, key=self.key)
|
||||
|
||||
async def set_data(self, data: Dict[str, Any]) -> None:
|
||||
await self.storage.set_data(
|
||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
|
||||
)
|
||||
await self.storage.set_data(bot=self.bot, key=self.key, data=data)
|
||||
|
||||
async def get_data(self) -> Dict[str, Any]:
|
||||
return await self.storage.get_data(
|
||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
|
||||
)
|
||||
return await self.storage.get_data(bot=self.bot, key=self.key)
|
||||
|
||||
async def update_data(
|
||||
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
if data:
|
||||
kwargs.update(data)
|
||||
return await self.storage.update_data(
|
||||
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
|
||||
)
|
||||
return await self.storage.update_data(bot=self.bot, key=self.key, data=kwargs)
|
||||
|
||||
async def clear(self) -> None:
|
||||
await self.set_state(state=None)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
|||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.context import FSMContext
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage
|
||||
from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StorageKey
|
||||
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.types import TelegramObject
|
||||
|
|
@ -31,21 +31,28 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
if context:
|
||||
data.update({"state": context, "raw_state": await context.get_state()})
|
||||
if self.isolate_events:
|
||||
async with self.storage.lock(
|
||||
bot=bot, chat_id=context.chat_id, user_id=context.user_id
|
||||
):
|
||||
async with self.storage.lock(bot=bot, key=context.key):
|
||||
return await handler(event, data)
|
||||
return await handler(event, data)
|
||||
|
||||
def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
|
||||
def resolve_event_context(
|
||||
self,
|
||||
bot: Bot,
|
||||
data: Dict[str, Any],
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> Optional[FSMContext]:
|
||||
user = data.get("event_from_user")
|
||||
chat = data.get("event_chat")
|
||||
chat_id = chat.id if chat else None
|
||||
user_id = user.id if user else None
|
||||
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
|
||||
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||
|
||||
def resolve_context(
|
||||
self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: Optional[int],
|
||||
user_id: Optional[int],
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> Optional[FSMContext]:
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
|
|
@ -54,8 +61,23 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
chat_id, user_id = apply_strategy(
|
||||
chat_id=chat_id, user_id=user_id, strategy=self.strategy
|
||||
)
|
||||
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
|
||||
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
|
||||
return None
|
||||
|
||||
def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
|
||||
return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)
|
||||
def get_context(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
destiny: str = DEFAULT_DESTINY,
|
||||
) -> FSMContext:
|
||||
return FSMContext(
|
||||
bot=bot,
|
||||
storage=self.storage,
|
||||
key=StorageKey(
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
bot_id=bot.id,
|
||||
destiny=destiny,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, Union
|
||||
|
||||
from aiogram import Bot
|
||||
|
|
@ -7,45 +8,43 @@ from aiogram.dispatcher.fsm.state import State
|
|||
|
||||
StateType = Optional[Union[str, State]]
|
||||
|
||||
DEFAULT_DESTINY = "default"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageKey:
|
||||
bot_id: int
|
||||
chat_id: int
|
||||
user_id: int
|
||||
destiny: str = DEFAULT_DESTINY
|
||||
|
||||
|
||||
class BaseStorage(ABC):
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def lock(
|
||||
self, bot: Bot, chat_id: int, user_id: int
|
||||
) -> AsyncGenerator[None, None]: # pragma: no cover
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
yield None
|
||||
|
||||
@abstractmethod
|
||||
async def set_state(
|
||||
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
|
||||
) -> None: # pragma: no cover
|
||||
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_state(
|
||||
self, bot: Bot, chat_id: int, user_id: int
|
||||
) -> Optional[str]: # pragma: no cover
|
||||
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_data(
|
||||
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
|
||||
) -> None: # pragma: no cover
|
||||
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_data(
|
||||
self, bot: Bot, chat_id: int, user_id: int
|
||||
) -> Dict[str, Any]: # pragma: no cover
|
||||
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
async def update_data(
|
||||
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
|
||||
async def update_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
current_data = await self.get_data(bot=bot, key=key)
|
||||
current_data.update(data)
|
||||
await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
|
||||
await self.set_data(bot=bot, key=key, data=current_data)
|
||||
return current_data.copy()
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
|
|||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.state import State
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -18,30 +18,26 @@ class MemoryStorageRecord:
|
|||
|
||||
class MemoryStorage(BaseStorage):
|
||||
def __init__(self) -> None:
|
||||
self.storage: DefaultDict[
|
||||
Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
|
||||
] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
|
||||
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
|
||||
MemoryStorageRecord
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
|
||||
async with self.storage[bot][chat_id][user_id].lock:
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
async with self.storage[key].lock:
|
||||
yield None
|
||||
|
||||
async def set_state(
|
||||
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
|
||||
) -> None:
|
||||
self.storage[bot][chat_id][user_id].state = (
|
||||
state.state if isinstance(state, State) else state
|
||||
)
|
||||
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||
self.storage[key].state = state.state if isinstance(state, State) else state
|
||||
|
||||
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
|
||||
return self.storage[bot][chat_id][user_id].state
|
||||
async def get_state(self, bot: Bot, key: StorageKey) -> Optional[str]:
|
||||
return self.storage[key].state
|
||||
|
||||
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
|
||||
self.storage[bot][chat_id][user_id].data = data.copy()
|
||||
async def set_data(self, bot: Bot, key: StorageKey, data: Dict[str, Any]) -> None:
|
||||
self.storage[key].data = data.copy()
|
||||
|
||||
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
|
||||
return self.storage[bot][chat_id][user_id].data.copy()
|
||||
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
|
||||
return self.storage[key].data.copy()
|
||||
|
|
|
|||
|
|
@ -1,35 +1,67 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
|
||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
|
||||
|
||||
from aioredis import ConnectionPool, Redis
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.state import State
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
|
||||
|
||||
PrefixFactoryType = Callable[[Bot], str]
|
||||
STATE_KEY = "state"
|
||||
STATE_DATA_KEY = "data"
|
||||
STATE_LOCK_KEY = "lock"
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||
|
||||
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
|
||||
|
||||
|
||||
class KeyBuilder(ABC):
|
||||
"""
|
||||
Base class for Redis key builder
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultKeyBuilder(KeyBuilder):
|
||||
"""
|
||||
Simple Redis key builder with default prefix.
|
||||
|
||||
Generates a colon-joined string with prefix, chat_id, user_id,
|
||||
optional bot_id and optional destiny.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, prefix: str = "fsm", with_bot_id: bool = False, with_destiny: bool = False
|
||||
) -> None:
|
||||
self.prefix = prefix
|
||||
self.with_bot_id = with_bot_id
|
||||
self.with_destiny = with_destiny
|
||||
|
||||
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
|
||||
parts = [self.prefix]
|
||||
if self.with_bot_id:
|
||||
parts.append(str(key.bot_id))
|
||||
parts.extend([str(key.chat_id), str(key.user_id)])
|
||||
if self.with_destiny:
|
||||
parts.append(key.destiny)
|
||||
parts.append(part)
|
||||
return ":".join(parts)
|
||||
|
||||
|
||||
class RedisStorage(BaseStorage):
|
||||
def __init__(
|
||||
self,
|
||||
redis: Redis,
|
||||
prefix: str = "fsm",
|
||||
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
|
||||
key_builder: Optional[KeyBuilder] = None,
|
||||
state_ttl: Optional[int] = None,
|
||||
data_ttl: Optional[int] = None,
|
||||
lock_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if key_builder is None:
|
||||
key_builder = DefaultKeyBuilder()
|
||||
if lock_kwargs is None:
|
||||
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
|
||||
self.redis = redis
|
||||
self.prefix = prefix
|
||||
self.prefix_bot = prefix_bot
|
||||
self.key_builder = key_builder
|
||||
self.state_ttl = state_ttl
|
||||
self.data_ttl = data_ttl
|
||||
self.lock_kwargs = lock_kwargs
|
||||
|
|
@ -47,40 +79,28 @@ class RedisStorage(BaseStorage):
|
|||
async def close(self) -> None:
|
||||
await self.redis.close() # type: ignore
|
||||
|
||||
def generate_key(self, bot: Bot, *parts: Any) -> str:
|
||||
prefix_parts = [self.prefix]
|
||||
if self.prefix_bot:
|
||||
if isinstance(self.prefix_bot, dict):
|
||||
prefix_parts.append(self.prefix_bot[bot.id])
|
||||
elif callable(self.prefix_bot):
|
||||
prefix_parts.append(self.prefix_bot(bot))
|
||||
else:
|
||||
prefix_parts.append(str(bot.id))
|
||||
prefix_parts.extend(parts)
|
||||
return ":".join(map(str, prefix_parts))
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(
|
||||
self, bot: Bot, chat_id: int, user_id: int, state_lock_key: str = STATE_LOCK_KEY
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
key = self.generate_key(bot, chat_id, user_id, state_lock_key)
|
||||
async with self.redis.lock(name=key, **self.lock_kwargs):
|
||||
redis_key = self.key_builder.build(key, "lock")
|
||||
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
|
||||
yield None
|
||||
|
||||
async def set_state(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
key: StorageKey,
|
||||
state: StateType = None,
|
||||
state_key: str = STATE_KEY,
|
||||
) -> None:
|
||||
key = self.generate_key(bot, chat_id, user_id, state_key)
|
||||
redis_key = self.key_builder.build(key, "state")
|
||||
if state is None:
|
||||
await self.redis.delete(key)
|
||||
await self.redis.delete(redis_key)
|
||||
else:
|
||||
await self.redis.set(
|
||||
key,
|
||||
redis_key,
|
||||
state.state if isinstance(state, State) else state, # type: ignore[arg-type]
|
||||
ex=self.state_ttl, # type: ignore[arg-type]
|
||||
)
|
||||
|
|
@ -88,12 +108,10 @@ class RedisStorage(BaseStorage):
|
|||
async def get_state(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
state_key: str = STATE_KEY,
|
||||
key: StorageKey,
|
||||
) -> Optional[str]:
|
||||
key = self.generate_key(bot, chat_id, user_id, state_key)
|
||||
value = await self.redis.get(key)
|
||||
redis_key = self.key_builder.build(key, "state")
|
||||
value = await self.redis.get(redis_key)
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
return cast(Optional[str], value)
|
||||
|
|
@ -101,27 +119,26 @@ class RedisStorage(BaseStorage):
|
|||
async def set_data(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
key: StorageKey,
|
||||
data: Dict[str, Any],
|
||||
state_data_key: str = STATE_DATA_KEY,
|
||||
) -> None:
|
||||
key = self.generate_key(bot, chat_id, user_id, state_data_key)
|
||||
redis_key = self.key_builder.build(key, "data")
|
||||
if not data:
|
||||
await self.redis.delete(key)
|
||||
await self.redis.delete(redis_key)
|
||||
return
|
||||
json_data = bot.session.json_dumps(data)
|
||||
await self.redis.set(key, json_data, ex=self.data_ttl) # type: ignore[arg-type]
|
||||
await self.redis.set(
|
||||
redis_key,
|
||||
bot.session.json_dumps(data),
|
||||
ex=self.data_ttl, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def get_data(
|
||||
self,
|
||||
bot: Bot,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
state_data_key: str = STATE_DATA_KEY,
|
||||
key: StorageKey,
|
||||
) -> Dict[str, Any]:
|
||||
key = self.generate_key(bot, chat_id, user_id, state_data_key)
|
||||
value = await self.redis.get(key)
|
||||
redis_key = self.key_builder.build(key, "data")
|
||||
value = await self.redis.get(redis_key)
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, bytes):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue