bound filters resolving rework, filters with default argument

* bound filters resolving rework, filters with default argument
This commit is contained in:
darksidecat 2021-10-11 12:33:10 +03:00 committed by GitHub
parent 3931253a88
commit 7484086d12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 16 deletions

1
CHANGES/727.misc Normal file
View file

@ -0,0 +1 @@
Rework filters resolving, support filters with default values

View file

@ -2,7 +2,18 @@ from __future__ import annotations
import functools
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)
from pydantic import ValidationError
@ -51,7 +62,7 @@ class TelegramEventObserver:
:param filters: positional filters
:param bound_filters: keyword filters
"""
resolved_filters = self.resolve_filters(bound_filters)
resolved_filters = self.resolve_filters(filters, bound_filters)
if self._handler.filters is None:
self._handler.filters = []
self._handler.filters.extend(
@ -101,16 +112,40 @@ class TelegramEventObserver:
return middlewares
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
def resolve_filters(
self,
filters: Tuple[FilterType, ...],
full_config: Dict[str, Any],
ignore_default: bool = True,
) -> List[BaseFilter]:
"""
Resolve keyword filters via filters factory
:param filters: positional filters
:param full_config: keyword arguments to initialize bounded filters for router/handler
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
"""
filters: List[BaseFilter] = []
if not full_config:
return filters
bound_filters: List[BaseFilter] = []
if ignore_default and not full_config:
return bound_filters
filter_types = (type(f) for f in filters)
validation_errors = []
for bound_filter in self._resolve_filters_chain():
# skip filter if filter was used directly:
if bound_filter in filter_types:
continue
# skip filter with no fields in full_config
if ignore_default:
full_config_keys = set(full_config.keys())
filter_fields = set(bound_filter.__fields__.keys())
if not full_config_keys.intersection(filter_fields):
continue
# Try to initialize filter.
try:
f = bound_filter(**full_config)
@ -123,7 +158,7 @@ class TelegramEventObserver:
for key in f.__fields__:
full_config.pop(key, None)
filters.append(f)
bound_filters.append(f)
if full_config:
possible_cases = []
@ -137,7 +172,7 @@ class TelegramEventObserver:
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
)
return filters
return bound_filters
def register(
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
@ -145,7 +180,7 @@ class TelegramEventObserver:
"""
Register event handler
"""
resolved_filters = self.resolve_filters(bound_filters)
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
self.handlers.append(
HandlerObject(
callback=callback,

View file

@ -75,3 +75,30 @@ For example if you need to make simple text filter:
Bound filters is always recursive propagates to the nested routers but will be available
in nested routers only after attaching routers so that's mean you will need to
include routers before registering handlers.
Resolving filters with default value
====================================
Bound Filters with only default arguments will be automatically applied with default values
to each handler in the router and nested routers to which this filter is bound.
For example, although we do not specify chat_type in the handler filters,
but since the filter has a default value, the filter will be applied to the handler
with a default value :code:`private`:
.. code-block:: python
class ChatType(BaseFilter):
chat_type: str = "private"
async def __call__(self, message: Message , event_chat: Chat) -> bool:
if event_chat:
return event_chat.type == chat_type
else:
return False
router.message.bind_filter(ChatType)
@router.message()
async def my_handler(message: Message): ...

View file

@ -1,6 +1,6 @@
import datetime
import functools
from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
import pytest
@ -45,6 +45,20 @@ class MyFilter3(MyFilter1):
pass
class OptionalFilter(BaseFilter):
optional: Optional[str]
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class DefaultFilter(BaseFilter):
default: str = "Default"
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class TestTelegramEventObserver:
def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test")
@ -90,21 +104,63 @@ class TestTelegramEventObserver:
observer = router.message
observer.bind_filter(MyFilter1)
resolved = observer.resolve_filters({"test": "PASS"})
resolved = observer.resolve_filters((), {"test": "PASS"})
assert isinstance(resolved, list)
assert any(isinstance(item, MyFilter1) for item in resolved)
# Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters({"@bad": "very"})
assert observer.resolve_filters((), {"@bad": "very"})
# Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
# Bad argument type
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
assert observer.resolve_filters({"test": ...})
assert observer.resolve_filters((), {"test": ...})
# Disallow same filter using
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
def test_dont_autoresolve_optional_filters_for_router(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
observer.filter(test="test")
assert len(observer._handler.filters) == 1
def test_register_autoresolve_optional_filters(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
assert observer.register(my_handler) == my_handler
assert isinstance(observer.handlers[0], HandlerObject)
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert len(observer.handlers[0].filters) == 2
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
observer.register(my_handler, test="ok")
assert isinstance(observer.handlers[1], HandlerObject)
assert len(observer.handlers[1].filters) == 3
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
observer.register(my_handler, test="ok", optional="ok")
assert isinstance(observer.handlers[2], HandlerObject)
assert len(observer.handlers[2].filters) == 3
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
def test_register(self):
router = Router(use_builtin_filters=False)
@ -125,10 +181,11 @@ class TestTelegramEventObserver:
assert isinstance(observer.handlers[2], HandlerObject)
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
observer.register(my_handler, f, test="PASS")
f2 = MyFilter2(test="ok")
observer.register(my_handler, f2, test="PASS")
assert isinstance(observer.handlers[3], HandlerObject)
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
assert f in callbacks
assert f2 in callbacks
assert MyFilter1(test="PASS") in callbacks
def test_register_decorator(self):