diff --git a/Makefile b/Makefile index 41761503..12ca3c63 100644 --- a/Makefile +++ b/Makefile @@ -114,8 +114,8 @@ docs-gettext: .PHONY: docs-gettext docs-serve: - rm -rf docs/_build - $(py) sphinx-autobuild --watch aiogram/ docs/ docs/_build/ $(OPTS) + #rm -rf docs/_build + $(py) sphinx-autobuild --watch aiogram/ --watch CHANGELOG.rst --watch README.rst docs/ docs/_build/ $(OPTS) .PHONY: docs-serve $(locale_targets): docs-serve-%: diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 3a507a4e..996f5742 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from magic_filter import MagicFilter from aiogram.dispatcher.flags import extract_flags_from_object +from aiogram.filters.base import Filter from aiogram.handlers import BaseHandler CallbackType = Callable[..., Any] @@ -53,8 +54,12 @@ class FilterObject(CallableMixin): if isinstance(self.callback, MagicFilter): # MagicFilter instance is callable but generates only "CallOperation" instead of applying the filter self.callback = self.callback.resolve + super().__post_init__() + if isinstance(self.callback, Filter): + self.awaitable = True + @dataclass class HandlerObject(CallableMixin): diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index d5170594..8171273f 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -1,17 +1,10 @@ from __future__ import annotations -import warnings -from inspect import isclass -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type - -from pydantic import ValidationError +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from aiogram.dispatcher.middlewares.manager import MiddlewareManager -from aiogram.filters.base import BaseFilter -from ...exceptions import FiltersResolveError -from ...filters import BUILTIN_FILTERS_SET +from ...filters.base import Filter from ...types import TelegramObject from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler from .handler import CallbackType, FilterObject, HandlerObject @@ -33,7 +26,6 @@ class TelegramEventObserver: self.event_name: str = event_name self.handlers: List[HandlerObject] = [] - self.filters: List[Type[BaseFilter]] = [] self.middleware = MiddlewareManager() self.outer_middleware = MiddlewareManager() @@ -42,63 +34,16 @@ class TelegramEventObserver: # with dummy callback which never will be used self._handler = HandlerObject(callback=lambda: True, filters=[]) - def filter(self, *filters: CallbackType, _stacklevel: int = 2, **bound_filters: Any) -> None: + def filter(self, *filters: CallbackType) -> None: """ Register filter for all handlers of this event observer :param filters: positional filters :param bound_filters: keyword filters """ - resolved_filters = self.resolve_filters( - filters, bound_filters, _stacklevel=_stacklevel + 1 - ) if self._handler.filters is None: self._handler.filters = [] - self._handler.filters.extend( - [ - FilterObject(filter_) # type: ignore - for filter_ in chain( - resolved_filters, - filters, - ) - ] - ) - - def bind_filter(self, bound_filter: Type[BaseFilter]) -> None: - """ - Register filter class in factory - - :param bound_filter: - """ - if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter): - raise TypeError( - "bound_filter() argument 'bound_filter' must be subclass of BaseFilter" - ) - if bound_filter not in BUILTIN_FILTERS_SET: - warnings.warn( - category=DeprecationWarning, - message="filters factory deprecated and will be removed in 3.0b5," - " use filters directly instead (Example: " - f"`{bound_filter.__name__}(=)` instead of `=`)", - stacklevel=2, - ) - self.filters.append(bound_filter) - - def _resolve_filters_chain(self) -> Generator[Type[BaseFilter], None, None]: - """ - Get all bounded filters from current observer and from the parents - with the same event type without duplicates - """ - registry: List[Type[BaseFilter]] = [] - - for router in reversed(tuple(self.router.chain_head)): - observer = router.observers[self.event_name] - - for filter_ in observer.filters: - if filter_ in registry: - continue - yield filter_ - registry.append(filter_) + self._handler.filters.extend([FilterObject(filter_) for filter_ in filters]) def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]: middlewares: List[MiddlewareType[TelegramObject]] = [] @@ -108,112 +53,30 @@ class TelegramEventObserver: return middlewares - def resolve_filters( - self, - filters: Tuple[CallbackType, ...], - full_config: Dict[str, Any], - ignore_default: bool = True, - _stacklevel: int = 2, - ) -> List[BaseFilter]: - """ - Resolve keyword filters via filters factory - - :param filters: positional filters - :param full_config: keyword arguments to initialize bounded filters for router/handler - :param ignore_default: ignore to resolving filters with only default arguments that are not in full_config - """ - bound_filters: List[BaseFilter] = [] - - if ignore_default and not full_config: - return bound_filters - - filter_types = set(type(f) for f in filters) - - validation_errors = [] - for bound_filter in self._resolve_filters_chain(): - # skip filter if filter was used as positional filter: - if bound_filter in filter_types: - continue - - # skip filter with no fields in full_config - if ignore_default: - full_config_keys = set(full_config.keys()) - filter_fields = set(bound_filter.__fields__.keys()) - - if not full_config_keys.intersection(filter_fields): - continue - - # Try to initialize filter. - try: - f = bound_filter(**full_config) - except ValidationError as e: - validation_errors.append(e) - continue - - # Clean full config to prevent to re-initialize another filter - # with the same configuration - for key in f.__fields__: - full_config.pop(key, None) - - bound_filters.append(f) - - if full_config: - possible_cases = [] - for error in validation_errors: - for sum_error in error.errors(): - if sum_error["loc"][0] in full_config: - possible_cases.append(error) - break - - raise FiltersResolveError( - unresolved_fields=set(full_config.keys()), possible_cases=possible_cases - ) - - if bound_filters: - warnings.warn( - category=DeprecationWarning, - message="Filters factory deprecated and will be removed in 3.0b5.\n" - "Use filters directly, for example instead of " - "`@router.message(commands=['help']')` " - "use `@router.message(Command(commands=['help'])`", - stacklevel=_stacklevel, - ) - return bound_filters - def register( self, callback: CallbackType, *filters: CallbackType, flags: Optional[Dict[str, Any]] = None, - _stacklevel: int = 2, - **bound_filters: Any, ) -> CallbackType: """ Register event handler """ if flags is None: flags = {} - resolved_filters = self.resolve_filters( - filters, - bound_filters, - ignore_default=False, - _stacklevel=_stacklevel + 1, - ) - for resolved_filter in resolved_filters: - resolved_filter.update_handler_flags(flags=flags) + + for item in filters: + if isinstance(item, Filter): + item.update_handler_flags(flags=flags) + self.handlers.append( HandlerObject( callback=callback, - filters=[ - FilterObject(filter_) # type: ignore - for filter_ in chain( - resolved_filters, - filters, - ) - ], + filters=[FilterObject(filter_) for filter_ in filters], flags=flags, ) ) + return callback def wrap_outer_middleware( @@ -253,19 +116,15 @@ class TelegramEventObserver: def __call__( self, - *args: CallbackType, + *filters: CallbackType, flags: Optional[Dict[str, Any]] = None, - _stacklevel: int = 2, - **bound_filters: Any, ) -> Callable[[CallbackType], CallbackType]: """ Decorator for registering event handlers """ def wrapper(callback: CallbackType) -> CallbackType: - self.register( - callback, *args, flags=flags, **bound_filters, _stacklevel=_stacklevel + 1 - ) + self.register(callback, *filters, flags=flags) return callback return wrapper diff --git a/aiogram/dispatcher/middlewares/error.py b/aiogram/dispatcher/middlewares/error.py index b33de5f9..4b68c0bc 100644 --- a/aiogram/dispatcher/middlewares/error.py +++ b/aiogram/dispatcher/middlewares/error.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, cast -from ...types import TelegramObject +from ...types import TelegramObject, Update +from ...types.error_event import ErrorEvent from ..event.bases import UNHANDLED, CancelHandler, SkipHandler from .base import BaseMiddleware @@ -26,7 +27,9 @@ class ErrorsMiddleware(BaseMiddleware): raise except Exception as e: response = await self.router.propagate_event( - update_type="error", event=event, **data, exception=e + update_type="error", + event=ErrorEvent(update=cast(Update, event), exception=e), + **data, ) if response is not UNHANDLED: return response diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index b2400396..6e0a33fd 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -3,8 +3,6 @@ from __future__ import annotations import warnings from typing import Any, Dict, Final, Generator, List, Optional, Set, Union -from aiogram.filters import BUILTIN_FILTERS - from ..types import TelegramObject from ..utils.warnings import CodeHasNoEffect from .event.bases import REJECTED, UNHANDLED @@ -28,7 +26,7 @@ class Router: def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None: """ - :param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory + :param use_builtin_filters: `aiogram` has many builtin filters, and you can controll automatic registration of this filters in factory :param name: Optional router name, can be useful for debugging """ @@ -83,12 +81,6 @@ class Router: "error": self.errors, } - # Builtin filters - if use_builtin_filters: - for name, observer in self.observers.items(): - for builtin_filter in BUILTIN_FILTERS.get(name, ()): - observer.bind_filter(builtin_filter) - def __str__(self) -> str: return f"{type(self).__name__} {self.name!r}" diff --git a/aiogram/filters/__init__.py b/aiogram/filters/__init__.py index 2626d51e..6732b717 100644 --- a/aiogram/filters/__init__.py +++ b/aiogram/filters/__init__.py @@ -1,7 +1,6 @@ -from itertools import chain from typing import Dict, Tuple, Type -from .base import BaseFilter +from .base import Filter from .chat_member_updated import ( ADMINISTRATOR, CREATOR, @@ -18,7 +17,6 @@ from .chat_member_updated import ( ChatMemberUpdatedFilter, ) from .command import Command, CommandObject, CommandStart -from .content_types import ContentTypesFilter from .exception import ExceptionMessageFilter, ExceptionTypeFilter from .logic import and_f, invert_f, or_f from .magic_data import MagicData @@ -27,12 +25,11 @@ from .text import Text __all__ = ( "BUILTIN_FILTERS", - "BaseFilter", + "Filter", "Text", "Command", "CommandObject", "CommandStart", - "ContentTypesFilter", "ExceptionMessageFilter", "ExceptionTypeFilter", "StateFilter", @@ -55,85 +52,4 @@ __all__ = ( "invert_f", ) -_ALL_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (MagicData,) -_TELEGRAM_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (StateFilter,) - -BUILTIN_FILTERS: Dict[str, Tuple[Type[BaseFilter], ...]] = { - "message": ( - Text, - Command, - ContentTypesFilter, - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "edited_message": ( - Text, - Command, - ContentTypesFilter, - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "channel_post": ( - Text, - ContentTypesFilter, - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "edited_channel_post": ( - Text, - ContentTypesFilter, - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "inline_query": ( - Text, - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "chosen_inline_result": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "callback_query": ( - Text, - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "shipping_query": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "pre_checkout_query": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "poll": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "poll_answer": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "my_chat_member": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ChatMemberUpdatedFilter, - ), - "chat_member": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ChatMemberUpdatedFilter, - ), - "chat_join_request": ( - *_ALL_EVENTS_FILTERS, - *_TELEGRAM_EVENTS_FILTERS, - ), - "error": ( - ExceptionMessageFilter, - ExceptionTypeFilter, - *_ALL_EVENTS_FILTERS, - ), -} - -BUILTIN_FILTERS_SET = set(chain.from_iterable(BUILTIN_FILTERS.values())) +BUILTIN_FILTERS: Dict[str, Tuple[Type[Filter], ...]] = {} diff --git a/aiogram/filters/base.py b/aiogram/filters/base.py index 82592e13..8facc586 100644 --- a/aiogram/filters/base.py +++ b/aiogram/filters/base.py @@ -1,19 +1,14 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union -from pydantic import BaseModel - from aiogram.filters.logic import _LogicFilter -class BaseFilter(BaseModel, ABC, _LogicFilter): +class Filter(_LogicFilter, ABC): """ If you want to register own filters like builtin filters you will need to write subclass of this class with overriding the :code:`__call__` method and adding filter attributes. - - BaseFilter is subclass of :class:`pydantic.BaseModel` that's mean all subclasses of BaseFilter has - the validators based on class attributes and custom validator. """ if TYPE_CHECKING: @@ -37,6 +32,14 @@ class BaseFilter(BaseModel, ABC, _LogicFilter): def update_handler_flags(self, flags: Dict[str, Any]) -> None: pass - def __await__(self): # type: ignore # pragma: no cover - # Is needed only for inspection and this method is never be called - return self.__call__ + def _signature_to_string(self, *args: Any, **kwargs: Any) -> str: + items = [repr(arg) for arg in args] + items.extend([f"{k}={v!r}" for k, v in kwargs.items()]) + + return f"{type(self).__name__}({', '.join(items)})" + + def __str__(self) -> str: + return self._signature_to_string() + + +BaseFilter = Filter diff --git a/aiogram/filters/callback_data.py b/aiogram/filters/callback_data.py index 1f53760b..cb3d7658 100644 --- a/aiogram/filters/callback_data.py +++ b/aiogram/filters/callback_data.py @@ -9,7 +9,7 @@ from uuid import UUID from magic_filter import MagicFilter from pydantic import BaseModel -from aiogram.filters import BaseFilter +from aiogram.filters.base import Filter from aiogram.types import CallbackQuery T = TypeVar("T", bound="CallbackData") @@ -122,11 +122,8 @@ class CallbackData(BaseModel): """ return CallbackQueryFilter(callback_data=cls, rule=rule) - # class Config: - # use_enum_values = True - -class CallbackQueryFilter(BaseFilter): +class CallbackQueryFilter(Filter): """ This filter helps to handle callback query. @@ -134,10 +131,18 @@ class CallbackQueryFilter(BaseFilter): via callback data instance """ - callback_data: Type[CallbackData] - """Expected type of callback data""" - rule: Optional[MagicFilter] = None - """Magic rule""" + def __init__( + self, + *, + callback_data: Type[CallbackData], + rule: Optional[MagicFilter] = None, + ): + """ + :param callback_data: Expected type of callback data + :param rule: Magic rule + """ + self.callback_data = callback_data + self.rule = rule async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]: if not isinstance(query, CallbackQuery) or not query.data: @@ -150,7 +155,3 @@ class CallbackQueryFilter(BaseFilter): if self.rule is None or self.rule.resolve(callback_data): return {"callback_data": callback_data} return False - - class Config: - arbitrary_types_allowed = True - use_enum_values = True diff --git a/aiogram/filters/chat_member_updated.py b/aiogram/filters/chat_member_updated.py index 52718399..7d8c036d 100644 --- a/aiogram/filters/chat_member_updated.py +++ b/aiogram/filters/chat_member_updated.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Optional, TypeVar, Union -from aiogram.filters import BaseFilter +from aiogram.filters.base import Filter from aiogram.types import ChatMember, ChatMemberUpdated MarkerT = TypeVar("MarkerT", bound="_MemberStatusMarker") @@ -154,16 +154,16 @@ LEAVE_TRANSITION = ~JOIN_TRANSITION PROMOTED_TRANSITION = (MEMBER | RESTRICTED | LEFT | KICKED) >> ADMINISTRATOR -class ChatMemberUpdatedFilter(BaseFilter): - member_status_changed: Union[ - _MemberStatusMarker, - _MemberStatusGroupMarker, - _MemberStatusTransition, - ] - """Accepts the status transition or new status of the member (see usage in docs)""" - - class Config: - arbitrary_types_allowed = True +class ChatMemberUpdatedFilter(Filter): + def __init__( + self, + member_status_changed: Union[ + _MemberStatusMarker, + _MemberStatusGroupMarker, + _MemberStatusTransition, + ], + ): + self.member_status_changed = member_status_changed async def __call__(self, member_updated: ChatMemberUpdated) -> Union[bool, Dict[str, Any]]: old = member_updated.old_chat_member diff --git a/aiogram/filters/command.py b/aiogram/filters/command.py index 1c35a1c0..b46ca80d 100644 --- a/aiogram/filters/command.py +++ b/aiogram/filters/command.py @@ -2,57 +2,111 @@ from __future__ import annotations import re from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Match, + Optional, + Pattern, + Sequence, + Union, + cast, +) from magic_filter import MagicFilter -from pydantic import Field, validator -from aiogram.filters import BaseFilter -from aiogram.types import Message +from aiogram.filters.base import Filter +from aiogram.types import BotCommand, Message from aiogram.utils.deep_linking import decode_payload if TYPE_CHECKING: from aiogram import Bot -CommandPatternType = Union[str, re.Pattern] +CommandPatternType = Union[str, re.Pattern, BotCommand] class CommandException(Exception): pass -class Command(BaseFilter): +class Command(Filter): """ This filter can be helpful for handling commands from the text messages. Works only with :class:`aiogram.types.message.Message` events which have the :code:`text`. """ - commands: Union[Sequence[CommandPatternType], CommandPatternType] - """List of commands (string or compiled regexp patterns)""" - commands_prefix: str = "/" - """Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes, - for example: :code:`"/!"` will work with commands prefixed by :code:`"/"` or :code:`"!"`.""" - commands_ignore_case: bool = False - """Ignore case (Does not work with regexp, use flags instead)""" - commands_ignore_mention: bool = False - """Ignore bot mention. By default bot can not handle commands intended for other bots""" - command_magic: Optional[MagicFilter] = None - """Validate command object via Magic filter after all checks done""" + def __init__( + self, + *values: CommandPatternType, + commands: Optional[Union[Sequence[CommandPatternType], CommandPatternType]] = None, + prefix: str = "/", + ignore_case: bool = False, + ignore_mention: bool = False, + magic: Optional[MagicFilter] = None, + ): + """ + List of commands (string or compiled regexp patterns) + + :param prefix: Prefix for command. + Prefix is always a single char but here you can pass all of allowed prefixes, + for example: :code:`"/!"` will work with commands prefixed + by :code:`"/"` or :code:`"!"`. + :param ignore_case: Ignore case (Does not work with regexp, use flags instead) + :param ignore_mention: Ignore bot mention. By default, + bot can not handle commands intended for other bots + :param magic: Validate command object via Magic filter after all checks done + """ + if commands is None: + commands = [] + if isinstance(commands, (str, re.Pattern, BotCommand)): + commands = [commands] + + if not isinstance(commands, Iterable): + ValueError( + "Command filter only supports str, re.Pattern, BotCommand object" + " or their Iterable" + ) + + items = [] + for command in (*values, *commands): + if isinstance(command, BotCommand): + command = command.command + if not isinstance(command, (str, re.Pattern)): + raise ValueError( + "Command filter only supports str, re.Pattern, BotCommand object" + " or their Iterable" + ) + items.append(command) + + if not items: + raise ValueError("At least one command should be specified") + + self.commands = tuple(items) + self.prefix = prefix + self.ignore_case = ignore_case + self.ignore_mention = ignore_mention + self.magic = magic + + def __str__(self) -> str: + return self._signature_to_string( + *self.commands, + prefix=self.prefix, + ignore_case=self.ignore_case, + ignore_mention=self.ignore_mention, + magic=self.magic, + ) def update_handler_flags(self, flags: Dict[str, Any]) -> None: commands = flags.setdefault("commands", []) commands.append(self) - @validator("commands", always=True) - def _validate_commands( - cls, value: Union[Sequence[CommandPatternType], CommandPatternType] - ) -> Sequence[CommandPatternType]: - if isinstance(value, (str, re.Pattern)): - value = [value] - return value - async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]: + if not isinstance(message, Message): + return False + text = message.text or message.caption if not text: return False @@ -82,11 +136,11 @@ class Command(BaseFilter): ) def validate_prefix(self, command: CommandObject) -> None: - if command.prefix not in self.commands_prefix: + if command.prefix not in self.prefix: raise CommandException("Invalid command prefix") async def validate_mention(self, bot: Bot, command: CommandObject) -> None: - if command.mention and not self.commands_ignore_mention: + if command.mention and not self.ignore_mention: me = await bot.me() if me.username and command.mention.lower() != me.username.lower(): raise CommandException("Mention did not match") @@ -119,16 +173,13 @@ class Command(BaseFilter): return command def do_magic(self, command: CommandObject) -> Any: - if not self.command_magic: + if not self.magic: return command - result = self.command_magic.resolve(command) + result = self.magic.resolve(command) if not result: raise CommandException("Rejected via magic filter") return replace(command, magic_result=result) - class Config: - arbitrary_types_allowed = True - @dataclass(frozen=True) class CommandObject: @@ -170,10 +221,34 @@ class CommandObject: class CommandStart(Command): - commands: Tuple[str] = Field(("start",), const=True) - commands_prefix: str = Field("/", const=True) - deep_link: bool = False - deep_link_encoded: bool = False + def __init__( + self, + deep_link: bool = False, + deep_link_encoded: bool = False, + ignore_case: bool = False, + ignore_mention: bool = False, + magic: Optional[MagicFilter] = None, + ): + super().__init__( + "start", + prefix="/", + ignore_case=ignore_case, + ignore_mention=ignore_mention, + magic=magic, + ) + self.deep_link = deep_link + self.deep_link_encoded = deep_link_encoded + + def __str__(self) -> str: + return self._signature_to_string( + *self.commands, + prefix=self.prefix, + ignore_case=self.ignore_case, + ignore_mention=self.ignore_mention, + magic=self.magic, + deep_link=self.deep_link, + deep_link_encoded=self.deep_link_encoded, + ) async def parse_command(self, text: str, bot: Bot) -> CommandObject: """ diff --git a/aiogram/filters/content_types.py b/aiogram/filters/content_types.py deleted file mode 100644 index 492c262b..00000000 --- a/aiogram/filters/content_types.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Any, Dict, Optional, Sequence, Union - -from pydantic import validator - -from aiogram.types import Message -from aiogram.types.message import ContentType - -from .base import BaseFilter - - -class ContentTypesFilter(BaseFilter): - """ - Is useful for handling specific types of messages (For example separate text and stickers handlers). - """ - - content_types: Union[Sequence[str], str] - """Sequence of allowed content types""" - - @validator("content_types") - def _validate_content_types( - cls, value: Optional[Union[Sequence[str], str]] - ) -> Optional[Sequence[str]]: - if not value: - return value - if isinstance(value, str): - value = [value] - allowed_content_types = set(ContentType.all()) - bad_content_types = set(value) - allowed_content_types - if bad_content_types: - raise ValueError(f"Invalid content types {bad_content_types} is not allowed here") - return value - - async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]: - return ContentType.ANY in self.content_types or message.content_type in self.content_types diff --git a/aiogram/filters/exception.py b/aiogram/filters/exception.py index 735af17b..95b08f3c 100644 --- a/aiogram/filters/exception.py +++ b/aiogram/filters/exception.py @@ -1,51 +1,47 @@ import re -from typing import Any, Dict, Pattern, Tuple, Type, Union, cast +from typing import Any, Dict, Pattern, Type, Union, cast -from pydantic import validator - -from aiogram.filters import BaseFilter +from aiogram.filters.base import Filter from aiogram.types import TelegramObject +from aiogram.types.error_event import ErrorEvent -class ExceptionTypeFilter(BaseFilter): +class ExceptionTypeFilter(Filter): """ Allows to match exception by type """ - exception: Union[Type[Exception], Tuple[Type[Exception]]] - """Exception type(s)""" + def __init__(self, *exceptions: Type[Exception]): + """ + :param exceptions: Exception type(s) + """ + if not exceptions: + raise ValueError("At least one exception type is required") + self.exceptions = exceptions - class Config: - arbitrary_types_allowed = True - - async def __call__( - self, obj: TelegramObject, exception: Exception - ) -> Union[bool, Dict[str, Any]]: - return isinstance(exception, self.exception) + async def __call__(self, obj: TelegramObject) -> Union[bool, Dict[str, Any]]: + return isinstance(cast(ErrorEvent, obj).exception, self.exceptions) -class ExceptionMessageFilter(BaseFilter): +class ExceptionMessageFilter(Filter): """ Allow to match exception by message """ - pattern: Union[str, Pattern[str]] - """Regexp pattern""" - - class Config: - arbitrary_types_allowed = True - - @validator("pattern") - def _validate_match(cls, value: Union[str, Pattern[str]]) -> Union[str, Pattern[str]]: - if isinstance(value, str): - return re.compile(value) - return value + def __init__(self, pattern: Union[str, Pattern[str]]): + """ + :param pattern: Regexp pattern + """ + if not isinstance(pattern, str): + pattern = re.compile(pattern) + self.pattern = pattern async def __call__( - self, obj: TelegramObject, exception: Exception + self, + obj: TelegramObject, ) -> Union[bool, Dict[str, Any]]: pattern = cast(Pattern[str], self.pattern) - result = pattern.match(str(exception)) + result = pattern.match(str(cast(ErrorEvent, obj).exception)) if not result: return False return {"match_exception": result} diff --git a/aiogram/filters/magic_data.py b/aiogram/filters/magic_data.py index c1a19083..e43ae889 100644 --- a/aiogram/filters/magic_data.py +++ b/aiogram/filters/magic_data.py @@ -2,15 +2,13 @@ from typing import Any from magic_filter import AttrDict, MagicFilter -from aiogram.filters import BaseFilter +from aiogram.filters.base import Filter from aiogram.types import TelegramObject -class MagicData(BaseFilter): - magic_data: MagicFilter - - class Config: - arbitrary_types_allowed = True +class MagicData(Filter): + def __init__(self, magic_data: MagicFilter) -> None: + self.magic_data = magic_data async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any: return self.magic_data.resolve( diff --git a/aiogram/filters/state.py b/aiogram/filters/state.py index 5363bf15..b5f624e1 100644 --- a/aiogram/filters/state.py +++ b/aiogram/filters/state.py @@ -1,40 +1,28 @@ from inspect import isclass -from typing import Any, Dict, Optional, Sequence, Type, Union, cast, no_type_check +from typing import Any, Dict, Optional, Sequence, Type, Union, cast -from pydantic import Field, validator - -from aiogram.filters import BaseFilter +from aiogram.filters.base import Filter from aiogram.fsm.state import State, StatesGroup from aiogram.types import TelegramObject StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]] -class StateFilter(BaseFilter): +class StateFilter(Filter): """ State filter """ - state: Union[StateType, Sequence[StateType]] = Field(...) + def __init__(self, *states: StateType) -> None: + if not states: + raise ValueError("At least one state is required") - class Config: - arbitrary_types_allowed = True - - @validator("state") - @no_type_check # issubclass breaks things - def _validate_state(cls, v: Union[StateType, Sequence[StateType]]) -> Sequence[StateType]: - if ( - isinstance(v, (str, State, StatesGroup)) - or (isclass(v) and issubclass(v, StatesGroup)) - or v is None - ): - return [v] - return v + self.states = states async def __call__( self, obj: Union[TelegramObject], raw_state: Optional[str] = None ) -> Union[bool, Dict[str, Any]]: - allowed_states = cast(Sequence[StateType], self.state) + allowed_states = cast(Sequence[StateType], self.states) for allowed_state in allowed_states: if isinstance(allowed_state, str) or allowed_state is None: if allowed_state == "*" or raw_state == allowed_state: diff --git a/aiogram/filters/text.py b/aiogram/filters/text.py index aa2d6bb5..a9a76720 100644 --- a/aiogram/filters/text.py +++ b/aiogram/filters/text.py @@ -1,8 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union -from pydantic import root_validator - -from aiogram.filters import BaseFilter +from aiogram.filters.base import Filter from aiogram.types import CallbackQuery, InlineQuery, Message, Poll if TYPE_CHECKING: @@ -11,7 +9,7 @@ if TYPE_CHECKING: TextType = Union[str, "LazyProxy"] -class Text(BaseFilter): +class Text(Filter): """ Is useful for filtering text :class:`aiogram.types.message.Message`, any :class:`aiogram.types.callback_query.CallbackQuery` with `data`, @@ -19,7 +17,7 @@ class Text(BaseFilter): .. warning:: - Only one of `text`, `text_contains`, `text_startswith` or `text_endswith` argument can be used at once. + Only one of `text`, `contains`, `startswith` or `endswith` argument can be used at once. Any of that arguments can be string, list, set or tuple of strings. .. deprecated:: 3.0 @@ -27,40 +25,54 @@ class Text(BaseFilter): use :ref:`magic-filter `. For example do :pycode:`F.text == "text"` instead """ - text: Optional[Union[Sequence[TextType], TextType]] = None - """Text equals value or one of values""" - text_contains: Optional[Union[Sequence[TextType], TextType]] = None - """Text contains value or one of values""" - text_startswith: Optional[Union[Sequence[TextType], TextType]] = None - """Text starts with value or one of values""" - text_endswith: Optional[Union[Sequence[TextType], TextType]] = None - """Text ends with value or one of values""" - text_ignore_case: bool = False - """Ignore case when checks""" + def __init__( + self, + text: Optional[Union[Sequence[TextType], TextType]] = None, + *, + contains: Optional[Union[Sequence[TextType], TextType]] = None, + startswith: Optional[Union[Sequence[TextType], TextType]] = None, + endswith: Optional[Union[Sequence[TextType], TextType]] = None, + ignore_case: bool = False, + ): + """ - class Config: - arbitrary_types_allowed = True - - @root_validator - def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: - # Validate that only one text filter type is presented - used_args = set( - key for key, value in values.items() if key != "text_ignore_case" and value is not None + :param text: Text equals value or one of values + :param contains: Text contains value or one of values + :param startswith: Text starts with value or one of values + :param endswith: Text ends with value or one of values + :param ignore_case: Ignore case when checks + """ + self._validate_constraints( + text=text, + contains=contains, + startswith=startswith, + endswith=endswith, ) + self.text = text + self.contains = contains + self.startswith = startswith + self.endswith = endswith + self.ignore_case = ignore_case + + @classmethod + def _prepare_argument( + cls, value: Optional[Union[Sequence[TextType], TextType]] + ) -> Optional[Sequence[TextType]]: + from aiogram.utils.i18n.lazy_proxy import LazyProxy + + if isinstance(value, (str, LazyProxy)): + return [value] + return value + + @classmethod + def _validate_constraints(cls, **values: Any) -> None: + # Validate that only one text filter type is presented + used_args = set(key for key, value in values.items() if value is not None) if len(used_args) < 1: - raise ValueError( - "Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}" - ) + raise ValueError(f"Filter should contain one of arguments: {set(values.keys())}") if len(used_args) > 1: raise ValueError(f"Arguments {used_args} cannot be used together") - # Convert single value to list - for arg in used_args: - if isinstance(values[arg], str): - values[arg] = [values[arg]] - - return values - async def __call__( self, obj: Union[Message, CallbackQuery, InlineQuery, Poll] ) -> Union[bool, Dict[str, Any]]: @@ -79,30 +91,30 @@ class Text(BaseFilter): if not text: return False - if self.text_ignore_case: + if self.ignore_case: text = text.lower() if self.text is not None: equals = list(map(self.prepare_text, self.text)) return text in equals - if self.text_contains is not None: - contains = list(map(self.prepare_text, self.text_contains)) + if self.contains is not None: + contains = list(map(self.prepare_text, self.contains)) return all(map(text.__contains__, contains)) - if self.text_startswith is not None: - startswith = list(map(self.prepare_text, self.text_startswith)) + if self.startswith is not None: + startswith = list(map(self.prepare_text, self.startswith)) return any(map(text.startswith, startswith)) - if self.text_endswith is not None: - endswith = list(map(self.prepare_text, self.text_endswith)) + if self.endswith is not None: + endswith = list(map(self.prepare_text, self.endswith)) return any(map(text.endswith, endswith)) # Impossible because the validator prevents this situation return False # pragma: no cover def prepare_text(self, text: str) -> str: - if self.text_ignore_case: + if self.ignore_case: return str(text).lower() else: return str(text) diff --git a/aiogram/types/error_event.py b/aiogram/types/error_event.py new file mode 100644 index 00000000..6ab03303 --- /dev/null +++ b/aiogram/types/error_event.py @@ -0,0 +1,10 @@ +from aiogram.types import Update +from aiogram.types.base import MutableTelegramObject + + +class ErrorEvent(MutableTelegramObject): + update: Update + exception: Exception + + class Config: + arbitrary_types_allowed = True diff --git a/docs/dispatcher/filters/command.rst b/docs/dispatcher/filters/command.rst index 7281d751..d899cabb 100644 --- a/docs/dispatcher/filters/command.rst +++ b/docs/dispatcher/filters/command.rst @@ -3,7 +3,7 @@ Command ======= .. autoclass:: aiogram.filters.command.Command - :members: + :members: __init__ :member-order: bysource :undoc-members: False @@ -18,10 +18,10 @@ When filter is passed the :class:`aiogram.filters.command.CommandObject` will be Usage ===== -1. Filter single variant of commands: :code:`Command(commands=["start"])` or :code:`Command(commands="start")` -2. Handle command by regexp pattern: :code:`Command(commands=[re.compile(r"item_(\d+)")])` -3. Match command by multiple variants: :code:`Command(commands=["item", re.compile(r"item_(\d+)")])` -4. Handle commands in public chats intended for other bots: :code:`Command(commands=["command"], commands_ignore_mention=True)` +1. Filter single variant of commands: :code:`Command("start")` +2. Handle command by regexp pattern: :code:`Command(re.compile(r"item_(\d+)"))` +3. Match command by multiple variants: :code:`Command("item", re.compile(r"item_(\d+)"))` +4. Handle commands in public chats intended for other bots: :code:`Command("command", ignore_mention=True)` .. warning:: diff --git a/examples/multibot.py b/examples/multibot.py index 83001680..c9ce099f 100644 --- a/examples/multibot.py +++ b/examples/multibot.py @@ -39,7 +39,7 @@ def is_bot_token(value: str) -> Union[bool, Dict[str, Any]]: return True -@main_router.message(Command(commands=["add"], command_magic=F.args.func(is_bot_token))) +@main_router.message(Command(commands=["add"], magic=F.args.func(is_bot_token))) async def command_add_bot(message: Message, command: CommandObject, bot: Bot) -> Any: new_bot = Bot(token=command.args, session=bot.session) try: diff --git a/tests/test_dispatcher/test_event/test_handler.py b/tests/test_dispatcher/test_event/test_handler.py index 2bdbd880..00cf295b 100644 --- a/tests/test_dispatcher/test_event/test_handler.py +++ b/tests/test_dispatcher/test_event/test_handler.py @@ -5,7 +5,7 @@ import pytest from aiogram import F from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject -from aiogram.filters import BaseFilter +from aiogram.filters import Filter from aiogram.handlers import BaseHandler from aiogram.types import Update @@ -28,7 +28,7 @@ async def callback4(foo: int, *, bar: int, baz: int): return locals() -class Filter(BaseFilter): +class Filter(Filter): async def __call__(self, foo: int, bar: int, baz: int) -> Union[bool, Dict[str, Any]]: return locals() diff --git a/tests/test_dispatcher/test_event/test_telegram.py b/tests/test_dispatcher/test_event/test_telegram.py index ebb4ae87..550d8841 100644 --- a/tests/test_dispatcher/test_event/test_telegram.py +++ b/tests/test_dispatcher/test_event/test_telegram.py @@ -9,7 +9,7 @@ from aiogram.dispatcher.event.handler import HandlerObject from aiogram.dispatcher.event.telegram import TelegramEventObserver from aiogram.dispatcher.router import Router from aiogram.exceptions import FiltersResolveError -from aiogram.filters import BaseFilter, Command +from aiogram.filters import Command, Filter from aiogram.types import Chat, Message, User from tests.deprecated import check_deprecated @@ -31,7 +31,7 @@ async def pipe_handler(*args, **kwargs): return args, kwargs -class MyFilter1(BaseFilter): +class MyFilter1(Filter): test: str async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: @@ -46,14 +46,14 @@ class MyFilter3(MyFilter1): pass -class OptionalFilter(BaseFilter): +class OptionalFilter(Filter): optional: Optional[str] async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: return True -class DefaultFilter(BaseFilter): +class DefaultFilter(Filter): default: str = "Default" async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: @@ -66,7 +66,7 @@ class TestTelegramEventObserver: with pytest.raises(TypeError): event_observer.bind_filter(object) # type: ignore - class MyFilter(BaseFilter): + class MyFilter(Filter): async def __call__( self, *args: Any, **kwargs: Any ) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]: @@ -103,13 +103,13 @@ class TestTelegramEventObserver: assert MyFilter3 not in filters_chain1 async def test_resolve_filters_data_from_parent_router(self): - class FilterSet(BaseFilter): + class FilterSet(Filter): set_filter: bool async def __call__(self, message: Message) -> dict: return {"test": "hello world"} - class FilterGet(BaseFilter): + class FilterGet(Filter): get_filter: bool async def __call__(self, message: Message, **data) -> bool: diff --git a/tests/test_filters/test_base.py b/tests/test_filters/test_base.py index 4d854cbd..1c30a4f1 100644 --- a/tests/test_filters/test_base.py +++ b/tests/test_filters/test_base.py @@ -2,7 +2,7 @@ from typing import Awaitable import pytest -from aiogram.filters import BaseFilter +from aiogram.filters import Filter try: from asynctest import CoroutineMock, patch @@ -13,7 +13,7 @@ except ImportError: pytestmark = pytest.mark.asyncio -class MyFilter(BaseFilter): +class MyFilter(Filter): foo: str async def __call__(self, event: str): diff --git a/tests/test_filters/test_command.py b/tests/test_filters/test_command.py index 4e75a9b7..815640b0 100644 --- a/tests/test_filters/test_command.py +++ b/tests/test_filters/test_command.py @@ -16,34 +16,34 @@ class TestCommandFilter: def test_convert_to_list(self): cmd = Command(commands="start") assert cmd.commands - assert isinstance(cmd.commands, list) + assert isinstance(cmd.commands, tuple) assert cmd.commands[0] == "start" - assert cmd == Command(commands=["start"]) + # assert cmd == Command(commands=["start"]) @pytest.mark.parametrize( "text,command,result", [ - ["/test@tbot", Command(commands=["test"], commands_prefix="/"), True], - ["!test", Command(commands=["test"], commands_prefix="/"), False], - ["/test@mention", Command(commands=["test"], commands_prefix="/"), False], - ["/tests", Command(commands=["test"], commands_prefix="/"), False], - ["/", Command(commands=["test"], commands_prefix="/"), False], - ["/ test", Command(commands=["test"], commands_prefix="/"), False], - ["", Command(commands=["test"], commands_prefix="/"), False], - [" ", Command(commands=["test"], commands_prefix="/"), False], - ["test", Command(commands=["test"], commands_prefix="/"), False], - [" test", Command(commands=["test"], commands_prefix="/"), False], - ["a", Command(commands=["test"], commands_prefix="/"), False], + ["/test@tbot", Command(commands=["test"], prefix="/"), True], + ["!test", Command(commands=["test"], prefix="/"), False], + ["/test@mention", Command(commands=["test"], prefix="/"), False], + ["/tests", Command(commands=["test"], prefix="/"), False], + ["/", Command(commands=["test"], prefix="/"), False], + ["/ test", Command(commands=["test"], prefix="/"), False], + ["", Command(commands=["test"], prefix="/"), False], + [" ", Command(commands=["test"], prefix="/"), False], + ["test", Command(commands=["test"], prefix="/"), False], + [" test", Command(commands=["test"], prefix="/"), False], + ["a", Command(commands=["test"], prefix="/"), False], ["/test@tbot some args", Command(commands=["test"]), True], ["/test42@tbot some args", Command(commands=[re.compile(r"test(\d+)")]), True], [ "/test42@tbot some args", - Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "some args"), + Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "some args"), True, ], [ "/test42@tbot some args", - Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "test"), + Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "test"), False, ], ["/start test", CommandStart(), True], @@ -99,7 +99,7 @@ class TestCommandFilter: chat=Chat(id=42, type="private"), date=datetime.datetime.now(), ) - command = Command(commands=["test"], command_magic=(F.args.as_("args"))) + command = Command(commands=["test"], magic=(F.args.as_("args"))) result = await command(message=message, bot=bot) assert "args" in result assert result["args"] == "42" diff --git a/tests/test_filters/test_content_types.py b/tests/test_filters/test_content_types.py deleted file mode 100644 index f6822519..00000000 --- a/tests/test_filters/test_content_types.py +++ /dev/null @@ -1,52 +0,0 @@ -from dataclasses import dataclass -from typing import cast - -import pytest -from pydantic import ValidationError - -from aiogram.filters import ContentTypesFilter -from aiogram.types import ContentType, Message - -pytestmark = pytest.mark.asyncio - - -@dataclass -class MinimalMessage: - content_type: str - - -class TestContentTypesFilter: - def test_validator_empty_list(self): - filter_ = ContentTypesFilter(content_types=[]) - assert filter_.content_types == [] - - def test_convert_to_list(self): - filter_ = ContentTypesFilter(content_types="text") - assert filter_.content_types - assert isinstance(filter_.content_types, list) - assert filter_.content_types[0] == "text" - assert filter_ == ContentTypesFilter(content_types=["text"]) - - @pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]]) - def test_validator_with_values(self, values): - filter_ = ContentTypesFilter(content_types=values) - assert filter_.content_types == values - - @pytest.mark.parametrize("values", [["test"], ["text", "test"], ["TEXT"]]) - def test_validator_with_bad_values(self, values): - with pytest.raises(ValidationError): - ContentTypesFilter(content_types=values) - - @pytest.mark.parametrize( - "values,content_type,result", - [ - [[ContentType.TEXT], ContentType.TEXT, True], - [[ContentType.PHOTO], ContentType.TEXT, False], - [[ContentType.ANY], ContentType.TEXT, True], - [[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True], - [[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True], - ], - ) - async def test_call(self, values, content_type, result): - filter_ = ContentTypesFilter(content_types=values) - assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result diff --git a/tests/test_filters/test_exception.py b/tests/test_filters/test_exception.py index 498ef6d6..68b9fae7 100644 --- a/tests/test_filters/test_exception.py +++ b/tests/test_filters/test_exception.py @@ -46,7 +46,7 @@ class TestExceptionTypeFilter: ], ) async def test_check(self, exception: Exception, value: bool): - obj = ExceptionTypeFilter(exception=MyException) + obj = ExceptionTypeFilter(exceptions=MyException) result = await obj(Update(update_id=0), exception=exception)