From d30cdff108a749f8962a364c50c08abfe9a35ad9 Mon Sep 17 00:00:00 2001 From: evgfilim1 Date: Sat, 24 Jul 2021 13:30:35 +0500 Subject: [PATCH] Fix outer_middleware resolution (#637) --- aiogram/dispatcher/event/telegram.py | 14 +++++--- tests/test_dispatcher/test_dispatcher.py | 45 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index 50e2412d..acd6a27b 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -60,15 +60,19 @@ class TelegramEventObserver: yield filter_ registry.append(filter_) - def _resolve_inner_middlewares(self) -> List[MiddlewareType]: + def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType]: """ - Get all inner middlewares in an tree + Get all middlewares in a tree + :param *: """ middlewares = [] for router in self.router.chain_head: observer = router.observers[self.event_name] - middlewares.extend(observer.middlewares) + if outer: + middlewares.extend(observer.outer_middlewares) + else: + middlewares.extend(observer.middlewares) return middlewares def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: @@ -131,7 +135,7 @@ class TelegramEventObserver: Propagate event to handlers and stops propagation on first match. Handler will be called when all its filters is pass. """ - wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger) + wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), self._trigger) return await wrapped_outer(event, kwargs) async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any: @@ -141,7 +145,7 @@ class TelegramEventObserver: kwargs.update(data) try: wrapped_inner = self._wrap_middleware( - self._resolve_inner_middlewares(), handler.call + self._resolve_middlewares(), handler.call ) return await wrapped_inner(event, kwargs) except SkipHandler: diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index e03bdc4d..a7e43d20 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -2,6 +2,7 @@ import asyncio import datetime import time import warnings +from collections import Counter from typing import Any import pytest @@ -494,6 +495,50 @@ class TestDispatcher: assert result["event_router"] == router1 assert result["test"] == "PASS" + @pytest.mark.asyncio + async def test_nested_router_middleware_resolution(self, bot: MockedBot): + counter = Counter() + + def mw(type_: str, inject_data: dict): + async def middleware(h, event, data): + counter[type_] += 1 + data.update(inject_data) + return await h(event, data) + + return middleware + + async def handler(event, foo, bar, baz, fizz, buzz): + counter['child.handler'] += 1 + + root = Dispatcher() + child = Router() + + root.message.outer_middleware(mw('root.outer_middleware', {'foo': True})) + root.message.middleware(mw('root.middleware', {'bar': None})) + child.message.outer_middleware(mw('child.outer_middleware', {'fizz': 42})) + child.message.middleware(mw('child.middleware', {'buzz': -42})) + child.message.register(handler) + + root.include_router(child) + await root.feed_update( + bot=bot, + update=Update( + update_id=42, + message=Message( + message_id=42, + date=datetime.datetime.fromtimestamp(0), + chat=Chat(id=-42, type='group'), + ), + ), + baz=..., + ) + + assert counter['root.outer_middleware'] == 2 + assert counter['root.middleware'] == 1 + assert counter['child.outer_middleware'] == 1 + assert counter['child.middleware'] == 1 + assert counter['child.handler'] == 1 + @pytest.mark.asyncio async def test_process_update_call_request(self, bot: MockedBot): dispatcher = Dispatcher()