fix: graceful aiohttp session close

This commit is contained in:
Oleg A 2023-09-20 12:53:26 +03:00
parent cb68075f70
commit 56a78b97eb
No known key found for this signature in database
GPG key ID: 5FE046817A9657C5
4 changed files with 150 additions and 142 deletions

View file

@ -139,6 +139,10 @@ class AiohttpSession(BaseSession):
if self._session is not None and not self._session.closed:
await self._session.close()
# Wait 250 ms for the underlying SSL connections to close
# https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
await asyncio.sleep(0.25)
def build_form_data(self, bot: Bot, method: TelegramMethod[TelegramType]) -> FormData:
form = FormData(quote_fields=False)
files: Dict[str, InputFile] = {}

View file

@ -31,27 +31,26 @@ class TestBot:
assert bot != "42:TEST"
async def test_emit(self):
bot = Bot("42:TEST")
async with Bot("42:TEST").context() as bot:
method = GetMe()
method = GetMe()
with patch(
"aiogram.client.session.aiohttp.AiohttpSession.make_request",
new_callable=AsyncMock,
) as mocked_make_request:
await bot(method)
mocked_make_request.assert_awaited_with(bot, method, timeout=None)
with patch(
"aiogram.client.session.aiohttp.AiohttpSession.make_request",
new_callable=AsyncMock,
) as mocked_make_request:
await bot(method)
mocked_make_request.assert_awaited_with(bot, method, timeout=None)
async def test_close(self):
session = AiohttpSession()
bot = Bot("42:TEST", session=session)
await session.create_session()
async with AiohttpSession() as session:
bot = Bot("42:TEST", session=session)
await session.create_session()
with patch(
"aiogram.client.session.aiohttp.AiohttpSession.close", new_callable=AsyncMock
) as mocked_close:
await bot.session.close()
mocked_close.assert_awaited()
with patch(
"aiogram.client.session.aiohttp.AiohttpSession.close", new_callable=AsyncMock
) as mocked_close:
await bot.session.close()
mocked_close.assert_awaited()
@pytest.mark.parametrize("close", [True, False])
async def test_context_manager(self, close: bool):
@ -64,10 +63,14 @@ class TestBot:
mocked_close.assert_awaited()
else:
mocked_close.assert_not_awaited()
await bot.session.close()
async def test_download_file(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
aresponses.ANY,
aresponses.ANY,
"get",
aresponses.Response(status=200, body=b"\f" * 10),
)
# https://github.com/Tinche/aiofiles#writing-tests-for-aiofiles
@ -77,37 +80,42 @@ class TestBot:
mock_file = MagicMock()
bot = Bot("42:TEST")
with patch("aiofiles.threadpool.sync_open", return_value=mock_file):
await bot.download_file("TEST", "file.png")
mock_file.write.assert_called_once_with(b"\f" * 10)
async with Bot("42:TEST").context() as bot:
with patch("aiofiles.threadpool.sync_open", return_value=mock_file):
await bot.download_file("TEST", "file.png")
mock_file.write.assert_called_once_with(b"\f" * 10)
async def test_download_file_default_destination(self, aresponses: ResponsesMockServer):
bot = Bot("42:TEST")
async with Bot("42:TEST").context() as bot:
aresponses.add(
aresponses.ANY,
aresponses.ANY,
"get",
aresponses.Response(status=200, body=b"\f" * 10),
)
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
result = await bot.download_file("TEST")
result = await bot.download_file("TEST")
assert isinstance(result, io.BytesIO)
assert result.read() == b"\f" * 10
assert isinstance(result, io.BytesIO)
assert result.read() == b"\f" * 10
async def test_download_file_custom_destination(self, aresponses: ResponsesMockServer):
bot = Bot("42:TEST")
async with Bot("42:TEST").context() as bot:
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
aresponses.add(
aresponses.ANY,
aresponses.ANY,
"get",
aresponses.Response(status=200, body=b"\f" * 10),
)
custom = io.BytesIO()
custom = io.BytesIO()
result = await bot.download_file("TEST", custom)
result = await bot.download_file("TEST", custom)
assert isinstance(result, io.BytesIO)
assert result is custom
assert result.read() == b"\f" * 10
assert isinstance(result, io.BytesIO)
assert result is custom
assert result.read() == b"\f" * 10
async def test_download(self, bot: MockedBot, aresponses: ResponsesMockServer):
bot.add_result_for(

View file

@ -12,7 +12,7 @@ from aiogram.client.session import aiohttp
from aiogram.client.session.aiohttp import AiohttpSession
from aiogram.exceptions import TelegramNetworkError
from aiogram.methods import TelegramMethod
from aiogram.types import UNSET_PARSE_MODE, InputFile
from aiogram.types import InputFile, UNSET_PARSE_MODE
from tests.mocked_bot import MockedBot
@ -24,61 +24,53 @@ class BareInputFile(InputFile):
class TestAiohttpSession:
async def test_create_session(self):
session = AiohttpSession()
assert session._session is None
aiohttp_session = await session.create_session()
assert session._session is not None
assert isinstance(aiohttp_session, aiohttp.ClientSession)
await session.close()
async def test_create_proxy_session(self):
session = AiohttpSession(
proxy=("socks5://proxy.url/", aiohttp.BasicAuth("login", "password", "encoding"))
)
auth = aiohttp.BasicAuth("login", "password", "encoding")
async with AiohttpSession(proxy=("socks5://proxy.url/", auth)) as session:
assert session._connector_type == aiohttp_socks.ProxyConnector
assert session._connector_type == aiohttp_socks.ProxyConnector
assert isinstance(session._connector_init, dict)
assert session._connector_init["proxy_type"] is aiohttp_socks.ProxyType.SOCKS5
assert isinstance(session._connector_init, dict)
assert session._connector_init["proxy_type"] is aiohttp_socks.ProxyType.SOCKS5
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ProxyConnector)
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ProxyConnector)
async def test_create_proxy_session_proxy_url(self):
session = AiohttpSession(proxy="socks4://proxy.url/")
async with AiohttpSession(proxy="socks4://proxy.url/") as session:
assert isinstance(session.proxy, str)
assert isinstance(session.proxy, str)
assert isinstance(session._connector_init, dict)
assert session._connector_init["proxy_type"] is aiohttp_socks.ProxyType.SOCKS4
assert isinstance(session._connector_init, dict)
assert session._connector_init["proxy_type"] is aiohttp_socks.ProxyType.SOCKS4
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ProxyConnector)
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ProxyConnector)
async def test_create_proxy_session_chained_proxies(self):
session = AiohttpSession(
proxy=[
"socks4://proxy.url/",
"socks5://proxy.url/",
"http://user:password@127.0.0.1:3128",
]
)
proxy_chain = [
"socks4://proxy.url/",
"socks5://proxy.url/",
"http://user:password@127.0.0.1:3128",
]
async with AiohttpSession(proxy=proxy_chain) as session:
assert isinstance(session.proxy, list)
assert isinstance(session.proxy, list)
assert isinstance(session._connector_init, dict)
assert isinstance(session._connector_init, dict)
assert isinstance(session._connector_init["proxy_infos"], list)
assert isinstance(session._connector_init["proxy_infos"][0], aiohttp_socks.ProxyInfo)
proxy_infos = session._connector_init["proxy_infos"]
assert isinstance(proxy_infos, list)
assert isinstance(proxy_infos[0], aiohttp_socks.ProxyInfo)
assert proxy_infos[0].proxy_type is aiohttp_socks.ProxyType.SOCKS4
assert proxy_infos[1].proxy_type is aiohttp_socks.ProxyType.SOCKS5
assert proxy_infos[2].proxy_type is aiohttp_socks.ProxyType.HTTP
assert (
session._connector_init["proxy_infos"][0].proxy_type is aiohttp_socks.ProxyType.SOCKS4
)
assert (
session._connector_init["proxy_infos"][1].proxy_type is aiohttp_socks.ProxyType.SOCKS5
)
assert session._connector_init["proxy_infos"][2].proxy_type is aiohttp_socks.ProxyType.HTTP
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
aiohttp_session = await session.create_session()
assert isinstance(aiohttp_session.connector, aiohttp_socks.ChainProxyConnector)
async def test_reset_connector(self):
session = AiohttpSession()
@ -93,6 +85,7 @@ class TestAiohttpSession:
assert session._should_reset_connector
await session.create_session()
assert session._should_reset_connector is False
await session.close()
async def test_close_session(self):
@ -170,54 +163,53 @@ class TestAiohttpSession:
),
)
session = AiohttpSession()
async with AiohttpSession() as session:
class TestMethod(TelegramMethod[int]):
__returning__ = int
__api_method__ = "method"
class TestMethod(TelegramMethod[int]):
__returning__ = int
__api_method__ = "method"
call = TestMethod()
call = TestMethod()
result = await session.make_request(bot, call)
assert isinstance(result, int)
assert result == 42
result = await session.make_request(bot, call)
assert isinstance(result, int)
assert result == 42
@pytest.mark.parametrize("error", [ClientError("mocked"), asyncio.TimeoutError()])
async def test_make_request_network_error(self, error):
bot = Bot("42:TEST")
async def side_effect(*args, **kwargs):
raise error
with patch(
"aiohttp.client.ClientSession._request",
new_callable=AsyncMock,
side_effect=side_effect,
):
with pytest.raises(TelegramNetworkError):
await bot.get_me()
async with Bot("42:TEST").context() as bot:
with patch(
"aiohttp.client.ClientSession._request",
new_callable=AsyncMock,
side_effect=side_effect,
):
with pytest.raises(TelegramNetworkError):
await bot.get_me()
async def test_stream_content(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
session = AiohttpSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png",
timeout=5,
chunk_size=1,
raise_for_status=True,
)
assert isinstance(stream, AsyncGenerator)
async with AiohttpSession() as session:
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png",
timeout=5,
chunk_size=1,
raise_for_status=True,
)
assert isinstance(stream, AsyncGenerator)
size = 0
async for chunk in stream:
assert isinstance(chunk, bytes)
chunk_size = len(chunk)
assert chunk_size == 1
size += chunk_size
assert size == 10
size = 0
async for chunk in stream:
assert isinstance(chunk, bytes)
chunk_size = len(chunk)
assert chunk_size == 1
size += chunk_size
assert size == 10
async def test_stream_content_404(self, aresponses: ResponsesMockServer):
aresponses.add(
@ -229,29 +221,30 @@ class TestAiohttpSession:
body=b"File not found",
),
)
session = AiohttpSession()
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png",
timeout=5,
chunk_size=1,
raise_for_status=True,
)
async with AiohttpSession() as session:
stream = session.stream_content(
"https://www.python.org/static/img/python-logo.png",
timeout=5,
chunk_size=1,
raise_for_status=True,
)
with pytest.raises(ClientError):
async for _ in stream:
...
with pytest.raises(ClientError):
async for _ in stream:
...
async def test_context_manager(self):
session = AiohttpSession()
assert isinstance(session, AsyncContextManager)
async with AiohttpSession() as session:
assert isinstance(session, AsyncContextManager)
with patch(
"aiogram.client.session.aiohttp.AiohttpSession.create_session",
new_callable=AsyncMock,
) as mocked_create_session, patch(
"aiogram.client.session.aiohttp.AiohttpSession.close", new_callable=AsyncMock
) as mocked_close:
async with session as ctx:
assert session == ctx
mocked_close.assert_awaited_once()
mocked_create_session.assert_awaited_once()
with patch(
"aiogram.client.session.aiohttp.AiohttpSession.create_session",
new_callable=AsyncMock,
) as mocked_create_session, patch(
"aiogram.client.session.aiohttp.AiohttpSession.close",
new_callable=AsyncMock,
) as mocked_close:
async with session as ctx:
assert session == ctx
mocked_close.assert_awaited_once()
mocked_create_session.assert_awaited_once()

View file

@ -68,15 +68,18 @@ class TestInputFile:
async def test_url_input_file(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
aresponses.ANY,
aresponses.ANY,
"get",
aresponses.Response(status=200, body=b"\f" * 10),
)
bot = Bot(token="42:TEST")
file = URLInputFile("https://test.org/", chunk_size=1)
async with Bot(token="42:TEST").context() as bot:
file = URLInputFile("https://test.org/", chunk_size=1)
size = 0
async for chunk in file.read(bot):
assert chunk == b"\f"
chunk_size = len(chunk)
assert chunk_size == 1
size += chunk_size
assert size == 10
size = 0
async for chunk in file.read(bot):
assert chunk == b"\f"
chunk_size = len(chunk)
assert chunk_size == 1
size += chunk_size
assert size == 10