Rework middlewares, separate management to MiddlewareManager class

This commit is contained in:
Alex Root Junior 2022-02-24 01:52:42 +02:00
parent 880dd153d3
commit 49a933d2ab
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
6 changed files with 93 additions and 108 deletions

View file

@ -0,0 +1,61 @@
import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType
from aiogram.dispatcher.event.handler import HandlerType
from aiogram.types import TelegramObject
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
def __init__(self) -> None:
self._middlewares: List[MiddlewareType[TelegramObject]] = []
def register(
self,
middleware: MiddlewareType[TelegramObject],
) -> MiddlewareType[TelegramObject]:
self._middlewares.append(middleware)
return middleware
def unregister(self, middleware: MiddlewareType[TelegramObject]) -> None:
self._middlewares.remove(middleware)
def __call__(
self,
middleware: Optional[MiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
MiddlewareType[TelegramObject],
]:
if middleware is None:
return self.register
return self.register(middleware)
@overload
def __getitem__(self, item: int) -> MiddlewareType[TelegramObject]:
pass
@overload
def __getitem__(self, item: slice) -> Sequence[MiddlewareType[TelegramObject]]:
pass
def __getitem__(
self, item: Union[int, slice]
) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]:
return self._middlewares[item]
def __len__(self) -> int:
return len(self._middlewares)
@staticmethod
def wrap_middlewares(
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType
) -> NextMiddlewareType[MiddlewareEventType]:
@functools.wraps(handler)
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
return handler(event, **kwargs)
middleware = handler_wrapper
for m in reversed(middlewares):
middleware = functools.partial(m, middleware)
return middleware

View file

@ -1,19 +1,9 @@
from __future__ import annotations
import functools
from inspect import isclass
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type
from pydantic import ValidationError
@ -29,6 +19,7 @@ from .bases import (
SkipHandler,
)
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
from .middleware import MiddlewareManager
if TYPE_CHECKING:
from aiogram.dispatcher.router import Router
@ -48,8 +39,9 @@ class TelegramEventObserver:
self.handlers: List[HandlerObject] = []
self.filters: List[Type[BaseFilter]] = []
self.outer_middlewares: List[MiddlewareType[TelegramObject]] = []
self.middlewares: List[MiddlewareType[TelegramObject]] = []
self.middleware = MiddlewareManager()
self.outer_middleware = MiddlewareManager()
# Re-used filters check method from already implemented handler object
# with dummy callback which never will be used
@ -75,7 +67,11 @@ class TelegramEventObserver:
:param bound_filter:
"""
if not issubclass(bound_filter, BaseFilter):
# TODO: This functionality should be deprecated in the future
# in due to bound filter has uncontrollable ordering and
# makes debugging process is harder that explicit using filters
if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter):
raise TypeError(
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
)
@ -97,18 +93,11 @@ class TelegramEventObserver:
yield filter_
registry.append(filter_)
def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType[TelegramObject]]:
"""
Get all middlewares in a tree
:param *:
"""
middlewares = []
if outer:
middlewares.extend(self.outer_middlewares)
else:
for router in reversed(tuple(self.router.chain_head)):
observer = router.observers[self.event_name]
middlewares.extend(observer.middlewares)
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
middlewares: List[MiddlewareType[TelegramObject]] = []
for router in reversed(tuple(self.router.chain_head)):
observer = router.observers[self.event_name]
middlewares.extend(observer.middleware)
return middlewares
@ -214,7 +203,10 @@ class TelegramEventObserver:
def wrap_outer_middleware(
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
) -> Any:
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback)
wrapped_outer = self.middleware.wrap_middlewares(
self.outer_middleware,
callback,
)
return wrapped_outer(event, data)
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
@ -233,8 +225,9 @@ class TelegramEventObserver:
if result:
kwargs.update(data, handler=handler)
try:
wrapped_inner = self._wrap_middleware(
self._resolve_middlewares(), handler.call
wrapped_inner = self.outer_middleware.wrap_middlewares(
self._resolve_middlewares(),
handler.call,
)
return await wrapped_inner(event, kwargs)
except SkipHandler:
@ -254,71 +247,3 @@ class TelegramEventObserver:
return callback
return wrapper
def middleware(
self,
middleware: Optional[MiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
MiddlewareType[TelegramObject],
]:
"""
Decorator for registering inner middlewares
Usage:
.. code-block:: python
@<event>.middleware() # via decorator (variant 1)
.. code-block:: python
@<event>.middleware # via decorator (variant 2)
.. code-block:: python
async def my_middleware(handler, event, data): ...
<event>.middleware(my_middleware) # via method
"""
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
self.middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)
def outer_middleware(
self,
middleware: Optional[MiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
MiddlewareType[TelegramObject],
]:
"""
Decorator for registering outer middlewares
Usage:
.. code-block:: python
@<event>.outer_middleware() # via decorator (variant 1)
.. code-block:: python
@<event>.outer_middleware # via decorator (variant 2)
.. code-block:: python
async def my_middleware(handler, event, data): ...
<event>.outer_middleware(my_middleware) # via method
"""
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
self.outer_middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)

View file

@ -74,7 +74,7 @@ class TestDispatcher:
assert dp.update.handlers
assert dp.update.handlers[0].callback == dp._listen_update
assert dp.update.outer_middlewares
assert dp.update.outer_middleware
def test_parent_router(self, dispatcher: Dispatcher):
with pytest.raises(RuntimeError):

View file

@ -297,10 +297,9 @@ class TestTelegramEventObserver:
def test_register_middleware(self, middleware_type):
event_observer = TelegramEventObserver(Router(), "test")
middlewares = getattr(event_observer, f"{middleware_type}s")
decorator = getattr(event_observer, middleware_type)
middlewares = getattr(event_observer, middleware_type)
@decorator
@middlewares
async def my_middleware1(handler, event, data):
pass
@ -308,7 +307,7 @@ class TestTelegramEventObserver:
assert my_middleware1.__name__ == "my_middleware1"
assert my_middleware1 in middlewares
@decorator()
@middlewares()
async def my_middleware2(handler, event, data):
pass
@ -319,13 +318,13 @@ class TestTelegramEventObserver:
async def my_middleware3(handler, event, data):
pass
decorator(my_middleware3)
middlewares(my_middleware3)
assert my_middleware3 is not None
assert my_middleware3.__name__ == "my_middleware3"
assert my_middleware3 in middlewares
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]
assert list(middlewares) == [my_middleware1, my_middleware2, my_middleware3]
def test_register_global_filters(self):
router = Router(use_builtin_filters=False)

View file

@ -2,7 +2,7 @@ import re
import pytest
from aiogram import Dispatcher, F
from aiogram import Dispatcher
from aiogram.dispatcher.filters import ExceptionMessageFilter, ExceptionTypeFilter
from aiogram.types import Update

View file

@ -111,8 +111,8 @@ class TestSimpleI18nMiddleware:
middleware = SimpleI18nMiddleware(i18n=i18n)
middleware.setup(router=dp)
assert middleware not in dp.update.outer_middlewares
assert middleware in dp.message.outer_middlewares
assert middleware not in dp.update.outer_middleware
assert middleware in dp.message.outer_middleware
async def test_get_unknown_locale(self, i18n: I18n):
dp = Dispatcher()