Fixed typehints and tests

This commit is contained in:
Alex Root Junior 2023-01-08 02:29:14 +02:00
parent 4ba07c0815
commit 053e0bbbcc
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
2 changed files with 25 additions and 13 deletions

View file

@ -3,7 +3,8 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Union
from aiogram import BaseMiddleware, loggers
from aiogram.dispatcher.flags import get_flag
from aiogram.exceptions import CallbackAnswerException
from aiogram.types import CallbackQuery
from aiogram.methods import AnswerCallbackQuery
from aiogram.types import CallbackQuery, TelegramObject
class CallbackAnswer:
@ -45,7 +46,7 @@ class CallbackAnswer:
return self._disabled
@disabled.setter
def disabled(self, value) -> None:
def disabled(self, value: bool) -> None:
if self._answered:
raise CallbackAnswerException("Can't change disabled state after answer")
self._disabled = value
@ -66,7 +67,7 @@ class CallbackAnswer:
return self._text
@text.setter
def text(self, value) -> None:
def text(self, value: Optional[str]) -> None:
if self._answered:
raise CallbackAnswerException("Can't change text after answer")
self._text = value
@ -79,7 +80,7 @@ class CallbackAnswer:
return self._show_alert
@show_alert.setter
def show_alert(self, value) -> None:
def show_alert(self, value: Optional[bool]) -> None:
if self._answered:
raise CallbackAnswerException("Can't change show_alert after answer")
self._show_alert = value
@ -92,7 +93,7 @@ class CallbackAnswer:
return self._url
@url.setter
def url(self, value) -> None:
def url(self, value: Optional[str]) -> None:
if self._answered:
raise CallbackAnswerException("Can't change url after answer")
self._url = value
@ -105,12 +106,12 @@ class CallbackAnswer:
return self._cache_time
@cache_time.setter
def cache_time(self, value) -> None:
def cache_time(self, value: Optional[int]) -> None:
if self._answered:
raise CallbackAnswerException("Can't change cache_time after answer")
self._cache_time = value
def __str__(self):
def __str__(self) -> str:
args = ", ".join(
f"{k}={v!r}"
for k, v in {
@ -134,7 +135,7 @@ class CallbackAnswerMiddleware(BaseMiddleware):
show_alert: Optional[bool] = None,
url: Optional[str] = None,
cache_time: Optional[int] = None,
):
) -> None:
"""
Inner middleware for callback query handlers, can be useful in bots with a lot of callback
handlers to automatically take answer to all requests
@ -153,10 +154,13 @@ class CallbackAnswerMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[[CallbackQuery, Dict[str, Any]], Awaitable[Any]],
event: CallbackQuery,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
if not isinstance(event, CallbackQuery):
return await handler(event, data)
callback_answer = data["callback_answer"] = self.construct_callback_answer(
properties=get_flag(data, "callback_answer")
)
@ -169,7 +173,9 @@ class CallbackAnswerMiddleware(BaseMiddleware):
if not callback_answer.disabled and not callback_answer.answered:
await self.answer(event, callback_answer)
def construct_callback_answer(self, properties: Optional[Union[Dict[str, Any], bool]]):
def construct_callback_answer(
self, properties: Optional[Union[Dict[str, Any], bool]]
) -> CallbackAnswer:
pre, disabled, text, show_alert, url, cache_time = (
self.pre,
False,
@ -195,7 +201,7 @@ class CallbackAnswerMiddleware(BaseMiddleware):
cache_time=cache_time,
)
def answer(self, event: CallbackQuery, callback_answer: CallbackAnswer):
def answer(self, event: CallbackQuery, callback_answer: CallbackAnswer) -> AnswerCallbackQuery:
loggers.middlewares.info("Answer to callback query id=%s", event.id)
return event.answer(
text=callback_answer.text,

View file

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -212,3 +212,9 @@ class TestCallbackAnswerMiddleware:
await middleware(handler, event, {})
assert stack == expected_stack
async def test_invalid_event_type(self):
middleware = CallbackAnswerMiddleware()
handler = AsyncMock()
await middleware(handler, None, {})
handler.assert_awaited()