mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
feat: add MediaGroupFilter
This commit is contained in:
parent
998d6c3742
commit
a8bd68eb35
3 changed files with 137 additions and 5 deletions
60
aiogram/filters/media_group.py
Normal file
60
aiogram/filters/media_group.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import Message
|
||||
|
||||
MIN_MEDIA_COUNT = 2
|
||||
DEFAULT_MAX_MEDIA_COUNT = 10
|
||||
|
||||
|
||||
class MediaGroupFilter(Filter):
|
||||
"""
|
||||
This filter helps to handle media groups.
|
||||
|
||||
Works only with :class:`aiogram.types.message.Message` events which have the :code:`album`
|
||||
in the handler context.
|
||||
"""
|
||||
|
||||
__slots__ = ("min_media_count", "max_media_count")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
count: int | None = None,
|
||||
min_media_count: int | None = None,
|
||||
max_media_count: int | None = None,
|
||||
):
|
||||
"""
|
||||
:param count: expected count of media in the group.
|
||||
:param min_media_count: min count of media in the group, inclusively
|
||||
:param max_media_count: max count of media in the group, inclusively
|
||||
"""
|
||||
if count is None:
|
||||
min_media_count = min_media_count or MIN_MEDIA_COUNT
|
||||
max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT
|
||||
else:
|
||||
if min_media_count is not None or max_media_count is not None:
|
||||
raise ValueError(
|
||||
"count and min_media_count or max_media_count can not be used together"
|
||||
)
|
||||
if count < MIN_MEDIA_COUNT:
|
||||
raise ValueError(f"count should be greater or equal to {MIN_MEDIA_COUNT}")
|
||||
min_media_count = max_media_count = count
|
||||
if min_media_count < MIN_MEDIA_COUNT:
|
||||
raise ValueError(f"min_media_count should be greater or equal to {MIN_MEDIA_COUNT}")
|
||||
if max_media_count < min_media_count:
|
||||
raise ValueError("max_media_count should be greater or equal to min_media_count")
|
||||
self.min_media_count = min_media_count
|
||||
self.max_media_count = max_media_count
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._signature_to_string(
|
||||
min_media_count=self.min_media_count, max_media_count=self.max_media_count
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self, message: Message, album: list[Message] = None
|
||||
) -> Literal[False] | dict[str, Any]:
|
||||
media_count = len(album or [])
|
||||
if not (self.min_media_count <= media_count <= self.max_media_count):
|
||||
return False
|
||||
return {"media_count": media_count}
|
||||
|
|
@ -6,20 +6,22 @@ from datetime import datetime
|
|||
from typing import Any
|
||||
import pytest
|
||||
|
||||
|
||||
class TestMediaGroupAggregatorMiddleware:
|
||||
def _get_message(self, message_id: int, **kwargs):
|
||||
chat = Chat(id=1, type="private", title="Test")
|
||||
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
|
||||
|
||||
|
||||
def get_middleware(self):
|
||||
return MediaGroupAggregatorMiddleware(delay=0.1)
|
||||
|
||||
async def test_skip_non_media_group(self):
|
||||
is_called = False
|
||||
|
||||
async def next_handler(*args, **kwargs):
|
||||
nonlocal is_called
|
||||
is_called = True
|
||||
|
||||
await self.get_middleware()(next_handler, self._get_message(1), {})
|
||||
assert is_called
|
||||
|
||||
|
|
@ -27,13 +29,15 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
middleware = self.get_middleware()
|
||||
counter = 0
|
||||
album = None
|
||||
|
||||
async def next_handler(_, data: dict[str, Any]):
|
||||
nonlocal counter, album
|
||||
counter += 1
|
||||
album = data.get("album")
|
||||
|
||||
await asyncio.gather(
|
||||
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
|
||||
middleware(next_handler, self._get_message(2, media_group_id="42"), {})
|
||||
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
|
||||
)
|
||||
assert album is not None
|
||||
assert len(album) == 2
|
||||
|
|
@ -42,12 +46,14 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
async def test_propagate_first_media_in_album(self):
|
||||
middleware = self.get_middleware()
|
||||
first_message = None
|
||||
|
||||
async def next_handler(message: Message, _):
|
||||
nonlocal first_message
|
||||
first_message = message
|
||||
|
||||
await asyncio.gather(
|
||||
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
|
||||
middleware(next_handler, self._get_message(1, media_group_id="42"), {})
|
||||
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
|
||||
)
|
||||
assert isinstance(first_message, Message)
|
||||
assert first_message.message_id == 1
|
||||
|
|
@ -56,13 +62,15 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
middleware = self.get_middleware()
|
||||
counter = 0
|
||||
albums = []
|
||||
|
||||
async def next_handler(_, data: dict[str, Any]):
|
||||
nonlocal counter, albums
|
||||
counter += 1
|
||||
albums.append(data.get("album"))
|
||||
|
||||
await asyncio.gather(
|
||||
middleware(next_handler, self._get_message(1, media_group_id="1"), {}),
|
||||
middleware(next_handler, self._get_message(2, media_group_id="2"), {})
|
||||
middleware(next_handler, self._get_message(2, media_group_id="2"), {}),
|
||||
)
|
||||
assert counter == 2
|
||||
assert len(albums) == 2
|
||||
|
|
@ -70,17 +78,20 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
async def test_retry_handling(self):
|
||||
middleware = self.get_middleware()
|
||||
album = None
|
||||
|
||||
async def failed_handler(*args, **kwargs):
|
||||
raise Exception("Failed")
|
||||
|
||||
async def working_handler(_, data: dict[str, Any]):
|
||||
nonlocal album
|
||||
album = data.get("album")
|
||||
|
||||
first_message = self._get_message(1, media_group_id="42")
|
||||
second_message = self._get_message(2, media_group_id="42")
|
||||
with pytest.raises(Exception):
|
||||
await asyncio.gather(
|
||||
middleware(failed_handler, first_message, {}),
|
||||
middleware(failed_handler, second_message, {})
|
||||
middleware(failed_handler, second_message, {}),
|
||||
)
|
||||
await middleware(working_handler, first_message, {})
|
||||
assert len(album) == 2
|
||||
|
|
|
|||
61
tests/test_filters/test_media_group.py
Normal file
61
tests/test_filters/test_media_group.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from aiogram.filters.media_group import MediaGroupFilter, MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT
|
||||
import pytest
|
||||
import datetime
|
||||
from aiogram.types import Message, Chat
|
||||
|
||||
|
||||
class TestMediaGroupFilter:
|
||||
@pytest.mark.parametrize(
|
||||
"args,min_count,max_count",
|
||||
[
|
||||
((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT),
|
||||
((3,), 3, 3),
|
||||
((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT),
|
||||
((None, None, 3), MIN_MEDIA_COUNT, 3),
|
||||
],
|
||||
)
|
||||
def test_init_range(self, args, min_count, max_count):
|
||||
filter = MediaGroupFilter(*args)
|
||||
assert filter.max_media_count == max_count
|
||||
assert filter.min_media_count == min_count
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"count,min_count,max_count",
|
||||
[
|
||||
(1, None, 1),
|
||||
(1, 1, None),
|
||||
(None, 1, None),
|
||||
(None, None, 1),
|
||||
(1, None, None),
|
||||
(None, 5, 3),
|
||||
],
|
||||
)
|
||||
def test_raise_error(self, count, min_count, max_count):
|
||||
with pytest.raises(ValueError):
|
||||
MediaGroupFilter(count, min_count, max_count)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"min_count,max_count,media_count,result",
|
||||
[
|
||||
[2, 2, 1, False],
|
||||
[2, 2, 2, True],
|
||||
[2, 2, 3, False],
|
||||
[2, 5, 2, True],
|
||||
[2, 5, 5, True],
|
||||
[2, 5, 6, False],
|
||||
],
|
||||
)
|
||||
async def test_call(self, min_count, max_count, media_count, result):
|
||||
filter = MediaGroupFilter(min_media_count=min_count, max_media_count=max_count)
|
||||
album = [
|
||||
Message(
|
||||
message_id=i,
|
||||
date=datetime.datetime.now(),
|
||||
chat=Chat(id=42, type="private"),
|
||||
)
|
||||
for i in range(media_count)
|
||||
]
|
||||
response = await filter(album[0], album)
|
||||
assert bool(response) is result
|
||||
if result:
|
||||
assert response.get("media_count") == media_count
|
||||
Loading…
Add table
Add a link
Reference in a new issue