diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 6c2a5e57..e9c1b496 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -1,12 +1,14 @@ import asyncio import contextvars import inspect +from contextlib import AsyncExitStack from dataclasses import dataclass, field from functools import partial -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, cast from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.handler.base import BaseHandler +from aiogram.dispatcher.requirement import CacheKeyType, Requirement, get_reqs_from_callable CallbackType = Callable[..., Awaitable[Any]] SyncFilter = Callable[..., Any] @@ -14,6 +16,9 @@ AsyncFilter = Callable[..., Awaitable[Any]] FilterType = Union[SyncFilter, AsyncFilter, BaseFilter] HandlerType = Union[FilterType, Type[BaseHandler]] +REQUIREMENT_CACHE_KEY = "_req_cache" +ASYNC_STACK_KEY = "_stack" + @dataclass class CallableMixin: @@ -21,17 +26,18 @@ class CallableMixin: awaitable: bool = field(init=False) spec: inspect.FullArgSpec = field(init=False) + __reqs__: Dict[str, Requirement[Any]] = field(init=False) + def __post_init__(self) -> None: callback = inspect.unwrap(self.callback) self.awaitable = inspect.isawaitable(callback) or inspect.iscoroutinefunction(callback) if isinstance(callback, BaseFilter): - # Pydantic 1.5 has incorrect signature generator - # Issue: https://github.com/samuelcolvin/pydantic/issues/1419 - # Fixes: https://github.com/samuelcolvin/pydantic/pull/1427 - # TODO: Remove this temporary fix - callback = inspect.unwrap(callback.__call__) + callback = callback.__call__ + self.spec = inspect.getfullargspec(callback) + self.__reqs__ = get_reqs_from_callable(callable_=callback) + def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: if self.spec.varkw: return kwargs @@ -39,7 +45,18 @@ class CallableMixin: return {k: v for k, v in kwargs.items() if k in self.spec.args} async def call(self, *args: Any, **kwargs: Any) -> Any: + if self.__reqs__: + stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY)) + cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {}) + + for kwarg, default in self.__reqs__.items(): + kwargs[kwarg] = await default(cache_dict=cache_dict, stack=stack, data=kwargs) + + kwargs.pop(ASYNC_STACK_KEY, None) + kwargs.pop(REQUIREMENT_CACHE_KEY, None) + wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) + if self.awaitable: return await wrapped() diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index e72d5db3..be1eac5a 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from contextlib import AsyncExitStack from itertools import chain from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union @@ -9,7 +10,15 @@ from pydantic import ValidationError from ...api.types import TelegramObject from ..filters.base import BaseFilter from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler -from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType +from .handler import ( + ASYNC_STACK_KEY, + REQUIREMENT_CACHE_KEY, + CallbackType, + FilterObject, + FilterType, + HandlerObject, + HandlerType, +) if TYPE_CHECKING: # pragma: no cover from aiogram.dispatcher.router import Router @@ -122,14 +131,18 @@ class TelegramEventObserver: async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any: for handler in self.handlers: - result, data = await handler.check(event, **kwargs) - if result: - kwargs.update(data) - try: - wrapped_inner = self._wrap_middleware(self.middlewares, handler.call) - return await wrapped_inner(event, kwargs) - except SkipHandler: - continue + async with AsyncExitStack() as stack: + # intermediate values. handler.call is responsible for cleaning them up + kwargs.update({ASYNC_STACK_KEY: stack, REQUIREMENT_CACHE_KEY: {}}) + + result, data = await handler.check(event, **kwargs) + if result: + kwargs.update(data) + try: + wrapped_inner = self._wrap_middleware(self.middlewares, handler.call) + return await wrapped_inner(event, kwargs) + except SkipHandler: + continue return NOT_HANDLED diff --git a/aiogram/dispatcher/requirement.py b/aiogram/dispatcher/requirement.py new file mode 100644 index 00000000..9082c9d6 --- /dev/null +++ b/aiogram/dispatcher/requirement.py @@ -0,0 +1,166 @@ +import abc +import enum +import inspect +from contextlib import AsyncExitStack, asynccontextmanager, contextmanager +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Dict, + Generic, + Optional, + TypeVar, + Union, + cast, +) + +T = TypeVar("T") +CacheKeyType = Union[str, int] +_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]] + + +class GeneratorKind(enum.IntEnum): + not_a_gen = enum.auto() # not a generator + plain_gen = enum.auto() # proper generator not async + async_gen = enum.auto() # async generator + + +@asynccontextmanager +async def move_to_async_gen(context_manager: Any) -> Any: + """ + Wrap existing contextmanager into a asynchronous generator, then async ctx manager + """ + try: + yield context_manager.__enter__() + except Exception as e: + err = context_manager.__exit__(type(e), e, None) + if not err: + raise e + else: + context_manager.__exit__(None, None, None) + + +class Requirement(abc.ABC, Generic[T]): + """ + Interface for all requirements + """ + + async def __call__( + self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any], + ) -> T: + raise NotImplementedError() + + +class CallableRequirement(Requirement[T]): + __slots__ = "callable", "children", "cache_key", "use_cache", "generator_type" + + def __init__( + self, + callable_: _RequiredCallback[T], + *, + cache_key: Optional[CacheKeyType] = None, + use_cache: bool = True, + ): + self.callable = callable_ + self.generator_type = GeneratorKind.not_a_gen + + if inspect.isasyncgenfunction(callable_): + self.generator_type = GeneratorKind.async_gen + elif inspect.isgeneratorfunction(callable_): + self.generator_type = GeneratorKind.plain_gen + + self.cache_key = hash(callable_) if cache_key is None else cache_key + self.use_cache = use_cache + self.children = get_reqs_from_callable(callable_) + + assert not ( + set(inspect.signature(callable_).parameters) ^ set(self.children) + ), "Required can't manage callbacks with non `Requirement` parameters" + + def filter_kwargs(self, data: Dict[str, Any]) -> Dict[str, Any]: + return {key: value for key, value in data.items() if key in self.children} + + async def initialize_children( + self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack + ) -> None: + for req_id, req in self.children.items(): + if isinstance(req, CachedRequirement): + data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack) + continue + + if isinstance(req, CallableRequirement): + await req.initialize_children(data, cache_dict, stack) + + if req.use_cache and req.cache_key in cache_dict: + data[req_id] = cache_dict[req.cache_key] + + else: + data[req_id] = await initialize_callable_requirement(req, data, stack) + + if req.use_cache: + cache_dict[req.cache_key] = data[req_id] + + async def __call__( + self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any], + ) -> T: + await self.initialize_children(data, cache_dict, stack) + + if self.use_cache and self.cache_key in cache_dict: + return cast(T, cache_dict[self.cache_key]) + else: + result = await initialize_callable_requirement(self, data, stack) + if self.use_cache: + cache_dict[self.cache_key] = result + return result + + +class CachedRequirement(Requirement[T]): + __slots__ = "cache_key", "value_on_miss" + + def __init__(self, cache_key: CacheKeyType, value_on_miss: T): + self.cache_key: CacheKeyType = cache_key + self.value_on_miss = value_on_miss + + async def __call__( + self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any], + ) -> T: + return cache_dict.get(self.cache_key, self.value_on_miss) + + +async def initialize_callable_requirement( + required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack +) -> T: + actual_data = required.filter_kwargs(data) + async_cm: Optional[Any] = None + + if required.generator_type is GeneratorKind.async_gen: + async_cm = asynccontextmanager(required.callable)(**actual_data) # type: ignore + elif required.generator_type is GeneratorKind.plain_gen: + async_cm = move_to_async_gen(contextmanager(required.callable)(**actual_data)) # type: ignore + + if async_cm is not None: + return await stack.enter_async_context(async_cm) + else: + result = required.callable(**actual_data) + if isinstance(result, Awaitable): + return cast(T, await result) + return cast(T, result) + + +def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]: + signature = inspect.signature(callable_) + return { + param_name: param_self.default + for param_name, param_self in signature.parameters.items() + if isinstance(param_self.default, Requirement) + } + + +def require( + what: _RequiredCallback[T], + *, + cache_key: Optional[CacheKeyType] = None, + use_cache: bool = True, +) -> T: + return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore diff --git a/poetry.lock b/poetry.lock index 72a5c042..9b328339 100644 --- a/poetry.lock +++ b/poetry.lock @@ -380,7 +380,7 @@ description = "File identification library for Python" name = "identify" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" -version = "1.4.16" +version = "1.4.17" [package.extras] license = ["editdistance"] @@ -942,7 +942,6 @@ wcwidth = "*" [[package]] category = "dev" description = "Run a subprocess in a pseudo terminal" -marker = "sys_platform != \"win32\"" name = "ptyprocess" optional = false python-versions = "*" @@ -1401,7 +1400,7 @@ fast = ["uvloop"] proxy = ["aiohttp-socks"] [metadata] -content-hash = "152bb9b155a00baadd3c8b9fa21f08af719180bddccb8ad6c3dd6548c3e71e3e" +content-hash = "50478042294d42c7ac6eef8df3d2036b4dc42455ca60cdfb9434cd476e27ddfc" python-versions = "^3.7" [metadata.files] @@ -1617,8 +1616,8 @@ html5lib = [ {file = "html5lib-1.0.1.tar.gz", hash = "sha256:66cb0dcfdbbc4f9c3ba1a63fdb511ffdbd4f513b2b6d81b80cd26ce6b3fb3736"}, ] identify = [ - {file = "identify-1.4.16-py2.py3-none-any.whl", hash = "sha256:0f3c3aac62b51b86fea6ff52fe8ff9e06f57f10411502443809064d23e16f1c2"}, - {file = "identify-1.4.16.tar.gz", hash = "sha256:f9ad3d41f01e98eb066b6e05c5b184fd1e925fadec48eb165b4e01c72a1ef3a7"}, + {file = "identify-1.4.17-py2.py3-none-any.whl", hash = "sha256:ef6fa3d125c27516f8d1aaa2038c3263d741e8723825eb38350cdc0288ab35eb"}, + {file = "identify-1.4.17.tar.gz", hash = "sha256:be66b9673d59336acd18a3a0e0c10d35b8a780309561edf16c46b6b74b83f6af"}, ] idna = [ {file = "idna-2.9-py2.py3-none-any.whl", hash = "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"}, @@ -1742,6 +1741,11 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] mccabe = [ diff --git a/pyproject.toml b/pyproject.toml index 2db33c6c..e878d0f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" aiohttp = "^3.6" -pydantic = "^1.5" +pydantic = "^1.5.1" Babel = "^2.7" aiofiles = "^0.4.0" uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'", optional = true} diff --git a/tests/test_dispatcher/test_requirement.py b/tests/test_dispatcher/test_requirement.py new file mode 100644 index 00000000..768952d3 --- /dev/null +++ b/tests/test_dispatcher/test_requirement.py @@ -0,0 +1,17 @@ +# todo +from aiogram.dispatcher.requirement import require, CallableRequirement + +tick_data = {"ticks": 0} + + +def test_require(): + x = require(lambda: "str", use_cache=True, cache_key=0) + assert isinstance(x, CallableRequirement) + assert callable(x) & callable(x.callable) + assert x.cache_key == 0 + assert x.use_cache + + +class TestCallableRequirementCache: + def test_cache(self): + ...