diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index e9c1b496..4d254f79 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -8,7 +8,12 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.handler.base import BaseHandler -from aiogram.dispatcher.requirement import CacheKeyType, Requirement, get_reqs_from_callable +from aiogram.dispatcher.requirement import ( + CacheKeyType, + Requirement, + get_reqs_from_callable, + get_reqs_from_class, +) CallbackType = Callable[..., Awaitable[Any]] SyncFilter = Callable[..., Any] @@ -20,6 +25,10 @@ REQUIREMENT_CACHE_KEY = "_req_cache" ASYNC_STACK_KEY = "_stack" +def _is_class_handler(handler: HandlerType) -> bool: + return isinstance(handler, type) and issubclass(handler, BaseHandler) + + @dataclass class CallableMixin: callback: HandlerType @@ -36,7 +45,12 @@ class CallableMixin: self.spec = inspect.getfullargspec(callback) - self.__reqs__ = get_reqs_from_callable(callable_=callback) + if _is_class_handler(callback): + self.awaitable = True + self.__reqs__ = get_reqs_from_class(callback) + + else: + self.__reqs__ = get_reqs_from_callable(callable_=callback) def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: if self.spec.varkw: @@ -45,17 +59,34 @@ class CallableMixin: return {k: v for k, v in kwargs.items() if k in self.spec.args} async def call(self, *args: Any, **kwargs: Any) -> Any: + # we don't requirements_data and kwargs keys to intersect + requirements_data: Dict[str, Any] = {} + if self.__reqs__: stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY)) cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {}) + requirements_data = kwargs.copy() - for kwarg, default in self.__reqs__.items(): - kwargs[kwarg] = await default(cache_dict=cache_dict, stack=stack, data=kwargs) + for req_id, req in self.__reqs__.items(): + requirements_data[req_id] = await req( + cache_dict=cache_dict, stack=stack, data=requirements_data + ) + + for to_pop in kwargs: + requirements_data.pop(to_pop, None) kwargs.pop(ASYNC_STACK_KEY, None) kwargs.pop(REQUIREMENT_CACHE_KEY, None) - wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) + if _is_class_handler(self.callback): + wrapped = partial(self.callback, *args, requirements_data, kwargs) + else: + wrapped = partial( + self.callback, + *args, + **self._prepare_kwargs(kwargs), + **self._prepare_kwargs(requirements_data), + ) if self.awaitable: return await wrapped() @@ -76,11 +107,6 @@ class HandlerObject(CallableMixin): callback: HandlerType filters: Optional[List[FilterObject]] = None - def __post_init__(self) -> None: - super(HandlerObject, self).__post_init__() - if inspect.isclass(self.callback) and issubclass(self.callback, BaseHandler): # type: ignore - self.awaitable = True - async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: if not self.filters: return True, kwargs diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index be1eac5a..513f6d57 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -132,7 +132,6 @@ class TelegramEventObserver: async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any: for handler in self.handlers: async with AsyncExitStack() as stack: - # intermediate values. handler.call is responsible for cleaning them up kwargs.update({ASYNC_STACK_KEY: stack, REQUIREMENT_CACHE_KEY: {}}) result, data = await handler.check(event, **kwargs) @@ -144,6 +143,9 @@ class TelegramEventObserver: except SkipHandler: continue + kwargs.pop(ASYNC_STACK_KEY, None) + kwargs.pop(REQUIREMENT_CACHE_KEY, None) + return NOT_HANDLED def __call__( diff --git a/aiogram/dispatcher/handler/base.py b/aiogram/dispatcher/handler/base.py index 90b8da45..3b47b456 100644 --- a/aiogram/dispatcher/handler/base.py +++ b/aiogram/dispatcher/handler/base.py @@ -18,9 +18,12 @@ class BaseHandler(BaseHandlerMixin[T], ABC): Base class for all class-based handlers """ - def __init__(self, event: T, **kwargs: Any) -> None: + def __init__(self, event: T, requirements_data: Dict[str, Any], data: Dict[str, Any]) -> None: self.event: T = event - self.data: Dict[str, Any] = kwargs + self.data: Dict[str, Any] = data + + for req_attr, req in requirements_data.items(): + setattr(self, req_attr, req) @property def bot(self) -> Bot: diff --git a/aiogram/dispatcher/requirement.py b/aiogram/dispatcher/requirement.py index 9082c9d6..f18b8d98 100644 --- a/aiogram/dispatcher/requirement.py +++ b/aiogram/dispatcher/requirement.py @@ -10,6 +10,7 @@ from typing import ( Dict, Generic, Optional, + Type, TypeVar, Union, cast, @@ -157,6 +158,12 @@ def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Require } +def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]: + return { + req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement) + } + + def require( what: _RequiredCallback[T], *,