Bound filters resolving rework, filters with default argument (#727)

* bound filters resolving rework, filters with default argument

* bound filters resolving rework, filters with default argument

* Update 727.misc

* clarification of the comment about skipping filter

* fix data transfer from parent to included routers filters

* fix checking containing value in generator

* Update docs/dispatcher/filters/index.rst

Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>

* Update 727.misc

* reformat

* better iterable types

Co-authored-by: Alex Root Junior <jroot.junior@gmail.com>
This commit is contained in:
darksidecat 2021-10-12 22:29:57 +03:00 committed by GitHub
parent f97367b3ee
commit 42cba8976f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 170 additions and 18 deletions

View file

@ -2,7 +2,18 @@ from __future__ import annotations
import functools
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)
from pydantic import ValidationError
@ -51,7 +62,7 @@ class TelegramEventObserver:
:param filters: positional filters
:param bound_filters: keyword filters
"""
resolved_filters = self.resolve_filters(bound_filters)
resolved_filters = self.resolve_filters(filters, bound_filters)
if self._handler.filters is None:
self._handler.filters = []
self._handler.filters.extend(
@ -77,7 +88,7 @@ class TelegramEventObserver:
"""
registry: List[Type[BaseFilter]] = []
for router in self.router.chain:
for router in reversed(tuple(self.router.chain)):
observer = router.observers[self.event_name]
for filter_ in observer.filters:
@ -95,22 +106,46 @@ class TelegramEventObserver:
if outer:
middlewares.extend(self.outer_middlewares)
else:
for router in reversed(list(self.router.chain_head)):
for router in reversed(tuple(self.router.chain_head)):
observer = router.observers[self.event_name]
middlewares.extend(observer.middlewares)
return middlewares
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
def resolve_filters(
self,
filters: Tuple[FilterType, ...],
full_config: Dict[str, Any],
ignore_default: bool = True,
) -> List[BaseFilter]:
"""
Resolve keyword filters via filters factory
:param filters: positional filters
:param full_config: keyword arguments to initialize bounded filters for router/handler
:param ignore_default: ignore to resolving filters with only default arguments that are not in full_config
"""
filters: List[BaseFilter] = []
if not full_config:
return filters
bound_filters: List[BaseFilter] = []
if ignore_default and not full_config:
return bound_filters
filter_types = set(type(f) for f in filters)
validation_errors = []
for bound_filter in self._resolve_filters_chain():
# skip filter if filter was used as positional filter:
if bound_filter in filter_types:
continue
# skip filter with no fields in full_config
if ignore_default:
full_config_keys = set(full_config.keys())
filter_fields = set(bound_filter.__fields__.keys())
if not full_config_keys.intersection(filter_fields):
continue
# Try to initialize filter.
try:
f = bound_filter(**full_config)
@ -123,7 +158,7 @@ class TelegramEventObserver:
for key in f.__fields__:
full_config.pop(key, None)
filters.append(f)
bound_filters.append(f)
if full_config:
possible_cases = []
@ -137,7 +172,7 @@ class TelegramEventObserver:
unresolved_fields=set(full_config.keys()), possible_cases=possible_cases
)
return filters
return bound_filters
def register(
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
@ -145,7 +180,7 @@ class TelegramEventObserver:
"""
Register event handler
"""
resolved_filters = self.resolve_filters(bound_filters)
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
self.handlers.append(
HandlerObject(
callback=callback,