mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Rewrite filters
This commit is contained in:
parent
3f57c69d4f
commit
8d2aae77c1
24 changed files with 311 additions and 539 deletions
4
Makefile
4
Makefile
|
|
@ -114,8 +114,8 @@ docs-gettext:
|
|||
.PHONY: docs-gettext
|
||||
|
||||
docs-serve:
|
||||
rm -rf docs/_build
|
||||
$(py) sphinx-autobuild --watch aiogram/ docs/ docs/_build/ $(OPTS)
|
||||
#rm -rf docs/_build
|
||||
$(py) sphinx-autobuild --watch aiogram/ --watch CHANGELOG.rst --watch README.rst docs/ docs/_build/ $(OPTS)
|
||||
.PHONY: docs-serve
|
||||
|
||||
$(locale_targets): docs-serve-%:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||
from magic_filter import MagicFilter
|
||||
|
||||
from aiogram.dispatcher.flags import extract_flags_from_object
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.handlers import BaseHandler
|
||||
|
||||
CallbackType = Callable[..., Any]
|
||||
|
|
@ -53,8 +54,12 @@ class FilterObject(CallableMixin):
|
|||
if isinstance(self.callback, MagicFilter):
|
||||
# MagicFilter instance is callable but generates only "CallOperation" instead of applying the filter
|
||||
self.callback = self.callback.resolve
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if isinstance(self.callback, Filter):
|
||||
self.awaitable = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandlerObject(CallableMixin):
|
||||
|
|
|
|||
|
|
@ -1,17 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from inspect import isclass
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import ValidationError
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
||||
from aiogram.filters.base import BaseFilter
|
||||
|
||||
from ...exceptions import FiltersResolveError
|
||||
from ...filters import BUILTIN_FILTERS_SET
|
||||
from ...filters.base import Filter
|
||||
from ...types import TelegramObject
|
||||
from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler
|
||||
from .handler import CallbackType, FilterObject, HandlerObject
|
||||
|
|
@ -33,7 +26,6 @@ class TelegramEventObserver:
|
|||
self.event_name: str = event_name
|
||||
|
||||
self.handlers: List[HandlerObject] = []
|
||||
self.filters: List[Type[BaseFilter]] = []
|
||||
|
||||
self.middleware = MiddlewareManager()
|
||||
self.outer_middleware = MiddlewareManager()
|
||||
|
|
@ -42,63 +34,16 @@ class TelegramEventObserver:
|
|||
# with dummy callback which never will be used
|
||||
self._handler = HandlerObject(callback=lambda: True, filters=[])
|
||||
|
||||
def filter(self, *filters: CallbackType, _stacklevel: int = 2, **bound_filters: Any) -> None:
|
||||
def filter(self, *filters: CallbackType) -> None:
|
||||
"""
|
||||
Register filter for all handlers of this event observer
|
||||
|
||||
:param filters: positional filters
|
||||
:param bound_filters: keyword filters
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(
|
||||
filters, bound_filters, _stacklevel=_stacklevel + 1
|
||||
)
|
||||
if self._handler.filters is None:
|
||||
self._handler.filters = []
|
||||
self._handler.filters.extend(
|
||||
[
|
||||
FilterObject(filter_) # type: ignore
|
||||
for filter_ in chain(
|
||||
resolved_filters,
|
||||
filters,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
||||
"""
|
||||
Register filter class in factory
|
||||
|
||||
:param bound_filter:
|
||||
"""
|
||||
if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter):
|
||||
raise TypeError(
|
||||
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
|
||||
)
|
||||
if bound_filter not in BUILTIN_FILTERS_SET:
|
||||
warnings.warn(
|
||||
category=DeprecationWarning,
|
||||
message="filters factory deprecated and will be removed in 3.0b5,"
|
||||
" use filters directly instead (Example: "
|
||||
f"`{bound_filter.__name__}(<argument>=<value>)` instead of `<argument>=<value>`)",
|
||||
stacklevel=2,
|
||||
)
|
||||
self.filters.append(bound_filter)
|
||||
|
||||
def _resolve_filters_chain(self) -> Generator[Type[BaseFilter], None, None]:
|
||||
"""
|
||||
Get all bounded filters from current observer and from the parents
|
||||
with the same event type without duplicates
|
||||
"""
|
||||
registry: List[Type[BaseFilter]] = []
|
||||
|
||||
for router in reversed(tuple(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
|
||||
for filter_ in observer.filters:
|
||||
if filter_ in registry:
|
||||
continue
|
||||
yield filter_
|
||||
registry.append(filter_)
|
||||
self._handler.filters.extend([FilterObject(filter_) for filter_ in filters])
|
||||
|
||||
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
|
||||
middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
|
|
@ -108,112 +53,30 @@ class TelegramEventObserver:
|
|||
|
||||
return middlewares
|
||||
|
||||
def resolve_filters(
|
||||
self,
|
||||
filters: Tuple[CallbackType, ...],
|
||||
full_config: Dict[str, Any],
|
||||
ignore_default: bool = True,
|
||||
_stacklevel: int = 2,
|
||||
) -> List[BaseFilter]:
|
||||
"""
|
||||
Resolve keyword filters via filters factory
|
||||
|
||||
:param filters: positional filters
|
||||
:param full_config: keyword arguments to initialize bounded filters for router/handler
|
||||
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
|
||||
"""
|
||||
bound_filters: List[BaseFilter] = []
|
||||
|
||||
if ignore_default and not full_config:
|
||||
return bound_filters
|
||||
|
||||
filter_types = set(type(f) for f in filters)
|
||||
|
||||
validation_errors = []
|
||||
for bound_filter in self._resolve_filters_chain():
|
||||
# skip filter if filter was used as positional filter:
|
||||
if bound_filter in filter_types:
|
||||
continue
|
||||
|
||||
# skip filter with no fields in full_config
|
||||
if ignore_default:
|
||||
full_config_keys = set(full_config.keys())
|
||||
filter_fields = set(bound_filter.__fields__.keys())
|
||||
|
||||
if not full_config_keys.intersection(filter_fields):
|
||||
continue
|
||||
|
||||
# Try to initialize filter.
|
||||
try:
|
||||
f = bound_filter(**full_config)
|
||||
except ValidationError as e:
|
||||
validation_errors.append(e)
|
||||
continue
|
||||
|
||||
# Clean full config to prevent to re-initialize another filter
|
||||
# with the same configuration
|
||||
for key in f.__fields__:
|
||||
full_config.pop(key, None)
|
||||
|
||||
bound_filters.append(f)
|
||||
|
||||
if full_config:
|
||||
possible_cases = []
|
||||
for error in validation_errors:
|
||||
for sum_error in error.errors():
|
||||
if sum_error["loc"][0] in full_config:
|
||||
possible_cases.append(error)
|
||||
break
|
||||
|
||||
raise FiltersResolveError(
|
||||
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
|
||||
)
|
||||
|
||||
if bound_filters:
|
||||
warnings.warn(
|
||||
category=DeprecationWarning,
|
||||
message="Filters factory deprecated and will be removed in 3.0b5.\n"
|
||||
"Use filters directly, for example instead of "
|
||||
"`@router.message(commands=['help']')` "
|
||||
"use `@router.message(Command(commands=['help'])`",
|
||||
stacklevel=_stacklevel,
|
||||
)
|
||||
return bound_filters
|
||||
|
||||
def register(
|
||||
self,
|
||||
callback: CallbackType,
|
||||
*filters: CallbackType,
|
||||
flags: Optional[Dict[str, Any]] = None,
|
||||
_stacklevel: int = 2,
|
||||
**bound_filters: Any,
|
||||
) -> CallbackType:
|
||||
"""
|
||||
Register event handler
|
||||
"""
|
||||
if flags is None:
|
||||
flags = {}
|
||||
resolved_filters = self.resolve_filters(
|
||||
filters,
|
||||
bound_filters,
|
||||
ignore_default=False,
|
||||
_stacklevel=_stacklevel + 1,
|
||||
)
|
||||
for resolved_filter in resolved_filters:
|
||||
resolved_filter.update_handler_flags(flags=flags)
|
||||
|
||||
for item in filters:
|
||||
if isinstance(item, Filter):
|
||||
item.update_handler_flags(flags=flags)
|
||||
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback,
|
||||
filters=[
|
||||
FilterObject(filter_) # type: ignore
|
||||
for filter_ in chain(
|
||||
resolved_filters,
|
||||
filters,
|
||||
)
|
||||
],
|
||||
filters=[FilterObject(filter_) for filter_ in filters],
|
||||
flags=flags,
|
||||
)
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
def wrap_outer_middleware(
|
||||
|
|
@ -253,19 +116,15 @@ class TelegramEventObserver:
|
|||
|
||||
def __call__(
|
||||
self,
|
||||
*args: CallbackType,
|
||||
*filters: CallbackType,
|
||||
flags: Optional[Dict[str, Any]] = None,
|
||||
_stacklevel: int = 2,
|
||||
**bound_filters: Any,
|
||||
) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
"""
|
||||
|
||||
def wrapper(callback: CallbackType) -> CallbackType:
|
||||
self.register(
|
||||
callback, *args, flags=flags, **bound_filters, _stacklevel=_stacklevel + 1
|
||||
)
|
||||
self.register(callback, *filters, flags=flags)
|
||||
return callback
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, cast
|
||||
|
||||
from ...types import TelegramObject
|
||||
from ...types import TelegramObject, Update
|
||||
from ...types.error_event import ErrorEvent
|
||||
from ..event.bases import UNHANDLED, CancelHandler, SkipHandler
|
||||
from .base import BaseMiddleware
|
||||
|
||||
|
|
@ -26,7 +27,9 @@ class ErrorsMiddleware(BaseMiddleware):
|
|||
raise
|
||||
except Exception as e:
|
||||
response = await self.router.propagate_event(
|
||||
update_type="error", event=event, **data, exception=e
|
||||
update_type="error",
|
||||
event=ErrorEvent(update=cast(Update, event), exception=e),
|
||||
**data,
|
||||
)
|
||||
if response is not UNHANDLED:
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ from __future__ import annotations
|
|||
import warnings
|
||||
from typing import Any, Dict, Final, Generator, List, Optional, Set, Union
|
||||
|
||||
from aiogram.filters import BUILTIN_FILTERS
|
||||
|
||||
from ..types import TelegramObject
|
||||
from ..utils.warnings import CodeHasNoEffect
|
||||
from .event.bases import REJECTED, UNHANDLED
|
||||
|
|
@ -28,7 +26,7 @@ class Router:
|
|||
def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None:
|
||||
"""
|
||||
|
||||
:param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory
|
||||
:param use_builtin_filters: `aiogram` has many builtin filters, and you can controll automatic registration of this filters in factory
|
||||
:param name: Optional router name, can be useful for debugging
|
||||
"""
|
||||
|
||||
|
|
@ -83,12 +81,6 @@ class Router:
|
|||
"error": self.errors,
|
||||
}
|
||||
|
||||
# Builtin filters
|
||||
if use_builtin_filters:
|
||||
for name, observer in self.observers.items():
|
||||
for builtin_filter in BUILTIN_FILTERS.get(name, ()):
|
||||
observer.bind_filter(builtin_filter)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{type(self).__name__} {self.name!r}"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from itertools import chain
|
||||
from typing import Dict, Tuple, Type
|
||||
|
||||
from .base import BaseFilter
|
||||
from .base import Filter
|
||||
from .chat_member_updated import (
|
||||
ADMINISTRATOR,
|
||||
CREATOR,
|
||||
|
|
@ -18,7 +17,6 @@ from .chat_member_updated import (
|
|||
ChatMemberUpdatedFilter,
|
||||
)
|
||||
from .command import Command, CommandObject, CommandStart
|
||||
from .content_types import ContentTypesFilter
|
||||
from .exception import ExceptionMessageFilter, ExceptionTypeFilter
|
||||
from .logic import and_f, invert_f, or_f
|
||||
from .magic_data import MagicData
|
||||
|
|
@ -27,12 +25,11 @@ from .text import Text
|
|||
|
||||
__all__ = (
|
||||
"BUILTIN_FILTERS",
|
||||
"BaseFilter",
|
||||
"Filter",
|
||||
"Text",
|
||||
"Command",
|
||||
"CommandObject",
|
||||
"CommandStart",
|
||||
"ContentTypesFilter",
|
||||
"ExceptionMessageFilter",
|
||||
"ExceptionTypeFilter",
|
||||
"StateFilter",
|
||||
|
|
@ -55,85 +52,4 @@ __all__ = (
|
|||
"invert_f",
|
||||
)
|
||||
|
||||
_ALL_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (MagicData,)
|
||||
_TELEGRAM_EVENTS_FILTERS: Tuple[Type[BaseFilter], ...] = (StateFilter,)
|
||||
|
||||
BUILTIN_FILTERS: Dict[str, Tuple[Type[BaseFilter], ...]] = {
|
||||
"message": (
|
||||
Text,
|
||||
Command,
|
||||
ContentTypesFilter,
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"edited_message": (
|
||||
Text,
|
||||
Command,
|
||||
ContentTypesFilter,
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"channel_post": (
|
||||
Text,
|
||||
ContentTypesFilter,
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"edited_channel_post": (
|
||||
Text,
|
||||
ContentTypesFilter,
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"inline_query": (
|
||||
Text,
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"chosen_inline_result": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"callback_query": (
|
||||
Text,
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"shipping_query": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"pre_checkout_query": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"poll": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"poll_answer": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"my_chat_member": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
ChatMemberUpdatedFilter,
|
||||
),
|
||||
"chat_member": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
ChatMemberUpdatedFilter,
|
||||
),
|
||||
"chat_join_request": (
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
*_TELEGRAM_EVENTS_FILTERS,
|
||||
),
|
||||
"error": (
|
||||
ExceptionMessageFilter,
|
||||
ExceptionTypeFilter,
|
||||
*_ALL_EVENTS_FILTERS,
|
||||
),
|
||||
}
|
||||
|
||||
BUILTIN_FILTERS_SET = set(chain.from_iterable(BUILTIN_FILTERS.values()))
|
||||
BUILTIN_FILTERS: Dict[str, Tuple[Type[Filter], ...]] = {}
|
||||
|
|
|
|||
|
|
@ -1,19 +1,14 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aiogram.filters.logic import _LogicFilter
|
||||
|
||||
|
||||
class BaseFilter(BaseModel, ABC, _LogicFilter):
|
||||
class Filter(_LogicFilter, ABC):
|
||||
"""
|
||||
If you want to register own filters like builtin filters you will need to write subclass
|
||||
of this class with overriding the :code:`__call__`
|
||||
method and adding filter attributes.
|
||||
|
||||
BaseFilter is subclass of :class:`pydantic.BaseModel` that's mean all subclasses of BaseFilter has
|
||||
the validators based on class attributes and custom validator.
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -37,6 +32,14 @@ class BaseFilter(BaseModel, ABC, _LogicFilter):
|
|||
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def __await__(self): # type: ignore # pragma: no cover
|
||||
# Is needed only for inspection and this method is never be called
|
||||
return self.__call__
|
||||
def _signature_to_string(self, *args: Any, **kwargs: Any) -> str:
|
||||
items = [repr(arg) for arg in args]
|
||||
items.extend([f"{k}={v!r}" for k, v in kwargs.items()])
|
||||
|
||||
return f"{type(self).__name__}({', '.join(items)})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._signature_to_string()
|
||||
|
||||
|
||||
BaseFilter = Filter
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from uuid import UUID
|
|||
from magic_filter import MagicFilter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import CallbackQuery
|
||||
|
||||
T = TypeVar("T", bound="CallbackData")
|
||||
|
|
@ -122,11 +122,8 @@ class CallbackData(BaseModel):
|
|||
"""
|
||||
return CallbackQueryFilter(callback_data=cls, rule=rule)
|
||||
|
||||
# class Config:
|
||||
# use_enum_values = True
|
||||
|
||||
|
||||
class CallbackQueryFilter(BaseFilter):
|
||||
class CallbackQueryFilter(Filter):
|
||||
"""
|
||||
This filter helps to handle callback query.
|
||||
|
||||
|
|
@ -134,10 +131,18 @@ class CallbackQueryFilter(BaseFilter):
|
|||
via callback data instance
|
||||
"""
|
||||
|
||||
callback_data: Type[CallbackData]
|
||||
"""Expected type of callback data"""
|
||||
rule: Optional[MagicFilter] = None
|
||||
"""Magic rule"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
callback_data: Type[CallbackData],
|
||||
rule: Optional[MagicFilter] = None,
|
||||
):
|
||||
"""
|
||||
:param callback_data: Expected type of callback data
|
||||
:param rule: Magic rule
|
||||
"""
|
||||
self.callback_data = callback_data
|
||||
self.rule = rule
|
||||
|
||||
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]:
|
||||
if not isinstance(query, CallbackQuery) or not query.data:
|
||||
|
|
@ -150,7 +155,3 @@ class CallbackQueryFilter(BaseFilter):
|
|||
if self.rule is None or self.rule.resolve(callback_data):
|
||||
return {"callback_data": callback_data}
|
||||
return False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
use_enum_values = True
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any, Dict, Optional, TypeVar, Union
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import ChatMember, ChatMemberUpdated
|
||||
|
||||
MarkerT = TypeVar("MarkerT", bound="_MemberStatusMarker")
|
||||
|
|
@ -154,16 +154,16 @@ LEAVE_TRANSITION = ~JOIN_TRANSITION
|
|||
PROMOTED_TRANSITION = (MEMBER | RESTRICTED | LEFT | KICKED) >> ADMINISTRATOR
|
||||
|
||||
|
||||
class ChatMemberUpdatedFilter(BaseFilter):
|
||||
member_status_changed: Union[
|
||||
_MemberStatusMarker,
|
||||
_MemberStatusGroupMarker,
|
||||
_MemberStatusTransition,
|
||||
]
|
||||
"""Accepts the status transition or new status of the member (see usage in docs)"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
class ChatMemberUpdatedFilter(Filter):
|
||||
def __init__(
|
||||
self,
|
||||
member_status_changed: Union[
|
||||
_MemberStatusMarker,
|
||||
_MemberStatusGroupMarker,
|
||||
_MemberStatusTransition,
|
||||
],
|
||||
):
|
||||
self.member_status_changed = member_status_changed
|
||||
|
||||
async def __call__(self, member_updated: ChatMemberUpdated) -> Union[bool, Dict[str, Any]]:
|
||||
old = member_updated.old_chat_member
|
||||
|
|
|
|||
|
|
@ -2,57 +2,111 @@ from __future__ import annotations
|
|||
|
||||
import re
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import TYPE_CHECKING, Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Match,
|
||||
Optional,
|
||||
Pattern,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from magic_filter import MagicFilter
|
||||
from pydantic import Field, validator
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.types import Message
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import BotCommand, Message
|
||||
from aiogram.utils.deep_linking import decode_payload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram import Bot
|
||||
|
||||
CommandPatternType = Union[str, re.Pattern]
|
||||
CommandPatternType = Union[str, re.Pattern, BotCommand]
|
||||
|
||||
|
||||
class CommandException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Command(BaseFilter):
|
||||
class Command(Filter):
|
||||
"""
|
||||
This filter can be helpful for handling commands from the text messages.
|
||||
|
||||
Works only with :class:`aiogram.types.message.Message` events which have the :code:`text`.
|
||||
"""
|
||||
|
||||
commands: Union[Sequence[CommandPatternType], CommandPatternType]
|
||||
"""List of commands (string or compiled regexp patterns)"""
|
||||
commands_prefix: str = "/"
|
||||
"""Prefix for command. Prefix is always is single char but here you can pass all of allowed prefixes,
|
||||
for example: :code:`"/!"` will work with commands prefixed by :code:`"/"` or :code:`"!"`."""
|
||||
commands_ignore_case: bool = False
|
||||
"""Ignore case (Does not work with regexp, use flags instead)"""
|
||||
commands_ignore_mention: bool = False
|
||||
"""Ignore bot mention. By default bot can not handle commands intended for other bots"""
|
||||
command_magic: Optional[MagicFilter] = None
|
||||
"""Validate command object via Magic filter after all checks done"""
|
||||
def __init__(
|
||||
self,
|
||||
*values: CommandPatternType,
|
||||
commands: Optional[Union[Sequence[CommandPatternType], CommandPatternType]] = None,
|
||||
prefix: str = "/",
|
||||
ignore_case: bool = False,
|
||||
ignore_mention: bool = False,
|
||||
magic: Optional[MagicFilter] = None,
|
||||
):
|
||||
"""
|
||||
List of commands (string or compiled regexp patterns)
|
||||
|
||||
:param prefix: Prefix for command.
|
||||
Prefix is always a single char but here you can pass all of allowed prefixes,
|
||||
for example: :code:`"/!"` will work with commands prefixed
|
||||
by :code:`"/"` or :code:`"!"`.
|
||||
:param ignore_case: Ignore case (Does not work with regexp, use flags instead)
|
||||
:param ignore_mention: Ignore bot mention. By default,
|
||||
bot can not handle commands intended for other bots
|
||||
:param magic: Validate command object via Magic filter after all checks done
|
||||
"""
|
||||
if commands is None:
|
||||
commands = []
|
||||
if isinstance(commands, (str, re.Pattern, BotCommand)):
|
||||
commands = [commands]
|
||||
|
||||
if not isinstance(commands, Iterable):
|
||||
ValueError(
|
||||
"Command filter only supports str, re.Pattern, BotCommand object"
|
||||
" or their Iterable"
|
||||
)
|
||||
|
||||
items = []
|
||||
for command in (*values, *commands):
|
||||
if isinstance(command, BotCommand):
|
||||
command = command.command
|
||||
if not isinstance(command, (str, re.Pattern)):
|
||||
raise ValueError(
|
||||
"Command filter only supports str, re.Pattern, BotCommand object"
|
||||
" or their Iterable"
|
||||
)
|
||||
items.append(command)
|
||||
|
||||
if not items:
|
||||
raise ValueError("At least one command should be specified")
|
||||
|
||||
self.commands = tuple(items)
|
||||
self.prefix = prefix
|
||||
self.ignore_case = ignore_case
|
||||
self.ignore_mention = ignore_mention
|
||||
self.magic = magic
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._signature_to_string(
|
||||
*self.commands,
|
||||
prefix=self.prefix,
|
||||
ignore_case=self.ignore_case,
|
||||
ignore_mention=self.ignore_mention,
|
||||
magic=self.magic,
|
||||
)
|
||||
|
||||
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
|
||||
commands = flags.setdefault("commands", [])
|
||||
commands.append(self)
|
||||
|
||||
@validator("commands", always=True)
|
||||
def _validate_commands(
|
||||
cls, value: Union[Sequence[CommandPatternType], CommandPatternType]
|
||||
) -> Sequence[CommandPatternType]:
|
||||
if isinstance(value, (str, re.Pattern)):
|
||||
value = [value]
|
||||
return value
|
||||
|
||||
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
|
||||
if not isinstance(message, Message):
|
||||
return False
|
||||
|
||||
text = message.text or message.caption
|
||||
if not text:
|
||||
return False
|
||||
|
|
@ -82,11 +136,11 @@ class Command(BaseFilter):
|
|||
)
|
||||
|
||||
def validate_prefix(self, command: CommandObject) -> None:
|
||||
if command.prefix not in self.commands_prefix:
|
||||
if command.prefix not in self.prefix:
|
||||
raise CommandException("Invalid command prefix")
|
||||
|
||||
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
|
||||
if command.mention and not self.commands_ignore_mention:
|
||||
if command.mention and not self.ignore_mention:
|
||||
me = await bot.me()
|
||||
if me.username and command.mention.lower() != me.username.lower():
|
||||
raise CommandException("Mention did not match")
|
||||
|
|
@ -119,16 +173,13 @@ class Command(BaseFilter):
|
|||
return command
|
||||
|
||||
def do_magic(self, command: CommandObject) -> Any:
|
||||
if not self.command_magic:
|
||||
if not self.magic:
|
||||
return command
|
||||
result = self.command_magic.resolve(command)
|
||||
result = self.magic.resolve(command)
|
||||
if not result:
|
||||
raise CommandException("Rejected via magic filter")
|
||||
return replace(command, magic_result=result)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommandObject:
|
||||
|
|
@ -170,10 +221,34 @@ class CommandObject:
|
|||
|
||||
|
||||
class CommandStart(Command):
|
||||
commands: Tuple[str] = Field(("start",), const=True)
|
||||
commands_prefix: str = Field("/", const=True)
|
||||
deep_link: bool = False
|
||||
deep_link_encoded: bool = False
|
||||
def __init__(
|
||||
self,
|
||||
deep_link: bool = False,
|
||||
deep_link_encoded: bool = False,
|
||||
ignore_case: bool = False,
|
||||
ignore_mention: bool = False,
|
||||
magic: Optional[MagicFilter] = None,
|
||||
):
|
||||
super().__init__(
|
||||
"start",
|
||||
prefix="/",
|
||||
ignore_case=ignore_case,
|
||||
ignore_mention=ignore_mention,
|
||||
magic=magic,
|
||||
)
|
||||
self.deep_link = deep_link
|
||||
self.deep_link_encoded = deep_link_encoded
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._signature_to_string(
|
||||
*self.commands,
|
||||
prefix=self.prefix,
|
||||
ignore_case=self.ignore_case,
|
||||
ignore_mention=self.ignore_mention,
|
||||
magic=self.magic,
|
||||
deep_link=self.deep_link,
|
||||
deep_link_encoded=self.deep_link_encoded,
|
||||
)
|
||||
|
||||
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,34 +0,0 @@
|
|||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
|
||||
from pydantic import validator
|
||||
|
||||
from aiogram.types import Message
|
||||
from aiogram.types.message import ContentType
|
||||
|
||||
from .base import BaseFilter
|
||||
|
||||
|
||||
class ContentTypesFilter(BaseFilter):
|
||||
"""
|
||||
Is useful for handling specific types of messages (For example separate text and stickers handlers).
|
||||
"""
|
||||
|
||||
content_types: Union[Sequence[str], str]
|
||||
"""Sequence of allowed content types"""
|
||||
|
||||
@validator("content_types")
|
||||
def _validate_content_types(
|
||||
cls, value: Optional[Union[Sequence[str], str]]
|
||||
) -> Optional[Sequence[str]]:
|
||||
if not value:
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
value = [value]
|
||||
allowed_content_types = set(ContentType.all())
|
||||
bad_content_types = set(value) - allowed_content_types
|
||||
if bad_content_types:
|
||||
raise ValueError(f"Invalid content types {bad_content_types} is not allowed here")
|
||||
return value
|
||||
|
||||
async def __call__(self, message: Message) -> Union[bool, Dict[str, Any]]:
|
||||
return ContentType.ANY in self.content_types or message.content_type in self.content_types
|
||||
|
|
@ -1,51 +1,47 @@
|
|||
import re
|
||||
from typing import Any, Dict, Pattern, Tuple, Type, Union, cast
|
||||
from typing import Any, Dict, Pattern, Type, Union, cast
|
||||
|
||||
from pydantic import validator
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import TelegramObject
|
||||
from aiogram.types.error_event import ErrorEvent
|
||||
|
||||
|
||||
class ExceptionTypeFilter(BaseFilter):
|
||||
class ExceptionTypeFilter(Filter):
|
||||
"""
|
||||
Allows to match exception by type
|
||||
"""
|
||||
|
||||
exception: Union[Type[Exception], Tuple[Type[Exception]]]
|
||||
"""Exception type(s)"""
|
||||
def __init__(self, *exceptions: Type[Exception]):
|
||||
"""
|
||||
:param exceptions: Exception type(s)
|
||||
"""
|
||||
if not exceptions:
|
||||
raise ValueError("At least one exception type is required")
|
||||
self.exceptions = exceptions
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def __call__(
|
||||
self, obj: TelegramObject, exception: Exception
|
||||
) -> Union[bool, Dict[str, Any]]:
|
||||
return isinstance(exception, self.exception)
|
||||
async def __call__(self, obj: TelegramObject) -> Union[bool, Dict[str, Any]]:
|
||||
return isinstance(cast(ErrorEvent, obj).exception, self.exceptions)
|
||||
|
||||
|
||||
class ExceptionMessageFilter(BaseFilter):
|
||||
class ExceptionMessageFilter(Filter):
|
||||
"""
|
||||
Allow to match exception by message
|
||||
"""
|
||||
|
||||
pattern: Union[str, Pattern[str]]
|
||||
"""Regexp pattern"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("pattern")
|
||||
def _validate_match(cls, value: Union[str, Pattern[str]]) -> Union[str, Pattern[str]]:
|
||||
if isinstance(value, str):
|
||||
return re.compile(value)
|
||||
return value
|
||||
def __init__(self, pattern: Union[str, Pattern[str]]):
|
||||
"""
|
||||
:param pattern: Regexp pattern
|
||||
"""
|
||||
if not isinstance(pattern, str):
|
||||
pattern = re.compile(pattern)
|
||||
self.pattern = pattern
|
||||
|
||||
async def __call__(
|
||||
self, obj: TelegramObject, exception: Exception
|
||||
self,
|
||||
obj: TelegramObject,
|
||||
) -> Union[bool, Dict[str, Any]]:
|
||||
pattern = cast(Pattern[str], self.pattern)
|
||||
result = pattern.match(str(exception))
|
||||
result = pattern.match(str(cast(ErrorEvent, obj).exception))
|
||||
if not result:
|
||||
return False
|
||||
return {"match_exception": result}
|
||||
|
|
|
|||
|
|
@ -2,15 +2,13 @@ from typing import Any
|
|||
|
||||
from magic_filter import AttrDict, MagicFilter
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
|
||||
class MagicData(BaseFilter):
|
||||
magic_data: MagicFilter
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
class MagicData(Filter):
|
||||
def __init__(self, magic_data: MagicFilter) -> None:
|
||||
self.magic_data = magic_data
|
||||
|
||||
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.magic_data.resolve(
|
||||
|
|
|
|||
|
|
@ -1,40 +1,28 @@
|
|||
from inspect import isclass
|
||||
from typing import Any, Dict, Optional, Sequence, Type, Union, cast, no_type_check
|
||||
from typing import Any, Dict, Optional, Sequence, Type, Union, cast
|
||||
|
||||
from pydantic import Field, validator
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.fsm.state import State, StatesGroup
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]]
|
||||
|
||||
|
||||
class StateFilter(BaseFilter):
|
||||
class StateFilter(Filter):
|
||||
"""
|
||||
State filter
|
||||
"""
|
||||
|
||||
state: Union[StateType, Sequence[StateType]] = Field(...)
|
||||
def __init__(self, *states: StateType) -> None:
|
||||
if not states:
|
||||
raise ValueError("At least one state is required")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("state")
|
||||
@no_type_check # issubclass breaks things
|
||||
def _validate_state(cls, v: Union[StateType, Sequence[StateType]]) -> Sequence[StateType]:
|
||||
if (
|
||||
isinstance(v, (str, State, StatesGroup))
|
||||
or (isclass(v) and issubclass(v, StatesGroup))
|
||||
or v is None
|
||||
):
|
||||
return [v]
|
||||
return v
|
||||
self.states = states
|
||||
|
||||
async def __call__(
|
||||
self, obj: Union[TelegramObject], raw_state: Optional[str] = None
|
||||
) -> Union[bool, Dict[str, Any]]:
|
||||
allowed_states = cast(Sequence[StateType], self.state)
|
||||
allowed_states = cast(Sequence[StateType], self.states)
|
||||
for allowed_state in allowed_states:
|
||||
if isinstance(allowed_state, str) or allowed_state is None:
|
||||
if allowed_state == "*" or raw_state == allowed_state:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import CallbackQuery, InlineQuery, Message, Poll
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -11,7 +9,7 @@ if TYPE_CHECKING:
|
|||
TextType = Union[str, "LazyProxy"]
|
||||
|
||||
|
||||
class Text(BaseFilter):
|
||||
class Text(Filter):
|
||||
"""
|
||||
Is useful for filtering text :class:`aiogram.types.message.Message`,
|
||||
any :class:`aiogram.types.callback_query.CallbackQuery` with `data`,
|
||||
|
|
@ -19,7 +17,7 @@ class Text(BaseFilter):
|
|||
|
||||
.. warning::
|
||||
|
||||
Only one of `text`, `text_contains`, `text_startswith` or `text_endswith` argument can be used at once.
|
||||
Only one of `text`, `contains`, `startswith` or `endswith` argument can be used at once.
|
||||
Any of that arguments can be string, list, set or tuple of strings.
|
||||
|
||||
.. deprecated:: 3.0
|
||||
|
|
@ -27,40 +25,54 @@ class Text(BaseFilter):
|
|||
use :ref:`magic-filter <magic-filters>`. For example do :pycode:`F.text == "text"` instead
|
||||
"""
|
||||
|
||||
text: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
"""Text equals value or one of values"""
|
||||
text_contains: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
"""Text contains value or one of values"""
|
||||
text_startswith: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
"""Text starts with value or one of values"""
|
||||
text_endswith: Optional[Union[Sequence[TextType], TextType]] = None
|
||||
"""Text ends with value or one of values"""
|
||||
text_ignore_case: bool = False
|
||||
"""Ignore case when checks"""
|
||||
def __init__(
|
||||
self,
|
||||
text: Optional[Union[Sequence[TextType], TextType]] = None,
|
||||
*,
|
||||
contains: Optional[Union[Sequence[TextType], TextType]] = None,
|
||||
startswith: Optional[Union[Sequence[TextType], TextType]] = None,
|
||||
endswith: Optional[Union[Sequence[TextType], TextType]] = None,
|
||||
ignore_case: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def _validate_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Validate that only one text filter type is presented
|
||||
used_args = set(
|
||||
key for key, value in values.items() if key != "text_ignore_case" and value is not None
|
||||
:param text: Text equals value or one of values
|
||||
:param contains: Text contains value or one of values
|
||||
:param startswith: Text starts with value or one of values
|
||||
:param endswith: Text ends with value or one of values
|
||||
:param ignore_case: Ignore case when checks
|
||||
"""
|
||||
self._validate_constraints(
|
||||
text=text,
|
||||
contains=contains,
|
||||
startswith=startswith,
|
||||
endswith=endswith,
|
||||
)
|
||||
self.text = text
|
||||
self.contains = contains
|
||||
self.startswith = startswith
|
||||
self.endswith = endswith
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
@classmethod
|
||||
def _prepare_argument(
|
||||
cls, value: Optional[Union[Sequence[TextType], TextType]]
|
||||
) -> Optional[Sequence[TextType]]:
|
||||
from aiogram.utils.i18n.lazy_proxy import LazyProxy
|
||||
|
||||
if isinstance(value, (str, LazyProxy)):
|
||||
return [value]
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _validate_constraints(cls, **values: Any) -> None:
|
||||
# Validate that only one text filter type is presented
|
||||
used_args = set(key for key, value in values.items() if value is not None)
|
||||
if len(used_args) < 1:
|
||||
raise ValueError(
|
||||
"Filter should contain one of arguments: {'text', 'text_contains', 'text_startswith', 'text_endswith'}"
|
||||
)
|
||||
raise ValueError(f"Filter should contain one of arguments: {set(values.keys())}")
|
||||
if len(used_args) > 1:
|
||||
raise ValueError(f"Arguments {used_args} cannot be used together")
|
||||
|
||||
# Convert single value to list
|
||||
for arg in used_args:
|
||||
if isinstance(values[arg], str):
|
||||
values[arg] = [values[arg]]
|
||||
|
||||
return values
|
||||
|
||||
async def __call__(
|
||||
self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]
|
||||
) -> Union[bool, Dict[str, Any]]:
|
||||
|
|
@ -79,30 +91,30 @@ class Text(BaseFilter):
|
|||
|
||||
if not text:
|
||||
return False
|
||||
if self.text_ignore_case:
|
||||
if self.ignore_case:
|
||||
text = text.lower()
|
||||
|
||||
if self.text is not None:
|
||||
equals = list(map(self.prepare_text, self.text))
|
||||
return text in equals
|
||||
|
||||
if self.text_contains is not None:
|
||||
contains = list(map(self.prepare_text, self.text_contains))
|
||||
if self.contains is not None:
|
||||
contains = list(map(self.prepare_text, self.contains))
|
||||
return all(map(text.__contains__, contains))
|
||||
|
||||
if self.text_startswith is not None:
|
||||
startswith = list(map(self.prepare_text, self.text_startswith))
|
||||
if self.startswith is not None:
|
||||
startswith = list(map(self.prepare_text, self.startswith))
|
||||
return any(map(text.startswith, startswith))
|
||||
|
||||
if self.text_endswith is not None:
|
||||
endswith = list(map(self.prepare_text, self.text_endswith))
|
||||
if self.endswith is not None:
|
||||
endswith = list(map(self.prepare_text, self.endswith))
|
||||
return any(map(text.endswith, endswith))
|
||||
|
||||
# Impossible because the validator prevents this situation
|
||||
return False # pragma: no cover
|
||||
|
||||
def prepare_text(self, text: str) -> str:
|
||||
if self.text_ignore_case:
|
||||
if self.ignore_case:
|
||||
return str(text).lower()
|
||||
else:
|
||||
return str(text)
|
||||
|
|
|
|||
10
aiogram/types/error_event.py
Normal file
10
aiogram/types/error_event.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from aiogram.types import Update
|
||||
from aiogram.types.base import MutableTelegramObject
|
||||
|
||||
|
||||
class ErrorEvent(MutableTelegramObject):
|
||||
update: Update
|
||||
exception: Exception
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
@ -3,7 +3,7 @@ Command
|
|||
=======
|
||||
|
||||
.. autoclass:: aiogram.filters.command.Command
|
||||
:members:
|
||||
:members: __init__
|
||||
:member-order: bysource
|
||||
:undoc-members: False
|
||||
|
||||
|
|
@ -18,10 +18,10 @@ When filter is passed the :class:`aiogram.filters.command.CommandObject` will be
|
|||
Usage
|
||||
=====
|
||||
|
||||
1. Filter single variant of commands: :code:`Command(commands=["start"])` or :code:`Command(commands="start")`
|
||||
2. Handle command by regexp pattern: :code:`Command(commands=[re.compile(r"item_(\d+)")])`
|
||||
3. Match command by multiple variants: :code:`Command(commands=["item", re.compile(r"item_(\d+)")])`
|
||||
4. Handle commands in public chats intended for other bots: :code:`Command(commands=["command"], commands_ignore_mention=True)`
|
||||
1. Filter single variant of commands: :code:`Command("start")`
|
||||
2. Handle command by regexp pattern: :code:`Command(re.compile(r"item_(\d+)"))`
|
||||
3. Match command by multiple variants: :code:`Command("item", re.compile(r"item_(\d+)"))`
|
||||
4. Handle commands in public chats intended for other bots: :code:`Command("command", ignore_mention=True)`
|
||||
|
||||
.. warning::
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def is_bot_token(value: str) -> Union[bool, Dict[str, Any]]:
|
|||
return True
|
||||
|
||||
|
||||
@main_router.message(Command(commands=["add"], command_magic=F.args.func(is_bot_token)))
|
||||
@main_router.message(Command(commands=["add"], magic=F.args.func(is_bot_token)))
|
||||
async def command_add_bot(message: Message, command: CommandObject, bot: Bot) -> Any:
|
||||
new_bot = Bot(token=command.args, session=bot.session)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
|
||||
from aiogram import F
|
||||
from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters import Filter
|
||||
from aiogram.handlers import BaseHandler
|
||||
from aiogram.types import Update
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ async def callback4(foo: int, *, bar: int, baz: int):
|
|||
return locals()
|
||||
|
||||
|
||||
class Filter(BaseFilter):
|
||||
class Filter(Filter):
|
||||
async def __call__(self, foo: int, bar: int, baz: int) -> Union[bool, Dict[str, Any]]:
|
||||
return locals()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from aiogram.dispatcher.event.handler import HandlerObject
|
|||
from aiogram.dispatcher.event.telegram import TelegramEventObserver
|
||||
from aiogram.dispatcher.router import Router
|
||||
from aiogram.exceptions import FiltersResolveError
|
||||
from aiogram.filters import BaseFilter, Command
|
||||
from aiogram.filters import Command, Filter
|
||||
from aiogram.types import Chat, Message, User
|
||||
from tests.deprecated import check_deprecated
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ async def pipe_handler(*args, **kwargs):
|
|||
return args, kwargs
|
||||
|
||||
|
||||
class MyFilter1(BaseFilter):
|
||||
class MyFilter1(Filter):
|
||||
test: str
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
|
|
@ -46,14 +46,14 @@ class MyFilter3(MyFilter1):
|
|||
pass
|
||||
|
||||
|
||||
class OptionalFilter(BaseFilter):
|
||||
class OptionalFilter(Filter):
|
||||
optional: Optional[str]
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return True
|
||||
|
||||
|
||||
class DefaultFilter(BaseFilter):
|
||||
class DefaultFilter(Filter):
|
||||
default: str = "Default"
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
|
|
@ -66,7 +66,7 @@ class TestTelegramEventObserver:
|
|||
with pytest.raises(TypeError):
|
||||
event_observer.bind_filter(object) # type: ignore
|
||||
|
||||
class MyFilter(BaseFilter):
|
||||
class MyFilter(Filter):
|
||||
async def __call__(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]:
|
||||
|
|
@ -103,13 +103,13 @@ class TestTelegramEventObserver:
|
|||
assert MyFilter3 not in filters_chain1
|
||||
|
||||
async def test_resolve_filters_data_from_parent_router(self):
|
||||
class FilterSet(BaseFilter):
|
||||
class FilterSet(Filter):
|
||||
set_filter: bool
|
||||
|
||||
async def __call__(self, message: Message) -> dict:
|
||||
return {"test": "hello world"}
|
||||
|
||||
class FilterGet(BaseFilter):
|
||||
class FilterGet(Filter):
|
||||
get_filter: bool
|
||||
|
||||
async def __call__(self, message: Message, **data) -> bool:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Awaitable
|
|||
|
||||
import pytest
|
||||
|
||||
from aiogram.filters import BaseFilter
|
||||
from aiogram.filters import Filter
|
||||
|
||||
try:
|
||||
from asynctest import CoroutineMock, patch
|
||||
|
|
@ -13,7 +13,7 @@ except ImportError:
|
|||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class MyFilter(BaseFilter):
|
||||
class MyFilter(Filter):
|
||||
foo: str
|
||||
|
||||
async def __call__(self, event: str):
|
||||
|
|
|
|||
|
|
@ -16,34 +16,34 @@ class TestCommandFilter:
|
|||
def test_convert_to_list(self):
|
||||
cmd = Command(commands="start")
|
||||
assert cmd.commands
|
||||
assert isinstance(cmd.commands, list)
|
||||
assert isinstance(cmd.commands, tuple)
|
||||
assert cmd.commands[0] == "start"
|
||||
assert cmd == Command(commands=["start"])
|
||||
# assert cmd == Command(commands=["start"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,command,result",
|
||||
[
|
||||
["/test@tbot", Command(commands=["test"], commands_prefix="/"), True],
|
||||
["!test", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["/test@mention", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["/tests", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["/", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["/ test", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["", Command(commands=["test"], commands_prefix="/"), False],
|
||||
[" ", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["test", Command(commands=["test"], commands_prefix="/"), False],
|
||||
[" test", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["a", Command(commands=["test"], commands_prefix="/"), False],
|
||||
["/test@tbot", Command(commands=["test"], prefix="/"), True],
|
||||
["!test", Command(commands=["test"], prefix="/"), False],
|
||||
["/test@mention", Command(commands=["test"], prefix="/"), False],
|
||||
["/tests", Command(commands=["test"], prefix="/"), False],
|
||||
["/", Command(commands=["test"], prefix="/"), False],
|
||||
["/ test", Command(commands=["test"], prefix="/"), False],
|
||||
["", Command(commands=["test"], prefix="/"), False],
|
||||
[" ", Command(commands=["test"], prefix="/"), False],
|
||||
["test", Command(commands=["test"], prefix="/"), False],
|
||||
[" test", Command(commands=["test"], prefix="/"), False],
|
||||
["a", Command(commands=["test"], prefix="/"), False],
|
||||
["/test@tbot some args", Command(commands=["test"]), True],
|
||||
["/test42@tbot some args", Command(commands=[re.compile(r"test(\d+)")]), True],
|
||||
[
|
||||
"/test42@tbot some args",
|
||||
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "some args"),
|
||||
Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "some args"),
|
||||
True,
|
||||
],
|
||||
[
|
||||
"/test42@tbot some args",
|
||||
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "test"),
|
||||
Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "test"),
|
||||
False,
|
||||
],
|
||||
["/start test", CommandStart(), True],
|
||||
|
|
@ -99,7 +99,7 @@ class TestCommandFilter:
|
|||
chat=Chat(id=42, type="private"),
|
||||
date=datetime.datetime.now(),
|
||||
)
|
||||
command = Command(commands=["test"], command_magic=(F.args.as_("args")))
|
||||
command = Command(commands=["test"], magic=(F.args.as_("args")))
|
||||
result = await command(message=message, bot=bot)
|
||||
assert "args" in result
|
||||
assert result["args"] == "42"
|
||||
|
|
|
|||
|
|
@ -1,52 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiogram.filters import ContentTypesFilter
|
||||
from aiogram.types import ContentType, Message
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinimalMessage:
|
||||
content_type: str
|
||||
|
||||
|
||||
class TestContentTypesFilter:
|
||||
def test_validator_empty_list(self):
|
||||
filter_ = ContentTypesFilter(content_types=[])
|
||||
assert filter_.content_types == []
|
||||
|
||||
def test_convert_to_list(self):
|
||||
filter_ = ContentTypesFilter(content_types="text")
|
||||
assert filter_.content_types
|
||||
assert isinstance(filter_.content_types, list)
|
||||
assert filter_.content_types[0] == "text"
|
||||
assert filter_ == ContentTypesFilter(content_types=["text"])
|
||||
|
||||
@pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]])
|
||||
def test_validator_with_values(self, values):
|
||||
filter_ = ContentTypesFilter(content_types=values)
|
||||
assert filter_.content_types == values
|
||||
|
||||
@pytest.mark.parametrize("values", [["test"], ["text", "test"], ["TEXT"]])
|
||||
def test_validator_with_bad_values(self, values):
|
||||
with pytest.raises(ValidationError):
|
||||
ContentTypesFilter(content_types=values)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"values,content_type,result",
|
||||
[
|
||||
[[ContentType.TEXT], ContentType.TEXT, True],
|
||||
[[ContentType.PHOTO], ContentType.TEXT, False],
|
||||
[[ContentType.ANY], ContentType.TEXT, True],
|
||||
[[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
|
||||
[[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
|
||||
],
|
||||
)
|
||||
async def test_call(self, values, content_type, result):
|
||||
filter_ = ContentTypesFilter(content_types=values)
|
||||
assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result
|
||||
|
|
@ -46,7 +46,7 @@ class TestExceptionTypeFilter:
|
|||
],
|
||||
)
|
||||
async def test_check(self, exception: Exception, value: bool):
|
||||
obj = ExceptionTypeFilter(exception=MyException)
|
||||
obj = ExceptionTypeFilter(exceptions=MyException)
|
||||
|
||||
result = await obj(Update(update_id=0), exception=exception)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue