From 7484086d122a0502459759e24846c7dd55b7f77f Mon Sep 17 00:00:00 2001 From: darksidecat <58224121+darksidecat@users.noreply.github.com> Date: Mon, 11 Oct 2021 12:33:10 +0300 Subject: [PATCH] bound filters resolving rework, filters with default argument * bound filters resolving rework, filters with default argument --- CHANGES/727.misc | 1 + aiogram/dispatcher/event/telegram.py | 53 +++++++++++--- docs/dispatcher/filters/index.rst | 27 +++++++ .../test_event/test_telegram.py | 71 +++++++++++++++++-- 4 files changed, 136 insertions(+), 16 deletions(-) create mode 100644 CHANGES/727.misc diff --git a/CHANGES/727.misc b/CHANGES/727.misc new file mode 100644 index 00000000..cbc6f983 --- /dev/null +++ b/CHANGES/727.misc @@ -0,0 +1 @@ +Rework filters resolving, support filters with default values diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index 386d2fa4..6dae5c6c 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -2,7 +2,18 @@ from __future__ import annotations import functools from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + Union, +) from pydantic import ValidationError @@ -51,7 +62,7 @@ class TelegramEventObserver: :param filters: positional filters :param bound_filters: keyword filters """ - resolved_filters = self.resolve_filters(bound_filters) + resolved_filters = self.resolve_filters(filters, bound_filters) if self._handler.filters is None: self._handler.filters = [] self._handler.filters.extend( @@ -101,16 +112,40 @@ class TelegramEventObserver: return middlewares - def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: + def resolve_filters( + self, + filters: Tuple[FilterType, ...], + full_config: Dict[str, Any], + ignore_default: bool = True, + ) -> List[BaseFilter]: """ Resolve keyword filters via filters factory + + :param filters: positional filters + :param full_config: keyword arguments to initialize bounded filters for router/handler + :param ignore_default: ignore to resolving filters with only default arguments that are not in full_config """ - filters: List[BaseFilter] = [] - if not full_config: - return filters + bound_filters: List[BaseFilter] = [] + + if ignore_default and not full_config: + return bound_filters + + filter_types = (type(f) for f in filters) validation_errors = [] for bound_filter in self._resolve_filters_chain(): + # skip filter if filter was used directly: + if bound_filter in filter_types: + continue + + # skip filter with no fields in full_config + if ignore_default: + full_config_keys = set(full_config.keys()) + filter_fields = set(bound_filter.__fields__.keys()) + + if not full_config_keys.intersection(filter_fields): + continue + # Try to initialize filter. try: f = bound_filter(**full_config) @@ -123,7 +158,7 @@ class TelegramEventObserver: for key in f.__fields__: full_config.pop(key, None) - filters.append(f) + bound_filters.append(f) if full_config: possible_cases = [] @@ -137,7 +172,7 @@ class TelegramEventObserver: unresolved_fields=set(full_config.keys()), possible_cases=possible_cases ) - return filters + return bound_filters def register( self, callback: HandlerType, *filters: FilterType, **bound_filters: Any @@ -145,7 +180,7 @@ class TelegramEventObserver: """ Register event handler """ - resolved_filters = self.resolve_filters(bound_filters) + resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False) self.handlers.append( HandlerObject( callback=callback, diff --git a/docs/dispatcher/filters/index.rst b/docs/dispatcher/filters/index.rst index 26de26a4..7200c611 100644 --- a/docs/dispatcher/filters/index.rst +++ b/docs/dispatcher/filters/index.rst @@ -75,3 +75,30 @@ For example if you need to make simple text filter: Bound filters is always recursive propagates to the nested routers but will be available in nested routers only after attaching routers so that's mean you will need to include routers before registering handlers. + +Resolving filters with default value +==================================== + +Bound Filters with only default arguments will be automatically applied with default values +to each handler in the router and nested routers to which this filter is bound. + +For example, although we do not specify chat_type in the handler filters, +but since the filter has a default value, the filter will be applied to the handler +with a default value :code:`private`: + +.. code-block:: python + + class ChatType(BaseFilter): + chat_type: str = "private" + + async def __call__(self, message: Message , event_chat: Chat) -> bool: + if event_chat: + return event_chat.type == chat_type + else: + return False + + + router.message.bind_filter(ChatType) + + @router.message() + async def my_handler(message: Message): ... diff --git a/tests/test_dispatcher/test_event/test_telegram.py b/tests/test_dispatcher/test_event/test_telegram.py index 563ffa9e..04e959f1 100644 --- a/tests/test_dispatcher/test_event/test_telegram.py +++ b/tests/test_dispatcher/test_event/test_telegram.py @@ -1,6 +1,6 @@ import datetime import functools -from typing import Any, Awaitable, Callable, Dict, NoReturn, Union +from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union import pytest @@ -45,6 +45,20 @@ class MyFilter3(MyFilter1): pass +class OptionalFilter(BaseFilter): + optional: Optional[str] + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return True + + +class DefaultFilter(BaseFilter): + default: str = "Default" + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return True + + class TestTelegramEventObserver: def test_bind_filter(self): event_observer = TelegramEventObserver(Router(), "test") @@ -90,21 +104,63 @@ class TestTelegramEventObserver: observer = router.message observer.bind_filter(MyFilter1) - resolved = observer.resolve_filters({"test": "PASS"}) + resolved = observer.resolve_filters((), {"test": "PASS"}) assert isinstance(resolved, list) assert any(isinstance(item, MyFilter1) for item in resolved) # Unknown filter with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"): - assert observer.resolve_filters({"@bad": "very"}) + assert observer.resolve_filters((), {"@bad": "very"}) # Unknown filter with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"): - assert observer.resolve_filters({"test": "ok", "@bad": "very"}) + assert observer.resolve_filters((), {"test": "ok", "@bad": "very"}) # Bad argument type with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"): - assert observer.resolve_filters({"test": ...}) + assert observer.resolve_filters((), {"test": ...}) + + # Disallow same filter using + with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"): + observer.resolve_filters((MyFilter1(test="test"),), {"test": ...}) + + def test_dont_autoresolve_optional_filters_for_router(self): + router = Router(use_builtin_filters=False) + observer = router.message + observer.bind_filter(MyFilter1) + observer.bind_filter(OptionalFilter) + observer.bind_filter(DefaultFilter) + + observer.filter(test="test") + assert len(observer._handler.filters) == 1 + + def test_register_autoresolve_optional_filters(self): + router = Router(use_builtin_filters=False) + observer = router.message + observer.bind_filter(MyFilter1) + observer.bind_filter(OptionalFilter) + observer.bind_filter(DefaultFilter) + + assert observer.register(my_handler) == my_handler + assert isinstance(observer.handlers[0], HandlerObject) + assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter) + assert len(observer.handlers[0].filters) == 2 + assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter) + assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter) + + observer.register(my_handler, test="ok") + assert isinstance(observer.handlers[1], HandlerObject) + assert len(observer.handlers[1].filters) == 3 + assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1) + assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter) + assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter) + + observer.register(my_handler, test="ok", optional="ok") + assert isinstance(observer.handlers[2], HandlerObject) + assert len(observer.handlers[2].filters) == 3 + assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1) + assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter) + assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter) def test_register(self): router = Router(use_builtin_filters=False) @@ -125,10 +181,11 @@ class TestTelegramEventObserver: assert isinstance(observer.handlers[2], HandlerObject) assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters) - observer.register(my_handler, f, test="PASS") + f2 = MyFilter2(test="ok") + observer.register(my_handler, f2, test="PASS") assert isinstance(observer.handlers[3], HandlerObject) callbacks = [filter_.callback for filter_ in observer.handlers[3].filters] - assert f in callbacks + assert f2 in callbacks assert MyFilter1(test="PASS") in callbacks def test_register_decorator(self):