Add download_file method

This commit is contained in:
Gabben 2020-03-16 22:22:35 +05:00
parent 26708154b0
commit 28c0295496
4 changed files with 177 additions and 2 deletions

View file

@ -1,7 +1,11 @@
from __future__ import annotations
import io
import pathlib
from contextlib import asynccontextmanager
from typing import Any, Optional, TypeVar
from typing import Any, AsyncGenerator, BinaryIO, Optional, TypeVar, Union
import aiofiles
from ...utils.mixins import ContextInstanceMixin, DataMixin
from ...utils.token import extract_bot_id, validate_token
@ -69,6 +73,63 @@ class BaseBot(ContextInstanceMixin, DataMixin):
await self.close()
self.reset_current(token)
@staticmethod
async def __download_file_binary_io(
destination: BinaryIO, seek: bool, stream: AsyncGenerator[bytes, None]
) -> BinaryIO:
async for chunk in stream:
destination.write(chunk)
destination.flush()
if seek is True:
destination.seek(0)
return destination
@staticmethod
async def __download_file(
destination: Union[str, pathlib.Path], stream: AsyncGenerator[bytes, None]
):
async with aiofiles.open(destination, "wb") as f:
async for chunk in stream:
await f.write(chunk)
async def download_file(
self,
file_path: str,
destination: Optional[Union[BinaryIO, pathlib.Path, str]] = None,
timeout: int = 30,
chunk_size: int = 65536,
seek: bool = True,
) -> Optional[BinaryIO]:
"""
Download file by file_path to destination.
If you want to automatically create destination (:class:`io.BytesIO`) use default
value of destination and handle result of this method.
:param file_path: File path on Telegram server (You can get it from :obj:`aiogram.types.File`)
:type file_path: str
:param destination: Filename, file path or instance of :class:`io.IOBase`. For e.g. :class:`io.BytesIO`, defaults to None
:type destination: Optional[Union[BinaryIO, pathlib.Path, str]]
:param timeout: Total timeout in seconds, defaults to 30
:type timeout: int
:param chunk_size: Chunk size, defaults to 65536
:type chunk_size: int
:param seek: Go to start of file when downloading is finished. Used only for :class:`typing.BinaryIO` type destination, defaults to True
:type seek: bool
"""
if destination is None:
destination = io.BytesIO()
url = self.session.api.file_url(token=self.__token, path=file_path)
stream = self.session.stream_content(url=url, timeout=timeout, chunk_size=chunk_size)
if isinstance(destination, (str, pathlib.Path)):
return await self.__download_file(destination=destination, stream=stream)
else:
return await self.__download_file_binary_io(
destination=destination, seek=seek, stream=stream
)
def __hash__(self) -> int:
"""
Get hash for the token

View file

@ -0,0 +1,61 @@
# How to download file?
Before you start, read the documentation for the [getFile](./methods/get_file.md) method.
## Download file manually
First, you must get the `file_id` of the file you want to download. Information about files sent to the bot is contained in [Message](./types/message.md).
For example, download the document that came to the bot.
```python3
file_id = message.document.file_id
```
Then use the [getFile](./methods/get_file.md) method to get `file_path`.
```python3
file = await bot.get_file(file_id)
file_path = file.file_path
```
After that, use the `download_file` method from the bot object.
### download_file(...)
Download file by file_path to destination.
If you want to automatically create destination (`#!python3 io.BytesIO`) use default
value of destination and handle result of this method.
|Argument|Type|Description|
|---|---|---|
| file_path | `#!python3 str` | File path on Telegram server |
| destination | `#!python3 Optional[Union[BinaryIO, pathlib.Path, str]]` | Filename, file path or instance of `#!python3 io.IOBase`. For e.g. `#!python3 io.BytesIO` (Default: `#!python3 None`) |
| chunk_size | `#!python3 int` | File chunks size (Default: `64 kb`) |
| timeout | `#!python3 int` | Total timeout in seconds (Default: `30`) |
| chunk_size | `#!python3 int` | Chunk size (Default: `65536`) |
| seek | `#!python3 bool` | Go to start of file when downloading is finished. Used only for destination with `#!python3 typing.BinaryIO` type (Default: `#!python3 True`) |
There are two options where you can download the file: to **disk** or to **binary I/O object**.
### Download file to disk
To download file to disk, you must specify the file name or path where to download the file. In this case, the function will return nothing.
```python3
await bot.download_file(file_path, "text.txt")
```
### Download file to binary I/O object
To download file to binary I/O object, you must specify an object with the `#!python3 typing.BinaryIO` type or use the default (`#!python3 None`) value.
In the first case, the function will return your object:
```python3
my_object = MyBinaryIO()
result: MyBinaryIO = await bot.download_file(file_path, my_object)
# print(result is my_object) # True
```
If you leave the default value, an `#!python3 io.BytesIO` object will be created and returned.
```python3
result: io.BytesIO = await bot.download_file(file_path)
```

View file

@ -218,6 +218,7 @@ nav:
- api/types/callback_game.md
- api/types/game_high_score.md
- api/sending_files.md
- api/downloading_files.md
- Dispatcher:
- dispatcher/index.md
- dispatcher/router.md

View file

@ -1,4 +1,9 @@
import io
import aiofiles
import pytest
from aiofiles import threadpool
from aresponses import ResponsesMockServer
from aiogram.api.client.base import BaseBot
from aiogram.api.client.session.aiohttp import AiohttpSession
@ -7,7 +12,7 @@ from aiogram.api.methods import GetMe
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
from unittest.mock import AsyncMock as CoroutineMock, MagicMock, patch # type: ignore
class TestBaseBot:
@ -63,3 +68,50 @@ class TestBaseBot:
mocked_close.assert_awaited()
else:
mocked_close.assert_not_awaited()
@pytest.mark.asyncio
async def test_download_file(self, aresponses: ResponsesMockServer):
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
# https://github.com/Tinche/aiofiles#writing-tests-for-aiofiles
aiofiles.threadpool.wrap.register(MagicMock)(
lambda *args, **kwargs: threadpool.AsyncBufferedIOBase(*args, **kwargs)
)
mock_file = MagicMock()
base_bot = BaseBot("42:TEST")
with patch("aiofiles.threadpool.sync_open", return_value=mock_file):
await base_bot.download_file("TEST", "file.png")
mock_file.write.assert_called_once_with(b"\f" * 10)
@pytest.mark.asyncio
async def test_download_file_default_destination(self, aresponses: ResponsesMockServer):
base_bot = BaseBot("42:TEST")
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
result = await base_bot.download_file("TEST")
assert isinstance(result, io.BytesIO)
assert result.read() == b"\f" * 10
@pytest.mark.asyncio
async def test_download_file_custom_destination(self, aresponses: ResponsesMockServer):
base_bot = BaseBot("42:TEST")
aresponses.add(
aresponses.ANY, aresponses.ANY, "get", aresponses.Response(status=200, body=b"\f" * 10)
)
custom = io.BytesIO()
result = await base_bot.download_file("TEST", custom)
assert isinstance(result, io.BytesIO)
assert result is custom
assert result.read() == b"\f" * 10