mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
fix(bot,dispatcher): do not use _MainThread event loop (#397)
* fix(bot,dispatcher): do not use _MainThread event loop on ::Bot, ::Dispatcher initializations * fix: use more generic get approach * docs: comments * chore(task): asyncio.create_task comes with py3.7 * fix(dispatcher): todo
This commit is contained in:
parent
7d1c8c42d3
commit
c99b165668
3 changed files with 53 additions and 20 deletions
|
|
@ -56,6 +56,8 @@ class BaseBot:
|
|||
:type timeout: :obj:`typing.Optional[typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]]`
|
||||
:raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError`
|
||||
"""
|
||||
self._main_loop = loop
|
||||
|
||||
# Authentication
|
||||
if validate_token:
|
||||
api.check_token(token)
|
||||
|
|
@ -66,19 +68,12 @@ class BaseBot:
|
|||
self.proxy = proxy
|
||||
self.proxy_auth = proxy_auth
|
||||
|
||||
# Asyncio loop instance
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
self.loop = loop
|
||||
|
||||
# aiohttp main session
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._connector_class: Type[aiohttp.TCPConnector] = aiohttp.TCPConnector
|
||||
self._connector_init = dict(
|
||||
limit=connections_limit, ssl=ssl_context, loop=self.loop
|
||||
)
|
||||
self._connector_init = dict(limit=connections_limit, ssl=ssl_context)
|
||||
|
||||
if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
|
||||
from aiohttp_socks import SocksConnector
|
||||
|
|
@ -106,11 +101,15 @@ class BaseBot:
|
|||
|
||||
def get_new_session(self) -> aiohttp.ClientSession:
|
||||
return aiohttp.ClientSession(
|
||||
connector=self._connector_class(**self._connector_init),
|
||||
loop=self.loop,
|
||||
connector=self._connector_class(**self._connector_init, loop=self._main_loop),
|
||||
loop=self._main_loop,
|
||||
json_serialize=json.dumps
|
||||
)
|
||||
|
||||
@property
|
||||
def loop(self) -> Optional[asyncio.AbstractEventLoop]:
|
||||
return self._main_loop
|
||||
|
||||
@property
|
||||
def session(self) -> Optional[aiohttp.ClientSession]:
|
||||
if self._session is None or self._session.closed:
|
||||
|
|
|
|||
|
|
@ -27,6 +27,13 @@ log = logging.getLogger(__name__)
|
|||
DEFAULT_RATE_LIMIT = .1
|
||||
|
||||
|
||||
def _ensure_loop(x: "asyncio.AbstractEventLoop"):
|
||||
assert isinstance(
|
||||
x, asyncio.AbstractEventLoop
|
||||
), f"Loop must be the implementation of {asyncio.AbstractEventLoop!r}, " \
|
||||
f"not {type(x)!r}"
|
||||
|
||||
|
||||
class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||
"""
|
||||
Simple Updates dispatcher
|
||||
|
|
@ -43,15 +50,15 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
if not isinstance(bot, Bot):
|
||||
raise TypeError(f"Argument 'bot' must be an instance of Bot, not '{type(bot).__name__}'")
|
||||
|
||||
if loop is None:
|
||||
loop = bot.loop
|
||||
if storage is None:
|
||||
storage = DisabledStorage()
|
||||
if filters_factory is None:
|
||||
filters_factory = FiltersFactory(self)
|
||||
|
||||
self.bot: Bot = bot
|
||||
self.loop = loop
|
||||
if loop is not None:
|
||||
_ensure_loop(loop)
|
||||
self._main_loop = loop
|
||||
self.storage = storage
|
||||
self.run_tasks_by_default = run_tasks_by_default
|
||||
|
||||
|
|
@ -79,10 +86,27 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
|
||||
self._polling = False
|
||||
self._closed = True
|
||||
self._close_waiter = loop.create_future()
|
||||
self._dispatcher_close_waiter = None
|
||||
|
||||
self._setup_filters()
|
||||
|
||||
@property
|
||||
def loop(self) -> typing.Optional[asyncio.AbstractEventLoop]:
|
||||
# for the sake of backward compatibility
|
||||
# lib internally must delegate tasks with respect to _main_loop attribute
|
||||
# however should never be used by the library itself
|
||||
# use more generic approaches from asyncio's namespace
|
||||
return self._main_loop
|
||||
|
||||
@property
|
||||
def _close_waiter(self) -> "asyncio.Future":
|
||||
if self._dispatcher_close_waiter is None:
|
||||
if self._main_loop is not None:
|
||||
self._dispatcher_close_waiter = self._main_loop.create_future()
|
||||
else:
|
||||
self._dispatcher_close_waiter = asyncio.get_event_loop().create_future()
|
||||
return self._dispatcher_close_waiter
|
||||
|
||||
def _setup_filters(self):
|
||||
filters_factory = self.filters_factory
|
||||
|
||||
|
|
@ -282,6 +306,13 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
|
||||
return await self.bot.delete_webhook()
|
||||
|
||||
def _loop_create_task(self, coro):
|
||||
if self._main_loop is None:
|
||||
return asyncio.create_task(coro)
|
||||
else:
|
||||
_ensure_loop(self._main_loop)
|
||||
return self._main_loop.create_task(coro)
|
||||
|
||||
async def start_polling(self,
|
||||
timeout=20,
|
||||
relax=0.1,
|
||||
|
|
@ -337,7 +368,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
log.debug(f"Received {len(updates)} updates.")
|
||||
offset = updates[-1].update_id + 1
|
||||
|
||||
self.loop.create_task(self._process_polling_updates(updates, fast))
|
||||
self._loop_create_task(self._process_polling_updates(updates, fast))
|
||||
|
||||
if relax:
|
||||
await asyncio.sleep(relax)
|
||||
|
|
@ -381,7 +412,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
|
||||
:return:
|
||||
"""
|
||||
await asyncio.shield(self._close_waiter, loop=self.loop)
|
||||
await asyncio.shield(self._close_waiter)
|
||||
|
||||
def is_polling(self):
|
||||
"""
|
||||
|
|
@ -1158,15 +1189,15 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
try:
|
||||
response = task.result()
|
||||
except Exception as e:
|
||||
self.loop.create_task(
|
||||
self._loop_create_task(
|
||||
self.errors_handlers.notify(types.Update.get_current(), e))
|
||||
else:
|
||||
if isinstance(response, BaseResponse):
|
||||
self.loop.create_task(response.execute_response(self.bot))
|
||||
self._loop_create_task(response.execute_response(self.bot))
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
task = self.loop.create_task(func(*args, **kwargs))
|
||||
task = self._loop_create_task(func(*args, **kwargs))
|
||||
task.add_done_callback(process_response)
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
|
|
@ -16,11 +16,14 @@ class MiddlewareManager:
|
|||
:param dispatcher: instance of Dispatcher
|
||||
"""
|
||||
self.dispatcher = dispatcher
|
||||
self.loop = dispatcher.loop
|
||||
self.bot = dispatcher.bot
|
||||
self.storage = dispatcher.storage
|
||||
self.applications = []
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
return self.dispatcher.loop
|
||||
|
||||
def setup(self, middleware):
|
||||
"""
|
||||
Setup middleware
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue