This commit is contained in:
Alex Root Junior 2023-02-18 15:33:17 +02:00
parent 82295c94e0
commit b225fab8ff
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
4 changed files with 119 additions and 25 deletions

View file

@ -8,14 +8,9 @@ from asyncio import CancelledError, Event, Future, Lock
from contextlib import suppress
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver
from .middlewares.error import ErrorsMiddleware
from .middlewares.user_context import UserContextMiddleware
from .router import Router
from .. import loggers
from ..client.bot import Bot
from ..exceptions import TelegramAPIError, TelegramNotFound, TelegramUnauthorizedError
from ..exceptions import TelegramAPIError
from ..fsm.middleware import FSMContextMiddleware
from ..fsm.storage.base import BaseEventIsolation, BaseStorage
from ..fsm.storage.memory import DisabledEventIsolation, MemoryStorage
@ -24,6 +19,11 @@ from ..methods import GetUpdates, Request, TelegramMethod
from ..types import Update, User
from ..types.update import UpdateTypeLookupError
from ..utils.backoff import Backoff, BackoffConfig
from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver
from .middlewares.error import ErrorsMiddleware
from .middlewares.user_context import UserContextMiddleware
from .router import Router
DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1)
@ -427,14 +427,16 @@ class Dispatcher(Router):
:return:
"""
if not self._running_lock.locked() or self._stop_signal.is_set():
if not self._running_lock.locked():
raise RuntimeError("Polling is not started")
self._stop_signal.set()
await self._stopped_signal.wait()
def _signal_stop_polling(self, sig: signal.Signals) -> None:
if self._running_lock.locked():
loggers.dispatcher.warning("Received %s signal", sig.name)
if not self._running_lock.locked():
return
loggers.dispatcher.warning("Received %s signal", sig.name)
self._stop_signal.set()
async def start_polling(
@ -461,14 +463,29 @@ class Dispatcher(Router):
:param kwargs: contextual data
:return:
"""
if not bots:
raise ValueError("At least one bot instance is required to start polling")
if "bot" in kwargs:
raise ValueError(
"Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument"
)
async with self._running_lock: # Prevent to run this method twice at a once
self._stop_signal.clear()
self._stopped_signal.clear()
if handle_signals:
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM)
loop.add_signal_handler(signal.SIGINT, self._signal_stop_polling, signal.SIGINT)
with suppress(NotImplementedError): # pragma: no cover
# Signals handling is not supported on Windows
# It also can't be covered on Windows
loop.add_signal_handler(
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
)
loop.add_signal_handler(
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
)
workflow_data = {"dispatcher": self, "bots": bots, "bot": bots[-1]}
workflow_data.update(kwargs)
@ -532,15 +549,16 @@ class Dispatcher(Router):
:param kwargs: contextual data
:return:
"""
return asyncio.run(
self.start_polling(
*bots,
**kwargs,
polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks,
backoff_config=backoff_config,
allowed_updates=allowed_updates,
handle_signals=handle_signals,
close_bot_session=close_bot_session,
with suppress(KeyboardInterrupt):
return asyncio.run(
self.start_polling(
*bots,
**kwargs,
polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks,
backoff_config=backoff_config,
allowed_updates=allowed_updates,
handle_signals=handle_signals,
close_bot_session=close_bot_session,
)
)
)

View file

@ -92,9 +92,9 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
"""
if isinstance(handler, dict) and "handler" in handler:
handler = handler["handler"]
if not hasattr(handler, "flags"):
return {}
return handler.flags
if hasattr(handler, "flags"):
return handler.flags
return {}
def get_flag(

View file

@ -42,6 +42,10 @@ class TelegramAPIError(DetailedAiogramError):
super().__init__(message=message)
self.method = method
def __str__(self) -> str:
original_message = super().__str__()
return f"Telegram server says {original_message}"
class TelegramNetworkError(TelegramAPIError):
pass

View file

@ -1,7 +1,9 @@
import asyncio
import datetime
import signal
import time
import warnings
from asyncio import Event
from collections import Counter
from typing import Any
from unittest.mock import AsyncMock, patch
@ -667,6 +669,17 @@ class TestDispatcher:
async def test_start_polling(self, bot: MockedBot):
dispatcher = Dispatcher()
with pytest.raises(
ValueError, match="At least one bot instance is required to start polling"
):
await dispatcher.start_polling()
with pytest.raises(
ValueError,
match="Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument",
):
await dispatcher.start_polling(bot, bot=bot)
bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
)
@ -690,6 +703,65 @@ class TestDispatcher:
mocked_process_update.assert_awaited()
mocked_emit_shutdown.assert_awaited()
async def test_stop_polling(self):
dispatcher = Dispatcher()
with pytest.raises(RuntimeError):
await dispatcher.stop_polling()
assert not dispatcher._stop_signal.is_set()
assert not dispatcher._stopped_signal.is_set()
with patch("asyncio.locks.Event.wait", new_callable=AsyncMock) as mocked_wait:
async with dispatcher._running_lock:
await dispatcher.stop_polling()
assert dispatcher._stop_signal.is_set()
mocked_wait.assert_awaited()
async def test_signal_stop_polling(self):
dispatcher = Dispatcher()
with patch("asyncio.locks.Event.set") as mocked_set:
dispatcher._signal_stop_polling(signal.SIGINT)
mocked_set.assert_not_called()
async with dispatcher._running_lock:
dispatcher._signal_stop_polling(signal.SIGINT)
mocked_set.assert_called()
async def test_stop_polling_by_method(self, bot: MockedBot):
dispatcher = Dispatcher()
bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
)
running = Event()
async def _mock_updates(*_):
running.set()
while True:
yield Update(update_id=42)
await asyncio.sleep(1)
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=AsyncMock
) as mocked_process_update, patch(
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates",
return_value=_mock_updates(),
):
task = asyncio.ensure_future(dispatcher.start_polling(bot))
await running.wait()
assert not dispatcher._stop_signal.is_set()
assert not dispatcher._stopped_signal.is_set()
await dispatcher.stop_polling()
assert dispatcher._stop_signal.is_set()
assert dispatcher._stopped_signal.is_set()
assert not task.exception()
mocked_process_update.assert_awaited()
@pytest.mark.skip("Stopping by signal should also be tested as the same as stopping by method")
async def test_stop_polling_by_signal(self, bot: MockedBot):
pass
def test_run_polling(self, bot: MockedBot):
dispatcher = Dispatcher()
with patch(