feat(handler): requirement class

implement POC of "smart defaults", pin newest pydantic 1.5.1 (resolves
issue with BaseFilter signature inspection)
This commit is contained in:
mpa 2020-05-27 12:22:22 +04:00
parent de3c5c1a8d
commit 32ffda2eb7
No known key found for this signature in database
GPG key ID: BCCFBFCCC9B754A8
6 changed files with 238 additions and 21 deletions

View file

@ -1,12 +1,14 @@
import asyncio
import contextvars
import inspect
from contextlib import AsyncExitStack
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
from aiogram.dispatcher.requirement import CacheKeyType, Requirement, get_reqs_from_callable
CallbackType = Callable[..., Awaitable[Any]]
SyncFilter = Callable[..., Any]
@ -14,6 +16,9 @@ AsyncFilter = Callable[..., Awaitable[Any]]
FilterType = Union[SyncFilter, AsyncFilter, BaseFilter]
HandlerType = Union[FilterType, Type[BaseHandler]]
REQUIREMENT_CACHE_KEY = "_req_cache"
ASYNC_STACK_KEY = "_stack"
@dataclass
class CallableMixin:
@ -21,17 +26,18 @@ class CallableMixin:
awaitable: bool = field(init=False)
spec: inspect.FullArgSpec = field(init=False)
__reqs__: Dict[str, Requirement[Any]] = field(init=False)
def __post_init__(self) -> None:
callback = inspect.unwrap(self.callback)
self.awaitable = inspect.isawaitable(callback) or inspect.iscoroutinefunction(callback)
if isinstance(callback, BaseFilter):
# Pydantic 1.5 has incorrect signature generator
# Issue: https://github.com/samuelcolvin/pydantic/issues/1419
# Fixes: https://github.com/samuelcolvin/pydantic/pull/1427
# TODO: Remove this temporary fix
callback = inspect.unwrap(callback.__call__)
callback = callback.__call__
self.spec = inspect.getfullargspec(callback)
self.__reqs__ = get_reqs_from_callable(callable_=callback)
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
if self.spec.varkw:
return kwargs
@ -39,7 +45,18 @@ class CallableMixin:
return {k: v for k, v in kwargs.items() if k in self.spec.args}
async def call(self, *args: Any, **kwargs: Any) -> Any:
if self.__reqs__:
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY))
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {})
for kwarg, default in self.__reqs__.items():
kwargs[kwarg] = await default(cache_dict=cache_dict, stack=stack, data=kwargs)
kwargs.pop(ASYNC_STACK_KEY, None)
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
if self.awaitable:
return await wrapped()

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import functools
from contextlib import AsyncExitStack
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Type, Union
@ -9,7 +10,15 @@ from pydantic import ValidationError
from ...api.types import TelegramObject
from ..filters.base import BaseFilter
from .bases import NOT_HANDLED, MiddlewareType, NextMiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
from .handler import (
ASYNC_STACK_KEY,
REQUIREMENT_CACHE_KEY,
CallbackType,
FilterObject,
FilterType,
HandlerObject,
HandlerType,
)
if TYPE_CHECKING: # pragma: no cover
from aiogram.dispatcher.router import Router
@ -122,14 +131,18 @@ class TelegramEventObserver:
async def _trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
for handler in self.handlers:
result, data = await handler.check(event, **kwargs)
if result:
kwargs.update(data)
try:
wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
return await wrapped_inner(event, kwargs)
except SkipHandler:
continue
async with AsyncExitStack() as stack:
# intermediate values. handler.call is responsible for cleaning them up
kwargs.update({ASYNC_STACK_KEY: stack, REQUIREMENT_CACHE_KEY: {}})
result, data = await handler.check(event, **kwargs)
if result:
kwargs.update(data)
try:
wrapped_inner = self._wrap_middleware(self.middlewares, handler.call)
return await wrapped_inner(event, kwargs)
except SkipHandler:
continue
return NOT_HANDLED

View file

@ -0,0 +1,166 @@
import abc
import enum
import inspect
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
Generic,
Optional,
TypeVar,
Union,
cast,
)
T = TypeVar("T")
CacheKeyType = Union[str, int]
_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]]
class GeneratorKind(enum.IntEnum):
not_a_gen = enum.auto() # not a generator
plain_gen = enum.auto() # proper generator not async
async_gen = enum.auto() # async generator
@asynccontextmanager
async def move_to_async_gen(context_manager: Any) -> Any:
"""
Wrap existing contextmanager into a asynchronous generator, then async ctx manager
"""
try:
yield context_manager.__enter__()
except Exception as e:
err = context_manager.__exit__(type(e), e, None)
if not err:
raise e
else:
context_manager.__exit__(None, None, None)
class Requirement(abc.ABC, Generic[T]):
"""
Interface for all requirements
"""
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
raise NotImplementedError()
class CallableRequirement(Requirement[T]):
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
def __init__(
self,
callable_: _RequiredCallback[T],
*,
cache_key: Optional[CacheKeyType] = None,
use_cache: bool = True,
):
self.callable = callable_
self.generator_type = GeneratorKind.not_a_gen
if inspect.isasyncgenfunction(callable_):
self.generator_type = GeneratorKind.async_gen
elif inspect.isgeneratorfunction(callable_):
self.generator_type = GeneratorKind.plain_gen
self.cache_key = hash(callable_) if cache_key is None else cache_key
self.use_cache = use_cache
self.children = get_reqs_from_callable(callable_)
assert not (
set(inspect.signature(callable_).parameters) ^ set(self.children)
), "Required can't manage callbacks with non `Requirement` parameters"
def filter_kwargs(self, data: Dict[str, Any]) -> Dict[str, Any]:
return {key: value for key, value in data.items() if key in self.children}
async def initialize_children(
self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack
) -> None:
for req_id, req in self.children.items():
if isinstance(req, CachedRequirement):
data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack)
continue
if isinstance(req, CallableRequirement):
await req.initialize_children(data, cache_dict, stack)
if req.use_cache and req.cache_key in cache_dict:
data[req_id] = cache_dict[req.cache_key]
else:
data[req_id] = await initialize_callable_requirement(req, data, stack)
if req.use_cache:
cache_dict[req.cache_key] = data[req_id]
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
await self.initialize_children(data, cache_dict, stack)
if self.use_cache and self.cache_key in cache_dict:
return cast(T, cache_dict[self.cache_key])
else:
result = await initialize_callable_requirement(self, data, stack)
if self.use_cache:
cache_dict[self.cache_key] = result
return result
class CachedRequirement(Requirement[T]):
__slots__ = "cache_key", "value_on_miss"
def __init__(self, cache_key: CacheKeyType, value_on_miss: T):
self.cache_key: CacheKeyType = cache_key
self.value_on_miss = value_on_miss
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
return cache_dict.get(self.cache_key, self.value_on_miss)
async def initialize_callable_requirement(
required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack
) -> T:
actual_data = required.filter_kwargs(data)
async_cm: Optional[Any] = None
if required.generator_type is GeneratorKind.async_gen:
async_cm = asynccontextmanager(required.callable)(**actual_data) # type: ignore
elif required.generator_type is GeneratorKind.plain_gen:
async_cm = move_to_async_gen(contextmanager(required.callable)(**actual_data)) # type: ignore
if async_cm is not None:
return await stack.enter_async_context(async_cm)
else:
result = required.callable(**actual_data)
if isinstance(result, Awaitable):
return cast(T, await result)
return cast(T, result)
def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]:
signature = inspect.signature(callable_)
return {
param_name: param_self.default
for param_name, param_self in signature.parameters.items()
if isinstance(param_self.default, Requirement)
}
def require(
what: _RequiredCallback[T],
*,
cache_key: Optional[CacheKeyType] = None,
use_cache: bool = True,
) -> T:
return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore

14
poetry.lock generated
View file

@ -380,7 +380,7 @@ description = "File identification library for Python"
name = "identify"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
version = "1.4.16"
version = "1.4.17"
[package.extras]
license = ["editdistance"]
@ -942,7 +942,6 @@ wcwidth = "*"
[[package]]
category = "dev"
description = "Run a subprocess in a pseudo terminal"
marker = "sys_platform != \"win32\""
name = "ptyprocess"
optional = false
python-versions = "*"
@ -1401,7 +1400,7 @@ fast = ["uvloop"]
proxy = ["aiohttp-socks"]
[metadata]
content-hash = "152bb9b155a00baadd3c8b9fa21f08af719180bddccb8ad6c3dd6548c3e71e3e"
content-hash = "50478042294d42c7ac6eef8df3d2036b4dc42455ca60cdfb9434cd476e27ddfc"
python-versions = "^3.7"
[metadata.files]
@ -1617,8 +1616,8 @@ html5lib = [
{file = "html5lib-1.0.1.tar.gz", hash = "sha256:66cb0dcfdbbc4f9c3ba1a63fdb511ffdbd4f513b2b6d81b80cd26ce6b3fb3736"},
]
identify = [
{file = "identify-1.4.16-py2.py3-none-any.whl", hash = "sha256:0f3c3aac62b51b86fea6ff52fe8ff9e06f57f10411502443809064d23e16f1c2"},
{file = "identify-1.4.16.tar.gz", hash = "sha256:f9ad3d41f01e98eb066b6e05c5b184fd1e925fadec48eb165b4e01c72a1ef3a7"},
{file = "identify-1.4.17-py2.py3-none-any.whl", hash = "sha256:ef6fa3d125c27516f8d1aaa2038c3263d741e8723825eb38350cdc0288ab35eb"},
{file = "identify-1.4.17.tar.gz", hash = "sha256:be66b9673d59336acd18a3a0e0c10d35b8a780309561edf16c46b6b74b83f6af"},
]
idna = [
{file = "idna-2.9-py2.py3-none-any.whl", hash = "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"},
@ -1742,6 +1741,11 @@ markupsafe = [
{file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"},
{file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"},
{file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"},
]
mccabe = [

View file

@ -34,7 +34,7 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.7"
aiohttp = "^3.6"
pydantic = "^1.5"
pydantic = "^1.5.1"
Babel = "^2.7"
aiofiles = "^0.4.0"
uvloop = {version = "^0.14.0", markers = "sys_platform == 'darwin' or sys_platform == 'linux'", optional = true}

View file

@ -0,0 +1,17 @@
# todo
from aiogram.dispatcher.requirement import require, CallableRequirement
tick_data = {"ticks": 0}
def test_require():
x = require(lambda: "str", use_cache=True, cache_key=0)
assert isinstance(x, CallableRequirement)
assert callable(x) & callable(x.callable)
assert x.cache_key == 0
assert x.use_cache
class TestCallableRequirementCache:
def test_cache(self):
...