mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Refactor dispatcher. Update docs. Fix tests.
This commit is contained in:
parent
9539c4e2fd
commit
d3b488e16d
290 changed files with 2153 additions and 1780 deletions
|
|
@ -7,11 +7,13 @@ from asyncio import CancelledError, Future, Lock
|
|||
from typing import Any, AsyncGenerator, Dict, Optional, Union
|
||||
|
||||
from .. import loggers
|
||||
from aiogram.client.bot import Bot
|
||||
from ..client.bot import Bot
|
||||
from ..methods import TelegramMethod
|
||||
from ..types import Update, User
|
||||
from ..utils.exceptions import TelegramAPIError
|
||||
from .event.bases import NOT_HANDLED
|
||||
from .event.bases import UNHANDLED, SkipHandler
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from .middlewares.error import ErrorsMiddleware
|
||||
from .middlewares.user_context import UserContextMiddleware
|
||||
from .router import Router
|
||||
|
||||
|
|
@ -23,10 +25,15 @@ class Dispatcher(Router):
|
|||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super(Dispatcher, self).__init__(**kwargs)
|
||||
self._running_lock = Lock()
|
||||
|
||||
# Default middleware is needed for contextual features
|
||||
self.update = TelegramEventObserver(router=self, event_name="update")
|
||||
self.observers["update"] = self.update
|
||||
|
||||
self.update.register(self._listen_update)
|
||||
self.update.outer_middleware(UserContextMiddleware())
|
||||
self.update.outer_middleware(ErrorsMiddleware(self))
|
||||
|
||||
self._running_lock = Lock()
|
||||
|
||||
@property
|
||||
def parent_router(self) -> None:
|
||||
|
|
@ -61,7 +68,7 @@ class Dispatcher(Router):
|
|||
Bot.set_current(bot)
|
||||
try:
|
||||
response = await self.update.trigger(update, bot=bot, **kwargs)
|
||||
handled = response is not NOT_HANDLED
|
||||
handled = response is not UNHANDLED
|
||||
return response
|
||||
finally:
|
||||
finish_time = loop.time()
|
||||
|
|
@ -97,6 +104,74 @@ class Dispatcher(Router):
|
|||
yield update
|
||||
update_id = update.update_id + 1
|
||||
|
||||
async def _listen_update(self, update: Update, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Main updates listener
|
||||
|
||||
Workflow:
|
||||
- Detect content type and propagate to observers in current router
|
||||
- If no one filter is pass - propagate update to child routers as Update
|
||||
|
||||
:param update:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
event: TelegramObject
|
||||
if update.message:
|
||||
update_type = "message"
|
||||
event = update.message
|
||||
elif update.edited_message:
|
||||
update_type = "edited_message"
|
||||
event = update.edited_message
|
||||
elif update.channel_post:
|
||||
update_type = "channel_post"
|
||||
event = update.channel_post
|
||||
elif update.edited_channel_post:
|
||||
update_type = "edited_channel_post"
|
||||
event = update.edited_channel_post
|
||||
elif update.inline_query:
|
||||
update_type = "inline_query"
|
||||
event = update.inline_query
|
||||
elif update.chosen_inline_result:
|
||||
update_type = "chosen_inline_result"
|
||||
event = update.chosen_inline_result
|
||||
elif update.callback_query:
|
||||
update_type = "callback_query"
|
||||
event = update.callback_query
|
||||
elif update.shipping_query:
|
||||
update_type = "shipping_query"
|
||||
event = update.shipping_query
|
||||
elif update.pre_checkout_query:
|
||||
update_type = "pre_checkout_query"
|
||||
event = update.pre_checkout_query
|
||||
elif update.poll:
|
||||
update_type = "poll"
|
||||
event = update.poll
|
||||
elif update.poll_answer:
|
||||
update_type = "poll_answer"
|
||||
event = update.poll_answer
|
||||
else:
|
||||
warnings.warn(
|
||||
"Detected unknown update type.\n"
|
||||
"Seems like Telegram Bot API was updated and you have "
|
||||
"installed not latest version of aiogram framework",
|
||||
RuntimeWarning,
|
||||
)
|
||||
raise SkipHandler
|
||||
|
||||
kwargs.update(event_update=update)
|
||||
|
||||
for router in self.chain:
|
||||
kwargs.update(event_router=router)
|
||||
observer = router.observers[update_type]
|
||||
response = await observer.trigger(event, update=update, **kwargs)
|
||||
if response is not UNHANDLED:
|
||||
break
|
||||
else:
|
||||
response = UNHANDLED
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
async def _silent_call_request(cls, bot: Bot, result: TelegramMethod[Any]) -> None:
|
||||
"""
|
||||
|
|
@ -129,7 +204,7 @@ class Dispatcher(Router):
|
|||
handled = False
|
||||
try:
|
||||
response = await self.feed_update(bot, update, **kwargs)
|
||||
handled = handled is not NOT_HANDLED
|
||||
handled = handled is not UNHANDLED
|
||||
if call_answer and isinstance(response, TelegramMethod):
|
||||
await self._silent_call_request(bot=bot, result=response)
|
||||
return handled
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ MiddlewareType = Union[
|
|||
BaseMiddleware, Callable[[NextMiddlewareType, TelegramObject, Dict[str, Any]], Awaitable[Any]]
|
||||
]
|
||||
|
||||
NOT_HANDLED = sentinel.NOT_HANDLED
|
||||
UNHANDLED = sentinel.UNHANDLED
|
||||
|
||||
|
||||
class SkipHandler(Exception):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from pydantic import ValidationError
|
|||
|
||||
from ...types import TelegramObject
|
||||
from ..filters.base import BaseFilter
|
||||
from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
|
||||
from .bases import UNHANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
|
|
@ -60,6 +60,17 @@ class TelegramEventObserver:
|
|||
yield filter_
|
||||
registry.append(filter_)
|
||||
|
||||
def _resolve_inner_middlewares(self) -> List[MiddlewareType]:
|
||||
"""
|
||||
Get all inner middlewares in an tree
|
||||
"""
|
||||
middlewares = []
|
||||
|
||||
for router in self.router.chain_head:
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middlewares)
|
||||
return middlewares
|
||||
|
||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||
"""
|
||||
Resolve keyword filters via filters factory
|
||||
|
|
@ -129,12 +140,14 @@ class TelegramEventObserver:
|
|||
if result:
|
||||
kwargs.update(data)
|
||||
try:
|
||||
wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
|
||||
wrapped_inner = self._wrap_middleware(
|
||||
self._resolve_inner_middlewares(), handler.call
|
||||
)
|
||||
return await wrapped_inner(event, kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
|
||||
return NOT_HANDLED
|
||||
return UNHANDLED
|
||||
|
||||
def __call__(
|
||||
self, *args: FilterType, **bound_filters: BaseFilter
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ __all__ = (
|
|||
)
|
||||
|
||||
BUILTIN_FILTERS: Dict[str, Tuple[Type[BaseFilter], ...]] = {
|
||||
"update": (),
|
||||
"message": (Text, Command, ContentTypesFilter),
|
||||
"edited_message": (Text, Command, ContentTypesFilter),
|
||||
"channel_post": (Text, ContentTypesFilter),
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
|
||||
|
||||
from ...types import Update
|
||||
from ..event.bases import NOT_HANDLED, CancelHandler, SkipHandler
|
||||
from ..event.bases import UNHANDLED, CancelHandler, SkipHandler
|
||||
from .base import BaseMiddleware
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
|
|
@ -25,7 +25,9 @@ class ErrorsMiddleware(BaseMiddleware[Update]):
|
|||
except (SkipHandler, CancelHandler): # pragma: no cover
|
||||
raise
|
||||
except Exception as e:
|
||||
response = await self.router.errors.trigger(event, exception=e, **data)
|
||||
if response is NOT_HANDLED:
|
||||
raise
|
||||
return response
|
||||
for router in self.router.chain:
|
||||
observer = router.observers["error"]
|
||||
response = await observer.trigger(event, exception=e, **data)
|
||||
if response is not UNHANDLED:
|
||||
return response
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -6,11 +6,10 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
|||
from ..types import TelegramObject, Update
|
||||
from ..utils.imports import import_module
|
||||
from ..utils.warnings import CodeHasNoEffect
|
||||
from .event.bases import NOT_HANDLED, SkipHandler
|
||||
from .event.bases import UNHANDLED, SkipHandler
|
||||
from .event.event import EventObserver
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from .filters import BUILTIN_FILTERS
|
||||
from .middlewares.error import ErrorsMiddleware
|
||||
|
||||
|
||||
class Router:
|
||||
|
|
@ -37,7 +36,6 @@ class Router:
|
|||
self.sub_routers: List[Router] = []
|
||||
|
||||
# Observers
|
||||
self.update = TelegramEventObserver(router=self, event_name="update")
|
||||
self.message = TelegramEventObserver(router=self, event_name="message")
|
||||
self.edited_message = TelegramEventObserver(router=self, event_name="edited_message")
|
||||
self.channel_post = TelegramEventObserver(router=self, event_name="channel_post")
|
||||
|
|
@ -61,7 +59,6 @@ class Router:
|
|||
self.shutdown = EventObserver()
|
||||
|
||||
self.observers: Dict[str, TelegramEventObserver] = {
|
||||
"update": self.update,
|
||||
"message": self.message,
|
||||
"edited_message": self.edited_message,
|
||||
"channel_post": self.channel_post,
|
||||
|
|
@ -76,11 +73,6 @@ class Router:
|
|||
"error": self.errors,
|
||||
}
|
||||
|
||||
# Root handler
|
||||
self.update.register(self._listen_update)
|
||||
|
||||
self.update.outer_middleware(ErrorsMiddleware(self))
|
||||
|
||||
# Builtin filters
|
||||
if use_builtin_filters:
|
||||
for name, observer in self.observers.items():
|
||||
|
|
@ -138,8 +130,9 @@ class Router:
|
|||
|
||||
if not self.use_builtin_filters and parent.use_builtin_filters:
|
||||
warnings.warn(
|
||||
f"Router(use_builtin_filters=False) has no effect for router {self} "
|
||||
f"in due to builtin filters is already registered in parent router",
|
||||
f"{self.__class__.__name__}(use_builtin_filters=False) has no effect"
|
||||
f" for router {self} in due to builtin filters is already registered"
|
||||
f" in parent router",
|
||||
CodeHasNoEffect,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
|
@ -167,73 +160,6 @@ class Router:
|
|||
router.parent_router = self
|
||||
return router
|
||||
|
||||
async def _listen_update(self, update: Update, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Main updates listener
|
||||
|
||||
Workflow:
|
||||
- Detect content type and propagate to observers in current router
|
||||
- If no one filter is pass - propagate update to child routers as Update
|
||||
|
||||
:param update:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
event: TelegramObject
|
||||
if update.message:
|
||||
update_type = "message"
|
||||
event = update.message
|
||||
elif update.edited_message:
|
||||
update_type = "edited_message"
|
||||
event = update.edited_message
|
||||
elif update.channel_post:
|
||||
update_type = "channel_post"
|
||||
event = update.channel_post
|
||||
elif update.edited_channel_post:
|
||||
update_type = "edited_channel_post"
|
||||
event = update.edited_channel_post
|
||||
elif update.inline_query:
|
||||
update_type = "inline_query"
|
||||
event = update.inline_query
|
||||
elif update.chosen_inline_result:
|
||||
update_type = "chosen_inline_result"
|
||||
event = update.chosen_inline_result
|
||||
elif update.callback_query:
|
||||
update_type = "callback_query"
|
||||
event = update.callback_query
|
||||
elif update.shipping_query:
|
||||
update_type = "shipping_query"
|
||||
event = update.shipping_query
|
||||
elif update.pre_checkout_query:
|
||||
update_type = "pre_checkout_query"
|
||||
event = update.pre_checkout_query
|
||||
elif update.poll:
|
||||
update_type = "poll"
|
||||
event = update.poll
|
||||
elif update.poll_answer:
|
||||
update_type = "poll_answer"
|
||||
event = update.poll_answer
|
||||
else:
|
||||
warnings.warn(
|
||||
"Detected unknown update type.\n"
|
||||
"Seems like Telegram Bot API was updated and you have "
|
||||
"installed not latest version of aiogram framework",
|
||||
RuntimeWarning,
|
||||
)
|
||||
raise SkipHandler
|
||||
|
||||
kwargs.update(event_update=update, event_router=self)
|
||||
observer = self.observers[update_type]
|
||||
response = await observer.trigger(event, update=update, **kwargs)
|
||||
|
||||
if response is NOT_HANDLED: # Resolve nested routers
|
||||
for router in self.sub_routers:
|
||||
response = await router.update.trigger(event=update, **kwargs)
|
||||
if response is NOT_HANDLED:
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
async def emit_startup(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Recursively call startup callbacks
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from ..types import (
|
||||
UNSET,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue