From 24fdd285fde7da638610e157c1378742990c8e94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=92=D0=B8=D1=82=D0=B0=D0=BB=D0=B8=D0=B9?= <128831423+Vitaly312@users.noreply.github.com> Date: Thu, 5 Mar 2026 18:16:47 +0300 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- aiogram/dispatcher/middlewares/media_group.py | 24 +++++++++---------- aiogram/filters/__init__.py | 2 ++ docs/utils/media_group.rst | 4 ++-- .../test_middlewares/test_media_group.py | 10 ++++---- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/aiogram/dispatcher/middlewares/media_group.py b/aiogram/dispatcher/middlewares/media_group.py index e53213ca..63765949 100644 --- a/aiogram/dispatcher/middlewares/media_group.py +++ b/aiogram/dispatcher/middlewares/media_group.py @@ -71,9 +71,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): ) -> None: """ :param ttl_sec: ttl for media group data in seconds - :param lock_ttl_sec: ttl for lock in seconds. Value should be too big to prevent lock - releasing until handler finished and too small to expire until telegram send retry if - handler failed. + :param lock_ttl_sec: ttl for lock in seconds. Value should be large enough to prevent the + lock from expiring before the handler finishes, but small enough to expire before + Telegram retries a failed delivery. """ self.redis = redis self.ttl_sec = ttl_sec @@ -115,7 +115,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): "else return 0 end" ) await cast( - Awaitable[str], + Awaitable[int], self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id), ) @@ -123,7 +123,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): result = await cast( Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1) ) - return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)]) + return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result]) async def delete_group(self, media_group_id: str) -> None: async with self.redis.pipeline(transaction=True) as pipe: @@ -172,7 +172,7 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): return len(self.groups[media_group_id]) async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool: - if self.locks.get(media_group_id): + if self.locks.get(media_group_id) is not None: return False self.locks[media_group_id] = lock_id return True @@ -220,8 +220,12 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id): return None try: - last_message_time = await self.media_group_aggregator.get_current_time() while True: + last_message_time = await self.media_group_aggregator.get_last_message_time( + event.media_group_id + ) + if not last_message_time: + return None delta = self.delay - ( await self.media_group_aggregator.get_current_time() - last_message_time ) @@ -238,11 +242,5 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): await self.media_group_aggregator.delete_group(event.media_group_id) return result await asyncio.sleep(delta) - new_last_message_time = await self.media_group_aggregator.get_last_message_time( - event.media_group_id - ) - if not new_last_message_time: - return None - last_message_time = new_last_message_time finally: await self.media_group_aggregator.release_lock(event.media_group_id, lock_id) diff --git a/aiogram/filters/__init__.py b/aiogram/filters/__init__.py index e2668830..42f1f740 100644 --- a/aiogram/filters/__init__.py +++ b/aiogram/filters/__init__.py @@ -18,6 +18,7 @@ from .command import Command, CommandObject, CommandStart from .exception import ExceptionMessageFilter, ExceptionTypeFilter from .logic import and_f, invert_f, or_f from .magic_data import MagicData +from .media_group import MediaGroupFilter from .state import StateFilter BaseFilter = Filter @@ -44,6 +45,7 @@ __all__ = ( "ExceptionTypeFilter", "Filter", "MagicData", + "MediaGroupFilter", "StateFilter", "and_f", "invert_f", diff --git a/docs/utils/media_group.rst b/docs/utils/media_group.rst index 51f018d0..97c6ab31 100644 --- a/docs/utils/media_group.rst +++ b/docs/utils/media_group.rst @@ -60,7 +60,7 @@ You also can use :class:`aiogram.filters.media_group.MediaGroupFilter` to filter media groups. Usage -===== +----- .. code-block:: python @@ -69,7 +69,7 @@ Usage # register middleware from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware - from aiogram.filters.media_group import MediaGroupFilter + from aiogram.filters import MediaGroupFilter router.message.outer_middleware(MediaGroupAggregatorMiddleware()) diff --git a/tests/test_dispatcher/test_middlewares/test_media_group.py b/tests/test_dispatcher/test_middlewares/test_media_group.py index a5c65d9e..ef2fd33f 100644 --- a/tests/test_dispatcher/test_middlewares/test_media_group.py +++ b/tests/test_dispatcher/test_middlewares/test_media_group.py @@ -189,14 +189,16 @@ class TestMediaGroupAggregator: assert await aggregator.get_group("42") == [] async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator): - await aggregator.acquire_lock("42", "key1") - assert not await aggregator.acquire_lock("42", "key2") - await aggregator.release_lock("42", "key1") - for i in ("key2", "key3"): + for i in ("key1", "key2"): assert await aggregator.acquire_lock("42", i) assert not await aggregator.acquire_lock("42", i) await aggregator.release_lock("42", i) + async def test_lock_not_acquired_with_wrong_key(self, aggregator: BaseMediaGroupAggregator): + await aggregator.acquire_lock("42", "key1") + await aggregator.release_lock("42", "key2") + assert not await aggregator.acquire_lock("42", "key1") + async def test_expired_objects_removed(self): aggregator = MemoryMediaGroupAggregator() await aggregator.add_into_group("42", _get_message(1))