Fix outer_middleware resolution (#637)

This commit is contained in:
evgfilim1 2021-07-24 13:30:35 +05:00
parent 16eac75332
commit d30cdff108
No known key found for this signature in database
GPG key ID: 16AEE4D0BB188AEC
2 changed files with 54 additions and 5 deletions

View file

@ -60,15 +60,19 @@ class TelegramEventObserver:
yield filter_
registry.append(filter_)
def _resolve_inner_middlewares(self) -> List[MiddlewareType]:
def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType]:
"""
Get all inner middlewares in an tree
Get all middlewares in a tree
:param *:
"""
middlewares = []
for router in self.router.chain_head:
observer = router.observers[self.event_name]
middlewares.extend(observer.middlewares)
if outer:
middlewares.extend(observer.outer_middlewares)
else:
middlewares.extend(observer.middlewares)
return middlewares
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
@ -131,7 +135,7 @@ class TelegramEventObserver:
Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass.
"""
wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger)
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), self._trigger)
return await wrapped_outer(event, kwargs)
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
@ -141,7 +145,7 @@ class TelegramEventObserver:
kwargs.update(data)
try:
wrapped_inner = self._wrap_middleware(
self._resolve_inner_middlewares(), handler.call
self._resolve_middlewares(), handler.call
)
return await wrapped_inner(event, kwargs)
except SkipHandler:

View file

@ -2,6 +2,7 @@ import asyncio
import datetime
import time
import warnings
from collections import Counter
from typing import Any
import pytest
@ -494,6 +495,50 @@ class TestDispatcher:
assert result["event_router"] == router1
assert result["test"] == "PASS"
@pytest.mark.asyncio
async def test_nested_router_middleware_resolution(self, bot: MockedBot):
counter = Counter()
def mw(type_: str, inject_data: dict):
async def middleware(h, event, data):
counter[type_] += 1
data.update(inject_data)
return await h(event, data)
return middleware
async def handler(event, foo, bar, baz, fizz, buzz):
counter['child.handler'] += 1
root = Dispatcher()
child = Router()
root.message.outer_middleware(mw('root.outer_middleware', {'foo': True}))
root.message.middleware(mw('root.middleware', {'bar': None}))
child.message.outer_middleware(mw('child.outer_middleware', {'fizz': 42}))
child.message.middleware(mw('child.middleware', {'buzz': -42}))
child.message.register(handler)
root.include_router(child)
await root.feed_update(
bot=bot,
update=Update(
update_id=42,
message=Message(
message_id=42,
date=datetime.datetime.fromtimestamp(0),
chat=Chat(id=-42, type='group'),
),
),
baz=...,
)
assert counter['root.outer_middleware'] == 2
assert counter['root.middleware'] == 1
assert counter['child.outer_middleware'] == 1
assert counter['child.middleware'] == 1
assert counter['child.handler'] == 1
@pytest.mark.asyncio
async def test_process_update_call_request(self, bot: MockedBot):
dispatcher = Dispatcher()