diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index c467ccad..610053db 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -16,8 +16,8 @@ from ..utils.backoff import Backoff, BackoffConfig from .event.bases import UNHANDLED, SkipHandler from .event.telegram import TelegramEventObserver from .fsm.middleware import FSMContextMiddleware -from .fsm.storage.base import BaseStorage -from .fsm.storage.memory import MemoryStorage +from .fsm.storage.base import BaseEventIsolation, BaseStorage +from .fsm.storage.memory import DisabledEventIsolation, MemoryStorage from .fsm.strategy import FSMStrategy from .middlewares.error import ErrorsMiddleware from .middlewares.user_context import UserContextMiddleware @@ -35,7 +35,7 @@ class Dispatcher(Router): self, storage: Optional[BaseStorage] = None, fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT, - isolate_events: bool = False, + events_isolation: Optional[BaseEventIsolation] = None, **kwargs: Any, ) -> None: super(Dispatcher, self).__init__(**kwargs) @@ -48,19 +48,22 @@ class Dispatcher(Router): ) self.update.register(self._listen_update) - # Error handlers should works is out of all other functions and be registered before all other middlewares + # Error handlers should work is out of all other functions and be registered before all others middlewares self.update.outer_middleware(ErrorsMiddleware(self)) + # User context middleware makes small optimization for all other builtin # middlewares via caching the user and chat instances in the event context self.update.outer_middleware(UserContextMiddleware()) + # FSM middleware should always be registered after User context middleware # because here is used context from previous step self.fsm = FSMContextMiddleware( storage=storage if storage else MemoryStorage(), strategy=fsm_strategy, - isolate_events=isolate_events, + events_isolation=events_isolation if events_isolation else DisabledEventIsolation(), ) self.update.outer_middleware(self.fsm) + self.shutdown.register(self.fsm.close) self._running_lock = Lock() diff --git a/aiogram/dispatcher/fsm/middleware.py b/aiogram/dispatcher/fsm/middleware.py index 8d59ff67..29db32ee 100644 --- a/aiogram/dispatcher/fsm/middleware.py +++ b/aiogram/dispatcher/fsm/middleware.py @@ -2,7 +2,12 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast from aiogram import Bot from aiogram.dispatcher.fsm.context import FSMContext -from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StorageKey +from aiogram.dispatcher.fsm.storage.base import ( + DEFAULT_DESTINY, + BaseEventIsolation, + BaseStorage, + StorageKey, +) from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.types import TelegramObject @@ -12,12 +17,12 @@ class FSMContextMiddleware(BaseMiddleware): def __init__( self, storage: BaseStorage, + events_isolation: BaseEventIsolation, strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT, - isolate_events: bool = True, ) -> None: self.storage = storage self.strategy = strategy - self.isolate_events = isolate_events + self.events_isolation = events_isolation async def __call__( self, @@ -30,9 +35,8 @@ class FSMContextMiddleware(BaseMiddleware): data["fsm_storage"] = self.storage if context: data.update({"state": context, "raw_state": await context.get_state()}) - if self.isolate_events: - async with self.storage.lock(bot=bot, key=context.key): - return await handler(event, data) + async with self.events_isolation.lock(bot=bot, key=context.key): + return await handler(event, data) return await handler(event, data) def resolve_event_context( @@ -81,3 +85,7 @@ class FSMContextMiddleware(BaseMiddleware): destiny=destiny, ), ) + + async def close(self) -> None: + await self.storage.close() + await self.events_isolation.close() diff --git a/aiogram/dispatcher/fsm/storage/base.py b/aiogram/dispatcher/fsm/storage/base.py index f4830e0f..0fd13c28 100644 --- a/aiogram/dispatcher/fsm/storage/base.py +++ b/aiogram/dispatcher/fsm/storage/base.py @@ -24,19 +24,6 @@ class BaseStorage(ABC): Base class for all FSM storages """ - @abstractmethod - @asynccontextmanager - async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: - """ - Isolate events with lock. - Will be used as context manager - - :param bot: instance of the current bot - :param key: storage key - :return: An async generator - """ - yield None - @abstractmethod async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None: """ @@ -101,3 +88,21 @@ class BaseStorage(ABC): Close storage (database connection, file or etc.) """ pass + + +class BaseEventIsolation(ABC): + @abstractmethod + @asynccontextmanager + async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: + """ + Isolate events with lock. + Will be used as context manager + + :param bot: instance of the current bot + :param key: storage key + :return: An async generator + """ + yield None + + async def close(self) -> None: + pass diff --git a/aiogram/dispatcher/fsm/storage/memory.py b/aiogram/dispatcher/fsm/storage/memory.py index 19b43fa9..1b04397b 100644 --- a/aiogram/dispatcher/fsm/storage/memory.py +++ b/aiogram/dispatcher/fsm/storage/memory.py @@ -2,18 +2,22 @@ from asyncio import Lock from collections import defaultdict from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional +from typing import Any, AsyncGenerator, DefaultDict, Dict, Hashable, Optional from aiogram import Bot from aiogram.dispatcher.fsm.state import State -from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType, StorageKey +from aiogram.dispatcher.fsm.storage.base import ( + BaseEventIsolation, + BaseStorage, + StateType, + StorageKey, +) @dataclass class MemoryStorageRecord: data: Dict[str, Any] = field(default_factory=dict) state: Optional[str] = None - lock: Lock = field(default_factory=Lock) class MemoryStorage(BaseStorage): @@ -34,11 +38,6 @@ class MemoryStorage(BaseStorage): async def close(self) -> None: pass - @asynccontextmanager - async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: - async with self.storage[key].lock: - yield None - async def set_state(self, bot: Bot, key: StorageKey, state: StateType = None) -> None: self.storage[key].state = state.state if isinstance(state, State) else state @@ -50,3 +49,21 @@ class MemoryStorage(BaseStorage): async def get_data(self, bot: Bot, key: StorageKey) -> Dict[str, Any]: return self.storage[key].data.copy() + + +class DisabledEventIsolation(BaseEventIsolation): + @asynccontextmanager + async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: + yield + + +class SimpleEventIsolation(BaseEventIsolation): + def __init__(self) -> None: + # TODO: Unused locks cleaner is needed + self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock) + + @asynccontextmanager + async def lock(self, bot: Bot, key: StorageKey) -> AsyncGenerator[None, None]: + lock = self._locks[key] + async with lock: + yield diff --git a/aiogram/dispatcher/fsm/storage/redis.py b/aiogram/dispatcher/fsm/storage/redis.py index 8828691f..58ce50f6 100644 --- a/aiogram/dispatcher/fsm/storage/redis.py +++ b/aiogram/dispatcher/fsm/storage/redis.py @@ -6,7 +6,13 @@ from aioredis import ConnectionPool, Redis from aiogram import Bot from aiogram.dispatcher.fsm.state import State -from aiogram.dispatcher.fsm.storage.base import DEFAULT_DESTINY, BaseStorage, StateType, StorageKey +from aiogram.dispatcher.fsm.storage.base import ( + DEFAULT_DESTINY, + BaseEventIsolation, + BaseStorage, + StateType, + StorageKey, +) DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60} @@ -121,19 +127,12 @@ class RedisStorage(BaseStorage): redis = Redis(connection_pool=pool) return cls(redis=redis, **kwargs) + def create_isolation(self, **kwargs) -> "RedisEventIsolation": + return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs) + async def close(self) -> None: await self.redis.close() # type: ignore - @asynccontextmanager - async def lock( - self, - bot: Bot, - key: StorageKey, - ) -> AsyncGenerator[None, None]: - redis_key = self.key_builder.build(key, "lock") - async with self.redis.lock(name=redis_key, **self.lock_kwargs): - yield None - async def set_state( self, bot: Bot, @@ -189,3 +188,38 @@ class RedisStorage(BaseStorage): if isinstance(value, bytes): value = value.decode("utf-8") return cast(Dict[str, Any], bot.session.json_loads(value)) + + +class RedisEventIsolation(BaseEventIsolation): + def __init__( + self, + redis: Redis, + key_builder: KeyBuilder, + lock_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.redis = redis + self.key_builder = key_builder + self.lock_kwargs = lock_kwargs or {} + + @classmethod + async def from_redis( + cls, + url: str, + connection_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> "RedisEventIsolation": + if connection_kwargs is None: + connection_kwargs = {} + pool = ConnectionPool.from_url(url, **connection_kwargs) + redis = Redis(connection_pool=pool) + return cls(redis=redis, **kwargs) + + @asynccontextmanager + async def lock( + self, + bot: Bot, + key: StorageKey, + ) -> AsyncGenerator[None, None]: + redis_key = self.key_builder.build(key, "lock") + async with self.redis.lock(name=redis_key, **self.lock_kwargs): + yield None diff --git a/poetry.lock b/poetry.lock index 8abb8e68..f7e2c3d1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -41,7 +41,7 @@ python-socks = {version = ">=1.0.1", extras = ["asyncio"]} [[package]] name = "aioredis" -version = "2.0.0" +version = "2.0.1" description = "asyncio (PEP 3156) Redis support" category = "main" optional = true @@ -917,7 +917,7 @@ pytest = ">=3.5" [[package]] name = "python-socks" -version = "2.0.2" +version = "2.0.3" description = "Core proxy (SOCKS4, SOCKS5, HTTP tunneling) functionality for Python" category = "main" optional = true @@ -1354,7 +1354,7 @@ redis = ["aioredis"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "678dbb9bbf362c01757842a977f7b0f0f65638a657e3599b6125f0f4e42f4b6c" +content-hash = "5b0879bc48b528901258ff901ba2b79fca747845dd51289eee1a01796390950b" [metadata.files] aiofiles = [ @@ -1440,8 +1440,8 @@ aiohttp-socks = [ {file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"}, ] aioredis = [ - {file = "aioredis-2.0.0-py3-none-any.whl", hash = "sha256:9921d68a3df5c5cdb0d5b49ad4fc88a4cfdd60c108325df4f0066e8410c55ffb"}, - {file = "aioredis-2.0.0.tar.gz", hash = "sha256:3a2de4b614e6a5f8e104238924294dc4e811aefbe17ddf52c04a93cbf06e67db"}, + {file = "aioredis-2.0.1-py3-none-any.whl", hash = "sha256:9ac0d0b3b485d293b8ca1987e6de8658d7dafcca1cddfcd1d506cae8cdebfdd6"}, + {file = "aioredis-2.0.1.tar.gz", hash = "sha256:eaa51aaf993f2d71f54b70527c440437ba65340588afeb786cd87c55c89cd98e"}, ] aiosignal = [ {file = "aiosignal-1.2.0-py3-none-any.whl", hash = "sha256:26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a"}, @@ -2033,8 +2033,8 @@ pytest-mypy = [ {file = "pytest_mypy-0.8.1-py3-none-any.whl", hash = "sha256:6e68e8eb7ceeb7d1c83a1590912f784879f037b51adfb9c17b95c6b2fc57466b"}, ] python-socks = [ - {file = "python-socks-2.0.2.tar.gz", hash = "sha256:aa9b7a53e81ae6b6e3ada602761012e470ea1c4cbcd5548f99b3fc102dce4fca"}, - {file = "python_socks-2.0.2-py3-none-any.whl", hash = "sha256:faa46857c79a8bf7def2e904ac839fb56755d7ab76c4cad12a131a85fec07241"}, + {file = "python-socks-2.0.3.tar.gz", hash = "sha256:e3a9ca8e554733862ce4d8ce1d10efb480fd3a3acdafd03393943ec00c98ba8a"}, + {file = "python_socks-2.0.3-py3-none-any.whl", hash = "sha256:950723f27d2cf401e193a9e0a0d45baab848341298f5b397d27fda0c4635e9a9"}, ] pytz = [ {file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"}, diff --git a/pyproject.toml b/pyproject.toml index b42260eb..ca994684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ Babel = { version = "^2.9.1", optional = true } # Proxy aiohttp-socks = { version = "^0.5.5", optional = true } # Redis -aioredis = { version = "^2.0.0", optional = true } +aioredis = {version = "^2.0.1", optional = true} # Docs Sphinx = { version = "^4.2.0", optional = true } sphinx-intl = { version = "^2.0.1", optional = true }