diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index 50e2412d..022f3d41 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -60,15 +60,19 @@ class TelegramEventObserver: yield filter_ registry.append(filter_) - def _resolve_inner_middlewares(self) -> List[MiddlewareType]: + def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType]: """ - Get all inner middlewares in an tree + Get all middlewares in a tree + :param *: """ middlewares = [] - for router in self.router.chain_head: + for router in reversed(list(self.router.chain_head)): observer = router.observers[self.event_name] - middlewares.extend(observer.middlewares) + if outer: + middlewares.extend(observer.outer_middlewares) + else: + middlewares.extend(observer.middlewares) return middlewares def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: @@ -131,7 +135,7 @@ class TelegramEventObserver: Propagate event to handlers and stops propagation on first match. Handler will be called when all its filters is pass. """ - wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger) + wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), self._trigger) return await wrapped_outer(event, kwargs) async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any: @@ -141,7 +145,7 @@ class TelegramEventObserver: kwargs.update(data) try: wrapped_inner = self._wrap_middleware( - self._resolve_inner_middlewares(), handler.call + self._resolve_middlewares(), handler.call ) return await wrapped_inner(event, kwargs) except SkipHandler: @@ -163,8 +167,7 @@ class TelegramEventObserver: return wrapper def middleware( - self, - middleware: Optional[MiddlewareType] = None, + self, middleware: Optional[MiddlewareType] = None, ) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]: """ Decorator for registering inner middlewares @@ -194,8 +197,7 @@ class TelegramEventObserver: return wrapper(middleware) def outer_middleware( - self, - middleware: Optional[MiddlewareType] = None, + self, middleware: Optional[MiddlewareType] = None, ) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]: """ Decorator for registering outer middlewares diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index e03bdc4d..5bcbd9f6 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -2,6 +2,7 @@ import asyncio import datetime import time import warnings +from collections import Counter from typing import Any import pytest @@ -494,6 +495,50 @@ class TestDispatcher: assert result["event_router"] == router1 assert result["test"] == "PASS" + @pytest.mark.asyncio + async def test_nested_router_middleware_resolution(self, bot: MockedBot): + counter = Counter() + + def mw(type_: str, inject_data: dict): + async def middleware(h, event, data): + counter[type_] += 1 + data.update(inject_data) + return await h(event, data) + + return middleware + + async def handler(event, foo, bar, baz, fizz, buzz): + counter["child.handler"] += 1 + + root = Dispatcher() + child = Router() + + root.message.outer_middleware(mw("root.outer_middleware", {"foo": True})) + root.message.middleware(mw("root.middleware", {"bar": None})) + child.message.outer_middleware(mw("child.outer_middleware", {"fizz": 42})) + child.message.middleware(mw("child.middleware", {"buzz": -42})) + child.message.register(handler) + + root.include_router(child) + await root.feed_update( + bot=bot, + update=Update( + update_id=42, + message=Message( + message_id=42, + date=datetime.datetime.fromtimestamp(0), + chat=Chat(id=-42, type="group"), + ), + ), + baz=..., + ) + + assert counter["root.outer_middleware"] == 2 + assert counter["root.middleware"] == 1 + assert counter["child.outer_middleware"] == 1 + assert counter["child.middleware"] == 1 + assert counter["child.handler"] == 1 + @pytest.mark.asyncio async def test_process_update_call_request(self, bot: MockedBot): dispatcher = Dispatcher()