From 358ecc78213183adeb70809d7f3e95f1524657fb Mon Sep 17 00:00:00 2001 From: darksidecat <58224121+darksidecat@users.noreply.github.com> Date: Mon, 6 Sep 2021 00:05:52 +0300 Subject: [PATCH] Fix #665, add separate parametrs for saving to directory and file (#677) * close #665 * add backward compatibility * improve doc, codestyle * warning text update * use tmpdir fixture in tests --- aiogram/types/mixins.py | 80 +++++++++++++++++++++++++---- tests/types/test_mixins.py | 102 +++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 11 deletions(-) create mode 100644 tests/types/test_mixins.py diff --git a/aiogram/types/mixins.py b/aiogram/types/mixins.py index 13f8412f..83c65032 100644 --- a/aiogram/types/mixins.py +++ b/aiogram/types/mixins.py @@ -1,5 +1,9 @@ import os import pathlib +from io import IOBase +from typing import Union, Optional + +from aiogram.utils.deprecated import warn_deprecated class Downloadable: @@ -7,32 +11,86 @@ class Downloadable: Mixin for files """ - async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True): + async def download( + self, + destination=None, + timeout=30, + chunk_size=65536, + seek=True, + make_dirs=True, + *, + destination_dir: Optional[Union[str, pathlib.Path]] = None, + destination_file: Optional[Union[str, pathlib.Path, IOBase]] = None + ): """ Download file - :param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO` + At most one of these parameters can be used: :param destination_dir:, :param destination_file: + + :param destination: deprecated, use :param destination_dir: or :param destination_file: instead :param timeout: Integer :param chunk_size: Integer :param seek: Boolean - go to start of file when downloading is finished. :param make_dirs: Make dirs if not exist + :param destination_dir: directory for saving files + :param destination_file: path to the file or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO` :return: destination """ + if destination: + warn_deprecated( + "destination parameter is deprecated, please use destination_dir or destination_file." + ) + if destination_dir and destination_file: + raise ValueError( + "Use only one of the parameters: destination_dir or destination_file." + ) + + file, destination = await self._prepare_destination( + destination, + destination_dir, + destination_file, + make_dirs + ) + + return await self.bot.download_file( + file_path=file.file_path, + destination=destination, + timeout=timeout, + chunk_size=chunk_size, + seek=seek, + ) + + async def _prepare_destination(self, dest, destination_dir, destination_file, make_dirs): file = await self.get_file() - is_path = True - if destination is None: + if not(any((dest, destination_dir, destination_file))): destination = file.file_path - elif isinstance(destination, (str, pathlib.Path)) and os.path.isdir(destination): - destination = os.path.join(destination, file.file_path) - else: - is_path = False - if is_path and make_dirs: + elif dest: # backward compatibility + if isinstance(dest, IOBase): + return file, dest + if isinstance(dest, (str, pathlib.Path)) and os.path.isdir(dest): + destination = os.path.join(dest, file.file_path) + else: + destination = dest + + elif destination_dir: + if isinstance(destination_dir, (str, pathlib.Path)): + destination = os.path.join(destination_dir, file.file_path) + else: + raise TypeError("destination_dir must be str or pathlib.Path") + else: + if isinstance(destination_file, IOBase): + return file, destination_file + elif isinstance(destination_file, (str, pathlib.Path)): + destination = destination_file + else: + raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type") + + if make_dirs and os.path.dirname(destination): os.makedirs(os.path.dirname(destination), exist_ok=True) - return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout, - chunk_size=chunk_size, seek=seek) + return file, destination async def get_file(self): """ diff --git a/tests/types/test_mixins.py b/tests/types/test_mixins.py new file mode 100644 index 00000000..4327e8aa --- /dev/null +++ b/tests/types/test_mixins.py @@ -0,0 +1,102 @@ +import os +from io import BytesIO +from pathlib import Path + +import pytest + +from aiogram import Bot +from aiogram.types import File +from aiogram.types.mixins import Downloadable +from tests import TOKEN +from tests.types.dataset import FILE + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(name='bot') +async def bot_fixture(): + """ Bot fixture """ + _bot = Bot(TOKEN) + yield _bot + await _bot.session.close() + + +@pytest.fixture +def tmppath(tmpdir, request): + os.chdir(tmpdir) + yield Path(tmpdir) + os.chdir(request.config.invocation_dir) + + +@pytest.fixture +def downloadable(bot): + async def get_file(): + return File(**FILE) + + downloadable = Downloadable() + downloadable.get_file = get_file + downloadable.bot = bot + + return downloadable + + +class TestDownloadable: + async def test_download_make_dirs_false_nodir(self, tmppath, downloadable): + with pytest.raises(FileNotFoundError): + await downloadable.download(make_dirs=False) + + async def test_download_make_dirs_false_mkdir(self, tmppath, downloadable): + os.mkdir('voice') + await downloadable.download(make_dirs=False) + assert os.path.isfile(tmppath.joinpath(FILE["file_path"])) + + async def test_download_make_dirs_true(self, tmppath, downloadable): + await downloadable.download(make_dirs=True) + assert os.path.isfile(tmppath.joinpath(FILE["file_path"])) + + async def test_download_deprecation_warning(self, tmppath, downloadable): + with pytest.deprecated_call(): + await downloadable.download("test.file") + + async def test_download_destination(self, tmppath, downloadable): + with pytest.deprecated_call(): + await downloadable.download("test.file") + assert os.path.isfile(tmppath.joinpath('test.file')) + + async def test_download_destination_dir_exist(self, tmppath, downloadable): + os.mkdir("test_folder") + with pytest.deprecated_call(): + await downloadable.download("test_folder") + assert os.path.isfile(tmppath.joinpath('test_folder', FILE["file_path"])) + + async def test_download_destination_with_dir(self, tmppath, downloadable): + with pytest.deprecated_call(): + await downloadable.download(os.path.join('dir_name', 'file_name')) + assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name')) + + async def test_download_destination_io_bytes(self, tmppath, downloadable): + file = BytesIO() + with pytest.deprecated_call(): + await downloadable.download(file) + assert len(file.read()) != 0 + + async def test_download_raise_value_error(self, tmppath, downloadable): + with pytest.raises(ValueError): + await downloadable.download(destination_dir="a", destination_file="b") + + async def test_download_destination_dir(self, tmppath, downloadable): + await downloadable.download(destination_dir='test_dir') + assert os.path.isfile(tmppath.joinpath('test_dir', FILE["file_path"])) + + async def test_download_destination_file(self, tmppath, downloadable): + await downloadable.download(destination_file='file_name') + assert os.path.isfile(tmppath.joinpath('file_name')) + + async def test_download_destination_file_with_dir(self, tmppath, downloadable): + await downloadable.download(destination_file=os.path.join('dir_name', 'file_name')) + assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name')) + + async def test_download_io_bytes(self, tmppath, downloadable): + file = BytesIO() + await downloadable.download(destination_file=file) + assert len(file.read()) != 0