diff --git a/CHANGES/1697.feature.rst b/CHANGES/1697.feature.rst new file mode 100644 index 00000000..99cf7cc4 --- /dev/null +++ b/CHANGES/1697.feature.rst @@ -0,0 +1 @@ +Added ``MediaGroupAggregatorMiddleware`` to aggregate media groups into a single event with an album list in the handler data. diff --git a/aiogram/dispatcher/middlewares/media_group.py b/aiogram/dispatcher/middlewares/media_group.py new file mode 100644 index 00000000..63765949 --- /dev/null +++ b/aiogram/dispatcher/middlewares/media_group.py @@ -0,0 +1,246 @@ +import asyncio +import time +import uuid +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, cast + +from aiogram import Bot +from aiogram.dispatcher.middlewares.base import BaseMiddleware +from aiogram.types import Message, TelegramObject + +if TYPE_CHECKING: + from redis.asyncio.client import Redis + +DELAY_SEC = 1.0 +LOCK_TTL_SEC = 30 +TTL_SEC = 600 + + +class BaseMediaGroupAggregator(ABC): + @abstractmethod + async def add_into_group(self, media_group_id: str, media: Message) -> int: + pass + + @abstractmethod + async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool: + pass + + @abstractmethod + async def release_lock(self, media_group_id: str, lock_id: str) -> None: + pass + + @abstractmethod + async def get_group(self, media_group_id: str) -> list[Message]: + pass + + @abstractmethod + async def delete_group(self, media_group_id: str) -> None: + pass + + @abstractmethod + async def get_last_message_time(self, media_group_id: str) -> float | None: + pass + + @staticmethod + def deduplicate_messages(messages: list[Message]) -> list[Message]: + message_ids = set() + result = [] + for message in messages: + if message.message_id in message_ids: + continue + result.append(message) + message_ids.add(message.message_id) + return result + + @abstractmethod + async def get_current_time(self) -> float: + pass + + 
+class RedisMediaGroupAggregator(BaseMediaGroupAggregator): + """ + Aggregates media groups in Redis. + """ + + redis: "Redis" + + def __init__( + self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC + ) -> None: + """ + :param ttl_sec: ttl for media group data in seconds + :param lock_ttl_sec: ttl for lock in seconds. Value should be large enough to prevent the + lock from expiring before the handler finishes, but small enough to expire before + Telegram retries a failed delivery. + """ + self.redis = redis + self.ttl_sec = ttl_sec + self.lock_ttl_sec = lock_ttl_sec + + @staticmethod + def get_group_key(media_group_id: str) -> str: + return f"media_group:{media_group_id}:album" + + @staticmethod + def get_last_message_time_key(media_group_id: str) -> str: + return f"media_group:{media_group_id}:last_message_time" + + @staticmethod + def get_group_lock_key(media_group_id: str) -> str: + return f"media_group:{media_group_id}:lock" + + async def add_into_group(self, media_group_id: str, media: Message) -> int: + current_time = await self.get_current_time() + async with self.redis.pipeline(transaction=True) as pipe: + pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec) + pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json()) + pipe.expire(self.get_group_key(media_group_id), self.ttl_sec) + res = await pipe.execute() + return cast(int, res[1]) + + async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool: + return cast( + bool, + await self.redis.set( + self.get_group_lock_key(media_group_id), lock_id, nx=True, ex=self.lock_ttl_sec + ), + ) + + async def release_lock(self, media_group_id: str, lock_id: str) -> None: + release_script = ( + 'if redis.call("get", KEYS[1]) == ARGV[1] then ' + 'return redis.call("del", KEYS[1]) ' + "else return 0 end" + ) + await cast( + Awaitable[int], + self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id), + ) + + 
async def get_group(self, media_group_id: str) -> list[Message]: + result = await cast( + Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1) + ) + return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result]) + + async def delete_group(self, media_group_id: str) -> None: + async with self.redis.pipeline(transaction=True) as pipe: + pipe.delete(self.get_group_key(media_group_id)) + pipe.delete(self.get_last_message_time_key(media_group_id)) + await pipe.execute() + + async def get_last_message_time(self, media_group_id: str) -> float | None: + result = await self.redis.get(self.get_last_message_time_key(media_group_id)) + if result is None: + return None + return float(result) + + async def get_current_time(self) -> float: + seconds, microseconds = cast(tuple[int, int], await self.redis.time()) + return seconds + microseconds / 1e6 + + +class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): + def __init__(self, ttl_sec: int = TTL_SEC) -> None: + self.groups: dict[str, list[Message]] = defaultdict(list) + self.last_message_timers: dict[str, float] = {} + self.locks: dict[str, str] = {} + self.ttl_sec = ttl_sec + + def remove_expired_objects(self) -> None: + expired_group_ids = [] + current_time = time.monotonic() + for group_id, last_message_time in self.last_message_timers.items(): + if current_time - last_message_time > self.ttl_sec: + expired_group_ids.append(group_id) + else: + break # the list is sorted in ascending order + # because python 3.7+ save dict in insertion order + for group_id in expired_group_ids: + self.groups.pop(group_id, None) + self.last_message_timers.pop(group_id, None) + self.locks.pop(group_id, None) + + async def add_into_group(self, media_group_id: str, media: Message) -> int: + self.remove_expired_objects() + if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]): + self.groups[media_group_id].append(media) + 
self.last_message_timers.pop(media_group_id, None) + self.last_message_timers[media_group_id] = time.monotonic() + return len(self.groups[media_group_id]) + + async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool: + if self.locks.get(media_group_id) is not None: + return False + self.locks[media_group_id] = lock_id + return True + + async def release_lock(self, media_group_id: str, lock_id: str) -> None: + if self.locks.get(media_group_id) == lock_id: + self.locks.pop(media_group_id) + + async def get_group(self, media_group_id: str) -> list[Message]: + return self.groups.get(media_group_id, []) + + async def delete_group(self, media_group_id: str) -> None: + self.groups.pop(media_group_id, None) + self.last_message_timers.pop(media_group_id, None) + + async def get_last_message_time(self, media_group_id: str) -> float | None: + return self.last_message_timers.get(media_group_id) + + async def get_current_time(self) -> float: + return time.monotonic() + + +class MediaGroupAggregatorMiddleware(BaseMiddleware): + def __init__( + self, + media_group_aggregator: BaseMediaGroupAggregator | None = None, + delay: float = DELAY_SEC, + ) -> None: + """ + :param delay: delay between last received message in media group and processing it + """ + self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator() + self.delay = delay + + async def __call__( + self, + handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]], + event: TelegramObject, + data: dict[str, Any], + ) -> Any: + if not isinstance(event, Message) or not event.media_group_id: + return await handler(event, data) + await self.media_group_aggregator.add_into_group(event.media_group_id, event) + lock_id = str(uuid.uuid4()) + if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id): + return None + try: + while True: + last_message_time = await self.media_group_aggregator.get_last_message_time( + event.media_group_id + ) + if not 
last_message_time: + return None + delta = self.delay - ( + await self.media_group_aggregator.get_current_time() - last_message_time + ) + if delta <= 0: + album = await self.media_group_aggregator.get_group(event.media_group_id) + if not album: + return None + album = sorted( + (msg.as_(cast(Bot, data.get("bot"))) for msg in album), + key=lambda msg: msg.message_id, + ) + data.update(album=album) + result = await handler(album[0], data) + await self.media_group_aggregator.delete_group(event.media_group_id) + return result + await asyncio.sleep(delta) + finally: + await self.media_group_aggregator.release_lock(event.media_group_id, lock_id) diff --git a/docs/utils/media_group.rst b/docs/utils/media_group.rst index c9501a66..cab5962b 100644 --- a/docs/utils/media_group.rst +++ b/docs/utils/media_group.rst @@ -1,8 +1,13 @@ -=================== -Media group builder -=================== +=========== +Media group +=========== -This module provides a builder for media groups, it can be used to build media groups +This module provides tools for media groups. + +Building media groups +===================== + +Media group builder can be used to build media groups for :class:`aiogram.types.input_media_photo.InputMediaPhoto`, :class:`aiogram.types.input_media_video.InputMediaVideo`, :class:`aiogram.types.input_media_document.InputMediaDocument` and :class:`aiogram.types.input_media_audio.InputMediaAudio`. @@ -39,8 +44,58 @@ it will be used as ``caption`` for first media in group. await bot.send_media_group(chat_id=chat_id, media=media_group.build()) +Handling media groups +===================== + +By default each media in the group is processed separately. + +You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware` +to process media groups as one. If you do, only one message from the group will be processed, and updates for +other messages with the same media group ID will be suppressed. 
There are two options for storing media groups:
+
+- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default
+- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - supports distributed environments
+
+You can also use :class:`aiogram.filters.magic_data.MagicData` with ``F.album``
+to filter media groups.
autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator + :members: + :special-members: __init__ diff --git a/tests/test_dispatcher/test_middlewares/test_media_group.py b/tests/test_dispatcher/test_middlewares/test_media_group.py new file mode 100644 index 00000000..ef2fd33f --- /dev/null +++ b/tests/test_dispatcher/test_middlewares/test_media_group.py @@ -0,0 +1,227 @@ +import asyncio +import time +from datetime import datetime +from typing import Any, Awaitable, Callable +from unittest import mock + +import pytest +from redis.asyncio.client import Redis + +from aiogram.dispatcher.middlewares.media_group import ( + BaseMediaGroupAggregator, + MediaGroupAggregatorMiddleware, + MemoryMediaGroupAggregator, + RedisMediaGroupAggregator, +) +from aiogram.types import Chat, Message + + +def _get_message(message_id: int, **kwargs): + chat = Chat(id=1, type="private", title="Test") + return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs) + + +async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any: + start_sleep = asyncio.Event() + real_sleep = asyncio.sleep + + async def mock_sleep(*args, **kwargs): + start_sleep.set() + await real_sleep(0) + + with mock.patch("asyncio.sleep", mock_sleep): + task1 = func(*args, **kwargs) + await start_sleep.wait() + return task1 + + +class TestMediaGroupAggregatorMiddleware: + def get_middleware(self): + return MediaGroupAggregatorMiddleware(delay=0.1) + + async def test_skip_non_media_group(self): + is_called = False + + async def next_handler(*args, **kwargs): + nonlocal is_called + is_called = True + + await self.get_middleware()(next_handler, _get_message(1), {}) + assert is_called + + async def test_called_once_for_album(self): + middleware = self.get_middleware() + counter = 0 + album = None + + async def next_handler(_, data: dict[str, Any]): + nonlocal counter, album + counter += 1 + album = data.get("album") + + await asyncio.gather( + 
middleware(next_handler, _get_message(1, media_group_id="42"), {}), + middleware(next_handler, _get_message(2, media_group_id="42"), {}), + ) + assert album is not None + assert len(album) == 2 + assert counter == 1 + + async def test_bot_object_saved(self, bot): + middleware = self.get_middleware() + event = album = None + + async def next_handler(message: Message, data: dict[str, Any]): + nonlocal event, album + event = message + album = data.get("album") + + await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot}) + assert event.bot is bot + assert all(msg.bot is bot for msg in album) + + async def test_propagate_first_media_in_album(self): + middleware = self.get_middleware() + first_message = None + + async def next_handler(message: Message, _): + nonlocal first_message + first_message = message + + task1 = await wait_until_func_call_sleep( + asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {}) + ) + await middleware(next_handler, _get_message(1, media_group_id="42"), {}) + await task1 + assert isinstance(first_message, Message) + assert first_message.message_id == 1 + + @pytest.mark.parametrize("deleted_object", ["album", "last_message_time"]) + async def test_skip_propagating_if_data_deleted(self, deleted_object): + middleware = self.get_middleware() + counter = 0 + + async def next_handler(*args, **kwargs): + nonlocal counter + counter += 1 + + task1 = await wait_until_func_call_sleep( + asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {}) + ) + if deleted_object == "album": + middleware.media_group_aggregator.groups.pop("42") + else: + middleware.media_group_aggregator.last_message_timers.pop("42") + await task1 + assert counter == 0 + + async def test_different_albums_non_interfere(self): + middleware = self.get_middleware() + counter = 0 + albums = [] + + async def next_handler(_, data: dict[str, Any]): + nonlocal counter, albums + counter += 1 + 
albums.append(data.get("album")) + + await asyncio.gather( + middleware(next_handler, _get_message(1, media_group_id="1"), {}), + middleware(next_handler, _get_message(2, media_group_id="2"), {}), + ) + assert counter == 2 + assert len(albums) == 2 + + async def test_retry_handling(self): + middleware = self.get_middleware() + album = None + + async def failed_handler(*args, **kwargs): + raise RuntimeError("Failed") + + async def working_handler(_, data: dict[str, Any]): + nonlocal album + album = data.get("album") + + first_message = _get_message(1, media_group_id="42") + second_message = _get_message(2, media_group_id="42") + with pytest.raises(RuntimeError): + await asyncio.gather( + middleware(failed_handler, first_message, {}), + middleware(failed_handler, second_message, {}), + ) + await middleware(working_handler, first_message, {}) + assert len(album) == 2 + + +def test_message_deduplication(): + message_1, message_2 = _get_message(1), _get_message(2) + res = [message_1, message_2] + assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2]) == res + assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_2]) == res + assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_1]) == res + + +@pytest.fixture(params=["memory", "redis"], scope="function") +async def aggregator(request): + if request.param == "memory": + yield MemoryMediaGroupAggregator() + else: + redis = Redis.from_url(request.getfixturevalue("redis_server")) + yield RedisMediaGroupAggregator(redis) + keys = await redis.keys("media_group:*") + if keys: + await redis.delete(*keys) + await redis.aclose() + + +class TestMediaGroupAggregator: + async def test_group_creating(self, aggregator: BaseMediaGroupAggregator): + msg1 = _get_message(1) + msg2 = _get_message(2) + assert await aggregator.add_into_group("42", msg1) == 1 + assert await aggregator.add_into_group("42", msg2) == 2 + assert {msg.message_id for msg in await 
aggregator.get_group("42")} == { + msg1.message_id, + msg2.message_id, + } + await aggregator.delete_group("42") + assert await aggregator.get_group("42") == [] + + async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator): + for i in ("key1", "key2"): + assert await aggregator.acquire_lock("42", i) + assert not await aggregator.acquire_lock("42", i) + await aggregator.release_lock("42", i) + + async def test_lock_not_acquired_with_wrong_key(self, aggregator: BaseMediaGroupAggregator): + await aggregator.acquire_lock("42", "key1") + await aggregator.release_lock("42", "key2") + assert not await aggregator.acquire_lock("42", "key1") + + async def test_expired_objects_removed(self): + aggregator = MemoryMediaGroupAggregator() + await aggregator.add_into_group("42", _get_message(1)) + with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1): + new_msg = _get_message(2) + await aggregator.add_into_group("24", new_msg) + assert await aggregator.get_group("42") == [] + assert await aggregator.get_group("24") == [new_msg] + + async def test_get_current_time_memory_aggregator(self): + aggregator = MemoryMediaGroupAggregator() + with mock.patch("time.monotonic", return_value=1.1): + assert await aggregator.get_current_time() == 1.1 + + async def test_get_current_time_redis_aggregator(self): + aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis)) + aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456)) + assert await aggregator.get_current_time() == 1.123456 + + async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator): + assert await aggregator.get_last_message_time("42") is None + await aggregator.add_into_group("42", _get_message(1)) + time_before_second_message = await aggregator.get_current_time() + assert await aggregator.get_last_message_time("42") <= time_before_second_message + await aggregator.add_into_group("42", _get_message(2)) + assert await aggregator.get_last_message_time("42") >= 
time_before_second_message