Add middlewares (API + Docs + Tests)

This commit is contained in:
Alex Root Junior 2020-04-12 20:27:32 +03:00
parent e4cd4c1763
commit 5b6ec599b1
24 changed files with 1120 additions and 42 deletions

View file

@ -3,6 +3,7 @@ from .api.client import session
from .api.client.bot import Bot
from .dispatcher import filters, handler
from .dispatcher.dispatcher import Dispatcher
from .dispatcher.middlewares.base import BaseMiddleware
from .dispatcher.router import Router
try:
@ -22,6 +23,7 @@ __all__ = (
"session",
"Dispatcher",
"Router",
"BaseMiddleware",
"filters",
"handler",
)

View file

@ -240,6 +240,8 @@ class Message(TelegramObject):
return ContentType.PASSPORT_DATA
if self.poll:
return ContentType.POLL
if self.dice:
return ContentType.DICE
return ContentType.UNKNOWN

View file

@ -1,21 +1,12 @@
from __future__ import annotations
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
List,
Optional,
Type,
)
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, Generator, List, Type
from pydantic import ValidationError
from ..filters.base import BaseFilter
from ..middlewares.types import MiddlewareStep, UpdateType
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
if TYPE_CHECKING: # pragma: no cover
@ -95,10 +86,8 @@ class TelegramEventObserver(EventObserver):
"""
registry: List[Type[BaseFilter]] = []
router: Optional[Router] = self.router
while router:
for router in self.router.chain:
observer = router.observers[self.event_name]
router = router.parent_router
for filter_ in observer.filters:
if filter_ in registry:
@ -133,6 +122,37 @@ class TelegramEventObserver(EventObserver):
return filters
async def trigger_middleware(
self, step: MiddlewareStep, event: UpdateType, data: Dict[str, Any], result: Any = None,
) -> None:
"""
Trigger middlewares chain
:param step:
:param event:
:param data:
:param result:
:return:
"""
reverse = step == MiddlewareStep.POST_PROCESS
recursive = self.event_name == "update" or step == MiddlewareStep.PROCESS
if self.event_name == "update":
routers = self.router.chain
else:
routers = self.router.chain_head
for router in routers:
await router.middleware.trigger(
step=step,
event_name=self.event_name,
event=event,
data=data,
result=result,
reverse=reverse,
)
if not recursive:
break
def register(
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
) -> HandlerType:
@ -153,12 +173,24 @@ class TelegramEventObserver(EventObserver):
Propagate event to handlers and stops propagation on first match.
Handler will be called when all its filters is pass.
"""
event = args[0]
await self.trigger_middleware(step=MiddlewareStep.PRE_PROCESS, event=event, data=kwargs)
for handler in self.handlers:
result, data = await handler.check(*args, **kwargs)
if result:
kwargs.update(data)
await self.trigger_middleware(
step=MiddlewareStep.PROCESS, event=event, data=kwargs
)
try:
yield await handler.call(*args, **kwargs)
response = await handler.call(*args, **kwargs)
await self.trigger_middleware(
step=MiddlewareStep.POST_PROCESS,
event=event,
data=kwargs,
result=response,
)
yield response
except SkipHandler:
continue
break

View file

@ -0,0 +1,61 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
class AbstractMiddleware(ABC):
"""
Abstract class for middleware.
"""
def __init__(self) -> None:
self._manager: Optional[MiddlewareManager] = None
@property
def manager(self) -> MiddlewareManager:
"""
Instance of MiddlewareManager
"""
if self._manager is None:
raise RuntimeError("Middleware is not configured!")
return self._manager
def setup(self, manager: MiddlewareManager, _stack_level: int = 1) -> AbstractMiddleware:
"""
Mark middleware as configured
:param manager:
:param _stack_level:
:return:
"""
if self.configured:
return manager.setup(self, _stack_level=_stack_level + 1)
self._manager = manager
return self
@property
def configured(self) -> bool:
"""
Check middleware is configured
:return:
"""
return bool(self._manager)
@abstractmethod
async def trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
data: Dict[str, Any],
result: Any = None,
) -> Any: # pragma: no cover
pass

View file

@ -0,0 +1,300 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict
from aiogram.dispatcher.middlewares.abstract import AbstractMiddleware
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.api.types import (
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
)
class BaseMiddleware(AbstractMiddleware):
"""
Base class for middleware.
All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message".
"""
async def trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
data: Dict[str, Any],
result: Any = None,
) -> Any:
"""
Trigger action.
:param step:
:param event_name:
:param event:
:param data:
:param result:
:return:
"""
handler_name = f"on_{step.value}_{event_name}"
handler = getattr(self, handler_name, None)
if not handler:
return None
args = (event, result, data) if step == MiddlewareStep.POST_PROCESS else (event, data)
return await handler(*args)
if TYPE_CHECKING: # pragma: no cover
# =============================================================================================
# Event that triggers before process <event>
# =============================================================================================
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process update
"""
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process message
"""
async def on_pre_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process edited_message
"""
async def on_pre_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process channel_post
"""
async def on_pre_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process edited_channel_post
"""
async def on_pre_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process inline_query
"""
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process chosen_inline_result
"""
async def on_pre_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process callback_query
"""
async def on_pre_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process shipping_query
"""
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process pre_checkout_query
"""
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
"""
Event that triggers before process poll
"""
async def on_pre_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
"""
Event that triggers before process poll_answer
"""
# =============================================================================================
# Event that triggers on process <event> after filters.
# =============================================================================================
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process update
"""
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process message
"""
async def on_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process edited_message
"""
async def on_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process channel_post
"""
async def on_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process edited_channel_post
"""
async def on_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process inline_query
"""
async def on_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process chosen_inline_result
"""
async def on_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process callback_query
"""
async def on_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process shipping_query
"""
async def on_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process pre_checkout_query
"""
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
"""
Event that triggers on process poll
"""
async def on_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
"""
Event that triggers on process poll_answer
"""
# =============================================================================================
# Event that triggers after process <event>.
# =============================================================================================
async def on_post_process_update(
self, update: Update, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing update
"""
async def on_post_process_message(
self, message: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing message
"""
async def on_post_process_edited_message(
self, edited_message: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing edited_message
"""
async def on_post_process_channel_post(
self, channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing channel_post
"""
async def on_post_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing edited_channel_post
"""
async def on_post_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing inline_query
"""
async def on_post_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing chosen_inline_result
"""
async def on_post_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing callback_query
"""
async def on_post_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing shipping_query
"""
async def on_post_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing pre_checkout_query
"""
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
"""
Event that triggers after processing poll
"""
async def on_post_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
) -> Any:
"""
Event that triggers after processing poll_answer
"""

View file

@ -0,0 +1,71 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List
from warnings import warn
from .abstract import AbstractMiddleware
from .types import MiddlewareStep, UpdateType
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.router import Router
class MiddlewareManager:
"""
Middleware manager.
"""
def __init__(self, router: Router) -> None:
self.router = router
self.middlewares: List[AbstractMiddleware] = []
def setup(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
"""
Setup middleware
:param middleware:
:param _stack_level:
:return:
"""
if not isinstance(middleware, AbstractMiddleware):
raise TypeError(
f"`middleware` should be instance of BaseMiddleware, not {type(middleware)}"
)
if middleware.configured:
if middleware.manager is self:
warn(
f"Middleware {middleware} is already configured for this Router "
"That's mean re-installing of this middleware has no effect.",
category=RuntimeWarning,
stacklevel=_stack_level + 1,
)
return middleware
raise ValueError(
f"Middleware is already configured for another manager {middleware.manager} "
f"in router {middleware.manager.router}!"
)
self.middlewares.append(middleware)
middleware.setup(self)
return middleware
async def trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
data: Dict[str, Any],
result: Any = None,
reverse: bool = False,
) -> Any:
"""
Call action to middlewares with args lilt.
"""
middlewares = reversed(self.middlewares) if reverse else self.middlewares
for middleware in middlewares:
await middleware.trigger(
step=step, event_name=event_name, event=event, data=data, result=result
)
def __contains__(self, item: AbstractMiddleware) -> bool:
return item in self.middlewares

View file

@ -0,0 +1,34 @@
from __future__ import annotations
from enum import Enum
from typing import Union
from aiogram.api.types import (
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
)
UpdateType = Union[
CallbackQuery,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
]
class MiddlewareStep(Enum):
PRE_PROCESS = "pre_process"
PROCESS = "process"
POST_PROCESS = "post_process"

View file

@ -1,13 +1,15 @@
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Generator, List, Optional, Union
from ..api.types import Chat, TelegramObject, Update, User
from ..utils.imports import import_module
from ..utils.warnings import CodeHasNoEffect
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
from .filters import BUILTIN_FILTERS
from .middlewares.abstract import AbstractMiddleware
from .middlewares.manager import MiddlewareManager
class Router:
@ -46,6 +48,7 @@ class Router:
)
self.poll_handler = TelegramEventObserver(router=self, event_name="poll")
self.poll_answer_handler = TelegramEventObserver(router=self, event_name="poll_answer")
self.middleware = MiddlewareManager(router=self)
self.startup = EventObserver()
self.shutdown = EventObserver()
@ -74,6 +77,36 @@ class Router:
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
observer.bind_filter(builtin_filter)
@property
def chain_head(self) -> Generator[Router, None, None]:
router: Optional[Router] = self
while router:
yield router
router = router.parent_router
@property
def chain_tail(self) -> Generator[Router, None, None]:
yield self
for router in self.sub_routers:
yield from router.chain_tail
@property
def chain(self) -> Generator[Router, None, None]:
yield from self.chain_head
tail = self.chain_tail
next(tail) # Skip self
yield from tail
def use(self, middleware: AbstractMiddleware, _stack_level: int = 1) -> AbstractMiddleware:
"""
Use middleware
:param middleware:
:param _stack_level:
:return:
"""
return self.middleware.setup(middleware, _stack_level=_stack_level + 1)
@property
def parent_router(self) -> Optional[Router]:
return self._parent_router

View file

@ -1,3 +1,4 @@
import logging
dispatcher = logging.getLogger("aiogram.dispatcher")
middlewares = logging.getLogger("aiogram.middlewares")