From 04ccb390d5c25a3c8882e04738a484572fd17998 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Sun, 8 Jan 2023 16:49:34 +0200 Subject: [PATCH] Callback answer feature (#1091) * Added callback answer feature * Fixed typehints and tests * Make context manager in tests compatible with Python 3.8 --- CHANGES/1091.feature.rst | 1 + aiogram/exceptions.py | 4 + aiogram/utils/callback_answer.py | 211 ++++++++++++++++++++++ docs/dispatcher/flags.rst | 2 + docs/utils/callback_answer.rst | 106 +++++++++++ docs/utils/index.rst | 1 + tests/test_filters/test_callback_data.py | 62 +++---- tests/test_utils/test_callback_answer.py | 217 +++++++++++++++++++++++ 8 files changed, 574 insertions(+), 30 deletions(-) create mode 100644 CHANGES/1091.feature.rst create mode 100644 aiogram/utils/callback_answer.py create mode 100644 docs/utils/callback_answer.rst create mode 100644 tests/test_utils/test_callback_answer.py diff --git a/CHANGES/1091.feature.rst b/CHANGES/1091.feature.rst new file mode 100644 index 00000000..b3e773ce --- /dev/null +++ b/CHANGES/1091.feature.rst @@ -0,0 +1 @@ +Added :ref:`callback answer ` feature diff --git a/aiogram/exceptions.py b/aiogram/exceptions.py index a32c3ec6..15f5c45f 100644 --- a/aiogram/exceptions.py +++ b/aiogram/exceptions.py @@ -24,6 +24,10 @@ class DetailedAiogramError(AiogramError): return f"{type(self).__name__}('{self}')" +class CallbackAnswerException(AiogramError): + pass + + class TelegramAPIError(DetailedAiogramError): def __init__( self, diff --git a/aiogram/utils/callback_answer.py b/aiogram/utils/callback_answer.py new file mode 100644 index 00000000..1515a7e6 --- /dev/null +++ b/aiogram/utils/callback_answer.py @@ -0,0 +1,211 @@ +from typing import Any, Awaitable, Callable, Dict, Optional, Union + +from aiogram import BaseMiddleware, loggers +from aiogram.dispatcher.flags import get_flag +from aiogram.exceptions import CallbackAnswerException +from aiogram.methods import AnswerCallbackQuery +from aiogram.types import CallbackQuery, TelegramObject + + +class CallbackAnswer: + def __init__( + self, + answered: bool, + disabled: bool = False, + text: Optional[str] = None, + show_alert: Optional[bool] = None, + url: Optional[str] = None, + cache_time: Optional[int] = None, + ) -> None: + """ + Callback answer configuration + + :param answered: this request is already answered by middleware + :param disabled: answer will not be performed + :param text: answer with text + :param show_alert: show alert + :param url: game url + :param cache_time: cache answer for some time + """ + self._answered = answered + self._disabled = disabled + self._text = text + self._show_alert = show_alert + self._url = url + self._cache_time = cache_time + + def disable(self) -> None: + """ + Deactivate answering for this handler + """ + self.disabled = True + + @property + def disabled(self) -> bool: + """Indicates that automatic answer is disabled in this handler""" + return self._disabled + + @disabled.setter + def disabled(self, value: bool) -> None: + if self._answered: + raise CallbackAnswerException("Can't change disabled state after answer") + self._disabled = value + + @property + def answered(self) -> bool: + """ + Indicates that request is already answered by middleware + """ + return self._answered + + @property + def text(self) -> Optional[str]: + """ + Response text + :return: + """ + return self._text + + @text.setter + def text(self, value: Optional[str]) -> None: + if self._answered: + raise CallbackAnswerException("Can't change text after answer") + self._text = value + + @property + def show_alert(self) -> Optional[bool]: + """ + Whether to display an alert + """ + return self._show_alert + + @show_alert.setter + def show_alert(self, value: Optional[bool]) -> None: + if self._answered: + raise CallbackAnswerException("Can't change show_alert after answer") + self._show_alert = value + + @property + def url(self) -> Optional[str]: + """ + Game url + """ + return self._url + + @url.setter + def url(self, value: Optional[str]) -> None: + if self._answered: + raise CallbackAnswerException("Can't change url after answer") + self._url = value + + @property + def cache_time(self) -> Optional[int]: + """ + Response cache time + """ + return self._cache_time + + @cache_time.setter + def cache_time(self, value: Optional[int]) -> None: + if self._answered: + raise CallbackAnswerException("Can't change cache_time after answer") + self._cache_time = value + + def __str__(self) -> str: + args = ", ".join( + f"{k}={v!r}" + for k, v in { + "answered": self.answered, + "disabled": self.disabled, + "text": self.text, + "show_alert": self.show_alert, + "url": self.url, + "cache_time": self.cache_time, + }.items() + if v is not None + ) + return f"{type(self).__name__}({args})" + + +class CallbackAnswerMiddleware(BaseMiddleware): + def __init__( + self, + pre: bool = False, + text: Optional[str] = None, + show_alert: Optional[bool] = None, + url: Optional[str] = None, + cache_time: Optional[int] = None, + ) -> None: + """ + Inner middleware for callback query handlers, can be useful in bots with a lot of callback + handlers to automatically take answer to all requests + + :param pre: send answer before execute handler + :param text: answer with text + :param show_alert: show alert + :param url: game url + :param cache_time: cache answer for some time + """ + self.pre = pre + self.text = text + self.show_alert = show_alert + self.url = url + self.cache_time = cache_time + + async def __call__( + self, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: TelegramObject, + data: Dict[str, Any], + ) -> Any: + if not isinstance(event, CallbackQuery): + return await handler(event, data) + + callback_answer = data["callback_answer"] = self.construct_callback_answer( + properties=get_flag(data, "callback_answer") + ) + + if not callback_answer.disabled and callback_answer.answered: + await self.answer(event, callback_answer) + try: + return await handler(event, data) + finally: + if not callback_answer.disabled and not callback_answer.answered: + await self.answer(event, callback_answer) + + def construct_callback_answer( + self, properties: Optional[Union[Dict[str, Any], bool]] + ) -> CallbackAnswer: + pre, disabled, text, show_alert, url, cache_time = ( + self.pre, + False, + self.text, + self.show_alert, + self.url, + self.cache_time, + ) + if isinstance(properties, dict): + pre = properties.get("pre", pre) + disabled = properties.get("disabled", disabled) + text = properties.get("text", text) + show_alert = properties.get("show_alert", show_alert) + url = properties.get("url", url) + cache_time = properties.get("cache_time", cache_time) + + return CallbackAnswer( + answered=pre, + disabled=disabled, + text=text, + show_alert=show_alert, + url=url, + cache_time=cache_time, + ) + + def answer(self, event: CallbackQuery, callback_answer: CallbackAnswer) -> AnswerCallbackQuery: + loggers.middlewares.info("Answer to callback query id=%s", event.id) + return event.answer( + text=callback_answer.text, + show_alert=callback_answer.show_alert, + url=callback_answer.url, + cache_time=callback_answer.cache_time, + ) diff --git a/docs/dispatcher/flags.rst b/docs/dispatcher/flags.rst index 1e65f28e..b958de1d 100644 --- a/docs/dispatcher/flags.rst +++ b/docs/dispatcher/flags.rst @@ -1,3 +1,5 @@ +.. _flags: + ===== Flags ===== diff --git a/docs/utils/callback_answer.rst b/docs/utils/callback_answer.rst new file mode 100644 index 00000000..f7fc4715 --- /dev/null +++ b/docs/utils/callback_answer.rst @@ -0,0 +1,106 @@ +.. _callback-answer-util: +=============== +Callback answer +=============== + +Helper for callback query handlers, can be useful in bots with a lot of callback +handlers to automatically take answer to all requests. + +Simple usage +============ + +For use, it is enough to register the inner middleware :class:`aiogram.utils.callback_answer.CallbackAnswerMiddleware` in dispatcher or specific router: + +.. code-block:: python + + dispatcher.callback_query.middleware(CallbackAnswerMiddleware()) + +After that all handled callback queries will be answered automatically after processing the handler. + +Advanced usage +============== + +In some cases you need to have some non-standard response parameters, this can be done in several ways: + +Global defaults +--------------- + +Change default parameters while initializing middleware, for example change answer to `pre` mode and text "OK": + +.. code-block:: python + + dispatcher.callback_query.middleware(CallbackAnswerMiddleware(pre=True, text="OK")) + + +Look at :class:`aiogram.utils.callback_answer.CallbackAnswerMiddleware` to get all available parameters + + +Handler specific +---------------- + +By using :ref:`flags ` you can change the behavior for specific handler + +.. code-block:: python + + @router.callback_query() + @flags.callback_answer(text="Thanks", cache_time=30) + async def my_handler(query: CallbackQuery): + ... + +Flag arguments is the same as in :class:`aiogram.utils.callback_answer.CallbackAnswerMiddleware` +with additional one :code:`disabled` to disable answer. + +A special case +-------------- + +It is not always correct to answer the same in every case, +so there is an option to change the answer inside the handler. You can get an instance of :class:`aiogram.utils.callback_answer.CallbackAnswer` object inside handler and change whatever you want. + +.. danger:: + + Note that is impossible to change callback answer attributes when you use :code:`pre=True` mode. + +.. code-block:: python + + @router.callback_query() + async def my_handler(query: CallbackQuery, callback_answer: CallbackAnswer): + ... + if : + callback_answer.text = "All is ok" + else: + callback_answer.text = "Something wrong" + callback_answer.cache_time = 10 + + +Combine that all at once +------------------------ + +For example you want to answer in most of cases before handler with text "🤔" but at some cases need to answer after the handler with custom text, +so you can do it: + +.. code-block:: python + + dispatcher.callback_query.middleware(CallbackAnswerMiddleware(pre=True, text="🤔")) + + @router.callback_query() + @flags.callback_answer(pre=False, cache_time=30) + async def my_handler(query: CallbackQuery): + ... + if : + callback_answer.text = "All is ok" + + +Description of objects +====================== + +.. autoclass:: aiogram.utils.callback_answer.CallbackAnswerMiddleware + :show-inheritance: + :member-order: bysource + :special-members: __init__ + :members: + +.. autoclass:: aiogram.utils.callback_answer.CallbackAnswer + :show-inheritance: + :member-order: bysource + :special-members: __init__ + :members: diff --git a/docs/utils/index.rst b/docs/utils/index.rst index a7a8474d..cfe5a543 100644 --- a/docs/utils/index.rst +++ b/docs/utils/index.rst @@ -8,3 +8,4 @@ Utils i18n chat_action web_app + callback_answer diff --git a/tests/test_filters/test_callback_data.py b/tests/test_filters/test_callback_data.py index e2f1f1fc..70388689 100644 --- a/tests/test_filters/test_callback_data.py +++ b/tests/test_filters/test_callback_data.py @@ -1,5 +1,8 @@ +from decimal import Decimal from enum import Enum, auto +from fractions import Fraction from typing import Optional +from uuid import UUID import pytest from magic_filter import MagicFilter @@ -45,36 +48,35 @@ class TestCallbackData: class MyInvalidCallback(CallbackData, prefix="sp@m", sep="@"): pass - # - # @pytest.mark.parametrize( - # "value,success,expected", - # [ - # [None, True, ""], - # [42, True, "42"], - # ["test", True, "test"], - # [9.99, True, "9.99"], - # [Decimal("9.99"), True, "9.99"], - # [Fraction("3/2"), True, "3/2"], - # [ - # UUID("123e4567-e89b-12d3-a456-426655440000"), - # True, - # "123e4567-e89b-12d3-a456-426655440000", - # ], - # [MyIntEnum.FOO, True, "1"], - # [MyStringEnum.FOO, True, "FOO"], - # [..., False, "..."], - # [object, False, "..."], - # [object(), False, "..."], - # [User(id=42, is_bot=False, first_name="test"), False, "..."], - # ], - # ) - # def test_encode_value(self, value, success, expected): - # callback = MyCallback(foo="test", bar=42) - # if success: - # assert callback._encode_value("test", value) == expected - # else: - # with pytest.raises(ValueError): - # assert callback._encode_value("test", value) == expected + @pytest.mark.parametrize( + "value,success,expected", + [ + [None, True, ""], + [42, True, "42"], + ["test", True, "test"], + [9.99, True, "9.99"], + [Decimal("9.99"), True, "9.99"], + [Fraction("3/2"), True, "3/2"], + [ + UUID("123e4567-e89b-12d3-a456-426655440000"), + True, + "123e4567-e89b-12d3-a456-426655440000", + ], + [MyIntEnum.FOO, True, "1"], + [MyStringEnum.FOO, True, "FOO"], + [..., False, "..."], + [object, False, "..."], + [object(), False, "..."], + [User(id=42, is_bot=False, first_name="test"), False, "..."], + ], + ) + def test_encode_value(self, value, success, expected): + callback = MyCallback(foo="test", bar=42) + if success: + assert callback._encode_value("test", value) == expected + else: + with pytest.raises(ValueError): + assert callback._encode_value("test", value) == expected def test_pack(self): with pytest.raises(ValueError, match="Separator symbol .+"): diff --git a/tests/test_utils/test_callback_answer.py b/tests/test_utils/test_callback_answer.py new file mode 100644 index 00000000..cf641f46 --- /dev/null +++ b/tests/test_utils/test_callback_answer.py @@ -0,0 +1,217 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from aiogram.exceptions import CallbackAnswerException +from aiogram.methods import AnswerCallbackQuery +from aiogram.types import CallbackQuery, User +from aiogram.utils.callback_answer import CallbackAnswer, CallbackAnswerMiddleware + + +class TestCallbackAnswer: + @pytest.mark.parametrize( + "name,value", + [ + ["answered", True], + ["answered", False], + ["disabled", True], + ["disabled", False], + ["text", "test"], + ["text", None], + ["show_alert", True], + ["show_alert", False], + ["show_alert", None], + ["url", "https://example.com"], + ["url", None], + ["cache_time", None], + ["cache_time", 10], + ], + ) + def test_getters(self, name, value): + kwargs = { + "answered": False, + name: value, + } + instance = CallbackAnswer(**kwargs) + result = getattr(instance, name) + assert result == value + + @pytest.mark.parametrize( + "name,value", + [ + ["disabled", True], + ["disabled", False], + ["text", None], + ["text", ""], + ["text", "test"], + ["show_alert", None], + ["show_alert", True], + ["show_alert", False], + ["url", None], + ["url", "https://example.com"], + ["cache_time", None], + ["cache_time", 0], + ["cache_time", 10], + ], + ) + def test_setter_allowed(self, name, value): + instance = CallbackAnswer(answered=False) + setattr(instance, name, value) + assert getattr(instance, name) == value + + @pytest.mark.parametrize( + "name", + [ + "disabled", + "text", + "show_alert", + "url", + "cache_time", + ], + ) + def test_setter_blocked(self, name): + instance = CallbackAnswer(answered=True) + with pytest.raises(CallbackAnswerException): + setattr(instance, name, "test") + + def test_disable(self): + instance = CallbackAnswer(answered=False) + assert not instance.disabled + instance.disable() + assert instance.disabled + + def test_str(self): + instance = CallbackAnswer(answered=False, text="test") + assert str(instance) == "CallbackAnswer(answered=False, disabled=False, text='test')" + + +class TestCallbackAnswerMiddleware: + @pytest.mark.parametrize( + "init_kwargs,flag_properties,expected", + [ + [ + {}, + True, + { + "answered": False, + "disabled": False, + "text": None, + "show_alert": None, + "url": None, + "cache_time": None, + }, + ], + [ + { + "pre": True, + "text": "test", + "show_alert": True, + "url": "https://example.com", + "cache_time": 5, + }, + True, + { + "answered": True, + "disabled": False, + "text": "test", + "show_alert": True, + "url": "https://example.com", + "cache_time": 5, + }, + ], + [ + { + "pre": False, + "text": "test", + "show_alert": True, + "url": "https://example.com", + "cache_time": 5, + }, + { + "pre": True, + "disabled": True, + "text": "another test", + "show_alert": False, + "url": "https://example.com/game.html", + "cache_time": 10, + }, + { + "answered": True, + "disabled": True, + "text": "another test", + "show_alert": False, + "url": "https://example.com/game.html", + "cache_time": 10, + }, + ], + ], + ) + def test_construct_answer(self, init_kwargs, flag_properties, expected): + middleware = CallbackAnswerMiddleware(**init_kwargs) + callback_answer = middleware.construct_callback_answer(properties=flag_properties) + for key, value in expected.items(): + assert getattr(callback_answer, key) == value + + def test_answer(self): + middleware = CallbackAnswerMiddleware() + event = CallbackQuery( + id="1", + from_user=User(id=42, first_name="Test", is_bot=False), + chat_instance="test", + ) + callback_answer = CallbackAnswer( + answered=False, + disabled=False, + text="another test", + show_alert=False, + url="https://example.com/game.html", + cache_time=10, + ) + method = middleware.answer(event=event, callback_answer=callback_answer) + + assert isinstance(method, AnswerCallbackQuery) + assert method.text == callback_answer.text + assert method.show_alert == callback_answer.show_alert + assert method.url == callback_answer.url + assert method.cache_time == callback_answer.cache_time + + @pytest.mark.parametrize( + "properties,expected_stack", + [ + [{"answered": False}, ["handler", "answer"]], + [{"answered": True}, ["answer", "handler"]], + [{"disabled": True}, ["handler"]], + ], + ) + async def test_call(self, properties, expected_stack): + stack = [] + event = CallbackQuery( + id="1", + from_user=User(id=42, first_name="Test", is_bot=False), + chat_instance="test", + ) + + async def handler(*args, **kwargs): + stack.append("handler") + + async def answer(*args, **kwargs): + stack.append("answer") + + middleware = CallbackAnswerMiddleware() + with patch( + "aiogram.utils.callback_answer.CallbackAnswerMiddleware.construct_callback_answer", + new_callable=MagicMock, + side_effect=lambda **kwargs: CallbackAnswer(**{"answered": False, **properties}), + ), patch( + "aiogram.utils.callback_answer.CallbackAnswerMiddleware.answer", + new=answer, + ): + await middleware(handler, event, {}) + + assert stack == expected_stack + + async def test_invalid_event_type(self): + middleware = CallbackAnswerMiddleware() + handler = AsyncMock() + await middleware(handler, None, {}) + handler.assert_awaited()