From 00cff4acf5af422cb2f1c5c34c0833d8d021301a Mon Sep 17 00:00:00 2001 From: Martin Winks <50446230+uwinx@users.noreply.github.com> Date: Sat, 22 Aug 2020 01:07:03 +0400 Subject: [PATCH] fix(handlerObj): fix parameter-spec solving (#408) --- aiogram/dispatcher/handler.py | 2 +- tests/test_dispatcher/test_handler.py | 66 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 tests/test_dispatcher/test_handler.py diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index cd5e9b50..38219012 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -33,7 +33,7 @@ def _check_spec(spec: inspect.FullArgSpec, kwargs: dict): if spec.varkw: return kwargs - return {k: v for k, v in kwargs.items() if k in spec.args} + return {k: v for k, v in kwargs.items() if k in set(spec.args + spec.kwonlyargs)} class Handler: diff --git a/tests/test_dispatcher/test_handler.py b/tests/test_dispatcher/test_handler.py new file mode 100644 index 00000000..b823c8f8 --- /dev/null +++ b/tests/test_dispatcher/test_handler.py @@ -0,0 +1,66 @@ +import functools + +import pytest + +from aiogram.dispatcher.handler import Handler, _check_spec, _get_spec + + +def callback1(foo: int, bar: int, baz: int): + return locals() + + +async def callback2(foo: int, bar: int, baz: int): + return locals() + + +async def callback3(foo: int, **kwargs): + return locals() + + +class TestHandlerObj: + def test_init_decorated(self): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + @decorator + def callback1(foo, bar, baz): + pass + + @decorator + @decorator + def callback2(foo, bar, baz): + pass + + obj1 = Handler.HandlerObj(callback1, _get_spec(callback1)) + obj2 = Handler.HandlerObj(callback2, _get_spec(callback2)) + + assert set(obj1.spec.args) == {"foo", "bar", "baz"} + assert obj1.handler == callback1 + assert set(obj2.spec.args) == {"foo", "bar", "baz"} + assert obj2.handler == callback2 + + @pytest.mark.parametrize( + "callback,kwargs,result", + [ + pytest.param( + callback1, {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"} + ), + pytest.param( + callback2, + {"foo": 42, "spam": True, "baz": "fuz", "bar": "test"}, + {"foo": 42, "baz": "fuz", "bar": "test"}, + ), + pytest.param( + callback3, + {"foo": 42, "spam": True, "baz": "fuz", "bar": "test"}, + {"foo": 42, "spam": True, "baz": "fuz", "bar": "test"}, + ), + ], + ) + def test__check_spec(self, callback, kwargs, result): + spec = _get_spec(callback) + assert _check_spec(spec, kwargs) == result