From 998d6c37428ee4118b4c28116dd57480769169dd Mon Sep 17 00:00:00 2001 From: Vitaly312 Date: Sat, 28 Feb 2026 14:54:30 +0300 Subject: [PATCH] feat: add TTL to MemoryMediaGroupAggregator --- aiogram/dispatcher/middlewares/media_group.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/aiogram/dispatcher/middlewares/media_group.py b/aiogram/dispatcher/middlewares/media_group.py index eca9f931..7ca2ac6b 100644 --- a/aiogram/dispatcher/middlewares/media_group.py +++ b/aiogram/dispatcher/middlewares/media_group.py @@ -91,7 +91,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): await self.redis.delete(self.get_group_lock_key(media_group_id)) async def get_group(self, media_group_id: str) -> list[Message]: - result = await self.redis.lrange(self.get_group_key(media_group_id), 0, -1) + result = await cast( + Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1) + ) return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result]) async def delete_group(self, media_group_id: str) -> None: @@ -113,9 +115,25 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): self.last_message_timers: dict[str, float] = {} self.locks: dict[str, bool] = {} + def remove_expired_objects(self) -> None: + expired_group_ids = [] + current_time = time.time() + for group_id, last_message_time in self.last_message_timers.items(): + if current_time - last_message_time > TTL_SEC: + expired_group_ids.append(group_id) + else: + break # the list is sorted in ascending order + # because python 3.7+ save dict in insertion order + for group_id in expired_group_ids: + self.groups.pop(group_id, None) + self.last_message_timers.pop(group_id, None) + self.locks.pop(group_id, None) + async def add_into_group(self, media_group_id: str, media: Message) -> int: + self.remove_expired_objects() if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]): self.groups[media_group_id].append(media) + self.last_message_timers.pop(media_group_id, None) self.last_message_timers[media_group_id] = time.time() return len(self.groups[media_group_id]) @@ -145,6 +163,9 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): media_group_aggregator: BaseMediaGroupAggregator | None = None, delay: float = DELAY_SEC, ) -> None: + """ + :param delay: delay between last received message in media group and processing it + """ self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator() self.delay = delay