mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
feat(handler): requirements for CBH
implement POC of requirements for class based handlers
This commit is contained in:
parent
32ffda2eb7
commit
0b3702f658
4 changed files with 51 additions and 13 deletions
|
|
@ -8,7 +8,12 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type,
|
|||
|
||||
from aiogram.dispatcher.filters.base import BaseFilter
|
||||
from aiogram.dispatcher.handler.base import BaseHandler
|
||||
from aiogram.dispatcher.requirement import CacheKeyType, Requirement, get_reqs_from_callable
|
||||
from aiogram.dispatcher.requirement import (
|
||||
CacheKeyType,
|
||||
Requirement,
|
||||
get_reqs_from_callable,
|
||||
get_reqs_from_class,
|
||||
)
|
||||
|
||||
CallbackType = Callable[..., Awaitable[Any]]
|
||||
SyncFilter = Callable[..., Any]
|
||||
|
|
@ -20,6 +25,10 @@ REQUIREMENT_CACHE_KEY = "_req_cache"
|
|||
ASYNC_STACK_KEY = "_stack"
|
||||
|
||||
|
||||
def _is_class_handler(handler: HandlerType) -> bool:
|
||||
return isinstance(handler, type) and issubclass(handler, BaseHandler)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallableMixin:
|
||||
callback: HandlerType
|
||||
|
|
@ -36,7 +45,12 @@ class CallableMixin:
|
|||
|
||||
self.spec = inspect.getfullargspec(callback)
|
||||
|
||||
self.__reqs__ = get_reqs_from_callable(callable_=callback)
|
||||
if _is_class_handler(callback):
|
||||
self.awaitable = True
|
||||
self.__reqs__ = get_reqs_from_class(callback)
|
||||
|
||||
else:
|
||||
self.__reqs__ = get_reqs_from_callable(callable_=callback)
|
||||
|
||||
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if self.spec.varkw:
|
||||
|
|
@ -45,17 +59,34 @@ class CallableMixin:
|
|||
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
||||
|
||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# we don't requirements_data and kwargs keys to intersect
|
||||
requirements_data: Dict[str, Any] = {}
|
||||
|
||||
if self.__reqs__:
|
||||
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY))
|
||||
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {})
|
||||
requirements_data = kwargs.copy()
|
||||
|
||||
for kwarg, default in self.__reqs__.items():
|
||||
kwargs[kwarg] = await default(cache_dict=cache_dict, stack=stack, data=kwargs)
|
||||
for req_id, req in self.__reqs__.items():
|
||||
requirements_data[req_id] = await req(
|
||||
cache_dict=cache_dict, stack=stack, data=requirements_data
|
||||
)
|
||||
|
||||
for to_pop in kwargs:
|
||||
requirements_data.pop(to_pop, None)
|
||||
|
||||
kwargs.pop(ASYNC_STACK_KEY, None)
|
||||
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
|
||||
|
||||
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
|
||||
if _is_class_handler(self.callback):
|
||||
wrapped = partial(self.callback, *args, requirements_data, kwargs)
|
||||
else:
|
||||
wrapped = partial(
|
||||
self.callback,
|
||||
*args,
|
||||
**self._prepare_kwargs(kwargs),
|
||||
**self._prepare_kwargs(requirements_data),
|
||||
)
|
||||
|
||||
if self.awaitable:
|
||||
return await wrapped()
|
||||
|
|
@ -76,11 +107,6 @@ class HandlerObject(CallableMixin):
|
|||
callback: HandlerType
|
||||
filters: Optional[List[FilterObject]] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super(HandlerObject, self).__post_init__()
|
||||
if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): # type: ignore
|
||||
self.awaitable = True
|
||||
|
||||
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
||||
if not self.filters:
|
||||
return True, kwargs
|
||||
|
|
|
|||
|
|
@ -132,7 +132,6 @@ class TelegramEventObserver:
|
|||
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
for handler in self.handlers:
|
||||
async with AsyncExitStack() as stack:
|
||||
# intermediate values. handler.call is responsible for cleaning them up
|
||||
kwargs.update({ASYNC_STACK_KEY: stack, REQUIREMENT_CACHE_KEY: {}})
|
||||
|
||||
result, data = await handler.check(event, **kwargs)
|
||||
|
|
@ -144,6 +143,9 @@ class TelegramEventObserver:
|
|||
except SkipHandler:
|
||||
continue
|
||||
|
||||
kwargs.pop(ASYNC_STACK_KEY, None)
|
||||
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
|
||||
|
||||
return NOT_HANDLED
|
||||
|
||||
def __call__(
|
||||
|
|
|
|||
|
|
@ -18,9 +18,12 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
|
|||
Base class for all class-based handlers
|
||||
"""
|
||||
|
||||
def __init__(self, event: T, **kwargs: Any) -> None:
|
||||
def __init__(self, event: T, requirements_data: Dict[str, Any], data: Dict[str, Any]) -> None:
|
||||
self.event: T = event
|
||||
self.data: Dict[str, Any] = kwargs
|
||||
self.data: Dict[str, Any] = data
|
||||
|
||||
for req_attr, req in requirements_data.items():
|
||||
setattr(self, req_attr, req)
|
||||
|
||||
@property
|
||||
def bot(self) -> Bot:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import (
|
|||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
|
|
@ -157,6 +158,12 @@ def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Require
|
|||
}
|
||||
|
||||
|
||||
def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]:
|
||||
return {
|
||||
req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement)
|
||||
}
|
||||
|
||||
|
||||
def require(
|
||||
what: _RequiredCallback[T],
|
||||
*,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue