diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 4d254f79..b996797c 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -35,7 +35,7 @@ class CallableMixin: awaitable: bool = field(init=False) spec: inspect.FullArgSpec = field(init=False) - __reqs__: Dict[str, Requirement[Any]] = field(init=False) + requirements: Dict[str, Requirement[Any]] = field(init=False) def __post_init__(self) -> None: callback = inspect.unwrap(self.callback) @@ -47,10 +47,9 @@ class CallableMixin: if _is_class_handler(callback): self.awaitable = True - self.__reqs__ = get_reqs_from_class(callback) - + self.requirements = get_reqs_from_class(callback) else: - self.__reqs__ = get_reqs_from_callable(callable_=callback) + self.requirements = get_reqs_from_callable(callable_=callback) def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: if self.spec.varkw: @@ -58,33 +57,33 @@ class CallableMixin: return {k: v for k, v in kwargs.items() if k in self.spec.args} - async def call(self, *args: Any, **kwargs: Any) -> Any: + async def call(self, *args: Any, **data: Any) -> Any: # we don't requirements_data and kwargs keys to intersect requirements_data: Dict[str, Any] = {} - if self.__reqs__: - stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY)) - cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {}) - requirements_data = kwargs.copy() + if self.requirements: + stack = cast(AsyncExitStack, data.get(ASYNC_STACK_KEY)) + cache_dict: Dict[CacheKeyType, Any] = data.get(REQUIREMENT_CACHE_KEY, {}) + requirements_data = data.copy() - for req_id, req in self.__reqs__.items(): + for req_id, req in self.requirements.items(): requirements_data[req_id] = await req( - cache_dict=cache_dict, stack=stack, data=requirements_data + cache_dict=cache_dict, stack=stack, data=requirements_data, ) - for to_pop in kwargs: + for to_pop in data: requirements_data.pop(to_pop, None) - kwargs.pop(ASYNC_STACK_KEY, None) - kwargs.pop(REQUIREMENT_CACHE_KEY, None) + data.pop(ASYNC_STACK_KEY, None) + data.pop(REQUIREMENT_CACHE_KEY, None) if _is_class_handler(self.callback): - wrapped = partial(self.callback, *args, requirements_data, kwargs) + wrapped = partial(self.callback, *args, requirements_data, data) else: wrapped = partial( self.callback, *args, - **self._prepare_kwargs(kwargs), + **self._prepare_kwargs(data), **self._prepare_kwargs(requirements_data), ) diff --git a/aiogram/dispatcher/requirement.py b/aiogram/dispatcher/requirement.py index f18b8d98..1d7b0ab0 100644 --- a/aiogram/dispatcher/requirement.py +++ b/aiogram/dispatcher/requirement.py @@ -1,4 +1,3 @@ -import abc import enum import inspect from contextlib import AsyncExitStack, asynccontextmanager, contextmanager @@ -8,6 +7,7 @@ from typing import ( Awaitable, Callable, Dict, + Generator, Generic, Optional, Type, @@ -18,7 +18,10 @@ from typing import ( T = TypeVar("T") CacheKeyType = Union[str, int] -_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]] +CacheType = Dict[CacheKeyType, Any] +_RequiredCallback = Callable[ + ..., Union[T, Generator[T, None, None], AsyncGenerator[None, T], Awaitable[T]] +] class GeneratorKind(enum.IntEnum): @@ -42,28 +45,19 @@ async def move_to_async_gen(context_manager: Any) -> Any: context_manager.__exit__(None, None, None) -class Requirement(abc.ABC, Generic[T]): - """ - Interface for all requirements - """ - - async def __call__( - self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any], - ) -> T: - raise NotImplementedError() - - -class CallableRequirement(Requirement[T]): +class Requirement(Generic[T]): __slots__ = "callable", "children", "cache_key", "use_cache", "generator_type" def __init__( self, callable_: _RequiredCallback[T], - *, cache_key: Optional[CacheKeyType] = None, use_cache: bool = True, ): self.callable = callable_ + self.use_cache = use_cache + self.children: Dict[str, Requirement[Any]] = {} + self.generator_type = GeneratorKind.not_a_gen if inspect.isasyncgenfunction(callable_): @@ -72,7 +66,6 @@ class CallableRequirement(Requirement[T]): self.generator_type = GeneratorKind.plain_gen self.cache_key = hash(callable_) if cache_key is None else cache_key - self.use_cache = use_cache self.children = get_reqs_from_callable(callable_) assert not ( @@ -83,27 +76,23 @@ class CallableRequirement(Requirement[T]): return {key: value for key, value in data.items() if key in self.children} async def initialize_children( - self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack + self, data: Dict[str, Any], cache_dict: CacheType, stack: AsyncExitStack, ) -> None: for req_id, req in self.children.items(): - if isinstance(req, CachedRequirement): - data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack) - continue - if isinstance(req, CallableRequirement): - await req.initialize_children(data, cache_dict, stack) + await req.initialize_children(data, cache_dict, stack) - if req.use_cache and req.cache_key in cache_dict: - data[req_id] = cache_dict[req.cache_key] + if req.use_cache and req.cache_key in cache_dict: + data[req_id] = cache_dict[req.cache_key] - else: - data[req_id] = await initialize_callable_requirement(req, data, stack) + else: + data[req_id] = await initialize_callable_requirement(req, data, stack) - if req.use_cache: - cache_dict[req.cache_key] = data[req_id] + if req.use_cache: + cache_dict[req.cache_key] = data[req_id] async def __call__( - self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any], + self, *, cache_dict: CacheType, stack: AsyncExitStack, data: Dict[str, Any], ) -> T: await self.initialize_children(data, cache_dict, stack) @@ -113,24 +102,12 @@ class CallableRequirement(Requirement[T]): result = await initialize_callable_requirement(self, data, stack) if self.use_cache: cache_dict[self.cache_key] = result + return result -class CachedRequirement(Requirement[T]): - __slots__ = "cache_key", "value_on_miss" - - def __init__(self, cache_key: CacheKeyType, value_on_miss: T): - self.cache_key: CacheKeyType = cache_key - self.value_on_miss = value_on_miss - - async def __call__( - self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any], - ) -> T: - return cache_dict.get(self.cache_key, self.value_on_miss) - - async def initialize_callable_requirement( - required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack + required: Requirement[T], data: Dict[str, Any], stack: AsyncExitStack ) -> T: actual_data = required.filter_kwargs(data) async_cm: Optional[Any] = None @@ -162,12 +139,3 @@ def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]: return { req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement) } - - -def require( - what: _RequiredCallback[T], - *, - cache_key: Optional[CacheKeyType] = None, - use_cache: bool = True, -) -> T: - return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore diff --git a/tests/test_dispatcher/test_event/test_handler.py b/tests/test_dispatcher/test_event/test_handler.py index 3db676db..d174fc57 100644 --- a/tests/test_dispatcher/test_event/test_handler.py +++ b/tests/test_dispatcher/test_event/test_handler.py @@ -18,7 +18,7 @@ async def callback2(foo: int, bar: int, baz: int): return locals() -async def callback3(foo: int, **kwargs): +async def callback3(foo: int, **data): return locals() @@ -62,8 +62,8 @@ class TestCallableMixin: def test_init_decorated(self): def decorator(func): @functools.wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) + def wrapper(*args, **data): + return func(*args, **data) return wrapper @@ -85,7 +85,7 @@ class TestCallableMixin: assert obj2.callback == callback2 @pytest.mark.parametrize( - "callback,kwargs,result", + "callback,data,result", [ pytest.param( callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"} @@ -108,9 +108,9 @@ class TestCallableMixin: ), ], ) - def test_prepare_kwargs(self, callback, kwargs, result): + def test_prepare_data(self, callback, data, result): obj = CallableMixin(callback) - assert obj._prepare_kwargs(kwargs) == result + assert obj._prepare_kwargs(data) == result @pytest.mark.asyncio async def test_sync_call(self): @@ -127,8 +127,8 @@ class TestCallableMixin: assert result == {"foo": 42, "bar": "test", "baz": "fuz"} -async def simple_handler(*args, **kwargs): - return args, kwargs +async def simple_handler(*args, **data): + return args, data class TestHandlerObject: diff --git a/tests/test_dispatcher/test_requirement.py b/tests/test_dispatcher/test_requirement.py index 768952d3..c1ac6893 100644 --- a/tests/test_dispatcher/test_requirement.py +++ b/tests/test_dispatcher/test_requirement.py @@ -1,17 +1,37 @@ # todo -from aiogram.dispatcher.requirement import require, CallableRequirement +from aiogram.dispatcher.requirement import Requirement, get_reqs_from_class, get_reqs_from_callable tick_data = {"ticks": 0} +req1 = Requirement(lambda: 1) +req2 = Requirement(lambda: 1) + + +async def callback( + o, + x=req1, + y=req2 +): + ... + + def test_require(): - x = require(lambda: "str", use_cache=True, cache_key=0) - assert isinstance(x, CallableRequirement) + x = Requirement(lambda: "str", use_cache=True, cache_key=0) + assert isinstance(x, Requirement) assert callable(x) & callable(x.callable) assert x.cache_key == 0 assert x.use_cache -class TestCallableRequirementCache: - def test_cache(self): - ... +class TestReqUtils: + def test_get_reqs_from_callable(self): + assert set(get_reqs_from_callable(callback).values()) == {req1, req2} + assert set(get_reqs_from_callable(callback).keys()) == {"x", "y"} + + def test_get_reqs_from_class(self): + class Class: + x = req1 + y = req2 + + assert set(get_reqs_from_class(Class)) == {"x", "y"}