From 6f1206b75176b4af744d4ba61f428472d0fbc948 Mon Sep 17 00:00:00 2001 From: darksidecat Date: Mon, 30 Aug 2021 18:03:50 +0300 Subject: [PATCH] add backward compatibility --- aiogram/types/mixins.py | 22 ++++++++++++++-------- tests/types/test_mixins.py | 35 +++++++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/aiogram/types/mixins.py b/aiogram/types/mixins.py index ba635c61..aee06308 100644 --- a/aiogram/types/mixins.py +++ b/aiogram/types/mixins.py @@ -27,7 +27,7 @@ class Downloadable: At most one of these parameters can be used: :param destination_dir:, :param destination_file: - :param destination: deprecated, alias for :param destination_dir: + :param destination: deprecated, use :param destination_dir: or :param destination_file: :param timeout: Integer :param chunk_size: Integer :param seek: Boolean - go to start of file when downloading is finished. @@ -38,11 +38,11 @@ class Downloadable: """ if destination: warn_deprecated("destination parameter is deprecated, please use destination_dir.") - destination_dir = destination if destination_dir and destination_file: raise ValueError("Use only one of the parameters: destination_dir or destination_file.") file, destination = await self._prepare_destination( + destination, destination_dir, destination_file, make_dirs @@ -56,16 +56,22 @@ class Downloadable: seek=seek, ) - async def _prepare_destination(self, destination_dir, destination_file, make_dirs): + async def _prepare_destination(self, dest, destination_dir, destination_file, make_dirs): file = await self.get_file() - if destination_dir is None and destination_file is None: + if not(any([dest, destination_dir, destination_file])): destination = file.file_path + elif dest: # backward compatibility + if isinstance(dest, IOBase): + return file, dest + if isinstance(dest, (str, pathlib.Path)) and os.path.isdir(dest): + destination = os.path.join(dest, file.file_path) + else: + destination = dest + elif destination_dir: - if isinstance(destination_dir, IOBase): # for backward compatibility - return file, destination_dir - elif isinstance(destination_dir, (str, pathlib.Path)): + if isinstance(destination_dir, (str, pathlib.Path)): destination = os.path.join(destination_dir, file.file_path) else: raise TypeError("destination_dir must be str or pathlib.Path") @@ -77,7 +83,7 @@ class Downloadable: else: raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type") - if make_dirs: + if make_dirs and os.path.dirname(destination) != '': os.makedirs(os.path.dirname(destination), exist_ok=True) return file, destination diff --git a/tests/types/test_mixins.py b/tests/types/test_mixins.py index ba8e461d..4860e76e 100644 --- a/tests/types/test_mixins.py +++ b/tests/types/test_mixins.py @@ -60,10 +60,30 @@ class TestDownloadable: await downloadable.download(make_dirs=True) assert os.path.isfile(work_directory.joinpath(FILE["file_path"])) - async def test_download_warning(self, work_directory, downloadable): + async def test_download_deprecation_warning(self, work_directory, downloadable): with pytest.deprecated_call(): - await downloadable.download("test") - assert os.path.isfile(work_directory.joinpath('test', FILE["file_path"])) + await downloadable.download("test.file") + + async def test_download_destination(self, work_directory, downloadable): + with pytest.deprecated_call(): + await downloadable.download("test.file") + assert os.path.isfile(work_directory.joinpath('test.file')) + + async def test_download_destination_dir_exist(self, work_directory, downloadable): + os.mkdir("test_folder") + with pytest.deprecated_call(): + await downloadable.download("test_folder") + assert os.path.isfile(work_directory.joinpath('test_folder', FILE["file_path"])) + + async def test_download_destination_with_dir(self, work_directory, downloadable): + with pytest.deprecated_call(): + await downloadable.download(os.path.join('dir_name', 'file_name')) + assert os.path.isfile(work_directory.joinpath(os.path.join('dir_name', 'file_name'))) + + async def test_download_destination_io_bytes(self, work_directory, downloadable): + file = BytesIO() + await downloadable.download(file) + assert len(file.read()) != 0 async def test_download_raise_value_error(self, work_directory, downloadable): with pytest.raises(ValueError): @@ -74,6 +94,10 @@ class TestDownloadable: assert os.path.isfile(work_directory.joinpath('test_dir', FILE["file_path"])) async def test_download_destination_file(self, work_directory, downloadable): + await downloadable.download(destination_file='file_name') + assert os.path.isfile(work_directory.joinpath('file_name')) + + async def test_download_destination_file_with_dir(self, work_directory, downloadable): await downloadable.download(destination_file=os.path.join('dir_name', 'file_name')) assert os.path.isfile(work_directory.joinpath('dir_name', 'file_name')) @@ -81,8 +105,3 @@ class TestDownloadable: file = BytesIO() await downloadable.download(destination_file=file) assert len(file.read()) != 0 - - async def test_download_io_bytes_backward_compatibility(self, work_directory, downloadable): - file = BytesIO() - await downloadable.download(destination_dir=file) - assert len(file.read()) != 0