mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Rework events isolation
This commit is contained in:
parent
b67a57eb04
commit
a846acd607
7 changed files with 118 additions and 51 deletions
|
|
@ -16,8 +16,8 @@ from ..utils.backoff import Backoff, BackoffConfig
|
|||
from .event.bases import UNHANDLED, SkipHandler
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from .fsm.middleware import FSMContextMiddleware
|
||||
from .fsm.storage.base import BaseStorage
|
||||
from .fsm.storage.memory import MemoryStorage
|
||||
from .fsm.storage.base import BaseEventIsolation, BaseStorage
|
||||
from .fsm.storage.memory import DisabledEventIsolation, MemoryStorage
|
||||
from .fsm.strategy import FSMStrategy
|
||||
from .middlewares.error import ErrorsMiddleware
|
||||
from .middlewares.user_context import UserContextMiddleware
|
||||
|
|
@ -35,7 +35,7 @@ class Dispatcher(Router):
|
|||
self,
|
||||
storage: Optional[BaseStorage] = None,
|
||||
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
|
||||
isolate_events: bool = False,
|
||||
events_isolation: Optional[BaseEventIsolation] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super(Dispatcher, self).__init__(**kwargs)
|
||||
|
|
@ -48,19 +48,22 @@ class Dispatcher(Router):
|
|||
)
|
||||
self.update.register(self._listen_update)
|
||||
|
||||
# Error handlers should works is out of all other functions and be registered before all other middlewares
|
||||
# Error handlers should work is out of all other functions and be registered before all others middlewares
|
||||
self.update.outer_middleware(ErrorsMiddleware(self))
|
||||
|
||||
# User context middleware makes small optimization for all other builtin
|
||||
# middlewares via caching the user and chat instances in the event context
|
||||
self.update.outer_middleware(UserContextMiddleware())
|
||||
|
||||
# FSM middleware should always be registered after User context middleware
|
||||
# because here is used context from previous step
|
||||
self.fsm = FSMContextMiddleware(
|
||||
storage=storage if storage else MemoryStorage(),
|
||||
strategy=fsm_strategy,
|
||||
isolate_events=isolate_events,
|
||||
events_isolation=events_isolation if events_isolation else DisabledEventIsolation(),
|
||||
)
|
||||
self.update.outer_middleware(self.fsm)
|
||||
self.shutdown.register(self.fsm.close)
|
||||
|
||||
self._running_lock = Lock()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,12 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
|||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.context import FSMContext
|
||||
from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StorageKey
|
||||
from aiogram.dispatcher.fsm.storage.base import (
|
||||
DEFAULT_DESTINY,
|
||||
BaseEventIsolation,
|
||||
BaseStorage,
|
||||
StorageKey,
|
||||
)
|
||||
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.types import TelegramObject
|
||||
|
|
@ -12,12 +17,12 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
def __init__(
|
||||
self,
|
||||
storage: BaseStorage,
|
||||
events_isolation: BaseEventIsolation,
|
||||
strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
|
||||
isolate_events: bool = True,
|
||||
) -> None:
|
||||
self.storage = storage
|
||||
self.strategy = strategy
|
||||
self.isolate_events = isolate_events
|
||||
self.events_isolation = events_isolation
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
|
|
@ -30,9 +35,8 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
data["fsm_storage"] = self.storage
|
||||
if context:
|
||||
data.update({"state": context, "raw_state": await context.get_state()})
|
||||
if self.isolate_events:
|
||||
async with self.storage.lock(bot=bot, key=context.key):
|
||||
return await handler(event, data)
|
||||
async with self.events_isolation.lock(bot=bot, key=context.key):
|
||||
return await handler(event, data)
|
||||
return await handler(event, data)
|
||||
|
||||
def resolve_event_context(
|
||||
|
|
@ -81,3 +85,7 @@ class FSMContextMiddleware(BaseMiddleware):
|
|||
destiny=destiny,
|
||||
),
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.storage.close()
|
||||
await self.events_isolation.close()
|
||||
|
|
|
|||
|
|
@ -24,19 +24,6 @@ class BaseStorage(ABC):
|
|||
Base class for all FSM storages
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
"""
|
||||
Isolate events with lock.
|
||||
Will be used as context manager
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:return: An async generator
|
||||
"""
|
||||
yield None
|
||||
|
||||
@abstractmethod
|
||||
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||
"""
|
||||
|
|
@ -101,3 +88,21 @@ class BaseStorage(ABC):
|
|||
Close storage (database connection, file or etc.)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseEventIsolation(ABC):
|
||||
@abstractmethod
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
"""
|
||||
Isolate events with lock.
|
||||
Will be used as context manager
|
||||
|
||||
:param bot: instance of the current bot
|
||||
:param key: storage key
|
||||
:return: An async generator
|
||||
"""
|
||||
yield None
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -2,18 +2,22 @@ from asyncio import Lock
|
|||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
|
||||
from typing import Any, AsyncGenerator, DefaultDict, Dict, Hashable, Optional
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.state import State
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||
from aiogram.dispatcher.fsm.storage.base import (
|
||||
BaseEventIsolation,
|
||||
BaseStorage,
|
||||
StateType,
|
||||
StorageKey,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryStorageRecord:
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
state: Optional[str] = None
|
||||
lock: Lock = field(default_factory=Lock)
|
||||
|
||||
|
||||
class MemoryStorage(BaseStorage):
|
||||
|
|
@ -34,11 +38,6 @@ class MemoryStorage(BaseStorage):
|
|||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
async with self.storage[key].lock:
|
||||
yield None
|
||||
|
||||
async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None:
|
||||
self.storage[key].state = state.state if isinstance(state, State) else state
|
||||
|
||||
|
|
@ -50,3 +49,21 @@ class MemoryStorage(BaseStorage):
|
|||
|
||||
async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]:
|
||||
return self.storage[key].data.copy()
|
||||
|
||||
|
||||
class DisabledEventIsolation(BaseEventIsolation):
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
yield
|
||||
|
||||
|
||||
class SimpleEventIsolation(BaseEventIsolation):
|
||||
def __init__(self) -> None:
|
||||
# TODO: Unused locks cleaner is needed
|
||||
self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||
lock = self._locks[key]
|
||||
async with lock:
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -6,7 +6,13 @@ from aioredis import ConnectionPool, Redis
|
|||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.state import State
|
||||
from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StateType, StorageKey
|
||||
from aiogram.dispatcher.fsm.storage.base import (
|
||||
DEFAULT_DESTINY,
|
||||
BaseEventIsolation,
|
||||
BaseStorage,
|
||||
StateType,
|
||||
StorageKey,
|
||||
)
|
||||
|
||||
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
|
||||
|
||||
|
|
@ -121,19 +127,12 @@ class RedisStorage(BaseStorage):
|
|||
redis = Redis(connection_pool=pool)
|
||||
return cls(redis=redis, **kwargs)
|
||||
|
||||
def create_isolation(self, **kwargs) -> "RedisEventIsolation":
|
||||
return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.redis.close() # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
redis_key = self.key_builder.build(key, "lock")
|
||||
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
|
||||
yield None
|
||||
|
||||
async def set_state(
|
||||
self,
|
||||
bot: Bot,
|
||||
|
|
@ -189,3 +188,38 @@ class RedisStorage(BaseStorage):
|
|||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
return cast(Dict[str, Any], bot.session.json_loads(value))
|
||||
|
||||
|
||||
class RedisEventIsolation(BaseEventIsolation):
|
||||
def __init__(
|
||||
self,
|
||||
redis: Redis,
|
||||
key_builder: KeyBuilder,
|
||||
lock_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.redis = redis
|
||||
self.key_builder = key_builder
|
||||
self.lock_kwargs = lock_kwargs or {}
|
||||
|
||||
@classmethod
|
||||
async def from_redis(
|
||||
cls,
|
||||
url: str,
|
||||
connection_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "RedisEventIsolation":
|
||||
if connection_kwargs is None:
|
||||
connection_kwargs = {}
|
||||
pool = ConnectionPool.from_url(url, **connection_kwargs)
|
||||
redis = Redis(connection_pool=pool)
|
||||
return cls(redis=redis, **kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(
|
||||
self,
|
||||
bot: Bot,
|
||||
key: StorageKey,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
redis_key = self.key_builder.build(key, "lock")
|
||||
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
|
||||
yield None
|
||||
|
|
|
|||
14
poetry.lock
generated
14
poetry.lock
generated
|
|
@ -41,7 +41,7 @@ python-socks = {version = ">=1.0.1", extras = ["asyncio"]}
|
|||
|
||||
[[package]]
|
||||
name = "aioredis"
|
||||
version = "2.0.0"
|
||||
version = "2.0.1"
|
||||
description = "asyncio (PEP 3156) Redis support"
|
||||
category = "main"
|
||||
optional = true
|
||||
|
|
@ -917,7 +917,7 @@ pytest = ">=3.5"
|
|||
|
||||
[[package]]
|
||||
name = "python-socks"
|
||||
version = "2.0.2"
|
||||
version = "2.0.3"
|
||||
description = "Core proxy (SOCKS4, SOCKS5, HTTP tunneling) functionality for Python"
|
||||
category = "main"
|
||||
optional = true
|
||||
|
|
@ -1354,7 +1354,7 @@ redis = ["aioredis"]
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "678dbb9bbf362c01757842a977f7b0f0f65638a657e3599b6125f0f4e42f4b6c"
|
||||
content-hash = "5b0879bc48b528901258ff901ba2b79fca747845dd51289eee1a01796390950b"
|
||||
|
||||
[metadata.files]
|
||||
aiofiles = [
|
||||
|
|
@ -1440,8 +1440,8 @@ aiohttp-socks = [
|
|||
{file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"},
|
||||
]
|
||||
aioredis = [
|
||||
{file = "aioredis-2.0.0-py3-none-any.whl", hash = "sha256:9921d68a3df5c5cdb0d5b49ad4fc88a4cfdd60c108325df4f0066e8410c55ffb"},
|
||||
{file = "aioredis-2.0.0.tar.gz", hash = "sha256:3a2de4b614e6a5f8e104238924294dc4e811aefbe17ddf52c04a93cbf06e67db"},
|
||||
{file = "aioredis-2.0.1-py3-none-any.whl", hash = "sha256:9ac0d0b3b485d293b8ca1987e6de8658d7dafcca1cddfcd1d506cae8cdebfdd6"},
|
||||
{file = "aioredis-2.0.1.tar.gz", hash = "sha256:eaa51aaf993f2d71f54b70527c440437ba65340588afeb786cd87c55c89cd98e"},
|
||||
]
|
||||
aiosignal = [
|
||||
{file = "aiosignal-1.2.0-py3-none-any.whl", hash = "sha256:26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a"},
|
||||
|
|
@ -2033,8 +2033,8 @@ pytest-mypy = [
|
|||
{file = "pytest_mypy-0.8.1-py3-none-any.whl", hash = "sha256:6e68e8eb7ceeb7d1c83a1590912f784879f037b51adfb9c17b95c6b2fc57466b"},
|
||||
]
|
||||
python-socks = [
|
||||
{file = "python-socks-2.0.2.tar.gz", hash = "sha256:aa9b7a53e81ae6b6e3ada602761012e470ea1c4cbcd5548f99b3fc102dce4fca"},
|
||||
{file = "python_socks-2.0.2-py3-none-any.whl", hash = "sha256:faa46857c79a8bf7def2e904ac839fb56755d7ab76c4cad12a131a85fec07241"},
|
||||
{file = "python-socks-2.0.3.tar.gz", hash = "sha256:e3a9ca8e554733862ce4d8ce1d10efb480fd3a3acdafd03393943ec00c98ba8a"},
|
||||
{file = "python_socks-2.0.3-py3-none-any.whl", hash = "sha256:950723f27d2cf401e193a9e0a0d45baab848341298f5b397d27fda0c4635e9a9"},
|
||||
]
|
||||
pytz = [
|
||||
{file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ Babel = { version = "^2.9.1", optional = true }
|
|||
# Proxy
|
||||
aiohttp-socks = { version = "^0.5.5", optional = true }
|
||||
# Redis
|
||||
aioredis = { version = "^2.0.0", optional = true }
|
||||
aioredis = {version = "^2.0.1", optional = true}
|
||||
# Docs
|
||||
Sphinx = { version = "^4.2.0", optional = true }
|
||||
sphinx-intl = { version = "^2.0.1", optional = true }
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue