mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Rework middlewares, separate management to MiddlewareManager class
This commit is contained in:
parent
880dd153d3
commit
49a933d2ab
6 changed files with 93 additions and 108 deletions
61
aiogram/dispatcher/event/middleware.py
Normal file
61
aiogram/dispatcher/event/middleware.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import functools
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
|
||||
|
||||
from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType
|
||||
from aiogram.dispatcher.event.handler import HandlerType
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
|
||||
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
||||
def __init__(self) -> None:
|
||||
self._middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
|
||||
def register(
|
||||
self,
|
||||
middleware: MiddlewareType[TelegramObject],
|
||||
) -> MiddlewareType[TelegramObject]:
|
||||
self._middlewares.append(middleware)
|
||||
return middleware
|
||||
|
||||
def unregister(self, middleware: MiddlewareType[TelegramObject]) -> None:
|
||||
self._middlewares.remove(middleware)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType[TelegramObject]] = None,
|
||||
) -> Union[
|
||||
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
|
||||
MiddlewareType[TelegramObject],
|
||||
]:
|
||||
if middleware is None:
|
||||
return self.register
|
||||
return self.register(middleware)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: int) -> MiddlewareType[TelegramObject]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: slice) -> Sequence[MiddlewareType[TelegramObject]]:
|
||||
pass
|
||||
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice]
|
||||
) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]:
|
||||
return self._middlewares[item]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._middlewares)
|
||||
|
||||
@staticmethod
|
||||
def wrap_middlewares(
|
||||
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType
|
||||
) -> NextMiddlewareType[MiddlewareEventType]:
|
||||
@functools.wraps(handler)
|
||||
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
|
||||
return handler(event, **kwargs)
|
||||
|
||||
middleware = handler_wrapper
|
||||
for m in reversed(middlewares):
|
||||
middleware = functools.partial(m, middleware)
|
||||
return middleware
|
||||
|
|
@ -1,19 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from inspect import isclass
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -29,6 +19,7 @@ from .bases import (
|
|||
SkipHandler,
|
||||
)
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
|
||||
from .middleware import MiddlewareManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.dispatcher.router import Router
|
||||
|
|
@ -48,8 +39,9 @@ class TelegramEventObserver:
|
|||
|
||||
self.handlers: List[HandlerObject] = []
|
||||
self.filters: List[Type[BaseFilter]] = []
|
||||
self.outer_middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
self.middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
|
||||
self.middleware = MiddlewareManager()
|
||||
self.outer_middleware = MiddlewareManager()
|
||||
|
||||
# Re-used filters check method from already implemented handler object
|
||||
# with dummy callback which never will be used
|
||||
|
|
@ -75,7 +67,11 @@ class TelegramEventObserver:
|
|||
|
||||
:param bound_filter:
|
||||
"""
|
||||
if not issubclass(bound_filter, BaseFilter):
|
||||
# TODO: This functionality should be deprecated in the future
|
||||
# in due to bound filter has uncontrollable ordering and
|
||||
# makes debugging process is harder that explicit using filters
|
||||
|
||||
if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter):
|
||||
raise TypeError(
|
||||
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
|
||||
)
|
||||
|
|
@ -97,18 +93,11 @@ class TelegramEventObserver:
|
|||
yield filter_
|
||||
registry.append(filter_)
|
||||
|
||||
def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType[TelegramObject]]:
|
||||
"""
|
||||
Get all middlewares in a tree
|
||||
:param *:
|
||||
"""
|
||||
middlewares = []
|
||||
if outer:
|
||||
middlewares.extend(self.outer_middlewares)
|
||||
else:
|
||||
for router in reversed(tuple(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middlewares)
|
||||
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
|
||||
middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
for router in reversed(tuple(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middleware)
|
||||
|
||||
return middlewares
|
||||
|
||||
|
|
@ -214,7 +203,10 @@ class TelegramEventObserver:
|
|||
def wrap_outer_middleware(
|
||||
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback)
|
||||
wrapped_outer = self.middleware.wrap_middlewares(
|
||||
self.outer_middleware,
|
||||
callback,
|
||||
)
|
||||
return wrapped_outer(event, data)
|
||||
|
||||
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
|
|
@ -233,8 +225,9 @@ class TelegramEventObserver:
|
|||
if result:
|
||||
kwargs.update(data, handler=handler)
|
||||
try:
|
||||
wrapped_inner = self._wrap_middleware(
|
||||
self._resolve_middlewares(), handler.call
|
||||
wrapped_inner = self.outer_middleware.wrap_middlewares(
|
||||
self._resolve_middlewares(),
|
||||
handler.call,
|
||||
)
|
||||
return await wrapped_inner(event, kwargs)
|
||||
except SkipHandler:
|
||||
|
|
@ -254,71 +247,3 @@ class TelegramEventObserver:
|
|||
return callback
|
||||
|
||||
return wrapper
|
||||
|
||||
def middleware(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType[TelegramObject]] = None,
|
||||
) -> Union[
|
||||
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
|
||||
MiddlewareType[TelegramObject],
|
||||
]:
|
||||
"""
|
||||
Decorator for registering inner middlewares
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.middleware() # via decorator (variant 1)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.middleware # via decorator (variant 2)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async def my_middleware(handler, event, data): ...
|
||||
<event>.middleware(my_middleware) # via method
|
||||
"""
|
||||
|
||||
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
|
||||
self.middlewares.append(m)
|
||||
return m
|
||||
|
||||
if middleware is None:
|
||||
return wrapper
|
||||
return wrapper(middleware)
|
||||
|
||||
def outer_middleware(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType[TelegramObject]] = None,
|
||||
) -> Union[
|
||||
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
|
||||
MiddlewareType[TelegramObject],
|
||||
]:
|
||||
"""
|
||||
Decorator for registering outer middlewares
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.outer_middleware() # via decorator (variant 1)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.outer_middleware # via decorator (variant 2)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async def my_middleware(handler, event, data): ...
|
||||
<event>.outer_middleware(my_middleware) # via method
|
||||
"""
|
||||
|
||||
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
|
||||
self.outer_middlewares.append(m)
|
||||
return m
|
||||
|
||||
if middleware is None:
|
||||
return wrapper
|
||||
return wrapper(middleware)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class TestDispatcher:
|
|||
|
||||
assert dp.update.handlers
|
||||
assert dp.update.handlers[0].callback == dp._listen_update
|
||||
assert dp.update.outer_middlewares
|
||||
assert dp.update.outer_middleware
|
||||
|
||||
def test_parent_router(self, dispatcher: Dispatcher):
|
||||
with pytest.raises(RuntimeError):
|
||||
|
|
|
|||
|
|
@ -297,10 +297,9 @@ class TestTelegramEventObserver:
|
|||
def test_register_middleware(self, middleware_type):
|
||||
event_observer = TelegramEventObserver(Router(), "test")
|
||||
|
||||
middlewares = getattr(event_observer, f"{middleware_type}s")
|
||||
decorator = getattr(event_observer, middleware_type)
|
||||
middlewares = getattr(event_observer, middleware_type)
|
||||
|
||||
@decorator
|
||||
@middlewares
|
||||
async def my_middleware1(handler, event, data):
|
||||
pass
|
||||
|
||||
|
|
@ -308,7 +307,7 @@ class TestTelegramEventObserver:
|
|||
assert my_middleware1.__name__ == "my_middleware1"
|
||||
assert my_middleware1 in middlewares
|
||||
|
||||
@decorator()
|
||||
@middlewares()
|
||||
async def my_middleware2(handler, event, data):
|
||||
pass
|
||||
|
||||
|
|
@ -319,13 +318,13 @@ class TestTelegramEventObserver:
|
|||
async def my_middleware3(handler, event, data):
|
||||
pass
|
||||
|
||||
decorator(my_middleware3)
|
||||
middlewares(my_middleware3)
|
||||
|
||||
assert my_middleware3 is not None
|
||||
assert my_middleware3.__name__ == "my_middleware3"
|
||||
assert my_middleware3 in middlewares
|
||||
|
||||
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]
|
||||
assert list(middlewares) == [my_middleware1, my_middleware2, my_middleware3]
|
||||
|
||||
def test_register_global_filters(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import re
|
|||
|
||||
import pytest
|
||||
|
||||
from aiogram import Dispatcher, F
|
||||
from aiogram import Dispatcher
|
||||
from aiogram.dispatcher.filters import ExceptionMessageFilter, ExceptionTypeFilter
|
||||
from aiogram.types import Update
|
||||
|
||||
|
|
|
|||
|
|
@ -111,8 +111,8 @@ class TestSimpleI18nMiddleware:
|
|||
middleware = SimpleI18nMiddleware(i18n=i18n)
|
||||
middleware.setup(router=dp)
|
||||
|
||||
assert middleware not in dp.update.outer_middlewares
|
||||
assert middleware in dp.message.outer_middlewares
|
||||
assert middleware not in dp.update.outer_middleware
|
||||
assert middleware in dp.message.outer_middleware
|
||||
|
||||
async def test_get_unknown_locale(self, i18n: I18n):
|
||||
dp = Dispatcher()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue