diff --git a/aiogram/dispatcher/middlewares/media_group.py b/aiogram/dispatcher/middlewares/media_group.py index 9ba860b1..eca9f931 100644 --- a/aiogram/dispatcher/middlewares/media_group.py +++ b/aiogram/dispatcher/middlewares/media_group.py @@ -26,13 +26,32 @@ class BaseMediaGroupAggregator(ABC): raise NotImplementedError @abstractmethod - async def get_and_delete_group(self, media_group_id: str) -> list[Message]: + async def release_lock(self, media_group_id: str) -> None: + raise NotImplementedError + + @abstractmethod + async def get_group(self, media_group_id: str) -> list[Message]: + raise NotImplementedError + + @abstractmethod + async def delete_group(self, media_group_id: str) -> None: raise NotImplementedError @abstractmethod async def get_last_message_time(self, media_group_id: str) -> float | None: raise NotImplementedError + @staticmethod + def deduplicate_messages(messages: list[Message]) -> list[Message]: + message_ids = set() + result = [] + for message in messages: + if message.message_id in message_ids: + continue + result.append(message) + message_ids.add(message.message_id) + return result + class RedisMediaGroupAggregator(BaseMediaGroupAggregator): redis: "Redis" @@ -68,13 +87,18 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): ), ) - async def get_and_delete_group(self, media_group_id: str) -> list[Message]: + async def release_lock(self, media_group_id: str) -> None: + await self.redis.delete(self.get_group_lock_key(media_group_id)) + + async def get_group(self, media_group_id: str) -> list[Message]: + result = await self.redis.lrange(self.get_group_key(media_group_id), 0, -1) + return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result]) + + async def delete_group(self, media_group_id: str) -> None: async with self.redis.pipeline(transaction=True) as pipe: - pipe.lrange(self.get_group_key(media_group_id), 0, -1) pipe.delete(self.get_group_key(media_group_id)) pipe.delete(self.get_last_message_time_key(media_group_id)) - result = await pipe.execute() - return [Message.model_validate_json(msg) for msg in result[0]] + await pipe.execute() async def get_last_message_time(self, media_group_id: str) -> float | None: result = await self.redis.get(self.get_last_message_time_key(media_group_id)) @@ -90,7 +114,8 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): self.locks: dict[str, bool] = {} async def add_into_group(self, media_group_id: str, media: Message) -> int: - self.groups[media_group_id].append(media) + if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]): + self.groups[media_group_id].append(media) self.last_message_timers[media_group_id] = time.time() return len(self.groups[media_group_id]) @@ -100,12 +125,15 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): self.locks[media_group_id] = True return True - async def get_and_delete_group(self, media_group_id: str) -> list[Message]: - group = self.groups[media_group_id] + async def release_lock(self, media_group_id: str) -> None: + self.locks.pop(media_group_id, None) + + async def get_group(self, media_group_id: str) -> list[Message]: + return self.groups.get(media_group_id, []) + + async def delete_group(self, media_group_id: str) -> None: self.groups.pop(media_group_id, None) self.last_message_timers.pop(media_group_id, None) - self.locks.pop(media_group_id, None) - return group async def get_last_message_time(self, media_group_id: str) -> float | None: return self.last_message_timers.get(media_group_id) @@ -115,8 +143,10 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): def __init__( self, media_group_aggregator: BaseMediaGroupAggregator | None = None, + delay: float = DELAY_SEC, ) -> None: self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator() + self.delay = delay async def __call__( self, @@ -129,23 +159,25 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): await self.media_group_aggregator.add_into_group(event.media_group_id, event) if not await self.media_group_aggregator.acquire_lock(event.media_group_id): return None - last_message_time = time.time() - while True: - delta = DELAY_SEC - (time.time() - last_message_time) - if delta <= 0: - album = await self.media_group_aggregator.get_and_delete_group( + try: + last_message_time = time.time() + while True: + delta = self.delay - (time.time() - last_message_time) + if delta <= 0: + album = await self.media_group_aggregator.get_group(event.media_group_id) + if not album: + return None + album.sort(key=lambda msg: msg.message_id) + data.update(album=album) + result = await handler(album[0], data) + await self.media_group_aggregator.delete_group(event.media_group_id) + return result + await asyncio.sleep(delta) + new_last_message_time = await self.media_group_aggregator.get_last_message_time( event.media_group_id ) - if not album: + if not new_last_message_time: return None - album.sort(key=lambda msg: msg.message_id) - data.update(album=album) - return await handler(album[0], data) - - await asyncio.sleep(delta) - new_last_message_time = await self.media_group_aggregator.get_last_message_time( - event.media_group_id - ) - if not new_last_message_time: - return None - last_message_time = new_last_message_time + last_message_time = new_last_message_time + finally: + await self.media_group_aggregator.release_lock(event.media_group_id)