mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
bound filters resolving rework, filters with default argument
* bound filters resolving rework, filters with default argument
This commit is contained in:
parent
3931253a88
commit
7484086d12
4 changed files with 136 additions and 16 deletions
1
CHANGES/727.misc
Normal file
1
CHANGES/727.misc
Normal file
|
|
@ -0,0 +1 @@
|
|||
Rework filters resolving, support filters with default values
|
||||
|
|
@ -2,7 +2,18 @@ from __future__ import annotations
|
|||
|
||||
import functools
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -51,7 +62,7 @@ class TelegramEventObserver:
|
|||
:param filters: positional filters
|
||||
:param bound_filters: keyword filters
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
resolved_filters = self.resolve_filters(filters, bound_filters)
|
||||
if self._handler.filters is None:
|
||||
self._handler.filters = []
|
||||
self._handler.filters.extend(
|
||||
|
|
@ -101,16 +112,40 @@ class TelegramEventObserver:
|
|||
|
||||
return middlewares
|
||||
|
||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||
def resolve_filters(
|
||||
self,
|
||||
filters: Tuple[FilterType, ...],
|
||||
full_config: Dict[str, Any],
|
||||
ignore_default: bool = True,
|
||||
) -> List[BaseFilter]:
|
||||
"""
|
||||
Resolve keyword filters via filters factory
|
||||
|
||||
:param filters: positional filters
|
||||
:param full_config: keyword arguments to initialize bounded filters for router/handler
|
||||
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
|
||||
"""
|
||||
filters: List[BaseFilter] = []
|
||||
if not full_config:
|
||||
return filters
|
||||
bound_filters: List[BaseFilter] = []
|
||||
|
||||
if ignore_default and not full_config:
|
||||
return bound_filters
|
||||
|
||||
filter_types = (type(f) for f in filters)
|
||||
|
||||
validation_errors = []
|
||||
for bound_filter in self._resolve_filters_chain():
|
||||
# skip filter if filter was used directly:
|
||||
if bound_filter in filter_types:
|
||||
continue
|
||||
|
||||
# skip filter with no fields in full_config
|
||||
if ignore_default:
|
||||
full_config_keys = set(full_config.keys())
|
||||
filter_fields = set(bound_filter.__fields__.keys())
|
||||
|
||||
if not full_config_keys.intersection(filter_fields):
|
||||
continue
|
||||
|
||||
# Try to initialize filter.
|
||||
try:
|
||||
f = bound_filter(**full_config)
|
||||
|
|
@ -123,7 +158,7 @@ class TelegramEventObserver:
|
|||
for key in f.__fields__:
|
||||
full_config.pop(key, None)
|
||||
|
||||
filters.append(f)
|
||||
bound_filters.append(f)
|
||||
|
||||
if full_config:
|
||||
possible_cases = []
|
||||
|
|
@ -137,7 +172,7 @@ class TelegramEventObserver:
|
|||
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
|
||||
)
|
||||
|
||||
return filters
|
||||
return bound_filters
|
||||
|
||||
def register(
|
||||
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
||||
|
|
@ -145,7 +180,7 @@ class TelegramEventObserver:
|
|||
"""
|
||||
Register event handler
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback,
|
||||
|
|
|
|||
|
|
@ -75,3 +75,30 @@ For example if you need to make simple text filter:
|
|||
Bound filters is always recursive propagates to the nested routers but will be available
|
||||
in nested routers only after attaching routers so that's mean you will need to
|
||||
include routers before registering handlers.
|
||||
|
||||
Resolving filters with default value
|
||||
====================================
|
||||
|
||||
Bound Filters with only default arguments will be automatically applied with default values
|
||||
to each handler in the router and nested routers to which this filter is bound.
|
||||
|
||||
For example, although we do not specify chat_type in the handler filters,
|
||||
but since the filter has a default value, the filter will be applied to the handler
|
||||
with a default value :code:`private`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class ChatType(BaseFilter):
|
||||
chat_type: str = "private"
|
||||
|
||||
async def __call__(self, message: Message , event_chat: Chat) -> bool:
|
||||
if event_chat:
|
||||
return event_chat.type == chat_type
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
router.message.bind_filter(ChatType)
|
||||
|
||||
@router.message()
|
||||
async def my_handler(message: Message): ...
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
import functools
|
||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -45,6 +45,20 @@ class MyFilter3(MyFilter1):
|
|||
pass
|
||||
|
||||
|
||||
class OptionalFilter(BaseFilter):
|
||||
optional: Optional[str]
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return True
|
||||
|
||||
|
||||
class DefaultFilter(BaseFilter):
|
||||
default: str = "Default"
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
||||
return True
|
||||
|
||||
|
||||
class TestTelegramEventObserver:
|
||||
def test_bind_filter(self):
|
||||
event_observer = TelegramEventObserver(Router(), "test")
|
||||
|
|
@ -90,21 +104,63 @@ class TestTelegramEventObserver:
|
|||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
|
||||
resolved = observer.resolve_filters({"test": "PASS"})
|
||||
resolved = observer.resolve_filters((), {"test": "PASS"})
|
||||
assert isinstance(resolved, list)
|
||||
assert any(isinstance(item, MyFilter1) for item in resolved)
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"@bad": "very"})
|
||||
assert observer.resolve_filters((), {"@bad": "very"})
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
|
||||
assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
|
||||
|
||||
# Bad argument type
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||
assert observer.resolve_filters({"test": ...})
|
||||
assert observer.resolve_filters((), {"test": ...})
|
||||
|
||||
# Disallow same filter using
|
||||
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
|
||||
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
|
||||
|
||||
def test_dont_autoresolve_optional_filters_for_router(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.bind_filter(OptionalFilter)
|
||||
observer.bind_filter(DefaultFilter)
|
||||
|
||||
observer.filter(test="test")
|
||||
assert len(observer._handler.filters) == 1
|
||||
|
||||
def test_register_autoresolve_optional_filters(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.bind_filter(MyFilter1)
|
||||
observer.bind_filter(OptionalFilter)
|
||||
observer.bind_filter(DefaultFilter)
|
||||
|
||||
assert observer.register(my_handler) == my_handler
|
||||
assert isinstance(observer.handlers[0], HandlerObject)
|
||||
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||
assert len(observer.handlers[0].filters) == 2
|
||||
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
|
||||
|
||||
observer.register(my_handler, test="ok")
|
||||
assert isinstance(observer.handlers[1], HandlerObject)
|
||||
assert len(observer.handlers[1].filters) == 3
|
||||
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
|
||||
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
|
||||
|
||||
observer.register(my_handler, test="ok", optional="ok")
|
||||
assert isinstance(observer.handlers[2], HandlerObject)
|
||||
assert len(observer.handlers[2].filters) == 3
|
||||
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
|
||||
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
|
||||
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
|
||||
|
||||
def test_register(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
|
|
@ -125,10 +181,11 @@ class TestTelegramEventObserver:
|
|||
assert isinstance(observer.handlers[2], HandlerObject)
|
||||
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
|
||||
|
||||
observer.register(my_handler, f, test="PASS")
|
||||
f2 = MyFilter2(test="ok")
|
||||
observer.register(my_handler, f2, test="PASS")
|
||||
assert isinstance(observer.handlers[3], HandlerObject)
|
||||
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
|
||||
assert f in callbacks
|
||||
assert f2 in callbacks
|
||||
assert MyFilter1(test="PASS") in callbacks
|
||||
|
||||
def test_register_decorator(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue