mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Reworked graceful shutdown (#1124)
* Reworked graceful shutdown * Remove special errors from polling process * Update dependencies * Coverage * Added changelog
This commit is contained in:
parent
a332e88bc3
commit
d0b7135ca6
7 changed files with 197 additions and 50 deletions
|
|
@ -2,8 +2,10 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import signal
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Lock
|
||||
from asyncio import CancelledError, Event, Future, Lock
|
||||
from contextlib import suppress
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from .. import loggers
|
||||
|
|
@ -89,6 +91,8 @@ class Dispatcher(Router):
|
|||
|
||||
self.workflow_data: Dict[str, Any] = kwargs
|
||||
self._running_lock = Lock()
|
||||
self._stop_signal = Event()
|
||||
self._stopped_signal = Event()
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
return self.workflow_data[item]
|
||||
|
|
@ -321,17 +325,26 @@ class Dispatcher(Router):
|
|||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
async for update in self._listen_updates(
|
||||
bot,
|
||||
polling_timeout=polling_timeout,
|
||||
backoff_config=backoff_config,
|
||||
allowed_updates=allowed_updates,
|
||||
):
|
||||
handle_update = self._process_update(bot=bot, update=update, **kwargs)
|
||||
if handle_as_tasks:
|
||||
asyncio.create_task(handle_update)
|
||||
else:
|
||||
await handle_update
|
||||
user: User = await bot.me()
|
||||
loggers.dispatcher.info(
|
||||
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||
)
|
||||
try:
|
||||
async for update in self._listen_updates(
|
||||
bot,
|
||||
polling_timeout=polling_timeout,
|
||||
backoff_config=backoff_config,
|
||||
allowed_updates=allowed_updates,
|
||||
):
|
||||
handle_update = self._process_update(bot=bot, update=update, **kwargs)
|
||||
if handle_as_tasks:
|
||||
asyncio.create_task(handle_update)
|
||||
else:
|
||||
await handle_update
|
||||
finally:
|
||||
loggers.dispatcher.info(
|
||||
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||
)
|
||||
|
||||
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
||||
"""
|
||||
|
|
@ -408,6 +421,24 @@ class Dispatcher(Router):
|
|||
|
||||
return None
|
||||
|
||||
async def stop_polling(self) -> None:
|
||||
"""
|
||||
Execute this method if you want to stop polling programmatically
|
||||
|
||||
:return:
|
||||
"""
|
||||
if not self._running_lock.locked():
|
||||
raise RuntimeError("Polling is not started")
|
||||
self._stop_signal.set()
|
||||
await self._stopped_signal.wait()
|
||||
|
||||
def _signal_stop_polling(self, sig: signal.Signals) -> None:
|
||||
if not self._running_lock.locked():
|
||||
return
|
||||
|
||||
loggers.dispatcher.warning("Received %s signal", sig.name)
|
||||
self._stop_signal.set()
|
||||
|
||||
async def start_polling(
|
||||
self,
|
||||
*bots: Bot,
|
||||
|
|
@ -415,32 +446,54 @@ class Dispatcher(Router):
|
|||
handle_as_tasks: bool = True,
|
||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||
allowed_updates: Optional[List[str]] = None,
|
||||
handle_signals: bool = True,
|
||||
close_bot_session: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Polling runner
|
||||
|
||||
:param bots:
|
||||
:param polling_timeout:
|
||||
:param handle_as_tasks:
|
||||
:param kwargs:
|
||||
:param backoff_config:
|
||||
:param allowed_updates:
|
||||
:param bots: Bot instances (one or mre)
|
||||
:param polling_timeout: Long-polling wait time
|
||||
:param handle_as_tasks: Run task for each event and no wait result
|
||||
:param backoff_config: backoff-retry config
|
||||
:param allowed_updates: List of the update types you want your bot to receive
|
||||
:param handle_signals: handle signals (SIGINT/SIGTERM)
|
||||
:param close_bot_session: close bot sessions on shutdown
|
||||
:param kwargs: contextual data
|
||||
:return:
|
||||
"""
|
||||
if not bots:
|
||||
raise ValueError("At least one bot instance is required to start polling")
|
||||
if "bot" in kwargs:
|
||||
raise ValueError(
|
||||
"Keyword argument 'bot' is not acceptable, "
|
||||
"the bot instance should be passed as positional argument"
|
||||
)
|
||||
|
||||
async with self._running_lock: # Prevent to run this method twice at a once
|
||||
self._stop_signal.clear()
|
||||
self._stopped_signal.clear()
|
||||
|
||||
if handle_signals:
|
||||
loop = asyncio.get_running_loop()
|
||||
with suppress(NotImplementedError): # pragma: no cover
|
||||
# Signals handling is not supported on Windows
|
||||
# It also can't be covered on Windows
|
||||
loop.add_signal_handler(
|
||||
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
|
||||
)
|
||||
loop.add_signal_handler(
|
||||
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
|
||||
)
|
||||
|
||||
workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]}
|
||||
workflow_data.update(kwargs)
|
||||
await self.emit_startup(**workflow_data)
|
||||
loggers.dispatcher.info("Start polling")
|
||||
try:
|
||||
coro_list = []
|
||||
for bot in bots:
|
||||
user: User = await bot.me()
|
||||
loggers.dispatcher.info(
|
||||
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
||||
)
|
||||
coro_list.append(
|
||||
tasks: List[asyncio.Task[Any]] = [
|
||||
asyncio.create_task(
|
||||
self._polling(
|
||||
bot=bot,
|
||||
handle_as_tasks=handle_as_tasks,
|
||||
|
|
@ -450,36 +503,53 @@ class Dispatcher(Router):
|
|||
**kwargs,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*coro_list)
|
||||
for bot in bots
|
||||
]
|
||||
tasks.append(asyncio.create_task(self._stop_signal.wait()))
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
for task in pending:
|
||||
# (mostly) Graceful shutdown unfinished tasks
|
||||
task.cancel()
|
||||
with suppress(CancelledError):
|
||||
await task
|
||||
# Wait finished tasks to propagate unhandled exceptions
|
||||
await asyncio.gather(*done)
|
||||
|
||||
finally:
|
||||
loggers.dispatcher.info("Polling stopped")
|
||||
try:
|
||||
await self.emit_shutdown(**workflow_data)
|
||||
finally:
|
||||
for bot in bots: # Close sessions
|
||||
await bot.session.close()
|
||||
if close_bot_session:
|
||||
await asyncio.gather(*(bot.session.close() for bot in bots))
|
||||
self._stopped_signal.set()
|
||||
|
||||
def run_polling(
|
||||
self,
|
||||
*bots: Bot,
|
||||
polling_timeout: int = 30,
|
||||
polling_timeout: int = 10,
|
||||
handle_as_tasks: bool = True,
|
||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||
allowed_updates: Optional[List[str]] = None,
|
||||
handle_signals: bool = True,
|
||||
close_bot_session: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Run many bots with polling
|
||||
|
||||
:param bots: Bot instances
|
||||
:param polling_timeout: Poling timeout
|
||||
:param backoff_config:
|
||||
:param bots: Bot instances (one or mre)
|
||||
:param polling_timeout: Long-polling wait time
|
||||
:param handle_as_tasks: Run task for each event and no wait result
|
||||
:param backoff_config: backoff-retry config
|
||||
:param allowed_updates: List of the update types you want your bot to receive
|
||||
:param handle_signals: handle signals (SIGINT/SIGTERM)
|
||||
:param close_bot_session: close bot sessions on shutdown
|
||||
:param kwargs: contextual data
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
with suppress(KeyboardInterrupt):
|
||||
return asyncio.run(
|
||||
self.start_polling(
|
||||
*bots,
|
||||
|
|
@ -488,8 +558,7 @@ class Dispatcher(Router):
|
|||
handle_as_tasks=handle_as_tasks,
|
||||
backoff_config=backoff_config,
|
||||
allowed_updates=allowed_updates,
|
||||
handle_signals=handle_signals,
|
||||
close_bot_session=close_bot_session,
|
||||
)
|
||||
)
|
||||
except (KeyboardInterrupt, SystemExit): # pragma: no cover
|
||||
# Allow to graceful shutdown
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
|
|||
"""
|
||||
if isinstance(handler, dict) and "handler" in handler:
|
||||
handler = handler["handler"]
|
||||
if not hasattr(handler, "flags"):
|
||||
return {}
|
||||
return handler.flags
|
||||
if hasattr(handler, "flags"):
|
||||
return handler.flags
|
||||
return {}
|
||||
|
||||
|
||||
def get_flag(
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError):
|
|||
super().__init__(message=message)
|
||||
self.method = method
|
||||
|
||||
def __str__(self) -> str:
|
||||
original_message = super().__str__()
|
||||
return f"Telegram server says {original_message}"
|
||||
|
||||
|
||||
class TelegramNetworkError(TelegramAPIError):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue