mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Added explicit logic filters, added slots to all other filters
This commit is contained in:
parent
9041cc72f1
commit
06f24a8cb3
13 changed files with 214 additions and 43 deletions
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Dict, Tuple, Type
|
||||
|
||||
from .base import Filter
|
||||
from .chat_member_updated import (
|
||||
ADMINISTRATOR,
|
||||
|
|
@ -18,6 +16,7 @@ from .chat_member_updated import (
|
|||
)
|
||||
from .command import Command, CommandObject, CommandStart
|
||||
from .exception import ExceptionMessageFilter, ExceptionTypeFilter
|
||||
from .logic import and_f, invert_f, or_f
|
||||
from .magic_data import MagicData
|
||||
from .state import StateFilter
|
||||
from .text import Text
|
||||
|
|
@ -25,7 +24,6 @@ from .text import Text
|
|||
BaseFilter = Filter
|
||||
|
||||
__all__ = (
|
||||
"BUILTIN_FILTERS",
|
||||
"Filter",
|
||||
"BaseFilter",
|
||||
"Text",
|
||||
|
|
@ -49,6 +47,7 @@ __all__ = (
|
|||
"IS_NOT_MEMBER",
|
||||
"JOIN_TRANSITION",
|
||||
"LEAVE_TRANSITION",
|
||||
"and_f",
|
||||
"or_f",
|
||||
"invert_f",
|
||||
)
|
||||
|
||||
BUILTIN_FILTERS: Dict[str, Tuple[Type[Filter], ...]] = {}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.dispatcher.event.handler import CallbackType, FilterObject
|
||||
from aiogram.filters.logic import _InvertFilter
|
||||
|
||||
|
||||
class Filter(ABC):
|
||||
|
|
@ -31,6 +31,8 @@ class Filter(ABC):
|
|||
pass
|
||||
|
||||
def __invert__(self) -> "_InvertFilter":
|
||||
from aiogram.filters.logic import invert_f
|
||||
|
||||
return invert_f(self)
|
||||
|
||||
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
|
||||
|
|
@ -50,22 +52,3 @@ class Filter(ABC):
|
|||
def __await__(self): # type: ignore # pragma: no cover
|
||||
# Is needed only for inspection and this method is never be called
|
||||
return self.__call__
|
||||
|
||||
|
||||
class _InvertFilter(Filter):
|
||||
__slots__ = ("target",)
|
||||
|
||||
def __init__(self, target: "FilterObject") -> None:
|
||||
self.target = target
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return not bool(await self.target.call(*args, **kwargs))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"~{self.target.callback}"
|
||||
|
||||
|
||||
def invert_f(target: "CallbackType") -> _InvertFilter:
|
||||
from aiogram.dispatcher.event.handler import FilterObject
|
||||
|
||||
return _InvertFilter(target=FilterObject(target))
|
||||
|
|
|
|||
|
|
@ -131,6 +131,11 @@ class CallbackQueryFilter(Filter):
|
|||
via callback data instance
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"callback_data",
|
||||
"rule",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,11 @@ TransitionT = TypeVar("TransitionT", bound="_MemberStatusTransition")
|
|||
|
||||
|
||||
class _MemberStatusMarker:
|
||||
__slots__ = (
|
||||
"name",
|
||||
"is_member",
|
||||
)
|
||||
|
||||
def __init__(self, name: str, *, is_member: Optional[bool] = None) -> None:
|
||||
self.name = name
|
||||
self.is_member = is_member
|
||||
|
|
@ -72,6 +77,8 @@ class _MemberStatusMarker:
|
|||
|
||||
|
||||
class _MemberStatusGroupMarker:
|
||||
__slots__ = ("statuses",)
|
||||
|
||||
def __init__(self, *statuses: _MemberStatusMarker) -> None:
|
||||
if not statuses:
|
||||
raise ValueError("Member status group should have at least one status included")
|
||||
|
|
@ -124,6 +131,11 @@ class _MemberStatusGroupMarker:
|
|||
|
||||
|
||||
class _MemberStatusTransition:
|
||||
__slots__ = (
|
||||
"old",
|
||||
"new",
|
||||
)
|
||||
|
||||
def __init__(self, *, old: _MemberStatusGroupMarker, new: _MemberStatusGroupMarker) -> None:
|
||||
self.old = old
|
||||
self.new = new
|
||||
|
|
@ -155,6 +167,8 @@ PROMOTED_TRANSITION = (MEMBER | RESTRICTED | LEFT | KICKED) >> ADMINISTRATOR
|
|||
|
||||
|
||||
class ChatMemberUpdatedFilter(Filter):
|
||||
__slots__ = ("member_status_changed",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
member_status_changed: Union[
|
||||
|
|
|
|||
|
|
@ -38,6 +38,14 @@ class Command(Filter):
|
|||
Works only with :class:`aiogram.types.message.Message` events which have the :code:`text`.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"commands",
|
||||
"prefix",
|
||||
"ignore_case",
|
||||
"ignore_mention",
|
||||
"magic",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*values: CommandPatternType,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ class ExceptionTypeFilter(Filter):
|
|||
Allows to match exception by type
|
||||
"""
|
||||
|
||||
__slots__ = ("exceptions",)
|
||||
|
||||
def __init__(self, *exceptions: Type[Exception]):
|
||||
"""
|
||||
:param exceptions: Exception type(s)
|
||||
|
|
@ -28,6 +30,8 @@ class ExceptionMessageFilter(Filter):
|
|||
Allow to match exception by message
|
||||
"""
|
||||
|
||||
__slots__ = ("pattern",)
|
||||
|
||||
def __init__(self, pattern: Union[str, Pattern[str]]):
|
||||
"""
|
||||
:param pattern: Regexp pattern
|
||||
|
|
|
|||
77
aiogram/filters/logic.py
Normal file
77
aiogram/filters/logic.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
from abc import ABC
|
||||
from typing import TYPE_CHECKING, Any, Dict, Union
|
||||
|
||||
from aiogram.filters import Filter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.dispatcher.event.handler import CallbackType, FilterObject
|
||||
|
||||
|
||||
class _LogicFilter(Filter, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class _InvertFilter(_LogicFilter):
|
||||
__slots__ = ("target",)
|
||||
|
||||
def __init__(self, target: "FilterObject") -> None:
|
||||
self.target = target
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return not bool(await self.target.call(*args, **kwargs))
|
||||
|
||||
|
||||
class _AndFilter(_LogicFilter):
|
||||
__slots__ = ("targets",)
|
||||
|
||||
def __init__(self, *targets: "FilterObject") -> None:
|
||||
self.targets = targets
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
final_result = {}
|
||||
|
||||
for target in self.targets:
|
||||
result = await target.call(*args, **kwargs)
|
||||
if not result:
|
||||
return False
|
||||
if isinstance(result, dict):
|
||||
final_result.update(result)
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
return True
|
||||
|
||||
|
||||
class _OrFilter(_LogicFilter):
|
||||
__slots__ = ("targets",)
|
||||
|
||||
def __init__(self, *targets: "FilterObject") -> None:
|
||||
self.targets = targets
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
for target in self.targets:
|
||||
result = await target.call(*args, **kwargs)
|
||||
if not result:
|
||||
continue
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return bool(result)
|
||||
return False
|
||||
|
||||
|
||||
def and_f(target1: "CallbackType", target2: "CallbackType") -> _AndFilter:
|
||||
from aiogram.dispatcher.event.handler import FilterObject
|
||||
|
||||
return _AndFilter(FilterObject(target1), FilterObject(target2))
|
||||
|
||||
|
||||
def or_f(target1: "CallbackType", target2: "CallbackType") -> _OrFilter:
|
||||
from aiogram.dispatcher.event.handler import FilterObject
|
||||
|
||||
return _OrFilter(FilterObject(target1), FilterObject(target2))
|
||||
|
||||
|
||||
def invert_f(target: "CallbackType") -> _InvertFilter:
|
||||
from aiogram.dispatcher.event.handler import FilterObject
|
||||
|
||||
return _InvertFilter(FilterObject(target))
|
||||
|
|
@ -7,6 +7,12 @@ from aiogram.types import TelegramObject
|
|||
|
||||
|
||||
class MagicData(Filter):
|
||||
"""
|
||||
This filter helps to filter event with contextual data
|
||||
"""
|
||||
|
||||
__slots__ = "magic_data"
|
||||
|
||||
def __init__(self, magic_data: MagicFilter) -> None:
|
||||
self.magic_data = magic_data
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ class StateFilter(Filter):
|
|||
State filter
|
||||
"""
|
||||
|
||||
__slots__ = ("states",)
|
||||
|
||||
def __init__(self, *states: StateType) -> None:
|
||||
if not states:
|
||||
raise ValueError("At least one state is required")
|
||||
|
|
|
|||
|
|
@ -25,6 +25,14 @@ class Text(Filter):
|
|||
use :ref:`magic-filter <magic-filters>`. For example do :pycode:`F.text == "text"` instead
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"text",
|
||||
"contains",
|
||||
"startswith",
|
||||
"endswith",
|
||||
"ignore_case",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: Optional[Union[Sequence[TextType], TextType]] = None,
|
||||
|
|
|
|||
|
|
@ -53,3 +53,48 @@ Own filter example
|
|||
For example if you need to make simple text filter:
|
||||
|
||||
.. literalinclude:: ../../../examples/own_filter.py
|
||||
|
||||
Combining Filters
|
||||
=================
|
||||
|
||||
In general, all filters can be combined in two ways
|
||||
|
||||
|
||||
Recommended way
|
||||
---------------
|
||||
|
||||
If you specify multiple filters in a row, it will be checked with an "and" condition:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<router>.message(Text(startswith="show"), Text(endswith="example"))
|
||||
|
||||
|
||||
Also, if you want to use two alternative ways to run the sage handler ("or" condition)
|
||||
you can register the handler twice or more times as you like
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<router>.message(Text(text="hi"))
|
||||
@<router>.message(CommandStart())
|
||||
|
||||
|
||||
Also sometimes you will need to invert the filter result, for example you have an *IsAdmin* filter
|
||||
and you want to check if the user is not an admin
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<router>.message(~IsAdmin())
|
||||
|
||||
|
||||
Another possible way
|
||||
--------------------
|
||||
|
||||
An alternative way is to combine using special functions (:func:`and_f`, :func:`or_f`, :func:`invert_f` from :code:`aiogram.filters` module):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
and_f(Text(startswith="show"), Text(endswith="example"))
|
||||
or_f(Text(text="hi"), CommandStart())
|
||||
invert_f(IsAdmin())
|
||||
and_f(<A>, or_f(<B>, <C>))
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import Awaitable
|
|||
import pytest
|
||||
|
||||
from aiogram.filters import Filter
|
||||
from aiogram.filters.base import _InvertFilter
|
||||
|
||||
try:
|
||||
from asynctest import CoroutineMock, patch
|
||||
|
|
@ -32,20 +31,3 @@ class TestBaseFilter:
|
|||
call = my_filter(event="test")
|
||||
await call
|
||||
mocked_call.assert_awaited_with(event="test")
|
||||
|
||||
async def test_invert(self):
|
||||
my_filter = MyFilter()
|
||||
my_inverted_filter = ~my_filter
|
||||
|
||||
assert str(my_inverted_filter) == f"~{str(my_filter)}"
|
||||
|
||||
assert isinstance(my_inverted_filter, _InvertFilter)
|
||||
|
||||
with patch(
|
||||
"tests.test_filters.test_base.MyFilter.__call__",
|
||||
new_callable=CoroutineMock,
|
||||
) as mocked_call:
|
||||
call = my_inverted_filter(event="test")
|
||||
result = await call
|
||||
mocked_call.assert_awaited_with(event="test")
|
||||
assert not result
|
||||
|
|
|
|||
38
tests/test_filters/test_logic.py
Normal file
38
tests/test_filters/test_logic.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.filters import Text, and_f, invert_f, or_f
|
||||
from aiogram.filters.logic import _AndFilter, _InvertFilter, _OrFilter
|
||||
|
||||
|
||||
class TestLogic:
|
||||
@pytest.mark.parametrize(
|
||||
"obj,case,result",
|
||||
[
|
||||
[True, and_f(lambda t: t is True, lambda t: t is True), True],
|
||||
[True, and_f(lambda t: t is True, lambda t: t is False), False],
|
||||
[True, and_f(lambda t: t is False, lambda t: t is False), False],
|
||||
[True, and_f(lambda t: {"t": t}, lambda t: t is False), False],
|
||||
[True, and_f(lambda t: {"t": t}, lambda t: t is True), {"t": True}],
|
||||
[True, or_f(lambda t: t is True, lambda t: t is True), True],
|
||||
[True, or_f(lambda t: t is True, lambda t: t is False), True],
|
||||
[True, or_f(lambda t: t is False, lambda t: t is False), False],
|
||||
[True, or_f(lambda t: t is False, lambda t: t is True), True],
|
||||
[True, or_f(lambda t: t is False, lambda t: {"t": t}), {"t": True}],
|
||||
[True, or_f(lambda t: {"t": t}, lambda t: {"a": 42}), {"t": True}],
|
||||
[True, invert_f(lambda t: t is False), True],
|
||||
],
|
||||
)
|
||||
async def test_logic(self, obj, case, result):
|
||||
assert await case(obj) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case,type_",
|
||||
[
|
||||
[or_f(Text(text="test"), Text(text="test")), _OrFilter],
|
||||
[and_f(Text(text="test"), Text(text="test")), _AndFilter],
|
||||
[invert_f(Text(text="test")), _InvertFilter],
|
||||
[~Text(text="test"), _InvertFilter],
|
||||
],
|
||||
)
|
||||
def test_dunder_methods(self, case, type_):
|
||||
assert isinstance(case, type_)
|
||||
Loading…
Add table
Add a link
Reference in a new issue