refactor(handler): remove require func

refactor requirements feature
This commit is contained in:
mpa 2020-05-31 09:09:56 +04:00
parent de962d043d
commit e060678b40
No known key found for this signature in database
GPG key ID: BCCFBFCCC9B754A8
4 changed files with 69 additions and 82 deletions

View file

@ -35,7 +35,7 @@ class CallableMixin:
awaitable: bool = field(init=False)
spec: inspect.FullArgSpec = field(init=False)
__reqs__: Dict[str, Requirement[Any]] = field(init=False)
requirements: Dict[str, Requirement[Any]] = field(init=False)
def __post_init__(self) -> None:
callback = inspect.unwrap(self.callback)
@ -47,10 +47,9 @@ class CallableMixin:
if _is_class_handler(callback):
self.awaitable = True
self.__reqs__ = get_reqs_from_class(callback)
self.requirements = get_reqs_from_class(callback)
else:
self.__reqs__ = get_reqs_from_callable(callable_=callback)
self.requirements = get_reqs_from_callable(callable_=callback)
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
if self.spec.varkw:
@ -58,33 +57,33 @@ class CallableMixin:
return {k: v for k, v in kwargs.items() if k in self.spec.args}
async def call(self, *args: Any, **kwargs: Any) -> Any:
async def call(self, *args: Any, **data: Any) -> Any:
# we don't requirements_data and kwargs keys to intersect
requirements_data: Dict[str, Any] = {}
if self.__reqs__:
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY))
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {})
requirements_data = kwargs.copy()
if self.requirements:
stack = cast(AsyncExitStack, data.get(ASYNC_STACK_KEY))
cache_dict: Dict[CacheKeyType, Any] = data.get(REQUIREMENT_CACHE_KEY, {})
requirements_data = data.copy()
for req_id, req in self.__reqs__.items():
for req_id, req in self.requirements.items():
requirements_data[req_id] = await req(
cache_dict=cache_dict, stack=stack, data=requirements_data
cache_dict=cache_dict, stack=stack, data=requirements_data,
)
for to_pop in kwargs:
for to_pop in data:
requirements_data.pop(to_pop, None)
kwargs.pop(ASYNC_STACK_KEY, None)
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
data.pop(ASYNC_STACK_KEY, None)
data.pop(REQUIREMENT_CACHE_KEY, None)
if _is_class_handler(self.callback):
wrapped = partial(self.callback, *args, requirements_data, kwargs)
wrapped = partial(self.callback, *args, requirements_data, data)
else:
wrapped = partial(
self.callback,
*args,
**self._prepare_kwargs(kwargs),
**self._prepare_kwargs(data),
**self._prepare_kwargs(requirements_data),
)

View file

@ -1,4 +1,3 @@
import abc
import enum
import inspect
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
@ -8,6 +7,7 @@ from typing import (
Awaitable,
Callable,
Dict,
Generator,
Generic,
Optional,
Type,
@ -18,7 +18,10 @@ from typing import (
T = TypeVar("T")
CacheKeyType = Union[str, int]
_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]]
CacheType = Dict[CacheKeyType, Any]
_RequiredCallback = Callable[
..., Union[T, Generator[T, None, None], AsyncGenerator[None, T], Awaitable[T]]
]
class GeneratorKind(enum.IntEnum):
@ -42,28 +45,19 @@ async def move_to_async_gen(context_manager: Any) -> Any:
context_manager.__exit__(None, None, None)
class Requirement(abc.ABC, Generic[T]):
"""
Interface for all requirements
"""
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
raise NotImplementedError()
class CallableRequirement(Requirement[T]):
class Requirement(Generic[T]):
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
def __init__(
self,
callable_: _RequiredCallback[T],
*,
cache_key: Optional[CacheKeyType] = None,
use_cache: bool = True,
):
self.callable = callable_
self.use_cache = use_cache
self.children: Dict[str, Requirement[Any]] = {}
self.generator_type = GeneratorKind.not_a_gen
if inspect.isasyncgenfunction(callable_):
@ -72,7 +66,6 @@ class CallableRequirement(Requirement[T]):
self.generator_type = GeneratorKind.plain_gen
self.cache_key = hash(callable_) if cache_key is None else cache_key
self.use_cache = use_cache
self.children = get_reqs_from_callable(callable_)
assert not (
@ -83,27 +76,23 @@ class CallableRequirement(Requirement[T]):
return {key: value for key, value in data.items() if key in self.children}
async def initialize_children(
self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack
self, data: Dict[str, Any], cache_dict: CacheType, stack: AsyncExitStack,
) -> None:
for req_id, req in self.children.items():
if isinstance(req, CachedRequirement):
data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack)
continue
if isinstance(req, CallableRequirement):
await req.initialize_children(data, cache_dict, stack)
await req.initialize_children(data, cache_dict, stack)
if req.use_cache and req.cache_key in cache_dict:
data[req_id] = cache_dict[req.cache_key]
if req.use_cache and req.cache_key in cache_dict:
data[req_id] = cache_dict[req.cache_key]
else:
data[req_id] = await initialize_callable_requirement(req, data, stack)
else:
data[req_id] = await initialize_callable_requirement(req, data, stack)
if req.use_cache:
cache_dict[req.cache_key] = data[req_id]
if req.use_cache:
cache_dict[req.cache_key] = data[req_id]
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
self, *, cache_dict: CacheType, stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
await self.initialize_children(data, cache_dict, stack)
@ -113,24 +102,12 @@ class CallableRequirement(Requirement[T]):
result = await initialize_callable_requirement(self, data, stack)
if self.use_cache:
cache_dict[self.cache_key] = result
return result
class CachedRequirement(Requirement[T]):
__slots__ = "cache_key", "value_on_miss"
def __init__(self, cache_key: CacheKeyType, value_on_miss: T):
self.cache_key: CacheKeyType = cache_key
self.value_on_miss = value_on_miss
async def __call__(
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
) -> T:
return cache_dict.get(self.cache_key, self.value_on_miss)
async def initialize_callable_requirement(
required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack
required: Requirement[T], data: Dict[str, Any], stack: AsyncExitStack
) -> T:
actual_data = required.filter_kwargs(data)
async_cm: Optional[Any] = None
@ -162,12 +139,3 @@ def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]:
return {
req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement)
}
def require(
what: _RequiredCallback[T],
*,
cache_key: Optional[CacheKeyType] = None,
use_cache: bool = True,
) -> T:
return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore

View file

@ -18,7 +18,7 @@ async def callback2(foo: int, bar: int, baz: int):
return locals()
async def callback3(foo: int, **kwargs):
async def callback3(foo: int, **data):
return locals()
@ -62,8 +62,8 @@ class TestCallableMixin:
def test_init_decorated(self):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
def wrapper(*args, **data):
return func(*args, **data)
return wrapper
@ -85,7 +85,7 @@ class TestCallableMixin:
assert obj2.callback == callback2
@pytest.mark.parametrize(
"callback,kwargs,result",
"callback,data,result",
[
pytest.param(
callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}
@ -108,9 +108,9 @@ class TestCallableMixin:
),
],
)
def test_prepare_kwargs(self, callback, kwargs, result):
def test_prepare_data(self, callback, data, result):
obj = CallableMixin(callback)
assert obj._prepare_kwargs(kwargs) == result
assert obj._prepare_kwargs(data) == result
@pytest.mark.asyncio
async def test_sync_call(self):
@ -127,8 +127,8 @@ class TestCallableMixin:
assert result == {"foo": 42, "bar": "test", "baz": "fuz"}
async def simple_handler(*args, **kwargs):
return args, kwargs
async def simple_handler(*args, **data):
return args, data
class TestHandlerObject:

View file

@ -1,17 +1,37 @@
# todo
from aiogram.dispatcher.requirement import require, CallableRequirement
from aiogram.dispatcher.requirement import Requirement, get_reqs_from_class, get_reqs_from_callable
tick_data = {"ticks": 0}
req1 = Requirement(lambda: 1)
req2 = Requirement(lambda: 1)
async def callback(
o,
x=req1,
y=req2
):
...
def test_require():
x = require(lambda: "str", use_cache=True, cache_key=0)
assert isinstance(x, CallableRequirement)
x = Requirement(lambda: "str", use_cache=True, cache_key=0)
assert isinstance(x, Requirement)
assert callable(x) & callable(x.callable)
assert x.cache_key == 0
assert x.use_cache
class TestCallableRequirementCache:
def test_cache(self):
...
class TestReqUtils:
def test_get_reqs_from_callable(self):
assert set(get_reqs_from_callable(callback).values()) == {req1, req2}
assert set(get_reqs_from_callable(callback).keys()) == {"x", "y"}
def test_get_reqs_from_class(self):
class Class:
x = req1
y = req2
assert set(get_reqs_from_class(Class)) == {"x", "y"}