From 10fa7315dafea3c39ae2b2585c9a05dd2b1b14ab Mon Sep 17 00:00:00 2001 From: Vitaly312 Date: Sun, 1 Mar 2026 11:21:44 +0300 Subject: [PATCH] update lock TTL --- aiogram/dispatcher/middlewares/media_group.py | 37 +++++++++++++++---- aiogram/filters/media_group.py | 6 ++- docs/utils/media_group.rst | 12 +++++- .../test_middlewares/test_media_group.py | 18 ++++++++- tests/test_filters/test_media_group.py | 2 + 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/aiogram/dispatcher/middlewares/media_group.py b/aiogram/dispatcher/middlewares/media_group.py index 5efcd0f8..83158f4a 100644 --- a/aiogram/dispatcher/middlewares/media_group.py +++ b/aiogram/dispatcher/middlewares/media_group.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from redis.asyncio.client import Redis DELAY_SEC = 1.0 -TIMEOUT_SEC = 10 +LOCK_TTL_SEC = 30 TTL_SEC = 600 @@ -53,13 +53,29 @@ class BaseMediaGroupAggregator(ABC): message_ids.add(message.message_id) return result + async def get_current_time(self) -> float: + return time.time() + class RedisMediaGroupAggregator(BaseMediaGroupAggregator): + """ + Aggregates media groups in Redis. + """ + redis: "Redis" - def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None: + def __init__( + self, redis: "Redis", ttl_sec: int = TTL_SEC, lock_ttl_sec: int = LOCK_TTL_SEC + ) -> None: + """ + :param ttl_sec: ttl for media group data in seconds + :param lock_ttl_sec: ttl for lock in seconds. Value should be too big to prevent lock + releasing until handler finished and too small to expire until telegram send retry if + handler failed. + """ self.redis = redis self.ttl_sec = ttl_sec + self.lock_ttl_sec = lock_ttl_sec @staticmethod def get_group_key(media_group_id: str) -> str: @@ -74,8 +90,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): return f"media_group:{media_group_id}:lock" async def add_into_group(self, media_group_id: str, media: Message) -> int: + current_time = await self.get_current_time() async with self.redis.pipeline(transaction=True) as pipe: - pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec) + pipe.set(self.get_last_message_time_key(media_group_id), current_time, ex=self.ttl_sec) pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json()) pipe.expire(self.get_group_key(media_group_id), self.ttl_sec) res = await pipe.execute() @@ -85,7 +102,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): return cast( bool, await self.redis.set( - self.get_group_lock_key(media_group_id), "1", nx=True, ex=TIMEOUT_SEC + self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec ), ) @@ -96,7 +113,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): result = await cast( Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1) ) - return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result]) + return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)]) async def delete_group(self, media_group_id: str) -> None: async with self.redis.pipeline(transaction=True) as pipe: @@ -110,6 +127,10 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): return None return float(result) + async def get_current_time(self) -> float: + seconds, microseconds = cast(tuple[int, int], await self.redis.time()) + return seconds + microseconds / 1e6 + class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): def __init__(self, ttl_sec: int = TTL_SEC) -> None: @@ -184,9 +205,11 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): if not await self.media_group_aggregator.acquire_lock(event.media_group_id): return None try: - last_message_time = time.time() + last_message_time = await self.media_group_aggregator.get_current_time() while True: - delta = self.delay - (time.time() - last_message_time) + delta = self.delay - ( + await self.media_group_aggregator.get_current_time() - last_message_time + ) if delta <= 0: album = await self.media_group_aggregator.get_group(event.media_group_id) if not album: diff --git a/aiogram/filters/media_group.py b/aiogram/filters/media_group.py index a685bb23..9e9b7d3f 100644 --- a/aiogram/filters/media_group.py +++ b/aiogram/filters/media_group.py @@ -29,8 +29,10 @@ class MediaGroupFilter(Filter): :param max_media_count: max count of media in the group, inclusively """ if count is None: - min_media_count = min_media_count or MIN_MEDIA_COUNT - max_media_count = max_media_count or DEFAULT_MAX_MEDIA_COUNT + if min_media_count is None: + min_media_count = MIN_MEDIA_COUNT + if max_media_count is None: + max_media_count = max(DEFAULT_MAX_MEDIA_COUNT, min_media_count) else: if min_media_count is not None or max_media_count is not None: raise ValueError( diff --git a/docs/utils/media_group.rst b/docs/utils/media_group.rst index dbf33a49..51f018d0 100644 --- a/docs/utils/media_group.rst +++ b/docs/utils/media_group.rst @@ -51,7 +51,10 @@ By default each media in the group is processed separately. You can use :class:`aiogram.dispatcher.middlewares.media_group.MediaGroupAggregatorMiddleware` to process media groups as one. If you do, only one message from the group will be processed, and updates for -other messages with the same media group ID will be suppressed. +other messages with the same media group ID will be suppressed. There are two options to store media groups: + +- :class:`aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator` - simple in-memory storage, used by default +- :class:`aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator` - support distributed environment You also can use :class:`aiogram.filters.media_group.MediaGroupFilter` to filter media groups. @@ -72,7 +75,7 @@ Usage # use middleware @router.message( - MediaGroupFilter(max_count=5), + MediaGroupFilter(max_media_count=5), F.caption == "album_caption" # other filters will be applied to the first message in the group ) async def start(message: Message, album: list[Message]): @@ -93,3 +96,8 @@ References :members: .. autoclass:: aiogram.filters.media_group.MediaGroupFilter :members: +.. autoclass:: aiogram.dispatcher.middlewares.media_group.MemoryMediaGroupAggregator + :members: +.. autoclass:: aiogram.dispatcher.middlewares.media_group.RedisMediaGroupAggregator + :members: + :special-members: __init__ diff --git a/tests/test_dispatcher/test_middlewares/test_media_group.py b/tests/test_dispatcher/test_middlewares/test_media_group.py index 91b4bdba..b3f83978 100644 --- a/tests/test_dispatcher/test_middlewares/test_media_group.py +++ b/tests/test_dispatcher/test_middlewares/test_media_group.py @@ -96,7 +96,8 @@ class TestMediaGroupAggregatorMiddleware: assert isinstance(first_message, Message) assert first_message.message_id == 1 - async def test_skip_propagating_if_data_deleted(self): + @pytest.mark.parametrize("deleted_object", ["album", "last_message_time"]) + async def test_skip_propagating_if_data_deleted(self, deleted_object): middleware = self.get_middleware() counter = 0 @@ -107,7 +108,10 @@ class TestMediaGroupAggregatorMiddleware: task1 = await wait_until_func_call_sleep( asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {}) ) - await middleware.media_group_aggregator.delete_group("42") + if deleted_object == "album": + middleware.media_group_aggregator.groups.pop("42") + else: + middleware.media_group_aggregator.last_message_timers.pop("42") await task1 assert counter == 0 @@ -199,6 +203,16 @@ class TestMediaGroupAggregator: assert await aggregator.get_group("42") == [] assert await aggregator.get_group("24") == [new_msg] + async def test_get_current_time_memory_aggregator(self): + aggregator = MemoryMediaGroupAggregator() + with mock.patch("time.time", return_value=1.1): + assert await aggregator.get_current_time() == 1.1 + + async def test_get_current_time_redis_aggregator(self): + aggregator = RedisMediaGroupAggregator(mock.Mock(spec=Redis)) + aggregator.redis.time = mock.AsyncMock(return_value=(1, 123456)) + assert await aggregator.get_current_time() == 1.123456 + async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator): assert await aggregator.get_last_message_time("42") is None await aggregator.add_into_group("42", _get_message(1)) diff --git a/tests/test_filters/test_media_group.py b/tests/test_filters/test_media_group.py index 3d1af49b..a436d062 100644 --- a/tests/test_filters/test_media_group.py +++ b/tests/test_filters/test_media_group.py @@ -12,6 +12,8 @@ class TestMediaGroupFilter: [ ((), MIN_MEDIA_COUNT, DEFAULT_MAX_MEDIA_COUNT), ((3,), 3, 3), + ((11,), 11, 11), + ((None, 11, None), 11, 11), ((None, 3), 3, DEFAULT_MAX_MEDIA_COUNT), ((None, None, 3), MIN_MEDIA_COUNT, 3), ],