From a8bd68eb358bbe3c26f1dc0966a6260d27989cd2 Mon Sep 17 00:00:00 2001 From: Vitaly312 Date: Sat, 28 Feb 2026 16:43:54 +0300 Subject: [PATCH] feat: add MediaGroupFilter --- aiogram/filters/media_group.py | 60 ++++++++++++++++++ .../test_middlewares/test_media_group.py | 21 +++++-- tests/test_filters/test_media_group.py | 61 +++++++++++++++++++ 3 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 aiogram/filters/media_group.py create mode 100644 tests/test_filters/test_media_group.py diff --git a/aiogram/filters/media_group.py b/aiogram/filters/media_group.py new file mode 100644 index 00000000..bc146e01 --- /dev/null +++ b/aiogram/filters/media_group.py @@ -0,0 +1,60 @@ +from typing import Any, Literal + +from aiogram.filters.base import Filter +from aiogram.types import Message + +MIN_MEDIA_COUNT = 2 +DEFAULT_MAX_MEDIA_COUNT = 10 + + +class MediaGroupFilter(Filter): + """ + This filter helps to handle media groups. + + Works only with :class:`aiogram.types.message.Message` events which have the :code:`album` + in the handler context. + """ + + __slots__ = ("min_media_count", "max_media_count") + + def __init__( + self, + count: int | None = None, + min_media_count: int | None = None, + max_media_count: int | None = None, + ): + """ + :param count: expected count of media in the group. + :param min_media_count: min count of media in the group, inclusively + :param max_media_count: max count of media in the group, inclusively + """ + if count is None: + min_media_count = min_media_count or MIN_MEDIA_COUNT + max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT + else: + if min_media_count is not None or max_media_count is not None: + raise ValueError( + "count and min_media_count or max_media_count can not be used together" + ) + if count < MIN_MEDIA_COUNT: + raise ValueError(f"count should be greater or equal to {MIN_MEDIA_COUNT}") + min_media_count = max_media_count = count + if min_media_count < MIN_MEDIA_COUNT: + raise ValueError(f"min_media_count should be greater or equal to {MIN_MEDIA_COUNT}") + if max_media_count < min_media_count: + raise ValueError("max_media_count should be greater or equal to min_media_count") + self.min_media_count = min_media_count + self.max_media_count = max_media_count + + def __str__(self) -> str: + return self._signature_to_string( + min_media_count=self.min_media_count, max_media_count=self.max_media_count + ) + + async def __call__( + self, message: Message, album: list[Message] = None + ) -> Literal[False] | dict[str, Any]: + media_count = len(album or []) + if not (self.min_media_count <= media_count <= self.max_media_count): + return False + return {"media_count": media_count} diff --git a/tests/test_dispatcher/test_middlewares/test_media_group.py b/tests/test_dispatcher/test_middlewares/test_media_group.py index 07ce6941..f9fd7373 100644 --- a/tests/test_dispatcher/test_middlewares/test_media_group.py +++ b/tests/test_dispatcher/test_middlewares/test_media_group.py @@ -6,20 +6,22 @@ from datetime import datetime from typing import Any import pytest + class TestMediaGroupAggregatorMiddleware: def _get_message(self, message_id: int, **kwargs): chat = Chat(id=1, type="private", title="Test") return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs) - def get_middleware(self): return MediaGroupAggregatorMiddleware(delay=0.1) async def test_skip_non_media_group(self): is_called = False + async def next_handler(*args, **kwargs): nonlocal is_called is_called = True + await self.get_middleware()(next_handler, self._get_message(1), {}) assert is_called @@ -27,13 +29,15 @@ class TestMediaGroupAggregatorMiddleware: middleware = self.get_middleware() counter = 0 album = None + async def next_handler(_, data: dict[str, Any]): nonlocal counter, album counter += 1 album = data.get("album") + await asyncio.gather( middleware(next_handler, self._get_message(1, media_group_id="42"), {}), - middleware(next_handler, self._get_message(2, media_group_id="42"), {}) + middleware(next_handler, self._get_message(2, media_group_id="42"), {}), ) assert album is not None assert len(album) == 2 @@ -42,12 +46,14 @@ class TestMediaGroupAggregatorMiddleware: async def test_propagate_first_media_in_album(self): middleware = self.get_middleware() first_message = None + async def next_handler(message: Message, _): nonlocal first_message first_message = message + await asyncio.gather( middleware(next_handler, self._get_message(2, media_group_id="42"), {}), - middleware(next_handler, self._get_message(1, media_group_id="42"), {}) + middleware(next_handler, self._get_message(1, media_group_id="42"), {}), ) assert isinstance(first_message, Message) assert first_message.message_id == 1 @@ -56,13 +62,15 @@ class TestMediaGroupAggregatorMiddleware: middleware = self.get_middleware() counter = 0 albums = [] + async def next_handler(_, data: dict[str, Any]): nonlocal counter, albums counter += 1 albums.append(data.get("album")) + await asyncio.gather( middleware(next_handler, self._get_message(1, media_group_id="1"), {}), - middleware(next_handler, self._get_message(2, media_group_id="2"), {}) + middleware(next_handler, self._get_message(2, media_group_id="2"), {}), ) assert counter == 2 assert len(albums) == 2 @@ -70,17 +78,20 @@ class TestMediaGroupAggregatorMiddleware: async def test_retry_handling(self): middleware = self.get_middleware() album = None + async def failed_handler(*args, **kwargs): raise Exception("Failed") + async def working_handler(_, data: dict[str, Any]): nonlocal album album = data.get("album") + first_message = self._get_message(1, media_group_id="42") second_message = self._get_message(2, media_group_id="42") with pytest.raises(Exception): await asyncio.gather( middleware(failed_handler, first_message, {}), - middleware(failed_handler, second_message, {}) + middleware(failed_handler, second_message, {}), ) await middleware(working_handler, first_message, {}) assert len(album) == 2 diff --git a/tests/test_filters/test_media_group.py b/tests/test_filters/test_media_group.py new file mode 100644 index 00000000..e2471210 --- /dev/null +++ b/tests/test_filters/test_media_group.py @@ -0,0 +1,61 @@ +from aiogram.filters.media_group import MediaGroupFilter, MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT +import pytest +import datetime +from aiogram.types import Message, Chat + + +class TestMediaGroupFilter: + @pytest.mark.parametrize( + "args,min_count,max_count", + [ + ((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT), + ((3,), 3, 3), + ((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT), + ((None, None, 3), MIN_MEDIA_COUNT, 3), + ], + ) + def test_init_range(self, args, min_count, max_count): + filter = MediaGroupFilter(*args) + assert filter.max_media_count == max_count + assert filter.min_media_count == min_count + + @pytest.mark.parametrize( + "count,min_count,max_count", + [ + (1, None, 1), + (1, 1, None), + (None, 1, None), + (None, None, 1), + (1, None, None), + (None, 5, 3), + ], + ) + def test_raise_error(self, count, min_count, max_count): + with pytest.raises(ValueError): + MediaGroupFilter(count, min_count, max_count) + + @pytest.mark.parametrize( + "min_count,max_count,media_count,result", + [ + [2, 2, 1, False], + [2, 2, 2, True], + [2, 2, 3, False], + [2, 5, 2, True], + [2, 5, 5, True], + [2, 5, 6, False], + ], + ) + async def test_call(self, min_count, max_count, media_count, result): + filter = MediaGroupFilter(min_media_count=min_count, max_media_count=max_count) + album = [ + Message( + message_id=i, + date=datetime.datetime.now(), + chat=Chat(id=42, type="private"), + ) + for i in range(media_count) + ] + response = await filter(album[0], album) + assert bool(response) is result + if result: + assert response.get("media_count") == media_count