Rewrite filters

This commit is contained in:
Alex Root Junior 2022-08-14 21:43:33 +03:00
parent 3f57c69d4f
commit 8d2aae77c1
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
24 changed files with 311 additions and 539 deletions

View file

@ -114,8 +114,8 @@ docs-gettext:
.PHONY: docs-gettext
docs-serve:
rm -rf docs/_build
$(py) sphinx-autobuild --watch aiogram/ docs/ docs/_build/ $(OPTS)
#rm -rf docs/_build
$(py) sphinx-autobuild --watch aiogram/ --watch CHANGELOG.rst --watch README.rst docs/ docs/_build/ $(OPTS)
.PHONY: docs-serve
$(locale_targets): docs-serve-%:

View file

@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from magic_filter import MagicFilter
from aiogram.dispatcher.flags import extract_flags_from_object
from aiogram.filters.base import Filter
from aiogram.handlers import BaseHandler
CallbackType = Callable[..., Any]
@ -53,8 +54,12 @@ class FilterObject(CallableMixin):
if isinstance(self.callback, MagicFilter):
# MagicFilter instance is callable but generates only "CallOperation" instead of applying the filter
self.callback = self.callback.resolve
super().__post_init__()
if isinstance(self.callback, Filter):
self.awaitable = True
@dataclass
class HandlerObject(CallableMixin):

View file

@ -1,17 +1,10 @@
from __future__ import annotations
import warnings
from inspect import isclass
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type
from pydantic import ValidationError
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from aiogram.filters.base import BaseFilter
from ...exceptions import FiltersResolveError
from ...filters import BUILTIN_FILTERS_SET
from ...filters.base import Filter
from ...types import TelegramObject
from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, HandlerObject
@ -33,7 +26,6 @@ class TelegramEventObserver:
self.event_name: str = event_name
self.handlers: List[HandlerObject] = []
self.filters: List[Type[BaseFilter]] = []
self.middleware = MiddlewareManager()
self.outer_middleware = MiddlewareManager()
@ -42,63 +34,16 @@ class TelegramEventObserver:
# with dummy callback which never will be used
self._handler = HandlerObject(callback=lambda: True, filters=[])
def filter(self, *filters: CallbackType, _stacklevel: int = 2, **bound_filters: Any) -> None:
def filter(self, *filters: CallbackType) -> None:
"""
Register filter for all handlers of this event observer
:param filters: positional filters
:param bound_filters: keyword filters
"""
resolved_filters = self.resolve_filters(
filters, bound_filters, _stacklevel=_stacklevel + 1
)
if self._handler.filters is None:
self._handler.filters = []
self._handler.filters.extend(
[
FilterObject(filter_) # type: ignore
for filter_ in chain(
resolved_filters,
filters,
)
]
)
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
"""
Register filter class in factory
:param bound_filter:
"""
if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter):
raise TypeError(
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
)
if bound_filter not in BUILTIN_FILTERS_SET:
warnings.warn(
category=DeprecationWarning,
message="filters factory deprecated and will be removed in 3.0b5,"
" use filters directly instead (Example: "
f"`{bound_filter.__name__}(<argument>=<value>)` instead of `<argument>=<value>`)",
stacklevel=2,
)
self.filters.append(bound_filter)
def _resolve_filters_chain(self) -> Generator[Type[BaseFilter], None, None]:
"""
Get all bounded filters from current observer and from the parents
with the same event type without duplicates
"""
registry: List[Type[BaseFilter]] = []
for router in reversed(tuple(self.router.chain_head)):
observer = router.observers[self.event_name]
for filter_ in observer.filters:
if filter_ in registry:
continue
yield filter_
registry.append(filter_)
self._handler.filters.extend([FilterObject(filter_) for filter_ in filters])
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
middlewares: List[MiddlewareType[TelegramObject]] = []
@ -108,112 +53,30 @@ class TelegramEventObserver:
return middlewares
def resolve_filters(
self,
filters: Tuple[CallbackType, ...],
full_config: Dict[str, Any],
ignore_default: bool = True,
_stacklevel: int = 2,
) -> List[BaseFilter]:
"""
Resolve keyword filters via filters factory
:param filters: positional filters
:param full_config: keyword arguments to initialize bounded filters for router/handler
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
"""
bound_filters: List[BaseFilter] = []
if ignore_default and not full_config:
return bound_filters
filter_types = set(type(f) for f in filters)
validation_errors = []
for bound_filter in self._resolve_filters_chain():
# skip filter if filter was used as positional filter:
if bound_filter in filter_types:
continue
# skip filter with no fields in full_config
if ignore_default:
full_config_keys = set(full_config.keys())
filter_fields = set(bound_filter.__fields__.keys())
if not full_config_keys.intersection(filter_fields):
continue
# Try to initialize filter.
try:
f = bound_filter(**full_config)
except ValidationError as e:
validation_errors.append(e)
continue
# Clean full config to prevent to re-initialize another filter
# with the same configuration
for key in f.__fields__:
full_config.pop(key, None)
bound_filters.append(f)
if full_config:
possible_cases = []
for error in validation_errors:
for sum_error in error.errors():
if sum_error["loc"][0] in full_config:
possible_cases.append(error)
break
raise FiltersResolveError(
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
)
if bound_filters:
warnings.warn(
category=DeprecationWarning,
message="Filters factory deprecated and will be removed in 3.0b5.\n"
"Use filters directly, for example instead of "
"`@router.message(commands=['help']')` "
"use `@router.message(Command(commands=['help'])`",
stacklevel=_stacklevel,
)
return bound_filters
def register(
self,
callback: CallbackType,
*filters: CallbackType,
flags: Optional[Dict[str, Any]] = None,
_stacklevel: int = 2,
**bound_filters: Any,
) -> CallbackType:
"""
Register event handler
"""
if flags is None:
flags = {}
resolved_filters = self.resolve_filters(
filters,
bound_filters,
ignore_default=False,
_stacklevel=_stacklevel + 1,
)
for resolved_filter in resolved_filters:
resolved_filter.update_handler_flags(flags=flags)
for item in filters:
if isinstance(item, Filter):
item.update_handler_flags(flags=flags)
self.handlers.append(
HandlerObject(
callback=callback,
filters=[
FilterObject(filter_) # type: ignore
for filter_ in chain(
resolved_filters,
filters,
)
],
filters=[FilterObject(filter_) for filter_ in filters],
flags=flags,
)
)
return callback
def wrap_outer_middleware(
@ -253,19 +116,15 @@ class TelegramEventObserver:
def __call__(
self,
*args: CallbackType,
*filters: CallbackType,
flags: Optional[Dict[str, Any]] = None,
_stacklevel: int = 2,
**bound_filters: Any,
) -> Callable[[CallbackType], CallbackType]:
"""
Decorator for registering event handlers
"""
def wrapper(callback: CallbackType) -> CallbackType:
self.register(
callback, *args, flags=flags, **bound_filters, _stacklevel=_stacklevel + 1
)
self.register(callback, *filters, flags=flags)
return callback
return wrapper

View file

@ -1,8 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, cast
from ...types import TelegramObject
from ...types import TelegramObject, Update
from ...types.error_event import ErrorEvent
from ..event.bases import UNHANDLED, CancelHandler, SkipHandler
from .base import BaseMiddleware
@ -26,7 +27,9 @@ class ErrorsMiddleware(BaseMiddleware):
raise
except Exception as e:
response = await self.router.propagate_event(
update_type="error", event=event, **data, exception=e
update_type="error",
event=ErrorEvent(update=cast(Update, event), exception=e),
**data,
)
if response is not UNHANDLED:
return response

View file

@ -3,8 +3,6 @@ from __future__ import annotations
import warnings
from typing import Any, Dict, Final, Generator, List, Optional, Set, Union
from aiogram.filters import BUILTIN_FILTERS
from ..types import TelegramObject
from ..utils.warnings import CodeHasNoEffect
from .event.bases import REJECTED, UNHANDLED
@ -28,7 +26,7 @@ class Router:
def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None:
"""
:param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory
:param use_builtin_filters: `aiogram` has many builtin filters, and you can controll automatic registration of this filters in factory
:param name: Optional router name, can be useful for debugging
"""
@ -83,12 +81,6 @@ class Router:
"error": self.errors,
}
# Builtin filters
if use_builtin_filters:
for name, observer in self.observers.items():
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
observer.bind_filter(builtin_filter)
def __str__(self) -> str:
return f"{type(self).__name__} {self.name!r}"

View file

@ -1,7 +1,6 @@
from itertools import chain
from typing import Dict, Tuple, Type
from .base import BaseFilter
from .base import Filter
from .chat_member_updated import (
ADMINISTRATOR,
CREATOR,
@ -18,7 +17,6 @@ from .chat_member_updated import (
ChatMemberUpdatedFilter,
)
from .command import Command, CommandObject, CommandStart
from .content_types import ContentTypesFilter
from .exception import ExceptionMessageFilter, ExceptionTypeFilter
from .logic import and_f, invert_f, or_f
from .magic_data import MagicData
@ -27,12 +25,11 @@ from .text import Text
__all__ = (
"BUILTIN_FILTERS",
"BaseFilter",
"Filter",
"Text",
"Command",
"CommandObject",
"CommandStart",
"ContentTypesFilter",
"ExceptionMessageFilter",
"ExceptionTypeFilter",
"StateFilter",
@ -55,85 +52,4 @@ __all__ = (
"invert_f",
)
_ALL_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (MagicData,)
_TELEGRAM_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (StateFilter,)
BUILTIN_FILTERS: Dict[str, Tuple[Type[BaseFilter], ...]] = {
"message": (
Text,
Command,
ContentTypesFilter,
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"edited_message": (
Text,
Command,
ContentTypesFilter,
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"channel_post": (
Text,
ContentTypesFilter,
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"edited_channel_post": (
Text,
ContentTypesFilter,
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"inline_query": (
Text,
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"chosen_inline_result": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"callback_query": (
Text,
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"shipping_query": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"pre_checkout_query": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"poll": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"poll_answer": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"my_chat_member": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
ChatMemberUpdatedFilter,
),
"chat_member": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
ChatMemberUpdatedFilter,
),
"chat_join_request": (
*_ALL_EVENTS_FILTERS,
*_TELEGRAM_EVENTS_FILTERS,
),
"error": (
ExceptionMessageFilter,
ExceptionTypeFilter,
*_ALL_EVENTS_FILTERS,
),
}
BUILTIN_FILTERS_SET = set(chain.from_iterable(BUILTIN_FILTERS.values()))
BUILTIN_FILTERS: Dict[str, Tuple[Type[Filter], ...]] = {}

View file

@ -1,19 +1,14 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
from pydantic import BaseModel
from aiogram.filters.logic import _LogicFilter
class BaseFilter(BaseModel, ABC, _LogicFilter):
class Filter(_LogicFilter, ABC):
"""
If you want to register own filters like builtin filters you will need to write subclass
of this class with overriding the :code:`__call__`
method and adding filter attributes.
BaseFilter is subclass of :class:`pydantic.BaseModel` that's mean all subclasses of BaseFilter has
the validators based on class attributes and custom validator.
"""
if TYPE_CHECKING:
@ -37,6 +32,14 @@ class BaseFilter(BaseModel, ABC, _LogicFilter):
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
pass
def __await__(self): # type: ignore # pragma: no cover
# Is needed only for inspection and this method is never be called
return self.__call__
def _signature_to_string(self, *args: Any, **kwargs: Any) -> str:
items = [repr(arg) for arg in args]
items.extend([f"{k}={v!r}" for k, v in kwargs.items()])
return f"{type(self).__name__}({', '.join(items)})"
def __str__(self) -> str:
return self._signature_to_string()
BaseFilter = Filter

View file

@ -9,7 +9,7 @@ from uuid import UUID
from magic_filter import MagicFilter
from pydantic import BaseModel
from aiogram.filters import BaseFilter
from aiogram.filters.base import Filter
from aiogram.types import CallbackQuery
T = TypeVar("T", bound="CallbackData")
@ -122,11 +122,8 @@ class CallbackData(BaseModel):
"""
return CallbackQueryFilter(callback_data=cls, rule=rule)
# class Config:
# use_enum_values = True
class CallbackQueryFilter(BaseFilter):
class CallbackQueryFilter(Filter):
"""
This filter helps to handle callback query.
@ -134,10 +131,18 @@ class CallbackQueryFilter(BaseFilter):
via callback data instance
"""
callback_data: Type[CallbackData]
"""Expected type of callback data"""
rule: Optional[MagicFilter] = None
"""Magic rule"""
def __init__(
self,
*,
callback_data: Type[CallbackData],
rule: Optional[MagicFilter] = None,
):
"""
:param callback_data: Expected type of callback data
:param rule: Magic rule
"""
self.callback_data = callback_data
self.rule = rule
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]:
if not isinstance(query, CallbackQuery) or not query.data:
@ -150,7 +155,3 @@ class CallbackQueryFilter(BaseFilter):
if self.rule is None or self.rule.resolve(callback_data):
return {"callback_data": callback_data}
return False
class Config:
arbitrary_types_allowed = True
use_enum_values = True

View file

@ -1,6 +1,6 @@
from typing import Any, Dict, Optional, TypeVar, Union
from aiogram.filters import BaseFilter
from aiogram.filters.base import Filter
from aiogram.types import ChatMember, ChatMemberUpdated
MarkerT = TypeVar("MarkerT", bound="_MemberStatusMarker")
@ -154,16 +154,16 @@ LEAVE_TRANSITION = ~JOIN_TRANSITION
PROMOTED_TRANSITION = (MEMBER | RESTRICTED | LEFT | KICKED) >> ADMINISTRATOR
class ChatMemberUpdatedFilter(BaseFilter):
member_status_changed: Union[
_MemberStatusMarker,
_MemberStatusGroupMarker,
_MemberStatusTransition,
]
"""Accepts the status transition or new status of the member (see usage in docs)"""
class Config:
arbitrary_types_allowed = True
class ChatMemberUpdatedFilter(Filter):
def __init__(
self,
member_status_changed: Union[
_MemberStatusMarker,
_MemberStatusGroupMarker,
_MemberStatusTransition,
],
):
self.member_status_changed = member_status_changed
async def __call__(self, member_updated: ChatMemberUpdated) -> Union[bool, Dict[str, Any]]:
old = member_updated.old_chat_member

View file

@ -2,57 +2,111 @@ from __future__ import annotations
import re
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Match,
Optional,
Pattern,
Sequence,
Union,
cast,
)
from magic_filter import MagicFilter
from pydantic import Field, validator
from aiogram.filters import BaseFilter
from aiogram.types import Message
from aiogram.filters.base import Filter
from aiogram.types import BotCommand, Message
from aiogram.utils.deep_linking import decode_payload
if TYPE_CHECKING:
from aiogram import Bot
CommandPatternType = Union[str, re.Pattern]
CommandPatternType = Union[str, re.Pattern, BotCommand]
class CommandException(Exception):
pass
class Command(BaseFilter):
class Command(Filter):
"""
This filter can be helpful for handling commands from the text messages.
Works only with :class:`aiogram.types.message.Message` events which have the :code:`text`.
"""
commands: Union[Sequence[CommandPatternType], CommandPatternType]
"""List of commands (string or compiled regexp patterns)"""
commands_prefix: str = "/"
"""Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes,
for example: :code:`"/!"` will work with commands prefixed by :code:`"/"` or :code:`"!"`."""
commands_ignore_case: bool = False
"""Ignore case (Does not work with regexp, use flags instead)"""
commands_ignore_mention: bool = False
"""Ignore bot mention. By default bot can not handle commands intended for other bots"""
command_magic: Optional[MagicFilter] = None
"""Validate command object via Magic filter after all checks done"""
def __init__(
self,
*values: CommandPatternType,
commands: Optional[Union[Sequence[CommandPatternType], CommandPatternType]] = None,
prefix: str = "/",
ignore_case: bool = False,
ignore_mention: bool = False,
magic: Optional[MagicFilter] = None,
):
"""
List of commands (string or compiled regexp patterns)
:param prefix: Prefix for command.
Prefix is always a single char but here you can pass all of allowed prefixes,
for example: :code:`"/!"` will work with commands prefixed
by :code:`"/"` or :code:`"!"`.
:param ignore_case: Ignore case (Does not work with regexp, use flags instead)
:param ignore_mention: Ignore bot mention. By default,
bot can not handle commands intended for other bots
:param magic: Validate command object via Magic filter after all checks done
"""
if commands is None:
commands = []
if isinstance(commands, (str, re.Pattern, BotCommand)):
commands = [commands]
if not isinstance(commands, Iterable):
ValueError(
"Command filter only supports str, re.Pattern, BotCommand object"
" or their Iterable"
)
items = []
for command in (*values, *commands):
if isinstance(command, BotCommand):
command = command.command
if not isinstance(command, (str, re.Pattern)):
raise ValueError(
"Command filter only supports str, re.Pattern, BotCommand object"
" or their Iterable"
)
items.append(command)
if not items:
raise ValueError("At least one command should be specified")
self.commands = tuple(items)
self.prefix = prefix
self.ignore_case = ignore_case
self.ignore_mention = ignore_mention
self.magic = magic
def __str__(self) -> str:
return self._signature_to_string(
*self.commands,
prefix=self.prefix,
ignore_case=self.ignore_case,
ignore_mention=self.ignore_mention,
magic=self.magic,
)
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
commands = flags.setdefault("commands", [])
commands.append(self)
@validator("commands", always=True)
def _validate_commands(
cls, value: Union[Sequence[CommandPatternType], CommandPatternType]
) -> Sequence[CommandPatternType]:
if isinstance(value, (str, re.Pattern)):
value = [value]
return value
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
if not isinstance(message, Message):
return False
text = message.text or message.caption
if not text:
return False
@ -82,11 +136,11 @@ class Command(BaseFilter):
)
def validate_prefix(self, command: CommandObject) -> None:
if command.prefix not in self.commands_prefix:
if command.prefix not in self.prefix:
raise CommandException("Invalid command prefix")
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
if command.mention and not self.commands_ignore_mention:
if command.mention and not self.ignore_mention:
me = await bot.me()
if me.username and command.mention.lower() != me.username.lower():
raise CommandException("Mention did not match")
@ -119,16 +173,13 @@ class Command(BaseFilter):
return command
def do_magic(self, command: CommandObject) -> Any:
if not self.command_magic:
if not self.magic:
return command
result = self.command_magic.resolve(command)
result = self.magic.resolve(command)
if not result:
raise CommandException("Rejected via magic filter")
return replace(command, magic_result=result)
class Config:
arbitrary_types_allowed = True
@dataclass(frozen=True)
class CommandObject:
@ -170,10 +221,34 @@ class CommandObject:
class CommandStart(Command):
commands: Tuple[str] = Field(("start",), const=True)
commands_prefix: str = Field("/", const=True)
deep_link: bool = False
deep_link_encoded: bool = False
def __init__(
self,
deep_link: bool = False,
deep_link_encoded: bool = False,
ignore_case: bool = False,
ignore_mention: bool = False,
magic: Optional[MagicFilter] = None,
):
super().__init__(
"start",
prefix="/",
ignore_case=ignore_case,
ignore_mention=ignore_mention,
magic=magic,
)
self.deep_link = deep_link
self.deep_link_encoded = deep_link_encoded
def __str__(self) -> str:
return self._signature_to_string(
*self.commands,
prefix=self.prefix,
ignore_case=self.ignore_case,
ignore_mention=self.ignore_mention,
magic=self.magic,
deep_link=self.deep_link,
deep_link_encoded=self.deep_link_encoded,
)
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
"""

View file

@ -1,34 +0,0 @@
from typing import Any, Dict, Optional, Sequence, Union
from pydantic import validator
from aiogram.types import Message
from aiogram.types.message import ContentType
from .base import BaseFilter
class ContentTypesFilter(BaseFilter):
"""
Is useful for handling specific types of messages (For example separate text and stickers handlers).
"""
content_types: Union[Sequence[str], str]
"""Sequence of allowed content types"""
@validator("content_types")
def _validate_content_types(
cls, value: Optional[Union[Sequence[str], str]]
) -> Optional[Sequence[str]]:
if not value:
return value
if isinstance(value, str):
value = [value]
allowed_content_types = set(ContentType.all())
bad_content_types = set(value) - allowed_content_types
if bad_content_types:
raise ValueError(f"Invalid content types {bad_content_types} is not allowed here")
return value
async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]:
return ContentType.ANY in self.content_types or message.content_type in self.content_types

View file

@ -1,51 +1,47 @@
import re
from typing import Any, Dict, Pattern, Tuple, Type, Union, cast
from typing import Any, Dict, Pattern, Type, Union, cast
from pydantic import validator
from aiogram.filters import BaseFilter
from aiogram.filters.base import Filter
from aiogram.types import TelegramObject
from aiogram.types.error_event import ErrorEvent
class ExceptionTypeFilter(BaseFilter):
class ExceptionTypeFilter(Filter):
"""
Allows to match exception by type
"""
exception: Union[Type[Exception], Tuple[Type[Exception]]]
"""Exception type(s)"""
def __init__(self, *exceptions: Type[Exception]):
"""
:param exceptions: Exception type(s)
"""
if not exceptions:
raise ValueError("At least one exception type is required")
self.exceptions = exceptions
class Config:
arbitrary_types_allowed = True
async def __call__(
self, obj: TelegramObject, exception: Exception
) -> Union[bool, Dict[str, Any]]:
return isinstance(exception, self.exception)
async def __call__(self, obj: TelegramObject) -> Union[bool, Dict[str, Any]]:
return isinstance(cast(ErrorEvent, obj).exception, self.exceptions)
class ExceptionMessageFilter(BaseFilter):
class ExceptionMessageFilter(Filter):
"""
Allow to match exception by message
"""
pattern: Union[str, Pattern[str]]
"""Regexp pattern"""
class Config:
arbitrary_types_allowed = True
@validator("pattern")
def _validate_match(cls, value: Union[str, Pattern[str]]) -> Union[str, Pattern[str]]:
if isinstance(value, str):
return re.compile(value)
return value
def __init__(self, pattern: Union[str, Pattern[str]]):
"""
:param pattern: Regexp pattern
"""
if not isinstance(pattern, str):
pattern = re.compile(pattern)
self.pattern = pattern
async def __call__(
self, obj: TelegramObject, exception: Exception
self,
obj: TelegramObject,
) -> Union[bool, Dict[str, Any]]:
pattern = cast(Pattern[str], self.pattern)
result = pattern.match(str(exception))
result = pattern.match(str(cast(ErrorEvent, obj).exception))
if not result:
return False
return {"match_exception": result}

View file

@ -2,15 +2,13 @@ from typing import Any
from magic_filter import AttrDict, MagicFilter
from aiogram.filters import BaseFilter
from aiogram.filters.base import Filter
from aiogram.types import TelegramObject
class MagicData(BaseFilter):
magic_data: MagicFilter
class Config:
arbitrary_types_allowed = True
class MagicData(Filter):
def __init__(self, magic_data: MagicFilter) -> None:
self.magic_data = magic_data
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
return self.magic_data.resolve(

View file

@ -1,40 +1,28 @@
from inspect import isclass
from typing import Any, Dict, Optional, Sequence, Type, Union, cast, no_type_check
from typing import Any, Dict, Optional, Sequence, Type, Union, cast
from pydantic import Field, validator
from aiogram.filters import BaseFilter
from aiogram.filters.base import Filter
from aiogram.fsm.state import State, StatesGroup
from aiogram.types import TelegramObject
StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]]
class StateFilter(BaseFilter):
class StateFilter(Filter):
"""
State filter
"""
state: Union[StateType, Sequence[StateType]] = Field(...)
def __init__(self, *states: StateType) -> None:
if not states:
raise ValueError("At least one state is required")
class Config:
arbitrary_types_allowed = True
@validator("state")
@no_type_check # issubclass breaks things
def _validate_state(cls, v: Union[StateType, Sequence[StateType]]) -> Sequence[StateType]:
if (
isinstance(v, (str, State, StatesGroup))
or (isclass(v) and issubclass(v, StatesGroup))
or v is None
):
return [v]
return v
self.states = states
async def __call__(
self, obj: Union[TelegramObject], raw_state: Optional[str] = None
) -> Union[bool, Dict[str, Any]]:
allowed_states = cast(Sequence[StateType], self.state)
allowed_states = cast(Sequence[StateType], self.states)
for allowed_state in allowed_states:
if isinstance(allowed_state, str) or allowed_state is None:
if allowed_state == "*" or raw_state == allowed_state:

View file

@ -1,8 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union
from pydantic import root_validator
from aiogram.filters import BaseFilter
from aiogram.filters.base import Filter
from aiogram.types import CallbackQuery, InlineQuery, Message, Poll
if TYPE_CHECKING:
@ -11,7 +9,7 @@ if TYPE_CHECKING:
TextType = Union[str, "LazyProxy"]
class Text(BaseFilter):
class Text(Filter):
"""
Is useful for filtering text :class:`aiogram.types.message.Message`,
any :class:`aiogram.types.callback_query.CallbackQuery` with `data`,
@ -19,7 +17,7 @@ class Text(BaseFilter):
.. warning::
Only one of `text`, `text_contains`, `text_startswith` or `text_endswith` argument can be used at once.
Only one of `text`, `contains`, `startswith` or `endswith` argument can be used at once.
Any of that arguments can be string, list, set or tuple of strings.
.. deprecated:: 3.0
@ -27,40 +25,54 @@ class Text(BaseFilter):
use :ref:`magic-filter <magic-filters>`. For example do :pycode:`F.text == "text"` instead
"""
text: Optional[Union[Sequence[TextType], TextType]] = None
"""Text equals value or one of values"""
text_contains: Optional[Union[Sequence[TextType], TextType]] = None
"""Text contains value or one of values"""
text_startswith: Optional[Union[Sequence[TextType], TextType]] = None
"""Text starts with value or one of values"""
text_endswith: Optional[Union[Sequence[TextType], TextType]] = None
"""Text ends with value or one of values"""
text_ignore_case: bool = False
"""Ignore case when checks"""
def __init__(
self,
text: Optional[Union[Sequence[TextType], TextType]] = None,
*,
contains: Optional[Union[Sequence[TextType], TextType]] = None,
startswith: Optional[Union[Sequence[TextType], TextType]] = None,
endswith: Optional[Union[Sequence[TextType], TextType]] = None,
ignore_case: bool = False,
):
"""
class Config:
arbitrary_types_allowed = True
@root_validator
def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
# Validate that only one text filter type is presented
used_args = set(
key for key, value in values.items() if key != "text_ignore_case" and value is not None
:param text: Text equals value or one of values
:param contains: Text contains value or one of values
:param startswith: Text starts with value or one of values
:param endswith: Text ends with value or one of values
:param ignore_case: Ignore case when checks
"""
self._validate_constraints(
text=text,
contains=contains,
startswith=startswith,
endswith=endswith,
)
self.text = text
self.contains = contains
self.startswith = startswith
self.endswith = endswith
self.ignore_case = ignore_case
@classmethod
def _prepare_argument(
cls, value: Optional[Union[Sequence[TextType], TextType]]
) -> Optional[Sequence[TextType]]:
from aiogram.utils.i18n.lazy_proxy import LazyProxy
if isinstance(value, (str, LazyProxy)):
return [value]
return value
@classmethod
def _validate_constraints(cls, **values: Any) -> None:
# Validate that only one text filter type is presented
used_args = set(key for key, value in values.items() if value is not None)
if len(used_args) < 1:
raise ValueError(
"Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}"
)
raise ValueError(f"Filter should contain one of arguments: {set(values.keys())}")
if len(used_args) > 1:
raise ValueError(f"Arguments {used_args} cannot be used together")
# Convert single value to list
for arg in used_args:
if isinstance(values[arg], str):
values[arg] = [values[arg]]
return values
async def __call__(
self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]
) -> Union[bool, Dict[str, Any]]:
@ -79,30 +91,30 @@ class Text(BaseFilter):
if not text:
return False
if self.text_ignore_case:
if self.ignore_case:
text = text.lower()
if self.text is not None:
equals = list(map(self.prepare_text, self.text))
return text in equals
if self.text_contains is not None:
contains = list(map(self.prepare_text, self.text_contains))
if self.contains is not None:
contains = list(map(self.prepare_text, self.contains))
return all(map(text.__contains__, contains))
if self.text_startswith is not None:
startswith = list(map(self.prepare_text, self.text_startswith))
if self.startswith is not None:
startswith = list(map(self.prepare_text, self.startswith))
return any(map(text.startswith, startswith))
if self.text_endswith is not None:
endswith = list(map(self.prepare_text, self.text_endswith))
if self.endswith is not None:
endswith = list(map(self.prepare_text, self.endswith))
return any(map(text.endswith, endswith))
# Impossible because the validator prevents this situation
return False # pragma: no cover
def prepare_text(self, text: str) -> str:
if self.text_ignore_case:
if self.ignore_case:
return str(text).lower()
else:
return str(text)

View file

@ -0,0 +1,10 @@
from aiogram.types import Update
from aiogram.types.base import MutableTelegramObject
class ErrorEvent(MutableTelegramObject):
update: Update
exception: Exception
class Config:
arbitrary_types_allowed = True

View file

@ -3,7 +3,7 @@ Command
=======
.. autoclass:: aiogram.filters.command.Command
:members:
:members: __init__
:member-order: bysource
:undoc-members: False
@ -18,10 +18,10 @@ When filter is passed the :class:`aiogram.filters.command.CommandObject` will be
Usage
=====
1. Filter single variant of commands: :code:`Command(commands=["start"])` or :code:`Command(commands="start")`
2. Handle command by regexp pattern: :code:`Command(commands=[re.compile(r"item_(\d+)")])`
3. Match command by multiple variants: :code:`Command(commands=["item", re.compile(r"item_(\d+)")])`
4. Handle commands in public chats intended for other bots: :code:`Command(commands=["command"], commands_ignore_mention=True)`
1. Filter single variant of commands: :code:`Command("start")`
2. Handle command by regexp pattern: :code:`Command(re.compile(r"item_(\d+)"))`
3. Match command by multiple variants: :code:`Command("item", re.compile(r"item_(\d+)"))`
4. Handle commands in public chats intended for other bots: :code:`Command("command", ignore_mention=True)`
.. warning::

View file

@ -39,7 +39,7 @@ def is_bot_token(value: str) -> Union[bool, Dict[str, Any]]:
return True
@main_router.message(Command(commands=["add"], command_magic=F.args.func(is_bot_token)))
@main_router.message(Command(commands=["add"], magic=F.args.func(is_bot_token)))
async def command_add_bot(message: Message, command: CommandObject, bot: Bot) -> Any:
new_bot = Bot(token=command.args, session=bot.session)
try:

View file

@ -5,7 +5,7 @@ import pytest
from aiogram import F
from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject
from aiogram.filters import BaseFilter
from aiogram.filters import Filter
from aiogram.handlers import BaseHandler
from aiogram.types import Update
@ -28,7 +28,7 @@ async def callback4(foo: int, *, bar: int, baz: int):
return locals()
class Filter(BaseFilter):
class Filter(Filter):
async def __call__(self, foo: int, bar: int, baz: int) -> Union[bool, Dict[str, Any]]:
return locals()

View file

@ -9,7 +9,7 @@ from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.router import Router
from aiogram.exceptions import FiltersResolveError
from aiogram.filters import BaseFilter, Command
from aiogram.filters import Command, Filter
from aiogram.types import Chat, Message, User
from tests.deprecated import check_deprecated
@ -31,7 +31,7 @@ async def pipe_handler(*args, **kwargs):
return args, kwargs
class MyFilter1(BaseFilter):
class MyFilter1(Filter):
test: str
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
@ -46,14 +46,14 @@ class MyFilter3(MyFilter1):
pass
class OptionalFilter(BaseFilter):
class OptionalFilter(Filter):
optional: Optional[str]
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class DefaultFilter(BaseFilter):
class DefaultFilter(Filter):
default: str = "Default"
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
@ -66,7 +66,7 @@ class TestTelegramEventObserver:
with pytest.raises(TypeError):
event_observer.bind_filter(object) # type: ignore
class MyFilter(BaseFilter):
class MyFilter(Filter):
async def __call__(
self, *args: Any, **kwargs: Any
) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]:
@ -103,13 +103,13 @@ class TestTelegramEventObserver:
assert MyFilter3 not in filters_chain1
async def test_resolve_filters_data_from_parent_router(self):
class FilterSet(BaseFilter):
class FilterSet(Filter):
set_filter: bool
async def __call__(self, message: Message) -> dict:
return {"test": "hello world"}
class FilterGet(BaseFilter):
class FilterGet(Filter):
get_filter: bool
async def __call__(self, message: Message, **data) -> bool:

View file

@ -2,7 +2,7 @@ from typing import Awaitable
import pytest
from aiogram.filters import BaseFilter
from aiogram.filters import Filter
try:
from asynctest import CoroutineMock, patch
@ -13,7 +13,7 @@ except ImportError:
pytestmark = pytest.mark.asyncio
class MyFilter(BaseFilter):
class MyFilter(Filter):
foo: str
async def __call__(self, event: str):

View file

@ -16,34 +16,34 @@ class TestCommandFilter:
def test_convert_to_list(self):
cmd = Command(commands="start")
assert cmd.commands
assert isinstance(cmd.commands, list)
assert isinstance(cmd.commands, tuple)
assert cmd.commands[0] == "start"
assert cmd == Command(commands=["start"])
# assert cmd == Command(commands=["start"])
@pytest.mark.parametrize(
"text,command,result",
[
["/test@tbot", Command(commands=["test"], commands_prefix="/"), True],
["!test", Command(commands=["test"], commands_prefix="/"), False],
["/test@mention", Command(commands=["test"], commands_prefix="/"), False],
["/tests", Command(commands=["test"], commands_prefix="/"), False],
["/", Command(commands=["test"], commands_prefix="/"), False],
["/ test", Command(commands=["test"], commands_prefix="/"), False],
["", Command(commands=["test"], commands_prefix="/"), False],
[" ", Command(commands=["test"], commands_prefix="/"), False],
["test", Command(commands=["test"], commands_prefix="/"), False],
[" test", Command(commands=["test"], commands_prefix="/"), False],
["a", Command(commands=["test"], commands_prefix="/"), False],
["/test@tbot", Command(commands=["test"], prefix="/"), True],
["!test", Command(commands=["test"], prefix="/"), False],
["/test@mention", Command(commands=["test"], prefix="/"), False],
["/tests", Command(commands=["test"], prefix="/"), False],
["/", Command(commands=["test"], prefix="/"), False],
["/ test", Command(commands=["test"], prefix="/"), False],
["", Command(commands=["test"], prefix="/"), False],
[" ", Command(commands=["test"], prefix="/"), False],
["test", Command(commands=["test"], prefix="/"), False],
[" test", Command(commands=["test"], prefix="/"), False],
["a", Command(commands=["test"], prefix="/"), False],
["/test@tbot some args", Command(commands=["test"]), True],
["/test42@tbot some args", Command(commands=[re.compile(r"test(\d+)")]), True],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "some args"),
Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "some args"),
True,
],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "test"),
Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "test"),
False,
],
["/start test", CommandStart(), True],
@ -99,7 +99,7 @@ class TestCommandFilter:
chat=Chat(id=42, type="private"),
date=datetime.datetime.now(),
)
command = Command(commands=["test"], command_magic=(F.args.as_("args")))
command = Command(commands=["test"], magic=(F.args.as_("args")))
result = await command(message=message, bot=bot)
assert "args" in result
assert result["args"] == "42"

View file

@ -1,52 +0,0 @@
from dataclasses import dataclass
from typing import cast
import pytest
from pydantic import ValidationError
from aiogram.filters import ContentTypesFilter
from aiogram.types import ContentType, Message
pytestmark = pytest.mark.asyncio
@dataclass
class MinimalMessage:
content_type: str
class TestContentTypesFilter:
def test_validator_empty_list(self):
filter_ = ContentTypesFilter(content_types=[])
assert filter_.content_types == []
def test_convert_to_list(self):
filter_ = ContentTypesFilter(content_types="text")
assert filter_.content_types
assert isinstance(filter_.content_types, list)
assert filter_.content_types[0] == "text"
assert filter_ == ContentTypesFilter(content_types=["text"])
@pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]])
def test_validator_with_values(self, values):
filter_ = ContentTypesFilter(content_types=values)
assert filter_.content_types == values
@pytest.mark.parametrize("values", [["test"], ["text", "test"], ["TEXT"]])
def test_validator_with_bad_values(self, values):
with pytest.raises(ValidationError):
ContentTypesFilter(content_types=values)
@pytest.mark.parametrize(
"values,content_type,result",
[
[[ContentType.TEXT], ContentType.TEXT, True],
[[ContentType.PHOTO], ContentType.TEXT, False],
[[ContentType.ANY], ContentType.TEXT, True],
[[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
[[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
],
)
async def test_call(self, values, content_type, result):
filter_ = ContentTypesFilter(content_types=values)
assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result

View file

@ -46,7 +46,7 @@ class TestExceptionTypeFilter:
],
)
async def test_check(self, exception: Exception, value: bool):
obj = ExceptionTypeFilter(exception=MyException)
obj = ExceptionTypeFilter(exceptions=MyException)
result = await obj(Update(update_id=0), exception=exception)