From 80b4abd3b5c2262cab32fdd987c8efcdced3a9c0 Mon Sep 17 00:00:00 2001 From: Vitaly312 Date: Sun, 1 Mar 2026 15:25:59 +0300 Subject: [PATCH] fix lock releasing --- aiogram/dispatcher/middlewares/media_group.py | 59 ++++++++++++------- .../test_middlewares/test_media_group.py | 17 +++--- 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/aiogram/dispatcher/middlewares/media_group.py b/aiogram/dispatcher/middlewares/media_group.py index 83158f4a..e53213ca 100644 --- a/aiogram/dispatcher/middlewares/media_group.py +++ b/aiogram/dispatcher/middlewares/media_group.py @@ -1,5 +1,6 @@ import asyncio import time +import uuid from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Awaitable, Callable @@ -20,27 +21,27 @@ TTL_SEC = 600 class BaseMediaGroupAggregator(ABC): @abstractmethod async def add_into_group(self, media_group_id: str, media: Message) -> int: - raise NotImplementedError + pass @abstractmethod - async def acquire_lock(self, media_group_id: str) -> bool: - raise NotImplementedError + async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool: + pass @abstractmethod - async def release_lock(self, media_group_id: str) -> None: - raise NotImplementedError + async def release_lock(self, media_group_id: str, lock_id: str) -> None: + pass @abstractmethod async def get_group(self, media_group_id: str) -> list[Message]: - raise NotImplementedError + pass @abstractmethod async def delete_group(self, media_group_id: str) -> None: - raise NotImplementedError + pass @abstractmethod async def get_last_message_time(self, media_group_id: str) -> float | None: - raise NotImplementedError + pass @staticmethod def deduplicate_messages(messages: list[Message]) -> list[Message]: @@ -53,8 +54,9 @@ class BaseMediaGroupAggregator(ABC): message_ids.add(message.message_id) return result + @abstractmethod async def get_current_time(self) -> float: - return time.time() + pass class RedisMediaGroupAggregator(BaseMediaGroupAggregator): @@ -98,16 +100,24 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator): res = await pipe.execute() return cast(int, res[1]) - async def acquire_lock(self, media_group_id: str) -> bool: + async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool: return cast( bool, await self.redis.set( - self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec + self.get_group_lock_key(media_group_id), lock_id, nx=True, ex=self.lock_ttl_sec ), ) - async def release_lock(self, media_group_id: str) -> None: - await self.redis.delete(self.get_group_lock_key(media_group_id)) + async def release_lock(self, media_group_id: str, lock_id: str) -> None: + release_script = ( + 'if redis.call("get", KEYS[1]) == ARGV[1] then ' + 'return redis.call("del", KEYS[1]) ' + "else return 0 end" + ) + await cast( + Awaitable[str], + self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id), + ) async def get_group(self, media_group_id: str) -> list[Message]: result = await cast( @@ -136,12 +146,12 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): def __init__(self, ttl_sec: int = TTL_SEC) -> None: self.groups: dict[str, list[Message]] = defaultdict(list) self.last_message_timers: dict[str, float] = {} - self.locks: dict[str, bool] = {} + self.locks: dict[str, str] = {} self.ttl_sec = ttl_sec def remove_expired_objects(self) -> None: expired_group_ids = [] - current_time = time.time() + current_time = time.monotonic() for group_id, last_message_time in self.last_message_timers.items(): if current_time - last_message_time > self.ttl_sec: expired_group_ids.append(group_id) @@ -158,17 +168,18 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]): self.groups[media_group_id].append(media) self.last_message_timers.pop(media_group_id, None) - self.last_message_timers[media_group_id] = time.time() + self.last_message_timers[media_group_id] = time.monotonic() return len(self.groups[media_group_id]) - async def acquire_lock(self, media_group_id: str) -> bool: + async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool: if self.locks.get(media_group_id): return False - self.locks[media_group_id] = True + self.locks[media_group_id] = lock_id return True - async def release_lock(self, media_group_id: str) -> None: - self.locks.pop(media_group_id, None) + async def release_lock(self, media_group_id: str, lock_id: str) -> None: + if self.locks.get(media_group_id) == lock_id: + self.locks.pop(media_group_id) async def get_group(self, media_group_id: str) -> list[Message]: return self.groups.get(media_group_id, []) @@ -180,6 +191,9 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator): async def get_last_message_time(self, media_group_id: str) -> float | None: return self.last_message_timers.get(media_group_id) + async def get_current_time(self) -> float: + return time.monotonic() + class MediaGroupAggregatorMiddleware(BaseMiddleware): def __init__( @@ -202,7 +216,8 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): if not isinstance(event, Message) or not event.media_group_id: return await handler(event, data) await self.media_group_aggregator.add_into_group(event.media_group_id, event) - if not await self.media_group_aggregator.acquire_lock(event.media_group_id): + lock_id = str(uuid.uuid4()) + if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id): return None try: last_message_time = await self.media_group_aggregator.get_current_time() @@ -230,4 +245,4 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware): return None last_message_time = new_last_message_time finally: - await self.media_group_aggregator.release_lock(event.media_group_id) + await self.media_group_aggregator.release_lock(event.media_group_id, lock_id) diff --git a/tests/test_dispatcher/test_middlewares/test_media_group.py b/tests/test_dispatcher/test_middlewares/test_media_group.py index b3f83978..a5c65d9e 100644 --- a/tests/test_dispatcher/test_middlewares/test_media_group.py +++ b/tests/test_dispatcher/test_middlewares/test_media_group.py @@ -189,15 +189,18 @@ class TestMediaGroupAggregator: assert await aggregator.get_group("42") == [] async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator): - for _ in range(2): - assert await aggregator.acquire_lock("42") - assert not await aggregator.acquire_lock("42") - await aggregator.release_lock("42") + await aggregator.acquire_lock("42", "key1") + assert not await aggregator.acquire_lock("42", "key2") + await aggregator.release_lock("42", "key1") + for i in ("key2", "key3"): + assert await aggregator.acquire_lock("42", i) + assert not await aggregator.acquire_lock("42", i) + await aggregator.release_lock("42", i) async def test_expired_objects_removed(self): aggregator = MemoryMediaGroupAggregator() await aggregator.add_into_group("42", _get_message(1)) - with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1): + with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1): new_msg = _get_message(2) await aggregator.add_into_group("24", new_msg) assert await aggregator.get_group("42") == [] @@ -205,7 +208,7 @@ class TestMediaGroupAggregator: async def test_get_current_time_memory_aggregator(self): aggregator = MemoryMediaGroupAggregator() - with mock.patch("time.time", return_value=1.1): + with mock.patch("time.monotonic", return_value=1.1): assert await aggregator.get_current_time() == 1.1 async def test_get_current_time_redis_aggregator(self): @@ -216,7 +219,7 @@ class TestMediaGroupAggregator: async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator): assert await aggregator.get_last_message_time("42") is None await aggregator.add_into_group("42", _get_message(1)) - time_before_second_message = time.time() + time_before_second_message = await aggregator.get_current_time() assert await aggregator.get_last_message_time("42") <= time_before_second_message await aggregator.add_into_group("42", _get_message(2)) assert await aggregator.get_last_message_time("42") >= time_before_second_message