From 28c0295496e3bdbacbdc2feebbb2c1d8b1206c25 Mon Sep 17 00:00:00 2001 From: Gabben <43146729+gabbhack@users.noreply.github.com> Date: Mon, 16 Mar 2020 22:22:35 +0500 Subject: [PATCH] :sparkles: Add download_file method --- aiogram/api/client/base.py | 63 ++++++++++++++++++++- docs/api/downloading_files.md | 61 ++++++++++++++++++++ mkdocs.yml | 1 + tests/test_api/test_client/test_base_bot.py | 54 +++++++++++++++++- 4 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 docs/api/downloading_files.md diff --git a/aiogram/api/client/base.py b/aiogram/api/client/base.py index 4d8b7453..70ed56c0 100644 --- a/aiogram/api/client/base.py +++ b/aiogram/api/client/base.py @@ -1,7 +1,11 @@ from __future__ import annotations +import io +import pathlib from contextlib import asynccontextmanager -from typing import Any, Optional, TypeVar +from typing import Any, AsyncGenerator, BinaryIO, Optional, TypeVar, Union + +import aiofiles from ...utils.mixins import ContextInstanceMixin, DataMixin from ...utils.token import extract_bot_id, validate_token @@ -69,6 +73,63 @@ class BaseBot(ContextInstanceMixin, DataMixin): await self.close() self.reset_current(token) + @staticmethod + async def __download_file_binary_io( + destination: BinaryIO, seek: bool, stream: AsyncGenerator[bytes, None] + ) -> BinaryIO: + async for chunk in stream: + destination.write(chunk) + destination.flush() + if seek is True: + destination.seek(0) + return destination + + @staticmethod + async def __download_file( + destination: Union[str, pathlib.Path], stream: AsyncGenerator[bytes, None] + ): + async with aiofiles.open(destination, "wb") as f: + async for chunk in stream: + await f.write(chunk) + + async def download_file( + self, + file_path: str, + destination: Optional[Union[BinaryIO, pathlib.Path, str]] = None, + timeout: int = 30, + chunk_size: int = 65536, + seek: bool = True, + ) -> Optional[BinaryIO]: + """ + Download file by file_path to destination. + + If you want to automatically create destination (:class:`io.BytesIO`) use default + value of destination and handle result of this method. + + :param file_path: File path on Telegram server (You can get it from :obj:`aiogram.types.File`) + :type file_path: str + :param destination: Filename, file path or instance of :class:`io.IOBase`. For e.g. :class:`io.BytesIO`, defaults to None + :type destination: Optional[Union[BinaryIO, pathlib.Path, str]] + :param timeout: Total timeout in seconds, defaults to 30 + :type timeout: int + :param chunk_size: Chunk size, defaults to 65536 + :type chunk_size: int + :param seek: Go to start of file when downloading is finished. Used only for :class:`typing.BinaryIO` type destination, defaults to True + :type seek: bool + """ + if destination is None: + destination = io.BytesIO() + + url = self.session.api.file_url(token=self.__token, path=file_path) + stream = self.session.stream_content(url=url, timeout=timeout, chunk_size=chunk_size) + + if isinstance(destination, (str, pathlib.Path)): + return await self.__download_file(destination=destination, stream=stream) + else: + return await self.__download_file_binary_io( + destination=destination, seek=seek, stream=stream + ) + def __hash__(self) -> int: """ Get hash for the token diff --git a/docs/api/downloading_files.md b/docs/api/downloading_files.md new file mode 100644 index 00000000..b544e202 --- /dev/null +++ b/docs/api/downloading_files.md @@ -0,0 +1,61 @@ +# How to download file? + +Before you start, read the documentation for the [getFile](./methods/get_file.md) method. + +## Download file manually +First, you must get the `file_id` of the file you want to download. Information about files sent to the bot is contained in [Message](./types/message.md). + +For example, download the document that came to the bot. +```python3 +file_id = message.document.file_id +``` + +Then use the [getFile](./methods/get_file.md) method to get `file_path`. +```python3 +file = await bot.get_file(file_id) +file_path = file.file_path +``` + +After that, use the `download_file` method from the bot object. + +### download_file(...) + +Download file by file_path to destination. +If you want to automatically create destination (`#!python3 io.BytesIO`) use default +value of destination and handle result of this method. + +|Argument|Type|Description| +|---|---|---| +| file_path | `#!python3 str` | File path on Telegram server | +| destination | `#!python3 Optional[Union[BinaryIO, pathlib.Path, str]]` | Filename, file path or instance of `#!python3 io.IOBase`. For e.g. `#!python3 io.BytesIO` (Default: `#!python3 None`) | +| chunk_size | `#!python3 int` | File chunks size (Default: `64 kb`) | +| timeout | `#!python3 int` | Total timeout in seconds (Default: `30`) | +| chunk_size | `#!python3 int` | Chunk size (Default: `65536`) | +| seek | `#!python3 bool` | Go to start of file when downloading is finished. Used only for destination with `#!python3 typing.BinaryIO` type (Default: `#!python3 True`) | + +There are two options where you can download the file: to **disk** or to **binary I/O object**. + +### Download file to disk + +To download file to disk, you must specify the file name or path where to download the file. In this case, the function will return nothing. + +```python3 +await bot.download_file(file_path, "text.txt") +``` + +### Download file to binary I/O object + +To download file to binary I/O object, you must specify an object with the `#!python3 typing.BinaryIO` type or use the default (`#!python3 None`) value. + +In the first case, the function will return your object: +```python3 +my_object = MyBinaryIO() +result: MyBinaryIO = await bot.download_file(file_path, my_object) +# print(result is my_object) # True +``` + +If you leave the default value, an `#!python3 io.BytesIO` object will be created and returned. + +```python3 +result: io.BytesIO = await bot.download_file(file_path) +``` diff --git a/mkdocs.yml b/mkdocs.yml index 1f1955f3..ee9f47c3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -218,6 +218,7 @@ nav: - api/types/callback_game.md - api/types/game_high_score.md - api/sending_files.md + - api/downloading_files.md - Dispatcher: - dispatcher/index.md - dispatcher/router.md diff --git a/tests/test_api/test_client/test_base_bot.py b/tests/test_api/test_client/test_base_bot.py index 652f0918..a1036119 100644 --- a/tests/test_api/test_client/test_base_bot.py +++ b/tests/test_api/test_client/test_base_bot.py @@ -1,4 +1,9 @@ +import io + +import aiofiles import pytest +from aiofiles import threadpool +from aresponses import ResponsesMockServer from aiogram.api.client.base import BaseBot from aiogram.api.client.session.aiohttp import AiohttpSession @@ -7,7 +12,7 @@ from aiogram.api.methods import GetMe try: from asynctest import CoroutineMock, patch except ImportError: - from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore + from unittest.mock import AsyncMock as CoroutineMock, MagicMock, patch # type: ignore class TestBaseBot: @@ -63,3 +68,50 @@ class TestBaseBot: mocked_close.assert_awaited() else: mocked_close.assert_not_awaited() + + @pytest.mark.asyncio + async def test_download_file(self, aresponses: ResponsesMockServer): + aresponses.add( + aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10) + ) + + # https://github.com/Tinche/aiofiles#writing-tests-for-aiofiles + aiofiles.threadpool.wrap.register(MagicMock)( + lambda *args, **kwargs: threadpool.AsyncBufferedIOBase(*args, **kwargs) + ) + + mock_file = MagicMock() + + base_bot = BaseBot("42:TEST") + with patch("aiofiles.threadpool.sync_open", return_value=mock_file): + await base_bot.download_file("TEST", "file.png") + mock_file.write.assert_called_once_with(b"\f" * 10) + + @pytest.mark.asyncio + async def test_download_file_default_destination(self, aresponses: ResponsesMockServer): + base_bot = BaseBot("42:TEST") + + aresponses.add( + aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10) + ) + + result = await base_bot.download_file("TEST") + + assert isinstance(result, io.BytesIO) + assert result.read() == b"\f" * 10 + + @pytest.mark.asyncio + async def test_download_file_custom_destination(self, aresponses: ResponsesMockServer): + base_bot = BaseBot("42:TEST") + + aresponses.add( + aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10) + ) + + custom = io.BytesIO() + + result = await base_bot.download_file("TEST", custom) + + assert isinstance(result, io.BytesIO) + assert result is custom + assert result.read() == b"\f" * 10