mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge remote-tracking branch 'origin/dev-3.x-api-5.3' into dev-3.x-api-5.3
This commit is contained in:
commit
7d56751b09
2 changed files with 57 additions and 10 deletions
|
|
@ -60,15 +60,19 @@ class TelegramEventObserver:
|
|||
yield filter_
|
||||
registry.append(filter_)
|
||||
|
||||
def _resolve_inner_middlewares(self) -> List[MiddlewareType]:
|
||||
def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType]:
|
||||
"""
|
||||
Get all inner middlewares in an tree
|
||||
Get all middlewares in a tree
|
||||
:param *:
|
||||
"""
|
||||
middlewares = []
|
||||
|
||||
for router in self.router.chain_head:
|
||||
for router in reversed(list(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middlewares)
|
||||
if outer:
|
||||
middlewares.extend(observer.outer_middlewares)
|
||||
else:
|
||||
middlewares.extend(observer.middlewares)
|
||||
return middlewares
|
||||
|
||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||
|
|
@ -131,7 +135,7 @@ class TelegramEventObserver:
|
|||
Propagate event to handlers and stops propagation on first match.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
wrapped_outer = self._wrap_middleware(self.outer_middlewares, self._trigger)
|
||||
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), self._trigger)
|
||||
return await wrapped_outer(event, kwargs)
|
||||
|
||||
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
|
|
@ -141,7 +145,7 @@ class TelegramEventObserver:
|
|||
kwargs.update(data)
|
||||
try:
|
||||
wrapped_inner = self._wrap_middleware(
|
||||
self._resolve_inner_middlewares(), handler.call
|
||||
self._resolve_middlewares(), handler.call
|
||||
)
|
||||
return await wrapped_inner(event, kwargs)
|
||||
except SkipHandler:
|
||||
|
|
@ -163,8 +167,7 @@ class TelegramEventObserver:
|
|||
return wrapper
|
||||
|
||||
def middleware(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType] = None,
|
||||
self, middleware: Optional[MiddlewareType] = None,
|
||||
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
|
||||
"""
|
||||
Decorator for registering inner middlewares
|
||||
|
|
@ -194,8 +197,7 @@ class TelegramEventObserver:
|
|||
return wrapper(middleware)
|
||||
|
||||
def outer_middleware(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType] = None,
|
||||
self, middleware: Optional[MiddlewareType] = None,
|
||||
) -> Union[Callable[[MiddlewareType], MiddlewareType], MiddlewareType]:
|
||||
"""
|
||||
Decorator for registering outer middlewares
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import datetime
|
||||
import time
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
|
@ -494,6 +495,50 @@ class TestDispatcher:
|
|||
assert result["event_router"] == router1
|
||||
assert result["test"] == "PASS"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_router_middleware_resolution(self, bot: MockedBot):
|
||||
counter = Counter()
|
||||
|
||||
def mw(type_: str, inject_data: dict):
|
||||
async def middleware(h, event, data):
|
||||
counter[type_] += 1
|
||||
data.update(inject_data)
|
||||
return await h(event, data)
|
||||
|
||||
return middleware
|
||||
|
||||
async def handler(event, foo, bar, baz, fizz, buzz):
|
||||
counter["child.handler"] += 1
|
||||
|
||||
root = Dispatcher()
|
||||
child = Router()
|
||||
|
||||
root.message.outer_middleware(mw("root.outer_middleware", {"foo": True}))
|
||||
root.message.middleware(mw("root.middleware", {"bar": None}))
|
||||
child.message.outer_middleware(mw("child.outer_middleware", {"fizz": 42}))
|
||||
child.message.middleware(mw("child.middleware", {"buzz": -42}))
|
||||
child.message.register(handler)
|
||||
|
||||
root.include_router(child)
|
||||
await root.feed_update(
|
||||
bot=bot,
|
||||
update=Update(
|
||||
update_id=42,
|
||||
message=Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.fromtimestamp(0),
|
||||
chat=Chat(id=-42, type="group"),
|
||||
),
|
||||
),
|
||||
baz=...,
|
||||
)
|
||||
|
||||
assert counter["root.outer_middleware"] == 2
|
||||
assert counter["root.middleware"] == 1
|
||||
assert counter["child.outer_middleware"] == 1
|
||||
assert counter["child.middleware"] == 1
|
||||
assert counter["child.handler"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_update_call_request(self, bot: MockedBot):
|
||||
dispatcher = Dispatcher()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue