mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
refactor(handler): remove require func
refactor requirements feature
This commit is contained in:
parent
de962d043d
commit
e060678b40
4 changed files with 69 additions and 82 deletions
|
|
@ -35,7 +35,7 @@ class CallableMixin:
|
|||
awaitable: bool = field(init=False)
|
||||
spec: inspect.FullArgSpec = field(init=False)
|
||||
|
||||
__reqs__: Dict[str, Requirement[Any]] = field(init=False)
|
||||
requirements: Dict[str, Requirement[Any]] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
callback = inspect.unwrap(self.callback)
|
||||
|
|
@ -47,10 +47,9 @@ class CallableMixin:
|
|||
|
||||
if _is_class_handler(callback):
|
||||
self.awaitable = True
|
||||
self.__reqs__ = get_reqs_from_class(callback)
|
||||
|
||||
self.requirements = get_reqs_from_class(callback)
|
||||
else:
|
||||
self.__reqs__ = get_reqs_from_callable(callable_=callback)
|
||||
self.requirements = get_reqs_from_callable(callable_=callback)
|
||||
|
||||
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if self.spec.varkw:
|
||||
|
|
@ -58,33 +57,33 @@ class CallableMixin:
|
|||
|
||||
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
||||
|
||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
async def call(self, *args: Any, **data: Any) -> Any:
|
||||
# we don't requirements_data and kwargs keys to intersect
|
||||
requirements_data: Dict[str, Any] = {}
|
||||
|
||||
if self.__reqs__:
|
||||
stack = cast(AsyncExitStack, kwargs.get(ASYNC_STACK_KEY))
|
||||
cache_dict: Dict[CacheKeyType, Any] = kwargs.get(REQUIREMENT_CACHE_KEY, {})
|
||||
requirements_data = kwargs.copy()
|
||||
if self.requirements:
|
||||
stack = cast(AsyncExitStack, data.get(ASYNC_STACK_KEY))
|
||||
cache_dict: Dict[CacheKeyType, Any] = data.get(REQUIREMENT_CACHE_KEY, {})
|
||||
requirements_data = data.copy()
|
||||
|
||||
for req_id, req in self.__reqs__.items():
|
||||
for req_id, req in self.requirements.items():
|
||||
requirements_data[req_id] = await req(
|
||||
cache_dict=cache_dict, stack=stack, data=requirements_data
|
||||
cache_dict=cache_dict, stack=stack, data=requirements_data,
|
||||
)
|
||||
|
||||
for to_pop in kwargs:
|
||||
for to_pop in data:
|
||||
requirements_data.pop(to_pop, None)
|
||||
|
||||
kwargs.pop(ASYNC_STACK_KEY, None)
|
||||
kwargs.pop(REQUIREMENT_CACHE_KEY, None)
|
||||
data.pop(ASYNC_STACK_KEY, None)
|
||||
data.pop(REQUIREMENT_CACHE_KEY, None)
|
||||
|
||||
if _is_class_handler(self.callback):
|
||||
wrapped = partial(self.callback, *args, requirements_data, kwargs)
|
||||
wrapped = partial(self.callback, *args, requirements_data, data)
|
||||
else:
|
||||
wrapped = partial(
|
||||
self.callback,
|
||||
*args,
|
||||
**self._prepare_kwargs(kwargs),
|
||||
**self._prepare_kwargs(data),
|
||||
**self._prepare_kwargs(requirements_data),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import abc
|
||||
import enum
|
||||
import inspect
|
||||
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
||||
|
|
@ -8,6 +7,7 @@ from typing import (
|
|||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
|
|
@ -18,7 +18,10 @@ from typing import (
|
|||
|
||||
T = TypeVar("T")
|
||||
CacheKeyType = Union[str, int]
|
||||
_RequiredCallback = Callable[..., Union[T, AsyncGenerator[None, T], Awaitable[T]]]
|
||||
CacheType = Dict[CacheKeyType, Any]
|
||||
_RequiredCallback = Callable[
|
||||
..., Union[T, Generator[T, None, None], AsyncGenerator[None, T], Awaitable[T]]
|
||||
]
|
||||
|
||||
|
||||
class GeneratorKind(enum.IntEnum):
|
||||
|
|
@ -42,28 +45,19 @@ async def move_to_async_gen(context_manager: Any) -> Any:
|
|||
context_manager.__exit__(None, None, None)
|
||||
|
||||
|
||||
class Requirement(abc.ABC, Generic[T]):
|
||||
"""
|
||||
Interface for all requirements
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
||||
) -> T:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CallableRequirement(Requirement[T]):
|
||||
class Requirement(Generic[T]):
|
||||
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
callable_: _RequiredCallback[T],
|
||||
*,
|
||||
cache_key: Optional[CacheKeyType] = None,
|
||||
use_cache: bool = True,
|
||||
):
|
||||
self.callable = callable_
|
||||
self.use_cache = use_cache
|
||||
self.children: Dict[str, Requirement[Any]] = {}
|
||||
|
||||
self.generator_type = GeneratorKind.not_a_gen
|
||||
|
||||
if inspect.isasyncgenfunction(callable_):
|
||||
|
|
@ -72,7 +66,6 @@ class CallableRequirement(Requirement[T]):
|
|||
self.generator_type = GeneratorKind.plain_gen
|
||||
|
||||
self.cache_key = hash(callable_) if cache_key is None else cache_key
|
||||
self.use_cache = use_cache
|
||||
self.children = get_reqs_from_callable(callable_)
|
||||
|
||||
assert not (
|
||||
|
|
@ -83,27 +76,23 @@ class CallableRequirement(Requirement[T]):
|
|||
return {key: value for key, value in data.items() if key in self.children}
|
||||
|
||||
async def initialize_children(
|
||||
self, data: Dict[str, Any], cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack
|
||||
self, data: Dict[str, Any], cache_dict: CacheType, stack: AsyncExitStack,
|
||||
) -> None:
|
||||
for req_id, req in self.children.items():
|
||||
if isinstance(req, CachedRequirement):
|
||||
data[req_id] = await req(data=data, cache_dict=cache_dict, stack=stack)
|
||||
continue
|
||||
|
||||
if isinstance(req, CallableRequirement):
|
||||
await req.initialize_children(data, cache_dict, stack)
|
||||
await req.initialize_children(data, cache_dict, stack)
|
||||
|
||||
if req.use_cache and req.cache_key in cache_dict:
|
||||
data[req_id] = cache_dict[req.cache_key]
|
||||
if req.use_cache and req.cache_key in cache_dict:
|
||||
data[req_id] = cache_dict[req.cache_key]
|
||||
|
||||
else:
|
||||
data[req_id] = await initialize_callable_requirement(req, data, stack)
|
||||
else:
|
||||
data[req_id] = await initialize_callable_requirement(req, data, stack)
|
||||
|
||||
if req.use_cache:
|
||||
cache_dict[req.cache_key] = data[req_id]
|
||||
if req.use_cache:
|
||||
cache_dict[req.cache_key] = data[req_id]
|
||||
|
||||
async def __call__(
|
||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
||||
self, *, cache_dict: CacheType, stack: AsyncExitStack, data: Dict[str, Any],
|
||||
) -> T:
|
||||
await self.initialize_children(data, cache_dict, stack)
|
||||
|
||||
|
|
@ -113,24 +102,12 @@ class CallableRequirement(Requirement[T]):
|
|||
result = await initialize_callable_requirement(self, data, stack)
|
||||
if self.use_cache:
|
||||
cache_dict[self.cache_key] = result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class CachedRequirement(Requirement[T]):
|
||||
__slots__ = "cache_key", "value_on_miss"
|
||||
|
||||
def __init__(self, cache_key: CacheKeyType, value_on_miss: T):
|
||||
self.cache_key: CacheKeyType = cache_key
|
||||
self.value_on_miss = value_on_miss
|
||||
|
||||
async def __call__(
|
||||
self, cache_dict: Dict[CacheKeyType, Any], stack: AsyncExitStack, data: Dict[str, Any],
|
||||
) -> T:
|
||||
return cache_dict.get(self.cache_key, self.value_on_miss)
|
||||
|
||||
|
||||
async def initialize_callable_requirement(
|
||||
required: CallableRequirement[T], data: Dict[str, Any], stack: AsyncExitStack
|
||||
required: Requirement[T], data: Dict[str, Any], stack: AsyncExitStack
|
||||
) -> T:
|
||||
actual_data = required.filter_kwargs(data)
|
||||
async_cm: Optional[Any] = None
|
||||
|
|
@ -162,12 +139,3 @@ def get_reqs_from_class(cls: Type[Any]) -> Dict[str, Requirement[Any]]:
|
|||
return {
|
||||
req_attr: req for req_attr, req in cls.__dict__.items() if isinstance(req, Requirement)
|
||||
}
|
||||
|
||||
|
||||
def require(
|
||||
what: _RequiredCallback[T],
|
||||
*,
|
||||
cache_key: Optional[CacheKeyType] = None,
|
||||
use_cache: bool = True,
|
||||
) -> T:
|
||||
return CallableRequirement(what, cache_key=cache_key, use_cache=use_cache) # type: ignore
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ async def callback2(foo: int, bar: int, baz: int):
|
|||
return locals()
|
||||
|
||||
|
||||
async def callback3(foo: int, **kwargs):
|
||||
async def callback3(foo: int, **data):
|
||||
return locals()
|
||||
|
||||
|
||||
|
|
@ -62,8 +62,8 @@ class TestCallableMixin:
|
|||
def test_init_decorated(self):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
def wrapper(*args, **data):
|
||||
return func(*args, **data)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ class TestCallableMixin:
|
|||
assert obj2.callback == callback2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callback,kwargs,result",
|
||||
"callback,data,result",
|
||||
[
|
||||
pytest.param(
|
||||
callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}
|
||||
|
|
@ -108,9 +108,9 @@ class TestCallableMixin:
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_prepare_kwargs(self, callback, kwargs, result):
|
||||
def test_prepare_data(self, callback, data, result):
|
||||
obj = CallableMixin(callback)
|
||||
assert obj._prepare_kwargs(kwargs) == result
|
||||
assert obj._prepare_kwargs(data) == result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_call(self):
|
||||
|
|
@ -127,8 +127,8 @@ class TestCallableMixin:
|
|||
assert result == {"foo": 42, "bar": "test", "baz": "fuz"}
|
||||
|
||||
|
||||
async def simple_handler(*args, **kwargs):
|
||||
return args, kwargs
|
||||
async def simple_handler(*args, **data):
|
||||
return args, data
|
||||
|
||||
|
||||
class TestHandlerObject:
|
||||
|
|
|
|||
|
|
@ -1,17 +1,37 @@
|
|||
# todo
|
||||
from aiogram.dispatcher.requirement import require, CallableRequirement
|
||||
from aiogram.dispatcher.requirement import Requirement, get_reqs_from_class, get_reqs_from_callable
|
||||
|
||||
tick_data = {"ticks": 0}
|
||||
|
||||
|
||||
req1 = Requirement(lambda: 1)
|
||||
req2 = Requirement(lambda: 1)
|
||||
|
||||
|
||||
async def callback(
|
||||
o,
|
||||
x=req1,
|
||||
y=req2
|
||||
):
|
||||
...
|
||||
|
||||
|
||||
def test_require():
|
||||
x = require(lambda: "str", use_cache=True, cache_key=0)
|
||||
assert isinstance(x, CallableRequirement)
|
||||
x = Requirement(lambda: "str", use_cache=True, cache_key=0)
|
||||
assert isinstance(x, Requirement)
|
||||
assert callable(x) & callable(x.callable)
|
||||
assert x.cache_key == 0
|
||||
assert x.use_cache
|
||||
|
||||
|
||||
class TestCallableRequirementCache:
|
||||
def test_cache(self):
|
||||
...
|
||||
class TestReqUtils:
|
||||
def test_get_reqs_from_callable(self):
|
||||
assert set(get_reqs_from_callable(callback).values()) == {req1, req2}
|
||||
assert set(get_reqs_from_callable(callback).keys()) == {"x", "y"}
|
||||
|
||||
def test_get_reqs_from_class(self):
|
||||
class Class:
|
||||
x = req1
|
||||
y = req2
|
||||
|
||||
assert set(get_reqs_from_class(Class)) == {"x", "y"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue