From d8368cbebe234cb584693ec2ff2624df860cbdc4 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Sun, 12 Feb 2023 17:54:50 +0200 Subject: [PATCH] Reworked graceful shutdown --- Makefile | 2 +- aiogram/dispatcher/dispatcher.py | 146 +++++++++++++++++++++---------- 2 files changed, 101 insertions(+), 47 deletions(-) diff --git a/Makefile b/Makefile index cbe5c333..5696616b 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ clean: rm -rf *.egg-info rm -f report.html rm -f .coverage - rm -rf {build,dist,site,.cache,.mypy_cache,reports} + rm -rf {build,dist,site,.cache,.mypy_cache,.ruff_cache,reports} # ================================================================================================= # Code quality diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 7f058ec2..335d2857 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -2,13 +2,15 @@ from __future__ import annotations import asyncio import contextvars +import signal import warnings -from asyncio import CancelledError, Future, Lock +from asyncio import CancelledError, Event, Future, Lock +from contextlib import suppress from typing import Any, AsyncGenerator, Dict, List, Optional, Union from .. import loggers from ..client.bot import Bot -from ..exceptions import TelegramAPIError +from ..exceptions import TelegramAPIError, TelegramNotFound from ..fsm.middleware import FSMContextMiddleware from ..fsm.storage.base import BaseEventIsolation, BaseStorage from ..fsm.storage.memory import DisabledEventIsolation, MemoryStorage @@ -89,6 +91,8 @@ class Dispatcher(Router): self.workflow_data: Dict[str, Any] = kwargs self._running_lock = Lock() + self._stop_signal = Event() + self._stopped_signal = Event() def __getitem__(self, item: str) -> Any: return self.workflow_data[item] @@ -197,6 +201,9 @@ class Dispatcher(Router): while True: try: updates = await bot(get_updates, **kwargs) + except TelegramNotFound: + loggers.dispatcher.error("Seems like Bot token is invalid") + raise except Exception as e: failed = True # In cases when Telegram Bot API was inaccessible don't need to stop polling @@ -321,17 +328,26 @@ class Dispatcher(Router): :param kwargs: :return: """ - async for update in self._listen_updates( - bot, - polling_timeout=polling_timeout, - backoff_config=backoff_config, - allowed_updates=allowed_updates, - ): - handle_update = self._process_update(bot=bot, update=update, **kwargs) - if handle_as_tasks: - asyncio.create_task(handle_update) - else: - await handle_update + user: User = await bot.me() + loggers.dispatcher.info( + "Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name + ) + try: + async for update in self._listen_updates( + bot, + polling_timeout=polling_timeout, + backoff_config=backoff_config, + allowed_updates=allowed_updates, + ): + handle_update = self._process_update(bot=bot, update=update, **kwargs) + if handle_as_tasks: + asyncio.create_task(handle_update) + else: + await handle_update + finally: + loggers.dispatcher.info( + "Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name + ) async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: """ @@ -408,6 +424,22 @@ class Dispatcher(Router): return None + async def stop_polling(self) -> None: + """ + Execute this method if you want to stop polling programmatically + + :return: + """ + if not self._running_lock.locked() or self._stop_signal.is_set(): + raise RuntimeError("Polling is not started") + self._stop_signal.set() + await self._stopped_signal.wait() + + def _signal_stop_polling(self, sig: signal.Signals) -> None: + if self._running_lock.locked(): + loggers.dispatcher.warning("Received %s signal", sig.name) + self._stop_signal.set() + async def start_polling( self, *bots: Bot, @@ -415,32 +447,39 @@ class Dispatcher(Router): handle_as_tasks: bool = True, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, allowed_updates: Optional[List[str]] = None, + handle_signals: bool = True, + close_bot_session: bool = True, **kwargs: Any, ) -> None: """ Polling runner - :param bots: - :param polling_timeout: - :param handle_as_tasks: - :param kwargs: - :param backoff_config: - :param allowed_updates: + :param bots: Bot instances (one or mre) + :param polling_timeout: Long-polling wait time + :param handle_as_tasks: Run task for each event and no wait result + :param backoff_config: backoff-retry config + :param allowed_updates: List of the update types you want your bot to receive + :param handle_signals: handle signals (SIGINT/SIGTERM) + :param close_bot_session: close bot sessions on shutdown + :param kwargs: contextual data :return: """ async with self._running_lock: # Prevent to run this method twice at a once + self._stop_signal.clear() + self._stopped_signal.clear() + + if handle_signals: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM) + loop.add_signal_handler(signal.SIGINT, self._signal_stop_polling, signal.SIGINT) + workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]} workflow_data.update(kwargs) await self.emit_startup(**workflow_data) loggers.dispatcher.info("Start polling") try: - coro_list = [] - for bot in bots: - user: User = await bot.me() - loggers.dispatcher.info( - "Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name - ) - coro_list.append( + tasks: List[asyncio.Task[Any]] = [ + asyncio.create_task( self._polling( bot=bot, handle_as_tasks=handle_as_tasks, @@ -450,46 +489,61 @@ class Dispatcher(Router): **kwargs, ) ) - await asyncio.gather(*coro_list) + for bot in bots + ] + tasks.append(asyncio.create_task(self._stop_signal.wait())) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + for task in pending: + # (mostly) Graceful shutdown unfinished tasks + task.cancel() + with suppress(CancelledError): + await task + # Wait finished tasks to propagate unhandled exceptions + await asyncio.gather(*done) + finally: loggers.dispatcher.info("Polling stopped") try: await self.emit_shutdown(**workflow_data) finally: - for bot in bots: # Close sessions - await bot.session.close() + if close_bot_session: + await asyncio.gather(*(bot.session.close() for bot in bots)) + self._stopped_signal.set() def run_polling( self, *bots: Bot, - polling_timeout: int = 30, + polling_timeout: int = 10, handle_as_tasks: bool = True, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, allowed_updates: Optional[List[str]] = None, + handle_signals: bool = True, + close_bot_session: bool = True, **kwargs: Any, ) -> None: """ Run many bots with polling - :param bots: Bot instances - :param polling_timeout: Poling timeout - :param backoff_config: + :param bots: Bot instances (one or mre) + :param polling_timeout: Long-polling wait time :param handle_as_tasks: Run task for each event and no wait result + :param backoff_config: backoff-retry config :param allowed_updates: List of the update types you want your bot to receive + :param handle_signals: handle signals (SIGINT/SIGTERM) + :param close_bot_session: close bot sessions on shutdown :param kwargs: contextual data :return: """ - try: - return asyncio.run( - self.start_polling( - *bots, - **kwargs, - polling_timeout=polling_timeout, - handle_as_tasks=handle_as_tasks, - backoff_config=backoff_config, - allowed_updates=allowed_updates, - ) + return asyncio.run( + self.start_polling( + *bots, + **kwargs, + polling_timeout=polling_timeout, + handle_as_tasks=handle_as_tasks, + backoff_config=backoff_config, + allowed_updates=allowed_updates, + handle_signals=handle_signals, + close_bot_session=close_bot_session, ) - except (KeyboardInterrupt, SystemExit): # pragma: no cover - # Allow to graceful shutdown - pass + )