From 06f24a8cb3f31188470cb338b4347571b7adb831 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Mon, 3 Oct 2022 01:04:53 +0300 Subject: [PATCH] Added explicit logic filters, added slots to all other filters --- aiogram/filters/__init__.py | 9 ++- aiogram/filters/base.py | 23 +------- aiogram/filters/callback_data.py | 5 ++ aiogram/filters/chat_member_updated.py | 14 +++++ aiogram/filters/command.py | 8 +++ aiogram/filters/exception.py | 4 ++ aiogram/filters/logic.py | 77 ++++++++++++++++++++++++++ aiogram/filters/magic_data.py | 6 ++ aiogram/filters/state.py | 2 + aiogram/filters/text.py | 8 +++ docs/dispatcher/filters/index.rst | 45 +++++++++++++++ tests/test_filters/test_base.py | 18 ------ tests/test_filters/test_logic.py | 38 +++++++++++++ 13 files changed, 214 insertions(+), 43 deletions(-) create mode 100644 aiogram/filters/logic.py create mode 100644 tests/test_filters/test_logic.py diff --git a/aiogram/filters/__init__.py b/aiogram/filters/__init__.py index 9041f6b6..b3bee0d8 100644 --- a/aiogram/filters/__init__.py +++ b/aiogram/filters/__init__.py @@ -1,5 +1,3 @@ -from typing import Dict, Tuple, Type - from .base import Filter from .chat_member_updated import ( ADMINISTRATOR, @@ -18,6 +16,7 @@ from .chat_member_updated import ( ) from .command import Command, CommandObject, CommandStart from .exception import ExceptionMessageFilter, ExceptionTypeFilter +from .logic import and_f, invert_f, or_f from .magic_data import MagicData from .state import StateFilter from .text import Text @@ -25,7 +24,6 @@ from .text import Text BaseFilter = Filter __all__ = ( - "BUILTIN_FILTERS", "Filter", "BaseFilter", "Text", @@ -49,6 +47,7 @@ __all__ = ( "IS_NOT_MEMBER", "JOIN_TRANSITION", "LEAVE_TRANSITION", + "and_f", + "or_f", + "invert_f", ) - -BUILTIN_FILTERS: Dict[str, Tuple[Type[Filter], ...]] = {} diff --git a/aiogram/filters/base.py b/aiogram/filters/base.py index 9e2e21b8..fdec295e 100644 --- a/aiogram/filters/base.py +++ b/aiogram/filters/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union if TYPE_CHECKING: - from aiogram.dispatcher.event.handler import CallbackType, FilterObject + from aiogram.filters.logic import _InvertFilter class Filter(ABC): @@ -31,6 +31,8 @@ class Filter(ABC): pass def __invert__(self) -> "_InvertFilter": + from aiogram.filters.logic import invert_f + return invert_f(self) def update_handler_flags(self, flags: Dict[str, Any]) -> None: @@ -50,22 +52,3 @@ class Filter(ABC): def __await__(self): # type: ignore # pragma: no cover # Is needed only for inspection and this method is never be called return self.__call__ - - -class _InvertFilter(Filter): - __slots__ = ("target",) - - def __init__(self, target: "FilterObject") -> None: - self.target = target - - async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: - return not bool(await self.target.call(*args, **kwargs)) - - def __str__(self) -> str: - return f"~{self.target.callback}" - - -def invert_f(target: "CallbackType") -> _InvertFilter: - from aiogram.dispatcher.event.handler import FilterObject - - return _InvertFilter(target=FilterObject(target)) diff --git a/aiogram/filters/callback_data.py b/aiogram/filters/callback_data.py index 750294be..cf06d7f1 100644 --- a/aiogram/filters/callback_data.py +++ b/aiogram/filters/callback_data.py @@ -131,6 +131,11 @@ class CallbackQueryFilter(Filter): via callback data instance """ + __slots__ = ( + "callback_data", + "rule", + ) + def __init__( self, *, diff --git a/aiogram/filters/chat_member_updated.py b/aiogram/filters/chat_member_updated.py index 99eee928..650ce631 100644 --- a/aiogram/filters/chat_member_updated.py +++ b/aiogram/filters/chat_member_updated.py @@ -9,6 +9,11 @@ TransitionT = TypeVar("TransitionT", bound="_MemberStatusTransition") class _MemberStatusMarker: + __slots__ = ( + "name", + "is_member", + ) + def __init__(self, name: str, *, is_member: Optional[bool] = None) -> None: self.name = name self.is_member = is_member @@ -72,6 +77,8 @@ class _MemberStatusMarker: class _MemberStatusGroupMarker: + __slots__ = ("statuses",) + def __init__(self, *statuses: _MemberStatusMarker) -> None: if not statuses: raise ValueError("Member status group should have at least one status included") @@ -124,6 +131,11 @@ class _MemberStatusGroupMarker: class _MemberStatusTransition: + __slots__ = ( + "old", + "new", + ) + def __init__(self, *, old: _MemberStatusGroupMarker, new: _MemberStatusGroupMarker) -> None: self.old = old self.new = new @@ -155,6 +167,8 @@ PROMOTED_TRANSITION = (MEMBER | RESTRICTED | LEFT | KICKED) >> ADMINISTRATOR class ChatMemberUpdatedFilter(Filter): + __slots__ = ("member_status_changed",) + def __init__( self, member_status_changed: Union[ diff --git a/aiogram/filters/command.py b/aiogram/filters/command.py index 01993efb..edf68ef9 100644 --- a/aiogram/filters/command.py +++ b/aiogram/filters/command.py @@ -38,6 +38,14 @@ class Command(Filter): Works only with :class:`aiogram.types.message.Message` events which have the :code:`text`. """ + __slots__ = ( + "commands", + "prefix", + "ignore_case", + "ignore_mention", + "magic", + ) + def __init__( self, *values: CommandPatternType, diff --git a/aiogram/filters/exception.py b/aiogram/filters/exception.py index c04ec0dc..2530d751 100644 --- a/aiogram/filters/exception.py +++ b/aiogram/filters/exception.py @@ -11,6 +11,8 @@ class ExceptionTypeFilter(Filter): Allows to match exception by type """ + __slots__ = ("exceptions",) + def __init__(self, *exceptions: Type[Exception]): """ :param exceptions: Exception type(s) @@ -28,6 +30,8 @@ class ExceptionMessageFilter(Filter): Allow to match exception by message """ + __slots__ = ("pattern",) + def __init__(self, pattern: Union[str, Pattern[str]]): """ :param pattern: Regexp pattern diff --git a/aiogram/filters/logic.py b/aiogram/filters/logic.py new file mode 100644 index 00000000..b9c95de1 --- /dev/null +++ b/aiogram/filters/logic.py @@ -0,0 +1,77 @@ +from abc import ABC +from typing import TYPE_CHECKING, Any, Dict, Union + +from aiogram.filters import Filter + +if TYPE_CHECKING: + from aiogram.dispatcher.event.handler import CallbackType, FilterObject + + +class _LogicFilter(Filter, ABC): + pass + + +class _InvertFilter(_LogicFilter): + __slots__ = ("target",) + + def __init__(self, target: "FilterObject") -> None: + self.target = target + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return not bool(await self.target.call(*args, **kwargs)) + + +class _AndFilter(_LogicFilter): + __slots__ = ("targets",) + + def __init__(self, *targets: "FilterObject") -> None: + self.targets = targets + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + final_result = {} + + for target in self.targets: + result = await target.call(*args, **kwargs) + if not result: + return False + if isinstance(result, dict): + final_result.update(result) + + if final_result: + return final_result + return True + + +class _OrFilter(_LogicFilter): + __slots__ = ("targets",) + + def __init__(self, *targets: "FilterObject") -> None: + self.targets = targets + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + for target in self.targets: + result = await target.call(*args, **kwargs) + if not result: + continue + if isinstance(result, dict): + return result + return bool(result) + return False + + +def and_f(target1: "CallbackType", target2: "CallbackType") -> _AndFilter: + from aiogram.dispatcher.event.handler import FilterObject + + return _AndFilter(FilterObject(target1), FilterObject(target2)) + + +def or_f(target1: "CallbackType", target2: "CallbackType") -> _OrFilter: + from aiogram.dispatcher.event.handler import FilterObject + + return _OrFilter(FilterObject(target1), FilterObject(target2)) + + +def invert_f(target: "CallbackType") -> _InvertFilter: + from aiogram.dispatcher.event.handler import FilterObject + + return _InvertFilter(FilterObject(target)) diff --git a/aiogram/filters/magic_data.py b/aiogram/filters/magic_data.py index a9b9baf0..351c8e4f 100644 --- a/aiogram/filters/magic_data.py +++ b/aiogram/filters/magic_data.py @@ -7,6 +7,12 @@ from aiogram.types import TelegramObject class MagicData(Filter): + """ + This filter helps to filter event with contextual data + """ + + __slots__ = "magic_data" + def __init__(self, magic_data: MagicFilter) -> None: self.magic_data = magic_data diff --git a/aiogram/filters/state.py b/aiogram/filters/state.py index 0168d130..5ad65ae5 100644 --- a/aiogram/filters/state.py +++ b/aiogram/filters/state.py @@ -13,6 +13,8 @@ class StateFilter(Filter): State filter """ + __slots__ = ("states",) + def __init__(self, *states: StateType) -> None: if not states: raise ValueError("At least one state is required") diff --git a/aiogram/filters/text.py b/aiogram/filters/text.py index 96748185..52585f95 100644 --- a/aiogram/filters/text.py +++ b/aiogram/filters/text.py @@ -25,6 +25,14 @@ class Text(Filter): use :ref:`magic-filter `. For example do :pycode:`F.text == "text"` instead """ + __slots__ = ( + "text", + "contains", + "startswith", + "endswith", + "ignore_case", + ) + def __init__( self, text: Optional[Union[Sequence[TextType], TextType]] = None, diff --git a/docs/dispatcher/filters/index.rst b/docs/dispatcher/filters/index.rst index 9eed36c3..01006774 100644 --- a/docs/dispatcher/filters/index.rst +++ b/docs/dispatcher/filters/index.rst @@ -53,3 +53,48 @@ Own filter example For example if you need to make simple text filter: .. literalinclude:: ../../../examples/own_filter.py + +Combining Filters +================= + +In general, all filters can be combined in two ways + + +Recommended way +--------------- + +If you specify multiple filters in a row, it will be checked with an "and" condition: + +.. code-block:: python + + @.message(Text(startswith="show"), Text(endswith="example")) + + +Also, if you want to use two alternative ways to run the sage handler ("or" condition) +you can register the handler twice or more times as you like + +.. code-block:: python + + @.message(Text(text="hi")) + @.message(CommandStart()) + + +Also sometimes you will need to invert the filter result, for example you have an *IsAdmin* filter +and you want to check if the user is not an admin + +.. code-block:: python + + @.message(~IsAdmin()) + + +Another possible way +-------------------- + +An alternative way is to combine using special functions (:func:`and_f`, :func:`or_f`, :func:`invert_f` from :code:`aiogram.filters` module): + +.. code-block:: python + + and_f(Text(startswith="show"), Text(endswith="example")) + or_f(Text(text="hi"), CommandStart()) + invert_f(IsAdmin()) + and_f(, or_f(, )) diff --git a/tests/test_filters/test_base.py b/tests/test_filters/test_base.py index 954a8d93..904c14b1 100644 --- a/tests/test_filters/test_base.py +++ b/tests/test_filters/test_base.py @@ -3,7 +3,6 @@ from typing import Awaitable import pytest from aiogram.filters import Filter -from aiogram.filters.base import _InvertFilter try: from asynctest import CoroutineMock, patch @@ -32,20 +31,3 @@ class TestBaseFilter: call = my_filter(event="test") await call mocked_call.assert_awaited_with(event="test") - - async def test_invert(self): - my_filter = MyFilter() - my_inverted_filter = ~my_filter - - assert str(my_inverted_filter) == f"~{str(my_filter)}" - - assert isinstance(my_inverted_filter, _InvertFilter) - - with patch( - "tests.test_filters.test_base.MyFilter.__call__", - new_callable=CoroutineMock, - ) as mocked_call: - call = my_inverted_filter(event="test") - result = await call - mocked_call.assert_awaited_with(event="test") - assert not result diff --git a/tests/test_filters/test_logic.py b/tests/test_filters/test_logic.py new file mode 100644 index 00000000..9c4d4f48 --- /dev/null +++ b/tests/test_filters/test_logic.py @@ -0,0 +1,38 @@ +import pytest + +from aiogram.filters import Text, and_f, invert_f, or_f +from aiogram.filters.logic import _AndFilter, _InvertFilter, _OrFilter + + +class TestLogic: + @pytest.mark.parametrize( + "obj,case,result", + [ + [True, and_f(lambda t: t is True, lambda t: t is True), True], + [True, and_f(lambda t: t is True, lambda t: t is False), False], + [True, and_f(lambda t: t is False, lambda t: t is False), False], + [True, and_f(lambda t: {"t": t}, lambda t: t is False), False], + [True, and_f(lambda t: {"t": t}, lambda t: t is True), {"t": True}], + [True, or_f(lambda t: t is True, lambda t: t is True), True], + [True, or_f(lambda t: t is True, lambda t: t is False), True], + [True, or_f(lambda t: t is False, lambda t: t is False), False], + [True, or_f(lambda t: t is False, lambda t: t is True), True], + [True, or_f(lambda t: t is False, lambda t: {"t": t}), {"t": True}], + [True, or_f(lambda t: {"t": t}, lambda t: {"a": 42}), {"t": True}], + [True, invert_f(lambda t: t is False), True], + ], + ) + async def test_logic(self, obj, case, result): + assert await case(obj) == result + + @pytest.mark.parametrize( + "case,type_", + [ + [or_f(Text(text="test"), Text(text="test")), _OrFilter], + [and_f(Text(text="test"), Text(text="test")), _AndFilter], + [invert_f(Text(text="test")), _InvertFilter], + [~Text(text="test"), _InvertFilter], + ], + ) + def test_dunder_methods(self, case, type_): + assert isinstance(case, type_)