mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge branch 'aiogram:dev-2.x' into dev-2.x-bot-download
This commit is contained in:
commit
962a7edb52
7 changed files with 259 additions and 46 deletions
|
|
@ -43,5 +43,5 @@ __all__ = (
|
|||
'utils',
|
||||
)
|
||||
|
||||
__version__ = '2.14.3'
|
||||
__version__ = '2.15'
|
||||
__api_version__ = '5.3'
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class RedisStorage(BaseStorage):
|
|||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._kwargs = kwargs
|
||||
|
||||
self._redis: typing.Optional[aioredis.RedisConnection] = None
|
||||
self._redis: typing.Optional["aioredis.RedisConnection"] = None
|
||||
self._connection_lock = asyncio.Lock(loop=self._loop)
|
||||
|
||||
async def close(self):
|
||||
|
|
@ -62,7 +62,7 @@ class RedisStorage(BaseStorage):
|
|||
return await self._redis.wait_closed()
|
||||
return True
|
||||
|
||||
async def redis(self) -> aioredis.RedisConnection:
|
||||
async def redis(self) -> "aioredis.RedisConnection":
|
||||
"""
|
||||
Get Redis connection
|
||||
"""
|
||||
|
|
@ -220,9 +220,9 @@ class AioRedisAdapterBase(ABC):
|
|||
pool_size: int = 10,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
prefix: str = "fsm",
|
||||
state_ttl: int = 0,
|
||||
data_ttl: int = 0,
|
||||
bucket_ttl: int = 0,
|
||||
state_ttl: typing.Optional[int] = None,
|
||||
data_ttl: typing.Optional[int] = None,
|
||||
bucket_ttl: typing.Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._host = host
|
||||
|
|
@ -247,7 +247,7 @@ class AioRedisAdapterBase(ABC):
|
|||
"""Get Redis connection."""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
"""Grace shutdown."""
|
||||
pass
|
||||
|
||||
|
|
@ -257,6 +257,8 @@ class AioRedisAdapterBase(ABC):
|
|||
|
||||
async def set(self, name, value, ex=None, **kwargs):
|
||||
"""Set the value at key ``name`` to ``value``."""
|
||||
if ex == 0:
|
||||
ex = None
|
||||
return await self._redis.set(name, value, ex=ex, **kwargs)
|
||||
|
||||
async def get(self, name, **kwargs):
|
||||
|
|
@ -295,7 +297,7 @@ class AioRedisAdapterV1(AioRedisAdapterBase):
|
|||
)
|
||||
return self._redis
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
async with self._connection_lock:
|
||||
if self._redis and not self._redis.closed:
|
||||
self._redis.close()
|
||||
|
|
@ -310,6 +312,8 @@ class AioRedisAdapterV1(AioRedisAdapterBase):
|
|||
return await self._redis.get(name, encoding="utf8", **kwargs)
|
||||
|
||||
async def set(self, name, value, ex=None, **kwargs):
|
||||
if ex == 0:
|
||||
ex = None
|
||||
return await self._redis.set(name, value, expire=ex, **kwargs)
|
||||
|
||||
async def keys(self, pattern, **kwargs):
|
||||
|
|
@ -331,6 +335,7 @@ class AioRedisAdapterV2(AioRedisAdapterBase):
|
|||
password=self._password,
|
||||
ssl=self._ssl,
|
||||
max_connections=self._pool_size,
|
||||
decode_responses=True,
|
||||
**self._kwargs,
|
||||
)
|
||||
return self._redis
|
||||
|
|
@ -367,9 +372,9 @@ class RedisStorage2(BaseStorage):
|
|||
pool_size: int = 10,
|
||||
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
|
||||
prefix: str = "fsm",
|
||||
state_ttl: int = 0,
|
||||
data_ttl: int = 0,
|
||||
bucket_ttl: int = 0,
|
||||
state_ttl: typing.Optional[int] = None,
|
||||
data_ttl: typing.Optional[int] = None,
|
||||
bucket_ttl: typing.Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._host = host
|
||||
|
|
@ -413,6 +418,9 @@ class RedisStorage2(BaseStorage):
|
|||
self._redis = AioRedisAdapterV1(**connection_data)
|
||||
elif redis_version == 2:
|
||||
self._redis = AioRedisAdapterV2(**connection_data)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported aioredis version: {redis_version}")
|
||||
await self._redis.get_redis()
|
||||
return self._redis
|
||||
|
||||
def generate_key(self, *parts):
|
||||
|
|
@ -420,11 +428,12 @@ class RedisStorage2(BaseStorage):
|
|||
|
||||
async def close(self):
|
||||
if self._redis:
|
||||
return self._redis.close()
|
||||
return await self._redis.close()
|
||||
|
||||
async def wait_closed(self):
|
||||
if self._redis:
|
||||
return await self._redis.wait_closed()
|
||||
await self._redis.wait_closed()
|
||||
self._redis = None
|
||||
|
||||
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
|
|
@ -451,7 +460,7 @@ class RedisStorage2(BaseStorage):
|
|||
if state is None:
|
||||
await redis.delete(key)
|
||||
else:
|
||||
await redis.set(key, self.resolve_state(state), expire=self._state_ttl)
|
||||
await redis.set(key, self.resolve_state(state), ex=self._state_ttl)
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None):
|
||||
|
|
@ -459,7 +468,7 @@ class RedisStorage2(BaseStorage):
|
|||
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||
redis = await self._get_adapter()
|
||||
if data:
|
||||
await redis.set(key, json.dumps(data), expire=self._data_ttl)
|
||||
await redis.set(key, json.dumps(data), ex=self._data_ttl)
|
||||
else:
|
||||
await redis.delete(key)
|
||||
|
||||
|
|
@ -490,7 +499,7 @@ class RedisStorage2(BaseStorage):
|
|||
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||
redis = await self._get_adapter()
|
||||
if bucket:
|
||||
await redis.set(key, json.dumps(bucket), expire=self._bucket_ttl)
|
||||
await redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
|
||||
else:
|
||||
await redis.delete(key)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
import os
|
||||
import pathlib
|
||||
from io import IOBase
|
||||
from typing import Union, Optional
|
||||
|
||||
from aiogram.utils.deprecated import warn_deprecated
|
||||
|
||||
|
||||
class Downloadable:
|
||||
|
|
@ -7,32 +11,86 @@ class Downloadable:
|
|||
Mixin for files
|
||||
"""
|
||||
|
||||
async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True):
|
||||
async def download(
|
||||
self,
|
||||
destination=None,
|
||||
timeout=30,
|
||||
chunk_size=65536,
|
||||
seek=True,
|
||||
make_dirs=True,
|
||||
*,
|
||||
destination_dir: Optional[Union[str, pathlib.Path]] = None,
|
||||
destination_file: Optional[Union[str, pathlib.Path, IOBase]] = None
|
||||
):
|
||||
"""
|
||||
Download file
|
||||
|
||||
:param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
|
||||
At most one of these parameters can be used: :param destination_dir:, :param destination_file:
|
||||
|
||||
:param destination: deprecated, use :param destination_dir: or :param destination_file: instead
|
||||
:param timeout: Integer
|
||||
:param chunk_size: Integer
|
||||
:param seek: Boolean - go to start of file when downloading is finished.
|
||||
:param make_dirs: Make dirs if not exist
|
||||
:param destination_dir: directory for saving files
|
||||
:param destination_file: path to the file or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
|
||||
:return: destination
|
||||
"""
|
||||
if destination:
|
||||
warn_deprecated(
|
||||
"destination parameter is deprecated, please use destination_dir or destination_file."
|
||||
)
|
||||
if destination_dir and destination_file:
|
||||
raise ValueError(
|
||||
"Use only one of the parameters: destination_dir or destination_file."
|
||||
)
|
||||
|
||||
file, destination = await self._prepare_destination(
|
||||
destination,
|
||||
destination_dir,
|
||||
destination_file,
|
||||
make_dirs
|
||||
)
|
||||
|
||||
return await self.bot.download_file(
|
||||
file_path=file.file_path,
|
||||
destination=destination,
|
||||
timeout=timeout,
|
||||
chunk_size=chunk_size,
|
||||
seek=seek,
|
||||
)
|
||||
|
||||
async def _prepare_destination(self, dest, destination_dir, destination_file, make_dirs):
|
||||
file = await self.get_file()
|
||||
|
||||
is_path = True
|
||||
if destination is None:
|
||||
if not(any((dest, destination_dir, destination_file))):
|
||||
destination = file.file_path
|
||||
elif isinstance(destination, (str, pathlib.Path)) and os.path.isdir(destination):
|
||||
destination = os.path.join(destination, file.file_path)
|
||||
else:
|
||||
is_path = False
|
||||
|
||||
if is_path and make_dirs:
|
||||
elif dest: # backward compatibility
|
||||
if isinstance(dest, IOBase):
|
||||
return file, dest
|
||||
if isinstance(dest, (str, pathlib.Path)) and os.path.isdir(dest):
|
||||
destination = os.path.join(dest, file.file_path)
|
||||
else:
|
||||
destination = dest
|
||||
|
||||
elif destination_dir:
|
||||
if isinstance(destination_dir, (str, pathlib.Path)):
|
||||
destination = os.path.join(destination_dir, file.file_path)
|
||||
else:
|
||||
raise TypeError("destination_dir must be str or pathlib.Path")
|
||||
else:
|
||||
if isinstance(destination_file, IOBase):
|
||||
return file, destination_file
|
||||
elif isinstance(destination_file, (str, pathlib.Path)):
|
||||
destination = destination_file
|
||||
else:
|
||||
raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type")
|
||||
|
||||
if make_dirs and os.path.dirname(destination):
|
||||
os.makedirs(os.path.dirname(destination), exist_ok=True)
|
||||
|
||||
return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout,
|
||||
chunk_size=chunk_size, seek=seek)
|
||||
return file, destination
|
||||
|
||||
async def get_file(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -64,10 +64,10 @@ class User(base.TelegramObject):
|
|||
return getattr(self, '_locale')
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
return f"tg://user?id={self.id}"
|
||||
|
||||
def get_mention(self, name=None, as_html=None):
|
||||
def get_mention(self, name: Optional[str] = None, as_html: Optional[bool] = None) -> str:
|
||||
if as_html is None and self.bot.parse_mode and self.bot.parse_mode.lower() == 'html':
|
||||
as_html = True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,28 +1,53 @@
|
|||
import aioredis
|
||||
import pytest
|
||||
from _pytest.config import UsageError
|
||||
import aioredis.util
|
||||
|
||||
try:
|
||||
import aioredis.util
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--redis", default=None,
|
||||
help="run tests which require redis connection")
|
||||
parser.addoption(
|
||||
"--redis",
|
||||
default=None,
|
||||
help="run tests which require redis connection",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"redis: marked tests require redis connection to run",
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
redis_uri = config.getoption("--redis")
|
||||
if redis_uri is None:
|
||||
skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
|
||||
skip_redis = pytest.mark.skip(
|
||||
reason="need --redis option with redis URI to run"
|
||||
)
|
||||
for item in items:
|
||||
if "redis" in item.keywords:
|
||||
item.add_marker(skip_redis)
|
||||
return
|
||||
|
||||
redis_version = int(aioredis.__version__.split(".")[0])
|
||||
options = None
|
||||
if redis_version == 1:
|
||||
(host, port), options = aioredis.util.parse_url(redis_uri)
|
||||
options.update({'host': host, 'port': port})
|
||||
elif redis_version == 2:
|
||||
try:
|
||||
options = aioredis.connection.parse_url(redis_uri)
|
||||
except ValueError as e:
|
||||
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
|
||||
|
||||
try:
|
||||
address, options = aioredis.util.parse_url(redis_uri)
|
||||
assert isinstance(address, tuple), "Only redis and rediss schemas are supported, eg redis://foo."
|
||||
assert isinstance(options, dict), \
|
||||
"Only redis and rediss schemas are supported, eg redis://foo."
|
||||
except AssertionError as e:
|
||||
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
|
||||
|
||||
|
|
@ -30,6 +55,20 @@ def pytest_collection_modifyitems(config, items):
|
|||
@pytest.fixture(scope='session')
|
||||
def redis_options(request):
|
||||
redis_uri = request.config.getoption("--redis")
|
||||
(host, port), options = aioredis.util.parse_url(redis_uri)
|
||||
options.update({'host': host, 'port': port})
|
||||
return options
|
||||
if redis_uri is None:
|
||||
pytest.skip("need --redis option with redis URI to run")
|
||||
return
|
||||
|
||||
redis_version = int(aioredis.__version__.split(".")[0])
|
||||
if redis_version == 1:
|
||||
(host, port), options = aioredis.util.parse_url(redis_uri)
|
||||
options.update({'host': host, 'port': port})
|
||||
return options
|
||||
|
||||
if redis_version == 2:
|
||||
try:
|
||||
return aioredis.connection.parse_url(redis_uri)
|
||||
except ValueError as e:
|
||||
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
|
||||
|
||||
raise UsageError("Unsupported aioredis version")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
import aioredis
|
||||
import pytest
|
||||
|
||||
from pytest_lazyfixture import lazy_fixture
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage2, RedisStorage
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage, RedisStorage2
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.mark.redis
|
||||
async def redis_store(redis_options):
|
||||
if int(aioredis.__version__.split(".")[0]) == 2:
|
||||
pytest.skip('aioredis v2 is not supported.')
|
||||
return
|
||||
s = RedisStorage(**redis_options)
|
||||
try:
|
||||
yield s
|
||||
|
|
@ -37,9 +41,9 @@ async def memory_store():
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"store", [
|
||||
pytest.lazy_fixture('redis_store'),
|
||||
pytest.lazy_fixture('redis_store2'),
|
||||
pytest.lazy_fixture('memory_store'),
|
||||
lazy_fixture('redis_store'),
|
||||
lazy_fixture('redis_store2'),
|
||||
lazy_fixture('memory_store'),
|
||||
]
|
||||
)
|
||||
class TestStorage:
|
||||
|
|
@ -63,8 +67,8 @@ class TestStorage:
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"store", [
|
||||
pytest.lazy_fixture('redis_store'),
|
||||
pytest.lazy_fixture('redis_store2'),
|
||||
lazy_fixture('redis_store'),
|
||||
lazy_fixture('redis_store2'),
|
||||
]
|
||||
)
|
||||
class TestRedisStorage2:
|
||||
|
|
@ -74,6 +78,7 @@ class TestRedisStorage2:
|
|||
assert await store.get_data(chat='1234') == {'foo': 'bar'}
|
||||
pool_id = id(store._redis)
|
||||
await store.close()
|
||||
await store.wait_closed()
|
||||
assert await store.get_data(chat='1234') == {
|
||||
'foo': 'bar'} # new pool was opened at this point
|
||||
assert id(store._redis) != pool_id
|
||||
|
|
|
|||
102
tests/types/test_mixins.py
Normal file
102
tests/types/test_mixins.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.types import File
|
||||
from aiogram.types.mixins import Downloadable
|
||||
from tests import TOKEN
|
||||
from tests.types.dataset import FILE
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture(name='bot')
|
||||
async def bot_fixture():
|
||||
""" Bot fixture """
|
||||
_bot = Bot(TOKEN)
|
||||
yield _bot
|
||||
await _bot.session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmppath(tmpdir, request):
|
||||
os.chdir(tmpdir)
|
||||
yield Path(tmpdir)
|
||||
os.chdir(request.config.invocation_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def downloadable(bot):
|
||||
async def get_file():
|
||||
return File(**FILE)
|
||||
|
||||
downloadable = Downloadable()
|
||||
downloadable.get_file = get_file
|
||||
downloadable.bot = bot
|
||||
|
||||
return downloadable
|
||||
|
||||
|
||||
class TestDownloadable:
|
||||
async def test_download_make_dirs_false_nodir(self, tmppath, downloadable):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await downloadable.download(make_dirs=False)
|
||||
|
||||
async def test_download_make_dirs_false_mkdir(self, tmppath, downloadable):
|
||||
os.mkdir('voice')
|
||||
await downloadable.download(make_dirs=False)
|
||||
assert os.path.isfile(tmppath.joinpath(FILE["file_path"]))
|
||||
|
||||
async def test_download_make_dirs_true(self, tmppath, downloadable):
|
||||
await downloadable.download(make_dirs=True)
|
||||
assert os.path.isfile(tmppath.joinpath(FILE["file_path"]))
|
||||
|
||||
async def test_download_deprecation_warning(self, tmppath, downloadable):
|
||||
with pytest.deprecated_call():
|
||||
await downloadable.download("test.file")
|
||||
|
||||
async def test_download_destination(self, tmppath, downloadable):
|
||||
with pytest.deprecated_call():
|
||||
await downloadable.download("test.file")
|
||||
assert os.path.isfile(tmppath.joinpath('test.file'))
|
||||
|
||||
async def test_download_destination_dir_exist(self, tmppath, downloadable):
|
||||
os.mkdir("test_folder")
|
||||
with pytest.deprecated_call():
|
||||
await downloadable.download("test_folder")
|
||||
assert os.path.isfile(tmppath.joinpath('test_folder', FILE["file_path"]))
|
||||
|
||||
async def test_download_destination_with_dir(self, tmppath, downloadable):
|
||||
with pytest.deprecated_call():
|
||||
await downloadable.download(os.path.join('dir_name', 'file_name'))
|
||||
assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name'))
|
||||
|
||||
async def test_download_destination_io_bytes(self, tmppath, downloadable):
|
||||
file = BytesIO()
|
||||
with pytest.deprecated_call():
|
||||
await downloadable.download(file)
|
||||
assert len(file.read()) != 0
|
||||
|
||||
async def test_download_raise_value_error(self, tmppath, downloadable):
|
||||
with pytest.raises(ValueError):
|
||||
await downloadable.download(destination_dir="a", destination_file="b")
|
||||
|
||||
async def test_download_destination_dir(self, tmppath, downloadable):
|
||||
await downloadable.download(destination_dir='test_dir')
|
||||
assert os.path.isfile(tmppath.joinpath('test_dir', FILE["file_path"]))
|
||||
|
||||
async def test_download_destination_file(self, tmppath, downloadable):
|
||||
await downloadable.download(destination_file='file_name')
|
||||
assert os.path.isfile(tmppath.joinpath('file_name'))
|
||||
|
||||
async def test_download_destination_file_with_dir(self, tmppath, downloadable):
|
||||
await downloadable.download(destination_file=os.path.join('dir_name', 'file_name'))
|
||||
assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name'))
|
||||
|
||||
async def test_download_io_bytes(self, tmppath, downloadable):
|
||||
file = BytesIO()
|
||||
await downloadable.download(destination_file=file)
|
||||
assert len(file.read()) != 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue