This commit is contained in:
darksidecat 2021-08-30 16:31:51 +03:00
parent c89bf6fbf8
commit e6d0ca2c97
2 changed files with 147 additions and 11 deletions

View file

@ -1,5 +1,9 @@
import os
import pathlib
from io import IOBase
from typing import Union, Optional
from aiogram.utils.deprecated import warn_deprecated
class Downloadable:
@ -7,32 +11,76 @@ class Downloadable:
Mixin for files
"""
async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True):
async def download(
self,
destination=None,
timeout=30,
chunk_size=65536,
seek=True,
make_dirs=True,
*,
destination_dir: Optional[Union[str, pathlib.Path]] = None,
destination_file: Optional[Union[str, pathlib.Path, IOBase]] = None
):
"""
Download file
:param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
At most one of these parameters can be used: :param destination_dir:, :param destination_file:
:param destination: deprecated, alias for :param destination_dir:
:param timeout: Integer
:param chunk_size: Integer
:param seek: Boolean - go to start of file when downloading is finished.
:param make_dirs: Make dirs if not exist
:param destination_dir: directory for saving files
:param destination_file: the path to the file or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
:return: destination
"""
if destination:
warn_deprecated("destination parameter is deprecated, please use destination_dir.")
destination_dir = destination
if destination_dir and destination_file:
raise ValueError("Use only one of the parameters: destination_dir or destination_file.")
file, destination = await self._prepare_destination(
destination_dir,
destination_file,
make_dirs
)
return await self.bot.download_file(
file_path=file.file_path,
destination=destination,
timeout=timeout,
chunk_size=chunk_size,
seek=seek,
)
async def _prepare_destination(self, destination_dir, destination_file, make_dirs):
file = await self.get_file()
is_path = True
if destination is None:
if destination_dir is None and destination_file is None:
destination = file.file_path
elif isinstance(destination, (str, pathlib.Path)) and os.path.isdir(destination):
destination = os.path.join(destination, file.file_path)
else:
is_path = False
if is_path and make_dirs:
elif destination_dir:
if isinstance(destination_dir, IOBase): # for backward compatibility
return file, destination_dir
elif isinstance(destination_dir, (str, pathlib.Path)):
destination = os.path.join(destination_dir, file.file_path)
else:
raise TypeError("destination_dir must be str or pathlib.Path")
else:
if isinstance(destination_file, IOBase):
return file, destination_file
elif isinstance(destination_file, (str, pathlib.Path)):
destination = destination_file
else:
raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type")
if make_dirs:
os.makedirs(os.path.dirname(destination), exist_ok=True)
return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout,
chunk_size=chunk_size, seek=seek)
return file, destination
async def get_file(self):
"""

View file

@ -0,0 +1,88 @@
import os
import shutil
from io import BytesIO
from pathlib import Path
import pytest
from aiogram import Bot
from aiogram.types import File
from aiogram.types.mixins import Downloadable
from tests import TOKEN
from tests.types.dataset import FILE
DIR_NAME = 'downloadable_tests'
DIR = Path.joinpath(Path(__file__).parent, DIR_NAME)
pytestmark = pytest.mark.asyncio
@pytest.fixture(name='bot')
async def bot_fixture():
""" Bot fixture """
_bot = Bot(TOKEN)
yield _bot
await _bot.session.close()
@pytest.fixture
def work_directory(request):
os.makedirs(DIR, exist_ok=True)
os.chdir(Path.joinpath(Path(request.fspath.dirname), DIR_NAME))
yield DIR
os.chdir(request.config.invocation_dir)
shutil.rmtree(DIR)
@pytest.fixture
def downloadable(bot):
async def get_file():
return File(**FILE)
downloadable = Downloadable()
downloadable.get_file = get_file
downloadable.bot = bot
return downloadable
class TestDownloadable:
async def test_download_make_dirs_false_nodir(self, work_directory, downloadable):
with pytest.raises(FileNotFoundError):
await downloadable.download(make_dirs=False)
async def test_download_make_dirs_false_mkdir(self, work_directory, downloadable):
os.mkdir('voice')
await downloadable.download(make_dirs=False)
assert os.path.isfile(work_directory.joinpath(FILE["file_path"]))
async def test_download_make_dirs_true(self, work_directory, downloadable):
await downloadable.download(make_dirs=True)
assert os.path.isfile(work_directory.joinpath(FILE["file_path"]))
async def test_download_warning(self, work_directory, downloadable):
with pytest.deprecated_call():
await downloadable.download("test")
assert os.path.isfile(work_directory.joinpath('test', FILE["file_path"]))
async def test_download_raise_value_error(self, work_directory, downloadable):
with pytest.raises(ValueError):
await downloadable.download(destination_dir="a", destination_file="b")
async def test_download_destination_dir(self, work_directory, downloadable):
await downloadable.download(destination_dir='test_dir')
assert os.path.isfile(work_directory.joinpath('test_dir', FILE["file_path"]))
async def test_download_destination_file(self, work_directory, downloadable):
await downloadable.download(destination_file=os.path.join('dir_name', 'file_name'))
assert os.path.isfile(work_directory.joinpath('dir_name', 'file_name'))
async def test_download_io_bytes(self, work_directory, downloadable):
file = BytesIO()
await downloadable.download(destination_file=file)
assert len(file.read()) != 0
async def test_download_io_bytes_backward_compatibility(self, work_directory, downloadable):
file = BytesIO()
await downloadable.download(destination_dir=file)
assert len(file.read()) != 0