feat(handler): requirements for CBH

implement POC of requirements for class based handlers
This commit is contained in:
mpa 2020-05-28 04:53:07 +04:00
parent 32ffda2eb7
commit 0b3702f658
No known key found for this signature in database
GPG key ID: BCCFBFCCC9B754A8
4 changed files with 51 additions and 13 deletions

View file

@ -8,7 +8,12 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type,
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
from aiogram.dispatcher.requirement import CacheKeyType, Requirement, get_reqs_from_callable
from aiogram.dispatcher.requirement import (
CacheKeyType,
Requirement,
get_reqs_from_callable,
get_reqs_from_class,
)
CallbackType = Callable[..., Awaitable[Any]]
SyncFilter = Callable[..., Any]
@ -20,6 +25,10 @@ REQUIREMENT_CACHE_KEY = "_req_cache"
ASYNC_STACK_KEY = "_stack"
def _is_class_handler(handler: HandlerType) -> bool:
return isinstance(handler, type) and issubclass(handler, BaseHandler)
@dataclass
class CallableMixin:
callback: HandlerType
@ -36,7 +45,12 @@ class CallableMixin:
self.spec = inspect.getfullargspec(callback)
self.__reqs__ = get_reqs_from_callable(callable_=callback)
if _is_class_handler(callback):
self.awaitable = True
self.__reqs__ = get_reqs_from_class(callback)
else:
self.__reqs__ = get_reqs_from_callable(callable_=callback)
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
if self.spec.varkw:
@ -45,17 +59,34 @@ class CallableMixin:
return {k: v for k, v in kwargs.items() if k in self.spec.args}
async def call(self, *args: Any, **kwargs: Any) -> Any:
# we don't requirements_data and kwargs keys to intersect
requirements_data: Dict[str, Any] = {}
if self.__reqs__:
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY))
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {})
requirements_data = kwargs.copy()
for kwarg, default in self.__reqs__.items():
kwargs[kwarg] = await default(cache_dict=cache_dict, stack=stack, data=kwargs)
for req_id, req in self.__reqs__.items():
requirements_data[req_id] = await req(
cache_dict=cache_dict, stack=stack, data=requirements_data
)
for to_pop in kwargs:
requirements_data.pop(to_pop, None)
kwargs.pop(ASYNC_STACK_KEY, None)
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
if _is_class_handler(self.callback):
wrapped = partial(self.callback, *args, requirements_data, kwargs)
else:
wrapped = partial(
self.callback,
*args,
**self._prepare_kwargs(kwargs),
**self._prepare_kwargs(requirements_data),
)
if self.awaitable:
return await wrapped()
@ -76,11 +107,6 @@ class HandlerObject(CallableMixin):
callback: HandlerType
filters: Optional[List[FilterObject]] = None
def __post_init__(self) -> None:
super(HandlerObject, self).__post_init__()
if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): # type: ignore
self.awaitable = True
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
if not self.filters:
return True, kwargs

View file

@ -132,7 +132,6 @@ class TelegramEventObserver:
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
for handler in self.handlers:
async with AsyncExitStack() as stack:
# intermediate values. handler.call is responsible for cleaning them up
kwargs.update({ASYNC_STACK_KEY: stack, REQUIREMENT_CACHE_KEY: {}})
result, data = await handler.check(event, **kwargs)
@ -144,6 +143,9 @@ class TelegramEventObserver:
except SkipHandler:
continue
kwargs.pop(ASYNC_STACK_KEY, None)
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
return NOT_HANDLED
def __call__(

View file

@ -18,9 +18,12 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
Base class for all class-based handlers
"""
def __init__(self, event: T, **kwargs: Any) -> None:
def __init__(self, event: T, requirements_data: Dict[str, Any], data: Dict[str, Any]) -> None:
self.event: T = event
self.data: Dict[str, Any] = kwargs
self.data: Dict[str, Any] = data
for req_attr, req in requirements_data.items():
setattr(self, req_attr, req)
@property
def bot(self) -> Bot:

View file

@ -10,6 +10,7 @@ from typing import (
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
cast,
@ -157,6 +158,12 @@ def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Require
}
def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]:
return {
req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement)
}
def require(
what: _RequiredCallback[T],
*,