diff --git a/aiogram/utils/callback_answer.py b/aiogram/utils/callback_answer.py index 6602ebdf..1515a7e6 100644 --- a/aiogram/utils/callback_answer.py +++ b/aiogram/utils/callback_answer.py @@ -3,7 +3,8 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Union from aiogram import BaseMiddleware, loggers from aiogram.dispatcher.flags import get_flag from aiogram.exceptions import CallbackAnswerException -from aiogram.types import CallbackQuery +from aiogram.methods import AnswerCallbackQuery +from aiogram.types import CallbackQuery, TelegramObject class CallbackAnswer: @@ -45,7 +46,7 @@ class CallbackAnswer: return self._disabled @disabled.setter - def disabled(self, value) -> None: + def disabled(self, value: bool) -> None: if self._answered: raise CallbackAnswerException("Can't change disabled state after answer") self._disabled = value @@ -66,7 +67,7 @@ class CallbackAnswer: return self._text @text.setter - def text(self, value) -> None: + def text(self, value: Optional[str]) -> None: if self._answered: raise CallbackAnswerException("Can't change text after answer") self._text = value @@ -79,7 +80,7 @@ class CallbackAnswer: return self._show_alert @show_alert.setter - def show_alert(self, value) -> None: + def show_alert(self, value: Optional[bool]) -> None: if self._answered: raise CallbackAnswerException("Can't change show_alert after answer") self._show_alert = value @@ -92,7 +93,7 @@ class CallbackAnswer: return self._url @url.setter - def url(self, value) -> None: + def url(self, value: Optional[str]) -> None: if self._answered: raise CallbackAnswerException("Can't change url after answer") self._url = value @@ -105,12 +106,12 @@ class CallbackAnswer: return self._cache_time @cache_time.setter - def cache_time(self, value) -> None: + def cache_time(self, value: Optional[int]) -> None: if self._answered: raise CallbackAnswerException("Can't change cache_time after answer") self._cache_time = value - def __str__(self): + def __str__(self) -> str: args = ", ".join( f"{k}={v!r}" for k, v in { @@ -134,7 +135,7 @@ class CallbackAnswerMiddleware(BaseMiddleware): show_alert: Optional[bool] = None, url: Optional[str] = None, cache_time: Optional[int] = None, - ): + ) -> None: """ Inner middleware for callback query handlers, can be useful in bots with a lot of callback handlers to automatically take answer to all requests @@ -153,10 +154,13 @@ class CallbackAnswerMiddleware(BaseMiddleware): async def __call__( self, - handler: Callable[[CallbackQuery, Dict[str, Any]], Awaitable[Any]], - event: CallbackQuery, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: TelegramObject, data: Dict[str, Any], ) -> Any: + if not isinstance(event, CallbackQuery): + return await handler(event, data) + callback_answer = data["callback_answer"] = self.construct_callback_answer( properties=get_flag(data, "callback_answer") ) @@ -169,7 +173,9 @@ class CallbackAnswerMiddleware(BaseMiddleware): if not callback_answer.disabled and not callback_answer.answered: await self.answer(event, callback_answer) - def construct_callback_answer(self, properties: Optional[Union[Dict[str, Any], bool]]): + def construct_callback_answer( + self, properties: Optional[Union[Dict[str, Any], bool]] + ) -> CallbackAnswer: pre, disabled, text, show_alert, url, cache_time = ( self.pre, False, @@ -195,7 +201,7 @@ class CallbackAnswerMiddleware(BaseMiddleware): cache_time=cache_time, ) - def answer(self, event: CallbackQuery, callback_answer: CallbackAnswer): + def answer(self, event: CallbackQuery, callback_answer: CallbackAnswer) -> AnswerCallbackQuery: loggers.middlewares.info("Answer to callback query id=%s", event.id) return event.answer( text=callback_answer.text, diff --git a/tests/test_utils/test_callback_answer.py b/tests/test_utils/test_callback_answer.py index 072e2edc..8c0d76c2 100644 --- a/tests/test_utils/test_callback_answer.py +++ b/tests/test_utils/test_callback_answer.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -212,3 +212,9 @@ class TestCallbackAnswerMiddleware: await middleware(handler, event, {}) assert stack == expected_stack + + async def test_invalid_event_type(self): + middleware = CallbackAnswerMiddleware() + handler = AsyncMock() + await middleware(handler, None, {}) + handler.assert_awaited()