mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Global filters for router (#644)
* Bump version * Added more comments * Cover registering global filters * Reformat code * Add more tests * Rework event propagation to routers mechanism. Fixed compatibility with Python 3.10 syntax (match keyword) * Fixed tests * Fixed coverage Co-authored-by: evgfilim1 <evgfilim1@yandex.ru>
This commit is contained in:
parent
a70ecb767f
commit
4f2cc75951
13 changed files with 176 additions and 31 deletions
|
|
@ -232,20 +232,11 @@ class Dispatcher(Router):
|
|||
"installed not latest version of aiogram framework",
|
||||
RuntimeWarning,
|
||||
)
|
||||
raise SkipHandler
|
||||
raise SkipHandler()
|
||||
|
||||
kwargs.update(event_update=update)
|
||||
|
||||
for router in self.chain:
|
||||
kwargs.update(event_router=router)
|
||||
observer = router.observers[update_type]
|
||||
response = await observer.trigger(event, update=update, **kwargs)
|
||||
if response is not UNHANDLED:
|
||||
break
|
||||
else:
|
||||
response = UNHANDLED
|
||||
|
||||
return response
|
||||
return await self.propagate_event(update_type=update_type, event=event, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def _silent_call_request(cls, bot: Bot, result: TelegramMethod[Any]) -> None:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ MiddlewareType = Union[
|
|||
]
|
||||
|
||||
UNHANDLED = sentinel.UNHANDLED
|
||||
REJECTED = sentinel.REJECTED
|
||||
|
||||
|
||||
class SkipHandler(Exception):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from pydantic import ValidationError
|
|||
|
||||
from ...types import TelegramObject
|
||||
from ..filters.base import BaseFilter
|
||||
from .bases import UNHANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
|
||||
from .bases import REJECTED, UNHANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
|
|
@ -32,6 +32,24 @@ class TelegramEventObserver:
|
|||
self.outer_middlewares: List[MiddlewareType] = []
|
||||
self.middlewares: List[MiddlewareType] = []
|
||||
|
||||
# Re-used filters check method from already implemented handler object
|
||||
# with dummy callback which never will be used
|
||||
self._handler = HandlerObject(callback=lambda: True, filters=[])
|
||||
|
||||
def filter(self, *filters: FilterType, **bound_filters: Any) -> None:
|
||||
"""
|
||||
Register filter for all handlers of this event observer
|
||||
|
||||
:param filters: positional filters
|
||||
:param bound_filters: keyword filters
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
if self._handler.filters is None:
|
||||
self._handler.filters = []
|
||||
self._handler.filters.extend(
|
||||
[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)]
|
||||
)
|
||||
|
||||
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
||||
"""
|
||||
Register filter class in factory
|
||||
|
|
@ -139,6 +157,12 @@ class TelegramEventObserver:
|
|||
return await wrapped_outer(event, kwargs)
|
||||
|
||||
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
# Check globally defined filters before any other handler will be checked
|
||||
result, data = await self._handler.check(event, **kwargs)
|
||||
if not result:
|
||||
return REJECTED
|
||||
kwargs.update(data)
|
||||
|
||||
for handler in self.handlers:
|
||||
result, data = await handler.check(event, **kwargs)
|
||||
if result:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class Command(BaseFilter):
|
|||
if isinstance(allowed_command, Pattern): # Regexp
|
||||
result = allowed_command.match(command.command)
|
||||
if result:
|
||||
return replace(command, match=result)
|
||||
return replace(command, regexp_match=result)
|
||||
elif command.command == allowed_command: # String
|
||||
return command
|
||||
raise CommandException("Command did not match pattern")
|
||||
|
|
@ -134,7 +134,7 @@ class CommandObject:
|
|||
"""Mention (if available)"""
|
||||
args: Optional[str] = field(repr=False, default=None)
|
||||
"""Command argument"""
|
||||
match: Optional[Match[str]] = field(repr=False, default=None)
|
||||
regexp_match: Optional[Match[str]] = field(repr=False, default=None)
|
||||
"""Will be presented match result if the command is presented as regexp in filter"""
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -26,20 +26,20 @@ class ExceptionMessageFilter(BaseFilter):
|
|||
Allow to match exception by message
|
||||
"""
|
||||
|
||||
match: Union[str, Pattern[str]]
|
||||
pattern: Union[str, Pattern[str]]
|
||||
"""Regexp pattern"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("match")
|
||||
@validator("pattern")
|
||||
def _validate_match(cls, value: Union[str, Pattern[str]]) -> Union[str, Pattern[str]]:
|
||||
if isinstance(value, str):
|
||||
return re.compile(value)
|
||||
return value
|
||||
|
||||
async def __call__(self, exception: Exception) -> Union[bool, Dict[str, Any]]:
|
||||
pattern = cast(Pattern[str], self.match)
|
||||
pattern = cast(Pattern[str], self.pattern)
|
||||
result = pattern.match(str(exception))
|
||||
if not result:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class RedisStorage(BaseStorage):
|
|||
return cls(redis=redis, **kwargs)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.redis.close()
|
||||
await self.redis.close() # type: ignore
|
||||
|
||||
def generate_key(self, bot: Bot, *parts: Any) -> str:
|
||||
prefix_parts = [self.prefix]
|
||||
|
|
@ -73,7 +73,7 @@ class RedisStorage(BaseStorage):
|
|||
await self.redis.delete(key)
|
||||
else:
|
||||
await self.redis.set(
|
||||
key, state.state if isinstance(state, State) else state, ex=self.state_ttl
|
||||
key, state.state if isinstance(state, State) else state, ex=self.state_ttl # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
|
||||
|
|
@ -89,7 +89,7 @@ class RedisStorage(BaseStorage):
|
|||
await self.redis.delete(key)
|
||||
return
|
||||
json_data = bot.session.json_dumps(data)
|
||||
await self.redis.set(key, json_data, ex=self.data_ttl)
|
||||
await self.redis.set(key, json_data, ex=self.data_ttl) # type: ignore[arg-type]
|
||||
|
||||
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
|
||||
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ from __future__ import annotations
|
|||
import warnings
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
from ..types import TelegramObject
|
||||
from ..utils.imports import import_module
|
||||
from ..utils.warnings import CodeHasNoEffect
|
||||
from .event.bases import REJECTED, UNHANDLED
|
||||
from .event.event import EventObserver
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from .filters import BUILTIN_FILTERS
|
||||
|
|
@ -82,6 +84,22 @@ class Router:
|
|||
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
|
||||
observer.bind_filter(builtin_filter)
|
||||
|
||||
async def propagate_event(self, update_type: str, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
kwargs.update(event_router=self)
|
||||
observer = self.observers[update_type]
|
||||
response = await observer.trigger(event, **kwargs)
|
||||
if response is REJECTED:
|
||||
return UNHANDLED
|
||||
if response is not UNHANDLED:
|
||||
return response
|
||||
|
||||
for router in self.sub_routers:
|
||||
response = await router.propagate_event(update_type=update_type, event=event, **kwargs)
|
||||
if response is not UNHANDLED:
|
||||
break
|
||||
|
||||
return response
|
||||
|
||||
@property
|
||||
def chain_head(self) -> Generator[Router, None, None]:
|
||||
router: Optional[Router] = self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue