diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index d6577702..63eb9a3a 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -8,14 +8,9 @@ from asyncio import CancelledError, Event, Future, Lock from contextlib import suppress from typing import Any, AsyncGenerator, Dict, List, Optional, Union -from .event.bases import UNHANDLED, SkipHandler -from .event.telegram import TelegramEventObserver -from .middlewares.error import ErrorsMiddleware -from .middlewares.user_context import UserContextMiddleware -from .router import Router from .. import loggers from ..client.bot import Bot -from ..exceptions import TelegramAPIError, TelegramNotFound, TelegramUnauthorizedError +from ..exceptions import TelegramAPIError from ..fsm.middleware import FSMContextMiddleware from ..fsm.storage.base import BaseEventIsolation, BaseStorage from ..fsm.storage.memory import DisabledEventIsolation, MemoryStorage @@ -24,6 +19,11 @@ from ..methods import GetUpdates, Request, TelegramMethod from ..types import Update, User from ..types.update import UpdateTypeLookupError from ..utils.backoff import Backoff, BackoffConfig +from .event.bases import UNHANDLED, SkipHandler +from .event.telegram import TelegramEventObserver +from .middlewares.error import ErrorsMiddleware +from .middlewares.user_context import UserContextMiddleware +from .router import Router DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1) @@ -427,14 +427,16 @@ class Dispatcher(Router): :return: """ - if not self._running_lock.locked() or self._stop_signal.is_set(): + if not self._running_lock.locked(): raise RuntimeError("Polling is not started") self._stop_signal.set() await self._stopped_signal.wait() def _signal_stop_polling(self, sig: signal.Signals) -> None: - if self._running_lock.locked(): - loggers.dispatcher.warning("Received %s signal", sig.name) + if not self._running_lock.locked(): + return + + loggers.dispatcher.warning("Received %s signal", sig.name) self._stop_signal.set() async def start_polling( @@ -461,14 +463,29 @@ class Dispatcher(Router): :param kwargs: contextual data :return: """ + if not bots: + raise ValueError("At least one bot instance is required to start polling") + if "bot" in kwargs: + raise ValueError( + "Keyword argument 'bot' is not acceptable, " + "the bot instance should be passed as positional argument" + ) + async with self._running_lock: # Prevent to run this method twice at a once self._stop_signal.clear() self._stopped_signal.clear() if handle_signals: loop = asyncio.get_running_loop() - loop.add_signal_handler(signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM) - loop.add_signal_handler(signal.SIGINT, self._signal_stop_polling, signal.SIGINT) + with suppress(NotImplementedError): # pragma: no cover + # Signals handling is not supported on Windows + # It also can't be covered on Windows + loop.add_signal_handler( + signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM + ) + loop.add_signal_handler( + signal.SIGINT, self._signal_stop_polling, signal.SIGINT + ) workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]} workflow_data.update(kwargs) @@ -532,15 +549,16 @@ class Dispatcher(Router): :param kwargs: contextual data :return: """ - return asyncio.run( - self.start_polling( - *bots, - **kwargs, - polling_timeout=polling_timeout, - handle_as_tasks=handle_as_tasks, - backoff_config=backoff_config, - allowed_updates=allowed_updates, - handle_signals=handle_signals, - close_bot_session=close_bot_session, + with suppress(KeyboardInterrupt): + return asyncio.run( + self.start_polling( + *bots, + **kwargs, + polling_timeout=polling_timeout, + handle_as_tasks=handle_as_tasks, + backoff_config=backoff_config, + allowed_updates=allowed_updates, + handle_signals=handle_signals, + close_bot_session=close_bot_session, + ) ) - ) diff --git a/aiogram/dispatcher/flags.py b/aiogram/dispatcher/flags.py index 82cdbc04..aad8a29f 100644 --- a/aiogram/dispatcher/flags.py +++ b/aiogram/dispatcher/flags.py @@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str, """ if isinstance(handler, dict) and "handler" in handler: handler = handler["handler"] - if not hasattr(handler, "flags"): - return {} - return handler.flags + if hasattr(handler, "flags"): + return handler.flags + return {} def get_flag( diff --git a/aiogram/exceptions.py b/aiogram/exceptions.py index dcf433c6..1c1e59fb 100644 --- a/aiogram/exceptions.py +++ b/aiogram/exceptions.py @@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError): super().__init__(message=message) self.method = method + def __str__(self) -> str: + original_message = super().__str__() + return f"Telegram server says {original_message}" + class TelegramNetworkError(TelegramAPIError): pass diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index 6d6fa1ab..509bff35 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -1,7 +1,9 @@ import asyncio import datetime +import signal import time import warnings +from asyncio import Event from collections import Counter from typing import Any from unittest.mock import AsyncMock, patch @@ -667,6 +669,17 @@ class TestDispatcher: async def test_start_polling(self, bot: MockedBot): dispatcher = Dispatcher() + with pytest.raises( + ValueError, match="At least one bot instance is required to start polling" + ): + await dispatcher.start_polling() + with pytest.raises( + ValueError, + match="Keyword argument 'bot' is not acceptable, " + "the bot instance should be passed as positional argument", + ): + await dispatcher.start_polling(bot, bot=bot) + bot.add_result_for( GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot") ) @@ -690,6 +703,65 @@ class TestDispatcher: mocked_process_update.assert_awaited() mocked_emit_shutdown.assert_awaited() + async def test_stop_polling(self): + dispatcher = Dispatcher() + with pytest.raises(RuntimeError): + await dispatcher.stop_polling() + + assert not dispatcher._stop_signal.is_set() + assert not dispatcher._stopped_signal.is_set() + with patch("asyncio.locks.Event.wait", new_callable=AsyncMock) as mocked_wait: + async with dispatcher._running_lock: + await dispatcher.stop_polling() + assert dispatcher._stop_signal.is_set() + mocked_wait.assert_awaited() + + async def test_signal_stop_polling(self): + dispatcher = Dispatcher() + with patch("asyncio.locks.Event.set") as mocked_set: + dispatcher._signal_stop_polling(signal.SIGINT) + mocked_set.assert_not_called() + + async with dispatcher._running_lock: + dispatcher._signal_stop_polling(signal.SIGINT) + mocked_set.assert_called() + + async def test_stop_polling_by_method(self, bot: MockedBot): + dispatcher = Dispatcher() + bot.add_result_for( + GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot") + ) + running = Event() + + async def _mock_updates(*_): + running.set() + while True: + yield Update(update_id=42) + await asyncio.sleep(1) + + with patch( + "aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=AsyncMock + ) as mocked_process_update, patch( + "aiogram.dispatcher.dispatcher.Dispatcher._listen_updates", + return_value=_mock_updates(), + ): + task = asyncio.ensure_future(dispatcher.start_polling(bot)) + await running.wait() + + assert not dispatcher._stop_signal.is_set() + assert not dispatcher._stopped_signal.is_set() + + await dispatcher.stop_polling() + assert dispatcher._stop_signal.is_set() + assert dispatcher._stopped_signal.is_set() + assert not task.exception() + + mocked_process_update.assert_awaited() + + @pytest.mark.skip("Stopping by signal should also be tested as the same as stopping by method") + async def test_stop_polling_by_signal(self, bot: MockedBot): + pass + def test_run_polling(self, bot: MockedBot): dispatcher = Dispatcher() with patch(