diff --git a/aiogram/bot/base.py b/aiogram/bot/base.py index 86347e88..f45546c3 100644 --- a/aiogram/bot/base.py +++ b/aiogram/bot/base.py @@ -56,6 +56,8 @@ class BaseBot: :type timeout: :obj:`typing.Optional[typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]]` :raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError` """ + self._main_loop = loop + # Authentication if validate_token: api.check_token(token) @@ -66,19 +68,12 @@ class BaseBot: self.proxy = proxy self.proxy_auth = proxy_auth - # Asyncio loop instance - if loop is None: - loop = asyncio.get_event_loop() - self.loop = loop - # aiohttp main session ssl_context = ssl.create_default_context(cafile=certifi.where()) self._session: Optional[aiohttp.ClientSession] = None self._connector_class: Type[aiohttp.TCPConnector] = aiohttp.TCPConnector - self._connector_init = dict( - limit=connections_limit, ssl=ssl_context, loop=self.loop - ) + self._connector_init = dict(limit=connections_limit, ssl=ssl_context) if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')): from aiohttp_socks import SocksConnector @@ -106,11 +101,15 @@ class BaseBot: def get_new_session(self) -> aiohttp.ClientSession: return aiohttp.ClientSession( - connector=self._connector_class(**self._connector_init), - loop=self.loop, + connector=self._connector_class(**self._connector_init, loop=self._main_loop), + loop=self._main_loop, json_serialize=json.dumps ) + @property + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + return self._main_loop + @property def session(self) -> Optional[aiohttp.ClientSession]: if self._session is None or self._session.closed: diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 7a3aa5b3..b38d3af1 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -27,6 +27,13 @@ log = logging.getLogger(__name__) DEFAULT_RATE_LIMIT = .1 +def _ensure_loop(x: "asyncio.AbstractEventLoop"): + assert isinstance( + x, asyncio.AbstractEventLoop + ), f"Loop must be the implementation of {asyncio.AbstractEventLoop!r}, " \ + f"not {type(x)!r}" + + class Dispatcher(DataMixin, ContextInstanceMixin): """ Simple Updates dispatcher @@ -43,15 +50,15 @@ class Dispatcher(DataMixin, ContextInstanceMixin): if not isinstance(bot, Bot): raise TypeError(f"Argument 'bot' must be an instance of Bot, not '{type(bot).__name__}'") - if loop is None: - loop = bot.loop if storage is None: storage = DisabledStorage() if filters_factory is None: filters_factory = FiltersFactory(self) self.bot: Bot = bot - self.loop = loop + if loop is not None: + _ensure_loop(loop) + self._main_loop = loop self.storage = storage self.run_tasks_by_default = run_tasks_by_default @@ -79,10 +86,27 @@ class Dispatcher(DataMixin, ContextInstanceMixin): self._polling = False self._closed = True - self._close_waiter = loop.create_future() + self._dispatcher_close_waiter = None self._setup_filters() + @property + def loop(self) -> typing.Optional[asyncio.AbstractEventLoop]: + # for the sake of backward compatibility + # lib internally must delegate tasks with respect to _main_loop attribute + # however should never be used by the library itself + # use more generic approaches from asyncio's namespace + return self._main_loop + + @property + def _close_waiter(self) -> "asyncio.Future": + if self._dispatcher_close_waiter is None: + if self._main_loop is not None: + self._dispatcher_close_waiter = self._main_loop.create_future() + else: + self._dispatcher_close_waiter = asyncio.get_event_loop().create_future() + return self._dispatcher_close_waiter + def _setup_filters(self): filters_factory = self.filters_factory @@ -282,6 +306,13 @@ class Dispatcher(DataMixin, ContextInstanceMixin): return await self.bot.delete_webhook() + def _loop_create_task(self, coro): + if self._main_loop is None: + return asyncio.create_task(coro) + else: + _ensure_loop(self._main_loop) + return self._main_loop.create_task(coro) + async def start_polling(self, timeout=20, relax=0.1, @@ -337,7 +368,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin): log.debug(f"Received {len(updates)} updates.") offset = updates[-1].update_id + 1 - self.loop.create_task(self._process_polling_updates(updates, fast)) + self._loop_create_task(self._process_polling_updates(updates, fast)) if relax: await asyncio.sleep(relax) @@ -381,7 +412,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin): :return: """ - await asyncio.shield(self._close_waiter, loop=self.loop) + await asyncio.shield(self._close_waiter) def is_polling(self): """ @@ -1158,15 +1189,15 @@ class Dispatcher(DataMixin, ContextInstanceMixin): try: response = task.result() except Exception as e: - self.loop.create_task( + self._loop_create_task( self.errors_handlers.notify(types.Update.get_current(), e)) else: if isinstance(response, BaseResponse): - self.loop.create_task(response.execute_response(self.bot)) + self._loop_create_task(response.execute_response(self.bot)) @functools.wraps(func) async def wrapper(*args, **kwargs): - task = self.loop.create_task(func(*args, **kwargs)) + task = self._loop_create_task(func(*args, **kwargs)) task.add_done_callback(process_response) return wrapper diff --git a/aiogram/dispatcher/middlewares.py b/aiogram/dispatcher/middlewares.py index dba3db4c..5fa09830 100644 --- a/aiogram/dispatcher/middlewares.py +++ b/aiogram/dispatcher/middlewares.py @@ -16,11 +16,14 @@ class MiddlewareManager: :param dispatcher: instance of Dispatcher """ self.dispatcher = dispatcher - self.loop = dispatcher.loop self.bot = dispatcher.bot self.storage = dispatcher.storage self.applications = [] + @property + def loop(self): + return self.dispatcher.loop + def setup(self, middleware): """ Setup middleware