From b51104465041c2202098295ca0b851778f762bd8 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Wed, 16 Feb 2022 19:38:46 +0200 Subject: [PATCH] Cover chat action sender --- Makefile | 2 +- aiogram/utils/chat_action.py | 32 ++++--- tests/test_utils/test_chat_action.py | 129 +++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 11 deletions(-) create mode 100644 tests/test_utils/test_chat_action.py diff --git a/Makefile b/Makefile index 343f64b0..c0508fea 100644 --- a/Makefile +++ b/Makefile @@ -59,7 +59,7 @@ clean: rm -rf `find . -name .pytest_cache` rm -rf *.egg-info rm -f report.html - rm -f .coverage* + rm -f .coverage rm -rf {build,dist,site,.cache,.mypy_cache,reports} # ================================================================================================= diff --git a/aiogram/utils/chat_action.py b/aiogram/utils/chat_action.py index fe96412d..2e78ab26 100644 --- a/aiogram/utils/chat_action.py +++ b/aiogram/utils/chat_action.py @@ -33,9 +33,14 @@ class ChatActionSender: self.initial_sleep = initial_sleep self.bot = bot - self._close_event = Event() - self._running = False self._lock = Lock() + self._close_event = Event() + self._closed_event = Event() + self._task: Optional[asyncio.Task[Any]] = None + + @property + def running(self) -> bool: + return bool(self._task) async def _wait(self, interval: float) -> None: with suppress(asyncio.TimeoutError): @@ -72,17 +77,24 @@ class ChatActionSender: self.chat_id, self.bot.id, ) - self._running = False + self._closed_event.set() async def _run(self) -> None: async with self._lock: - if self._running: + self._close_event.clear() + self._closed_event.clear() + if self.running: raise RuntimeError("Already running") - asyncio.create_task(self._worker()) + self._task = asyncio.create_task(self._worker()) - def _stop(self) -> None: - if not self._close_event.is_set(): - self._close_event.set() + async def _stop(self) -> None: + async with self._lock: + if not self.running: + return + if not self._close_event.is_set(): + self._close_event.set() + await self._closed_event.wait() + self._task = None async def __aenter__(self) -> "ChatActionSender": await self._run() @@ -94,7 +106,7 @@ class ChatActionSender: exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Any: - self._stop() + await self._stop() @classmethod def typing( @@ -115,8 +127,8 @@ class ChatActionSender: @classmethod def upload_photo( cls, - bot: Bot, chat_id: Union[int, str], + bot: Optional[Bot] = None, interval: float = DEFAULT_INTERVAL, initial_sleep: float = DEFAULT_INITIAL_SLEEP, ) -> "ChatActionSender": diff --git a/tests/test_utils/test_chat_action.py b/tests/test_utils/test_chat_action.py new file mode 100644 index 00000000..c20f5718 --- /dev/null +++ b/tests/test_utils/test_chat_action.py @@ -0,0 +1,129 @@ +import asyncio +import time +from datetime import datetime + +import pytest + +from aiogram import Bot, flags +from aiogram.dispatcher.event.handler import HandlerObject +from aiogram.types import Chat, Message, User +from aiogram.utils.chat_action import ChatActionMiddleware, ChatActionSender +from tests.mocked_bot import MockedBot + +try: + from asynctest import CoroutineMock, patch +except ImportError: + from unittest.mock import AsyncMock as CoroutineMock # type: ignore + from unittest.mock import patch + +pytestmarm = pytest.mark.asyncio + + +class TestChatActionSender: + async def test_wait(self, bot: Bot, loop: asyncio.BaseEventLoop): + sender = ChatActionSender.typing(bot=bot, chat_id=42) + loop.call_soon(sender._close_event.set) + start = time.monotonic() + await sender._wait(1) + assert time.monotonic() - start < 1 + + @pytest.mark.parametrize( + "action", + [ + "typing", + "upload_photo", + "record_video", + "upload_video", + "record_voice", + "upload_voice", + "upload_document", + "choose_sticker", + "find_location", + "record_video_note", + "upload_video_note", + ], + ) + @pytest.mark.parametrize("pass_bot", [True, False]) + def test_factory(self, action: str, bot: MockedBot, pass_bot: bool): + sender_factory = getattr(ChatActionSender, action) + sender = sender_factory(chat_id=42, bot=bot if pass_bot else None) + assert isinstance(sender, ChatActionSender) + assert sender.action == action + assert sender.chat_id == 42 + assert sender.bot is bot + + async def test_worker(self, bot: Bot): + with patch( + "aiogram.client.bot.Bot.send_chat_action", + new_callable=CoroutineMock, + ) as mocked_send_chat_action: + async with ChatActionSender.typing( + bot=bot, chat_id=42, interval=0.01, initial_sleep=0 + ): + await asyncio.sleep(0.1) + assert mocked_send_chat_action.await_count > 1 + mocked_send_chat_action.assert_awaited_with(action="typing", chat_id=42) + + async def test_contextmanager(self, bot: MockedBot): + sender: ChatActionSender = ChatActionSender.typing(bot=bot, chat_id=42) + assert not sender.running + await sender._stop() # nothing + + async with sender: + assert sender.running + assert not sender._close_event.is_set() + + with pytest.raises(RuntimeError): + await sender._run() + + assert not sender.running + + +class TestChatActionMiddleware: + @pytest.mark.parametrize( + "value", + [ + None, + "sticker", + {"action": "upload_photo"}, + {"interval": 1, "initial_sleep": 0.5}, + ], + ) + async def test_call_default(self, value, bot: Bot): + async def handler(event, data): + return "OK" + + if value is None: + handler1 = flags.chat_action(handler) + else: + handler1 = flags.chat_action(value)(handler) + + middleware = ChatActionMiddleware() + with patch( + "aiogram.utils.chat_action.ChatActionSender._run", + new_callable=CoroutineMock, + ) as mocked_run, patch( + "aiogram.utils.chat_action.ChatActionSender._stop", + new_callable=CoroutineMock, + ) as mocked_stop: + data = {"handler": HandlerObject(callback=handler1), "bot": bot} + message = Message( + chat=Chat(id=42, type="private", title="Test"), + from_user=User(id=42, is_bot=False, first_name="Test"), + date=datetime.now(), + message_id=42, + ) + + result = await middleware(handler=handler1, event=None, data=data) + assert result == "OK" + mocked_run.assert_not_awaited() + mocked_stop.assert_not_awaited() + + result = await middleware( + handler=handler1, + event=message, + data=data, + ) + assert result == "OK" + mocked_run.assert_awaited() + mocked_stop.assert_awaited()