From a7916c11031b216a145fdc4a9766c228a162a6d7 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Wed, 2 Aug 2023 21:41:07 +0300 Subject: [PATCH] Reworked InputFile sending (#1238) * Reworked InputFile sending * Added changelog --- CHANGES/1238.misc.rst | 3 +++ aiogram/client/session/aiohttp.py | 7 +++++- aiogram/types/input_file.py | 23 ++++++++----------- aiogram/webhook/aiohttp_server.py | 2 +- .../test_session/test_aiohttp_session.py | 4 ++-- .../test_types/test_chat_join_request.py | 6 +---- tests/test_api/test_types/test_input_file.py | 21 ++++++++--------- 7 files changed, 32 insertions(+), 34 deletions(-) create mode 100644 CHANGES/1238.misc.rst diff --git a/CHANGES/1238.misc.rst b/CHANGES/1238.misc.rst new file mode 100644 index 00000000..4f62cae1 --- /dev/null +++ b/CHANGES/1238.misc.rst @@ -0,0 +1,3 @@ +Reworked InputFile reading, removed :code:`__aiter__` method, added `bot: Bot` argument to +the :code:`.read(...)` method, so, from now URLInputFile can be used without specifying +bot instance. diff --git a/aiogram/client/session/aiohttp.py b/aiogram/client/session/aiohttp.py index 79e2fa4f..19a8edbd 100644 --- a/aiogram/client/session/aiohttp.py +++ b/aiogram/client/session/aiohttp.py @@ -6,6 +6,7 @@ from typing import ( TYPE_CHECKING, Any, AsyncGenerator, + AsyncIterator, Dict, Iterable, List, @@ -147,7 +148,11 @@ class AiohttpSession(BaseSession): continue form.add_field(key, value) for key, value in files.items(): - form.add_field(key, value, filename=value.filename or key) + form.add_field( + key, + value.read(bot), + filename=value.filename or key, + ) return form async def make_request( diff --git a/aiogram/types/input_file.py b/aiogram/types/input_file.py index ed0a2433..7ad8bee3 100644 --- a/aiogram/types/input_file.py +++ b/aiogram/types/input_file.py @@ -41,13 +41,9 @@ class InputFile(ABC): self.chunk_size = chunk_size @abstractmethod - async def read(self, chunk_size: int) -> AsyncGenerator[bytes, None]: # pragma: no cover + async def read(self, bot: "Bot") -> AsyncGenerator[bytes, None]: # pragma: no cover yield b"" - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self.read(self.chunk_size): - yield chunk - class BufferedInputFile(InputFile): def __init__(self, file: bytes, filename: str, chunk_size: int = DEFAULT_CHUNK_SIZE): @@ -84,9 +80,9 @@ class BufferedInputFile(InputFile): data = f.read() return cls(data, filename=filename, chunk_size=chunk_size) - async def read(self, chunk_size: int) -> AsyncGenerator[bytes, None]: + async def read(self, bot: "Bot") -> AsyncGenerator[bytes, None]: buffer = io.BytesIO(self.data) - while chunk := buffer.read(chunk_size): + while chunk := buffer.read(self.chunk_size): yield chunk @@ -111,9 +107,9 @@ class FSInputFile(InputFile): self.path = path - async def read(self, chunk_size: int) -> AsyncGenerator[bytes, None]: + async def read(self, bot: "Bot") -> AsyncGenerator[bytes, None]: async with aiofiles.open(self.path, "rb") as f: - while chunk := await f.read(chunk_size): + while chunk := await f.read(self.chunk_size): yield chunk @@ -121,11 +117,11 @@ class URLInputFile(InputFile): def __init__( self, url: str, - bot: "Bot", headers: Optional[Dict[str, Any]] = None, filename: Optional[str] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, timeout: int = 30, + bot: Optional["Bot"] = None, ): """ Represents object for streaming files from internet @@ -136,7 +132,7 @@ class URLInputFile(InputFile): :param chunk_size: Uploading chunk size :param timeout: Timeout for downloading :param bot: Bot instance to use HTTP session from. - If not specified, will be used current bot from context. + If not specified, will be used current bot """ super().__init__(filename=filename, chunk_size=chunk_size) if headers is None: @@ -147,8 +143,9 @@ class URLInputFile(InputFile): self.timeout = timeout self.bot = bot - async def read(self, chunk_size: int) -> AsyncGenerator[bytes, None]: - stream = self.bot.session.stream_content( + async def read(self, bot: "Bot") -> AsyncGenerator[bytes, None]: + bot = self.bot or bot + stream = bot.session.stream_content( url=self.url, headers=self.headers, timeout=self.timeout, diff --git a/aiogram/webhook/aiohttp_server.py b/aiogram/webhook/aiohttp_server.py index f623484f..af1c3c56 100644 --- a/aiogram/webhook/aiohttp_server.py +++ b/aiogram/webhook/aiohttp_server.py @@ -170,7 +170,7 @@ class BaseRequestHandler(ABC): payload.set_content_disposition("form-data", name=key) for key, value in files.items(): - payload = writer.append(value) + payload = writer.append(value.read(bot)) payload.set_content_disposition( "form-data", name=key, diff --git a/tests/test_api/test_client/test_session/test_aiohttp_session.py b/tests/test_api/test_client/test_session/test_aiohttp_session.py index fd7cabfd..d79eb875 100644 --- a/tests/test_api/test_client/test_session/test_aiohttp_session.py +++ b/tests/test_api/test_client/test_session/test_aiohttp_session.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, AsyncContextManager, AsyncGenerator, Dict, List +from typing import Any, AsyncContextManager, AsyncGenerator, AsyncIterable, Dict, List from unittest.mock import AsyncMock, patch import aiohttp_socks @@ -156,7 +156,7 @@ class TestAiohttpSession: assert fields[1][2].startswith("attach://") assert fields[2][0]["name"] == fields[1][2][9:] assert fields[2][0]["filename"] == "file.txt" - assert isinstance(fields[2][2], BareInputFile) + assert isinstance(fields[2][2], AsyncIterable) async def test_make_request(self, bot: MockedBot, aresponses: ResponsesMockServer): aresponses.add( diff --git a/tests/test_api/test_types/test_chat_join_request.py b/tests/test_api/test_types/test_chat_join_request.py index a36dd935..ec442cb1 100644 --- a/tests/test_api/test_types/test_chat_join_request.py +++ b/tests/test_api/test_types/test_chat_join_request.py @@ -24,11 +24,7 @@ from aiogram.methods import ( SendVideoNote, SendVoice, ) -from aiogram.types import ( - Chat, - ChatJoinRequest, - User, -) +from aiogram.types import Chat, ChatJoinRequest, User class TestChatJoinRequest: diff --git a/tests/test_api/test_types/test_input_file.py b/tests/test_api/test_types/test_input_file.py index 81e80ad5..28daa92c 100644 --- a/tests/test_api/test_types/test_input_file.py +++ b/tests/test_api/test_types/test_input_file.py @@ -12,19 +12,18 @@ class TestInputFile: file = FSInputFile(__file__) assert isinstance(file, InputFile) - assert isinstance(file, AsyncIterable) assert file.filename is not None assert file.filename.startswith("test_") assert file.filename.endswith(".py") assert file.chunk_size > 0 - async def test_fs_input_file_readable(self): + async def test_fs_input_file_readable(self, bot: MockedBot): file = FSInputFile(__file__, chunk_size=1) assert file.chunk_size == 1 size = 0 - async for chunk in file: + async for chunk in file.read(bot): chunk_size = len(chunk) assert chunk_size == 1 size += chunk_size @@ -34,15 +33,14 @@ class TestInputFile: file = BufferedInputFile(b"\f" * 10, filename="file.bin") assert isinstance(file, InputFile) - assert isinstance(file, AsyncIterable) assert file.filename == "file.bin" assert isinstance(file.data, bytes) - async def test_buffered_input_file_readable(self): + async def test_buffered_input_file_readable(self, bot: MockedBot): file = BufferedInputFile(b"\f" * 10, filename="file.bin", chunk_size=1) size = 0 - async for chunk in file: + async for chunk in file.read(bot): chunk_size = len(chunk) assert chunk_size == 1 size += chunk_size @@ -52,32 +50,31 @@ class TestInputFile: file = BufferedInputFile.from_file(__file__, chunk_size=10) assert isinstance(file, InputFile) - assert isinstance(file, AsyncIterable) assert file.filename is not None assert file.filename.startswith("test_") assert file.filename.endswith(".py") assert isinstance(file.data, bytes) assert file.chunk_size == 10 - async def test_buffered_input_file_from_file_readable(self): + async def test_buffered_input_file_from_file_readable(self, bot: MockedBot): file = BufferedInputFile.from_file(__file__, chunk_size=1) size = 0 - async for chunk in file: + async for chunk in file.read(bot): chunk_size = len(chunk) assert chunk_size == 1 size += chunk_size assert size > 0 - async def test_uri_input_file(self, aresponses: ResponsesMockServer): + async def test_url_input_file(self, aresponses: ResponsesMockServer): aresponses.add( aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10) ) bot = Bot(token="42:TEST") - file = URLInputFile("https://test.org/", bot, chunk_size=1) + file = URLInputFile("https://test.org/", chunk_size=1) size = 0 - async for chunk in file: + async for chunk in file.read(bot): assert chunk == b"\f" chunk_size = len(chunk) assert chunk_size == 1