From e6d0ca2c97af2a12c4a2fdc256e7cff35fecbec1 Mon Sep 17 00:00:00 2001 From: darksidecat Date: Mon, 30 Aug 2021 16:31:51 +0300 Subject: [PATCH] close #665 --- aiogram/types/mixins.py | 70 +++++++++++++++++++++++++----- tests/types/test_mixins.py | 88 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 11 deletions(-) create mode 100644 tests/types/test_mixins.py diff --git a/aiogram/types/mixins.py b/aiogram/types/mixins.py index 13f8412f..ba635c61 100644 --- a/aiogram/types/mixins.py +++ b/aiogram/types/mixins.py @@ -1,5 +1,9 @@ import os import pathlib +from io import IOBase +from typing import Union, Optional + +from aiogram.utils.deprecated import warn_deprecated class Downloadable: @@ -7,32 +11,76 @@ class Downloadable: Mixin for files """ - async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True): + async def download( + self, + destination=None, + timeout=30, + chunk_size=65536, + seek=True, + make_dirs=True, + *, + destination_dir: Optional[Union[str, pathlib.Path]] = None, + destination_file: Optional[Union[str, pathlib.Path, IOBase]] = None + ): """ Download file - :param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO` + At most one of these parameters can be used: :param destination_dir:, :param destination_file: + + :param destination: deprecated, alias for :param destination_dir: :param timeout: Integer :param chunk_size: Integer :param seek: Boolean - go to start of file when downloading is finished. :param make_dirs: Make dirs if not exist + :param destination_dir: directory for saving files + :param destination_file: the path to the file or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO` :return: destination """ + if destination: + warn_deprecated("destination parameter is deprecated, please use destination_dir.") + destination_dir = destination + if destination_dir and destination_file: + raise ValueError("Use only one of the parameters: destination_dir or destination_file.") + + file, destination = await self._prepare_destination( + destination_dir, + destination_file, + make_dirs + ) + + return await self.bot.download_file( + file_path=file.file_path, + destination=destination, + timeout=timeout, + chunk_size=chunk_size, + seek=seek, + ) + + async def _prepare_destination(self, destination_dir, destination_file, make_dirs): file = await self.get_file() - is_path = True - if destination is None: + if destination_dir is None and destination_file is None: destination = file.file_path - elif isinstance(destination, (str, pathlib.Path)) and os.path.isdir(destination): - destination = os.path.join(destination, file.file_path) - else: - is_path = False - if is_path and make_dirs: + elif destination_dir: + if isinstance(destination_dir, IOBase): # for backward compatibility + return file, destination_dir + elif isinstance(destination_dir, (str, pathlib.Path)): + destination = os.path.join(destination_dir, file.file_path) + else: + raise TypeError("destination_dir must be str or pathlib.Path") + else: + if isinstance(destination_file, IOBase): + return file, destination_file + elif isinstance(destination_file, (str, pathlib.Path)): + destination = destination_file + else: + raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type") + + if make_dirs: os.makedirs(os.path.dirname(destination), exist_ok=True) - return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout, - chunk_size=chunk_size, seek=seek) + return file, destination async def get_file(self): """ diff --git a/tests/types/test_mixins.py b/tests/types/test_mixins.py new file mode 100644 index 00000000..ba8e461d --- /dev/null +++ b/tests/types/test_mixins.py @@ -0,0 +1,88 @@ +import os +import shutil +from io import BytesIO +from pathlib import Path + +import pytest + +from aiogram import Bot +from aiogram.types import File +from aiogram.types.mixins import Downloadable +from tests import TOKEN +from tests.types.dataset import FILE + +DIR_NAME = 'downloadable_tests' +DIR = Path.joinpath(Path(__file__).parent, DIR_NAME) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(name='bot') +async def bot_fixture(): + """ Bot fixture """ + _bot = Bot(TOKEN) + yield _bot + await _bot.session.close() + + +@pytest.fixture +def work_directory(request): + os.makedirs(DIR, exist_ok=True) + os.chdir(Path.joinpath(Path(request.fspath.dirname), DIR_NAME)) + yield DIR + os.chdir(request.config.invocation_dir) + shutil.rmtree(DIR) + + +@pytest.fixture +def downloadable(bot): + async def get_file(): + return File(**FILE) + + downloadable = Downloadable() + downloadable.get_file = get_file + downloadable.bot = bot + + return downloadable + + +class TestDownloadable: + async def test_download_make_dirs_false_nodir(self, work_directory, downloadable): + with pytest.raises(FileNotFoundError): + await downloadable.download(make_dirs=False) + + async def test_download_make_dirs_false_mkdir(self, work_directory, downloadable): + os.mkdir('voice') + await downloadable.download(make_dirs=False) + assert os.path.isfile(work_directory.joinpath(FILE["file_path"])) + + async def test_download_make_dirs_true(self, work_directory, downloadable): + await downloadable.download(make_dirs=True) + assert os.path.isfile(work_directory.joinpath(FILE["file_path"])) + + async def test_download_warning(self, work_directory, downloadable): + with pytest.deprecated_call(): + await downloadable.download("test") + assert os.path.isfile(work_directory.joinpath('test', FILE["file_path"])) + + async def test_download_raise_value_error(self, work_directory, downloadable): + with pytest.raises(ValueError): + await downloadable.download(destination_dir="a", destination_file="b") + + async def test_download_destination_dir(self, work_directory, downloadable): + await downloadable.download(destination_dir='test_dir') + assert os.path.isfile(work_directory.joinpath('test_dir', FILE["file_path"])) + + async def test_download_destination_file(self, work_directory, downloadable): + await downloadable.download(destination_file=os.path.join('dir_name', 'file_name')) + assert os.path.isfile(work_directory.joinpath('dir_name', 'file_name')) + + async def test_download_io_bytes(self, work_directory, downloadable): + file = BytesIO() + await downloadable.download(destination_file=file) + assert len(file.read()) != 0 + + async def test_download_io_bytes_backward_compatibility(self, work_directory, downloadable): + file = BytesIO() + await downloadable.download(destination_dir=file) + assert len(file.read()) != 0