From 6508235d16413d6b35e7a082ac53d41c8d94a689 Mon Sep 17 00:00:00 2001 From: mpa Date: Sun, 10 May 2020 01:13:00 +0400 Subject: [PATCH] fix(BaseBot): remove __del__ method from BaseBot implement "lazy" session property getter and new get_new_session for BaseBot --- aiogram/bot/base.py | 43 ++++++++++++++---------- tests/test_bot/test_session.py | 61 ++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 18 deletions(-) create mode 100644 tests/test_bot/test_session.py diff --git a/aiogram/bot/base.py b/aiogram/bot/base.py index b7015881..86347e88 100644 --- a/aiogram/bot/base.py +++ b/aiogram/bot/base.py @@ -5,7 +5,7 @@ import ssl import typing import warnings from contextvars import ContextVar -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Type import aiohttp import certifi @@ -74,6 +74,12 @@ class BaseBot: # aiohttp main session ssl_context = ssl.create_default_context(cafile=certifi.where()) + self._session: Optional[aiohttp.ClientSession] = None + self._connector_class: Type[aiohttp.TCPConnector] = aiohttp.TCPConnector + self._connector_init = dict( + limit=connections_limit, ssl=ssl_context, loop=self.loop + ) + if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')): from aiohttp_socks import SocksConnector from aiohttp_socks.utils import parse_proxy_url @@ -85,30 +91,31 @@ class BaseBot: if not password: password = proxy_auth.password - connector = SocksConnector(socks_ver=socks_ver, host=host, port=port, - username=username, password=password, - limit=connections_limit, ssl_context=ssl_context, - rdns=True, loop=self.loop) - + self._connector_class = SocksConnector + self._connector_init.update( + socks_ver=socks_ver, host=host, port=port, + username=username, password=password, rdns=True, + ) self.proxy = None self.proxy_auth = None - else: - connector = aiohttp.TCPConnector(limit=connections_limit, ssl=ssl_context, loop=self.loop) + self._timeout = None self.timeout = timeout - self.session = aiohttp.ClientSession(connector=connector, loop=self.loop, json_serialize=json.dumps) - self.parse_mode = parse_mode - def __del__(self): - if not hasattr(self, 'loop') or not hasattr(self, 'session'): - return - if self.loop.is_running(): - self.loop.create_task(self.close()) - return - loop = asyncio.new_event_loop() - loop.run_until_complete(self.close()) + def get_new_session(self) -> aiohttp.ClientSession: + return aiohttp.ClientSession( + connector=self._connector_class(**self._connector_init), + loop=self.loop, + json_serialize=json.dumps + ) + + @property + def session(self) -> Optional[aiohttp.ClientSession]: + if self._session is None or self._session.closed: + self._session = self.get_new_session() + return self._session @staticmethod def _prepare_timeout( diff --git a/tests/test_bot/test_session.py b/tests/test_bot/test_session.py new file mode 100644 index 00000000..dec6379c --- /dev/null +++ b/tests/test_bot/test_session.py @@ -0,0 +1,61 @@ +import aiohttp +import aiohttp_socks +import pytest + +from aiogram.bot.base import BaseBot + +try: + from asynctest import CoroutineMock, patch +except ImportError: + from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore + + +class TestAiohttpSession: + @pytest.mark.asyncio + async def test_create_bot(self): + bot = BaseBot(token="42:correct") + + assert bot._session is None + assert isinstance(bot._connector_init, dict) + assert all(key in {"limit", "ssl", "loop"} for key in bot._connector_init) + assert isinstance(bot._connector_class, type) + assert issubclass(bot._connector_class, aiohttp.TCPConnector) + + assert bot._session is None + + assert isinstance(bot.session, aiohttp.ClientSession) + assert bot.session == bot._session + + @pytest.mark.asyncio + async def test_create_proxy_bot(self): + socks_ver, host, port, username, password = ( + "socks5", "124.90.90.90", 9999, "login", "password" + ) + + bot = BaseBot( + token="42:correct", + proxy=f"{socks_ver}://{host}:{port}/", + proxy_auth=aiohttp.BasicAuth(username, password, "encoding"), + ) + + assert bot._connector_class == aiohttp_socks.SocksConnector + + assert isinstance(bot._connector_init, dict) + + init_kwargs = bot._connector_init + assert init_kwargs["username"] == username + assert init_kwargs["password"] == password + assert init_kwargs["host"] == host + assert init_kwargs["port"] == port + + @pytest.mark.asyncio + async def test_close_session(self): + bot = BaseBot(token="42:correct",) + aiohttp_client_0 = bot.session + + with patch("aiohttp.ClientSession.close", new=CoroutineMock()) as mocked_close: + await aiohttp_client_0.close() + mocked_close.assert_called_once() + + await aiohttp_client_0.close() + assert aiohttp_client_0 != bot.session # will create new session