mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
feat(handler): requirement class
implement POC of "smart defaults", pin newest pydantic 1.5.1 (resolves issue with BaseFilter signature inspection)
This commit is contained in:
parent
de3c5c1a8d
commit
32ffda2eb7
6 changed files with 238 additions and 21 deletions
|
|
@ -1,12 +1,14 @@
|
|||
import asyncio
|
||||
import contextvars
|
||||
import inspect
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from aiogram.dispatcher.filters.base import BaseFilter
|
||||
from aiogram.dispatcher.handler.base import BaseHandler
|
||||
from aiogram.dispatcher.requirement import CacheKeyType, Requirement, get_reqs_from_callable
|
||||
|
||||
CallbackType = Callable[..., Awaitable[Any]]
|
||||
SyncFilter = Callable[..., Any]
|
||||
|
|
@ -14,6 +16,9 @@ AsyncFilter = Callable[..., Awaitable[Any]]
|
|||
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
|
||||
HandlerType = Union[FilterType, Type[BaseHandler]]
|
||||
|
||||
REQUIREMENT_CACHE_KEY = "_req_cache"
|
||||
ASYNC_STACK_KEY = "_stack"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallableMixin:
|
||||
|
|
@ -21,17 +26,18 @@ class CallableMixin:
|
|||
awaitable: bool = field(init=False)
|
||||
spec: inspect.FullArgSpec = field(init=False)
|
||||
|
||||
__reqs__: Dict[str, Requirement[Any]] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
callback = inspect.unwrap(self.callback)
|
||||
self.awaitable = inspect.isawaitable(callback) or inspect.iscoroutinefunction(callback)
|
||||
if isinstance(callback, BaseFilter):
|
||||
# Pydantic 1.5 has incorrect signature generator
|
||||
# Issue: https://github.com/samuelcolvin/pydantic/issues/1419
|
||||
# Fixes: https://github.com/samuelcolvin/pydantic/pull/1427
|
||||
# TODO: Remove this temporary fix
|
||||
callback = inspect.unwrap(callback.__call__)
|
||||
callback = callback.__call__
|
||||
|
||||
self.spec = inspect.getfullargspec(callback)
|
||||
|
||||
self.__reqs__ = get_reqs_from_callable(callable_=callback)
|
||||
|
||||
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if self.spec.varkw:
|
||||
return kwargs
|
||||
|
|
@ -39,7 +45,18 @@ class CallableMixin:
|
|||
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
||||
|
||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if self.__reqs__:
|
||||
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY))
|
||||
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {})
|
||||
|
||||
for kwarg, default in self.__reqs__.items():
|
||||
kwargs[kwarg] = await default(cache_dict=cache_dict, stack=stack, data=kwargs)
|
||||
|
||||
kwargs.pop(ASYNC_STACK_KEY, None)
|
||||
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
|
||||
|
||||
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
|
||||
|
||||
if self.awaitable:
|
||||
return await wrapped()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from contextlib import AsyncExitStack
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
|
|
@ -9,7 +10,15 @@ from pydantic import ValidationError
|
|||
from ...api.types import TelegramObject
|
||||
from ..filters.base import BaseFilter
|
||||
from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
|
||||
from .handler import (
|
||||
ASYNC_STACK_KEY,
|
||||
REQUIREMENT_CACHE_KEY,
|
||||
CallbackType,
|
||||
FilterObject,
|
||||
FilterType,
|
||||
HandlerObject,
|
||||
HandlerType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from aiogram.dispatcher.router import Router
|
||||
|
|
@ -122,14 +131,18 @@ class TelegramEventObserver:
|
|||
|
||||
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
for handler in self.handlers:
|
||||
result, data = await handler.check(event, **kwargs)
|
||||
if result:
|
||||
kwargs.update(data)
|
||||
try:
|
||||
wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
|
||||
return await wrapped_inner(event, kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
async with AsyncExitStack() as stack:
|
||||
# intermediate values. handler.call is responsible for cleaning them up
|
||||
kwargs.update({ASYNC_STACK_KEY: stack, REQUIREMENT_CACHE_KEY: {}})
|
||||
|
||||
result, data = await handler.check(event, **kwargs)
|
||||
if result:
|
||||
kwargs.update(data)
|
||||
try:
|
||||
wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
|
||||
return await wrapped_inner(event, kwargs)
|
||||
except SkipHandler:
|
||||
continue
|
||||
|
||||
return NOT_HANDLED
|
||||
|
||||
|
|
|
|||
166
aiogram/dispatcher/requirement.py
Normal file
166
aiogram/dispatcher/requirement.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
import abc
|
||||
import enum
|
||||
import inspect
|
||||
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
CacheKeyType = Union[str, int]
|
||||
_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]]
|
||||
|
||||
|
||||
class GeneratorKind(enum.IntEnum):
|
||||
not_a_gen = enum.auto() # not a generator
|
||||
plain_gen = enum.auto() # proper generator not async
|
||||
async_gen = enum.auto() # async generator
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def move_to_async_gen(context_manager: Any) -> Any:
|
||||
"""
|
||||
Wrap existing contextmanager into a asynchronous generator, then async ctx manager
|
||||
"""
|
||||
try:
|
||||
yield context_manager.__enter__()
|
||||
except Exception as e:
|
||||
err = context_manager.__exit__(type(e), e, None)
|
||||
if not err:
|
||||
raise e
|
||||
else:
|
||||
context_manager.__exit__(None, None, None)
|
||||
|
||||
|
||||
class Requirement(abc.ABC, Generic[T]):
|
||||
"""
|
||||
Interface for all requirements
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
||||
) -> T:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CallableRequirement(Requirement[T]):
|
||||
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
callable_: _RequiredCallback[T],
|
||||
*,
|
||||
cache_key: Optional[CacheKeyType] = None,
|
||||
use_cache: bool = True,
|
||||
):
|
||||
self.callable = callable_
|
||||
self.generator_type = GeneratorKind.not_a_gen
|
||||
|
||||
if inspect.isasyncgenfunction(callable_):
|
||||
self.generator_type = GeneratorKind.async_gen
|
||||
elif inspect.isgeneratorfunction(callable_):
|
||||
self.generator_type = GeneratorKind.plain_gen
|
||||
|
||||
self.cache_key = hash(callable_) if cache_key is None else cache_key
|
||||
self.use_cache = use_cache
|
||||
self.children = get_reqs_from_callable(callable_)
|
||||
|
||||
assert not (
|
||||
set(inspect.signature(callable_).parameters) ^ set(self.children)
|
||||
), "Required can't manage callbacks with non `Requirement` parameters"
|
||||
|
||||
def filter_kwargs(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {key: value for key, value in data.items() if key in self.children}
|
||||
|
||||
async def initialize_children(
|
||||
self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack
|
||||
) -> None:
|
||||
for req_id, req in self.children.items():
|
||||
if isinstance(req, CachedRequirement):
|
||||
data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack)
|
||||
continue
|
||||
|
||||
if isinstance(req, CallableRequirement):
|
||||
await req.initialize_children(data, cache_dict, stack)
|
||||
|
||||
if req.use_cache and req.cache_key in cache_dict:
|
||||
data[req_id] = cache_dict[req.cache_key]
|
||||
|
||||
else:
|
||||
data[req_id] = await initialize_callable_requirement(req, data, stack)
|
||||
|
||||
if req.use_cache:
|
||||
cache_dict[req.cache_key] = data[req_id]
|
||||
|
||||
async def __call__(
|
||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
||||
) -> T:
|
||||
await self.initialize_children(data, cache_dict, stack)
|
||||
|
||||
if self.use_cache and self.cache_key in cache_dict:
|
||||
return cast(T, cache_dict[self.cache_key])
|
||||
else:
|
||||
result = await initialize_callable_requirement(self, data, stack)
|
||||
if self.use_cache:
|
||||
cache_dict[self.cache_key] = result
|
||||
return result
|
||||
|
||||
|
||||
class CachedRequirement(Requirement[T]):
|
||||
__slots__ = "cache_key", "value_on_miss"
|
||||
|
||||
def __init__(self, cache_key: CacheKeyType, value_on_miss: T):
|
||||
self.cache_key: CacheKeyType = cache_key
|
||||
self.value_on_miss = value_on_miss
|
||||
|
||||
async def __call__(
|
||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
||||
) -> T:
|
||||
return cache_dict.get(self.cache_key, self.value_on_miss)
|
||||
|
||||
|
||||
async def initialize_callable_requirement(
|
||||
required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack
|
||||
) -> T:
|
||||
actual_data = required.filter_kwargs(data)
|
||||
async_cm: Optional[Any] = None
|
||||
|
||||
if required.generator_type is GeneratorKind.async_gen:
|
||||
async_cm = asynccontextmanager(required.callable)(**actual_data) # type: ignore
|
||||
elif required.generator_type is GeneratorKind.plain_gen:
|
||||
async_cm = move_to_async_gen(contextmanager(required.callable)(**actual_data)) # type: ignore
|
||||
|
||||
if async_cm is not None:
|
||||
return await stack.enter_async_context(async_cm)
|
||||
else:
|
||||
result = required.callable(**actual_data)
|
||||
if isinstance(result, Awaitable):
|
||||
return cast(T, await result)
|
||||
return cast(T, result)
|
||||
|
||||
|
||||
def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]:
|
||||
signature = inspect.signature(callable_)
|
||||
return {
|
||||
param_name: param_self.default
|
||||
for param_name, param_self in signature.parameters.items()
|
||||
if isinstance(param_self.default, Requirement)
|
||||
}
|
||||
|
||||
|
||||
def require(
|
||||
what: _RequiredCallback[T],
|
||||
*,
|
||||
cache_key: Optional[CacheKeyType] = None,
|
||||
use_cache: bool = True,
|
||||
) -> T:
|
||||
return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore
|
||||
14
poetry.lock
generated
14
poetry.lock
generated
|
|
@ -380,7 +380,7 @@ description = "File identification library for Python"
|
|||
name = "identify"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
|
||||
version = "1.4.16"
|
||||
version = "1.4.17"
|
||||
|
||||
[package.extras]
|
||||
license = ["editdistance"]
|
||||
|
|
@ -942,7 +942,6 @@ wcwidth = "*"
|
|||
[[package]]
|
||||
category = "dev"
|
||||
description = "Run a subprocess in a pseudo terminal"
|
||||
marker = "sys_platform != \"win32\""
|
||||
name = "ptyprocess"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
|
@ -1401,7 +1400,7 @@ fast = ["uvloop"]
|
|||
proxy = ["aiohttp-socks"]
|
||||
|
||||
[metadata]
|
||||
content-hash = "152bb9b155a00baadd3c8b9fa21f08af719180bddccb8ad6c3dd6548c3e71e3e"
|
||||
content-hash = "50478042294d42c7ac6eef8df3d2036b4dc42455ca60cdfb9434cd476e27ddfc"
|
||||
python-versions = "^3.7"
|
||||
|
||||
[metadata.files]
|
||||
|
|
@ -1617,8 +1616,8 @@ html5lib = [
|
|||
{file = "html5lib-1.0.1.tar.gz", hash = "sha256:66cb0dcfdbbc4f9c3ba1a63fdb511ffdbd4f513b2b6d81b80cd26ce6b3fb3736"},
|
||||
]
|
||||
identify = [
|
||||
{file = "identify-1.4.16-py2.py3-none-any.whl", hash = "sha256:0f3c3aac62b51b86fea6ff52fe8ff9e06f57f10411502443809064d23e16f1c2"},
|
||||
{file = "identify-1.4.16.tar.gz", hash = "sha256:f9ad3d41f01e98eb066b6e05c5b184fd1e925fadec48eb165b4e01c72a1ef3a7"},
|
||||
{file = "identify-1.4.17-py2.py3-none-any.whl", hash = "sha256:ef6fa3d125c27516f8d1aaa2038c3263d741e8723825eb38350cdc0288ab35eb"},
|
||||
{file = "identify-1.4.17.tar.gz", hash = "sha256:be66b9673d59336acd18a3a0e0c10d35b8a780309561edf16c46b6b74b83f6af"},
|
||||
]
|
||||
idna = [
|
||||
{file = "idna-2.9-py2.py3-none-any.whl", hash = "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"},
|
||||
|
|
@ -1742,6 +1741,11 @@ markupsafe = [
|
|||
{file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"},
|
||||
{file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"},
|
||||
{file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"},
|
||||
{file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"},
|
||||
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"},
|
||||
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"},
|
||||
{file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"},
|
||||
{file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"},
|
||||
{file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"},
|
||||
]
|
||||
mccabe = [
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ classifiers = [
|
|||
[tool.poetry.dependencies]
|
||||
python = "^3.7"
|
||||
aiohttp = "^3.6"
|
||||
pydantic = "^1.5"
|
||||
pydantic = "^1.5.1"
|
||||
Babel = "^2.7"
|
||||
aiofiles = "^0.4.0"
|
||||
uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'", optional = true}
|
||||
|
|
|
|||
17
tests/test_dispatcher/test_requirement.py
Normal file
17
tests/test_dispatcher/test_requirement.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# todo
|
||||
from aiogram.dispatcher.requirement import require, CallableRequirement
|
||||
|
||||
tick_data = {"ticks": 0}
|
||||
|
||||
|
||||
def test_require():
|
||||
x = require(lambda: "str", use_cache=True, cache_key=0)
|
||||
assert isinstance(x, CallableRequirement)
|
||||
assert callable(x) & callable(x.callable)
|
||||
assert x.cache_key == 0
|
||||
assert x.use_cache
|
||||
|
||||
|
||||
class TestCallableRequirementCache:
|
||||
def test_cache(self):
|
||||
...
|
||||
Loading…
Add table
Add a link
Reference in a new issue