From 49a933d2ab036708b8ca5400b876f18586dd596e Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Thu, 24 Feb 2022 01:52:42 +0200 Subject: [PATCH] Rework middlewares, separate management to `MiddlewareManager` class --- aiogram/dispatcher/event/middleware.py | 61 +++++++++ aiogram/dispatcher/event/telegram.py | 121 ++++-------------- tests/test_dispatcher/test_dispatcher.py | 2 +- .../test_event/test_telegram.py | 11 +- .../test_filters/test_exception.py | 2 +- tests/test_utils/test_i18n.py | 4 +- 6 files changed, 93 insertions(+), 108 deletions(-) create mode 100644 aiogram/dispatcher/event/middleware.py diff --git a/aiogram/dispatcher/event/middleware.py b/aiogram/dispatcher/event/middleware.py new file mode 100644 index 00000000..89892e77 --- /dev/null +++ b/aiogram/dispatcher/event/middleware.py @@ -0,0 +1,61 @@ +import functools +from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload + +from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType +from aiogram.dispatcher.event.handler import HandlerType +from aiogram.types import TelegramObject + + +class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]): + def __init__(self) -> None: + self._middlewares: List[MiddlewareType[TelegramObject]] = [] + + def register( + self, + middleware: MiddlewareType[TelegramObject], + ) -> MiddlewareType[TelegramObject]: + self._middlewares.append(middleware) + return middleware + + def unregister(self, middleware: MiddlewareType[TelegramObject]) -> None: + self._middlewares.remove(middleware) + + def __call__( + self, + middleware: Optional[MiddlewareType[TelegramObject]] = None, + ) -> Union[ + Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]], + MiddlewareType[TelegramObject], + ]: + if middleware is None: + return self.register + return self.register(middleware) + + @overload + def __getitem__(self, item: int) -> MiddlewareType[TelegramObject]: + pass + + @overload + def __getitem__(self, item: slice) -> Sequence[MiddlewareType[TelegramObject]]: + pass + + def __getitem__( + self, item: Union[int, slice] + ) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]: + return self._middlewares[item] + + def __len__(self) -> int: + return len(self._middlewares) + + @staticmethod + def wrap_middlewares( + middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType + ) -> NextMiddlewareType[MiddlewareEventType]: + @functools.wraps(handler) + def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any: + return handler(event, **kwargs) + + middleware = handler_wrapper + for m in reversed(middlewares): + middleware = functools.partial(m, middleware) + return middleware diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index a08d7cb1..523f3374 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -1,19 +1,9 @@ from __future__ import annotations import functools +from inspect import isclass from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - List, - Optional, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type from pydantic import ValidationError @@ -29,6 +19,7 @@ from .bases import ( SkipHandler, ) from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType +from .middleware import MiddlewareManager if TYPE_CHECKING: from aiogram.dispatcher.router import Router @@ -48,8 +39,9 @@ class TelegramEventObserver: self.handlers: List[HandlerObject] = [] self.filters: List[Type[BaseFilter]] = [] - self.outer_middlewares: List[MiddlewareType[TelegramObject]] = [] - self.middlewares: List[MiddlewareType[TelegramObject]] = [] + + self.middleware = MiddlewareManager() + self.outer_middleware = MiddlewareManager() # Re-used filters check method from already implemented handler object # with dummy callback which never will be used @@ -75,7 +67,11 @@ class TelegramEventObserver: :param bound_filter: """ - if not issubclass(bound_filter, BaseFilter): + # TODO: This functionality should be deprecated in the future + # in due to bound filter has uncontrollable ordering and + # makes debugging process is harder that explicit using filters + + if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter): raise TypeError( "bound_filter() argument 'bound_filter' must be subclass of BaseFilter" ) @@ -97,18 +93,11 @@ class TelegramEventObserver: yield filter_ registry.append(filter_) - def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType[TelegramObject]]: - """ - Get all middlewares in a tree - :param *: - """ - middlewares = [] - if outer: - middlewares.extend(self.outer_middlewares) - else: - for router in reversed(tuple(self.router.chain_head)): - observer = router.observers[self.event_name] - middlewares.extend(observer.middlewares) + def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]: + middlewares: List[MiddlewareType[TelegramObject]] = [] + for router in reversed(tuple(self.router.chain_head)): + observer = router.observers[self.event_name] + middlewares.extend(observer.middleware) return middlewares @@ -214,7 +203,10 @@ class TelegramEventObserver: def wrap_outer_middleware( self, callback: Any, event: TelegramObject, data: Dict[str, Any] ) -> Any: - wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback) + wrapped_outer = self.middleware.wrap_middlewares( + self.outer_middleware, + callback, + ) return wrapped_outer(event, data) async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any: @@ -233,8 +225,9 @@ class TelegramEventObserver: if result: kwargs.update(data, handler=handler) try: - wrapped_inner = self._wrap_middleware( - self._resolve_middlewares(), handler.call + wrapped_inner = self.outer_middleware.wrap_middlewares( + self._resolve_middlewares(), + handler.call, ) return await wrapped_inner(event, kwargs) except SkipHandler: @@ -254,71 +247,3 @@ class TelegramEventObserver: return callback return wrapper - - def middleware( - self, - middleware: Optional[MiddlewareType[TelegramObject]] = None, - ) -> Union[ - Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]], - MiddlewareType[TelegramObject], - ]: - """ - Decorator for registering inner middlewares - - Usage: - - .. code-block:: python - - @.middleware() # via decorator (variant 1) - - .. code-block:: python - - @.middleware # via decorator (variant 2) - - .. code-block:: python - - async def my_middleware(handler, event, data): ... - .middleware(my_middleware) # via method - """ - - def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]: - self.middlewares.append(m) - return m - - if middleware is None: - return wrapper - return wrapper(middleware) - - def outer_middleware( - self, - middleware: Optional[MiddlewareType[TelegramObject]] = None, - ) -> Union[ - Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]], - MiddlewareType[TelegramObject], - ]: - """ - Decorator for registering outer middlewares - - Usage: - - .. code-block:: python - - @.outer_middleware() # via decorator (variant 1) - - .. code-block:: python - - @.outer_middleware # via decorator (variant 2) - - .. code-block:: python - - async def my_middleware(handler, event, data): ... - .outer_middleware(my_middleware) # via method - """ - - def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]: - self.outer_middlewares.append(m) - return m - - if middleware is None: - return wrapper - return wrapper(middleware) diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index 1150f073..c3621193 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -74,7 +74,7 @@ class TestDispatcher: assert dp.update.handlers assert dp.update.handlers[0].callback == dp._listen_update - assert dp.update.outer_middlewares + assert dp.update.outer_middleware def test_parent_router(self, dispatcher: Dispatcher): with pytest.raises(RuntimeError): diff --git a/tests/test_dispatcher/test_event/test_telegram.py b/tests/test_dispatcher/test_event/test_telegram.py index a30eb3a9..3a569010 100644 --- a/tests/test_dispatcher/test_event/test_telegram.py +++ b/tests/test_dispatcher/test_event/test_telegram.py @@ -297,10 +297,9 @@ class TestTelegramEventObserver: def test_register_middleware(self, middleware_type): event_observer = TelegramEventObserver(Router(), "test") - middlewares = getattr(event_observer, f"{middleware_type}s") - decorator = getattr(event_observer, middleware_type) + middlewares = getattr(event_observer, middleware_type) - @decorator + @middlewares async def my_middleware1(handler, event, data): pass @@ -308,7 +307,7 @@ class TestTelegramEventObserver: assert my_middleware1.__name__ == "my_middleware1" assert my_middleware1 in middlewares - @decorator() + @middlewares() async def my_middleware2(handler, event, data): pass @@ -319,13 +318,13 @@ class TestTelegramEventObserver: async def my_middleware3(handler, event, data): pass - decorator(my_middleware3) + middlewares(my_middleware3) assert my_middleware3 is not None assert my_middleware3.__name__ == "my_middleware3" assert my_middleware3 in middlewares - assert middlewares == [my_middleware1, my_middleware2, my_middleware3] + assert list(middlewares) == [my_middleware1, my_middleware2, my_middleware3] def test_register_global_filters(self): router = Router(use_builtin_filters=False) diff --git a/tests/test_dispatcher/test_filters/test_exception.py b/tests/test_dispatcher/test_filters/test_exception.py index c1ffb6d8..1fcf7015 100644 --- a/tests/test_dispatcher/test_filters/test_exception.py +++ b/tests/test_dispatcher/test_filters/test_exception.py @@ -2,7 +2,7 @@ import re import pytest -from aiogram import Dispatcher, F +from aiogram import Dispatcher from aiogram.dispatcher.filters import ExceptionMessageFilter, ExceptionTypeFilter from aiogram.types import Update diff --git a/tests/test_utils/test_i18n.py b/tests/test_utils/test_i18n.py index 31843080..a4381c45 100644 --- a/tests/test_utils/test_i18n.py +++ b/tests/test_utils/test_i18n.py @@ -111,8 +111,8 @@ class TestSimpleI18nMiddleware: middleware = SimpleI18nMiddleware(i18n=i18n) middleware.setup(router=dp) - assert middleware not in dp.update.outer_middlewares - assert middleware in dp.message.outer_middlewares + assert middleware not in dp.update.outer_middleware + assert middleware in dp.message.outer_middleware async def test_get_unknown_locale(self, i18n: I18n): dp = Dispatcher()