fix lock releasing

This commit is contained in:
Vitaly312 2026-03-01 15:25:59 +03:00
parent 10fa7315da
commit 80b4abd3b5
2 changed files with 47 additions and 29 deletions

View file

@ -1,5 +1,6 @@
import asyncio
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable
@ -20,27 +21,27 @@ TTL_SEC = 600
class BaseMediaGroupAggregator(ABC):
@abstractmethod
async def add_into_group(self, media_group_id: str, media: Message) -> int:
raise NotImplementedError
pass
@abstractmethod
async def acquire_lock(self, media_group_id: str) -> bool:
raise NotImplementedError
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
pass
@abstractmethod
async def release_lock(self, media_group_id: str) -> None:
raise NotImplementedError
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
pass
@abstractmethod
async def get_group(self, media_group_id: str) -> list[Message]:
raise NotImplementedError
pass
@abstractmethod
async def delete_group(self, media_group_id: str) -> None:
raise NotImplementedError
pass
@abstractmethod
async def get_last_message_time(self, media_group_id: str) -> float | None:
raise NotImplementedError
pass
@staticmethod
def deduplicate_messages(messages: list[Message]) -> list[Message]:
@ -53,8 +54,9 @@ class BaseMediaGroupAggregator(ABC):
message_ids.add(message.message_id)
return result
@abstractmethod
async def get_current_time(self) -> float:
return time.time()
pass
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
@ -98,16 +100,24 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
res = await pipe.execute()
return cast(int, res[1])
async def acquire_lock(self, media_group_id: str) -> bool:
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
return cast(
bool,
await self.redis.set(
self.get_group_lock_key(media_group_id), "1", nx=True, ex=self.lock_ttl_sec
self.get_group_lock_key(media_group_id), lock_id, nx=True, ex=self.lock_ttl_sec
),
)
async def release_lock(self, media_group_id: str) -> None:
await self.redis.delete(self.get_group_lock_key(media_group_id))
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
release_script = (
'if redis.call("get", KEYS[1]) == ARGV[1] then '
'return redis.call("del", KEYS[1]) '
"else return 0 end"
)
await cast(
Awaitable[str],
self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id),
)
async def get_group(self, media_group_id: str) -> list[Message]:
result = await cast(
@ -136,12 +146,12 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
self.groups: dict[str, list[Message]] = defaultdict(list)
self.last_message_timers: dict[str, float] = {}
self.locks: dict[str, bool] = {}
self.locks: dict[str, str] = {}
self.ttl_sec = ttl_sec
def remove_expired_objects(self) -> None:
expired_group_ids = []
current_time = time.time()
current_time = time.monotonic()
for group_id, last_message_time in self.last_message_timers.items():
if current_time - last_message_time > self.ttl_sec:
expired_group_ids.append(group_id)
@ -158,17 +168,18 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]):
self.groups[media_group_id].append(media)
self.last_message_timers.pop(media_group_id, None)
self.last_message_timers[media_group_id] = time.time()
self.last_message_timers[media_group_id] = time.monotonic()
return len(self.groups[media_group_id])
async def acquire_lock(self, media_group_id: str) -> bool:
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
if self.locks.get(media_group_id):
return False
self.locks[media_group_id] = True
self.locks[media_group_id] = lock_id
return True
async def release_lock(self, media_group_id: str) -> None:
self.locks.pop(media_group_id, None)
async def release_lock(self, media_group_id: str, lock_id: str) -> None:
if self.locks.get(media_group_id) == lock_id:
self.locks.pop(media_group_id)
async def get_group(self, media_group_id: str) -> list[Message]:
return self.groups.get(media_group_id, [])
@ -180,6 +191,9 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
async def get_last_message_time(self, media_group_id: str) -> float | None:
return self.last_message_timers.get(media_group_id)
async def get_current_time(self) -> float:
return time.monotonic()
class MediaGroupAggregatorMiddleware(BaseMiddleware):
def __init__(
@ -202,7 +216,8 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
if not isinstance(event, Message) or not event.media_group_id:
return await handler(event, data)
await self.media_group_aggregator.add_into_group(event.media_group_id, event)
if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
lock_id = str(uuid.uuid4())
if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id):
return None
try:
last_message_time = await self.media_group_aggregator.get_current_time()
@ -230,4 +245,4 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
return None
last_message_time = new_last_message_time
finally:
await self.media_group_aggregator.release_lock(event.media_group_id)
await self.media_group_aggregator.release_lock(event.media_group_id, lock_id)

View file

@ -189,15 +189,18 @@ class TestMediaGroupAggregator:
assert await aggregator.get_group("42") == []
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
for _ in range(2):
assert await aggregator.acquire_lock("42")
assert not await aggregator.acquire_lock("42")
await aggregator.release_lock("42")
await aggregator.acquire_lock("42", "key1")
assert not await aggregator.acquire_lock("42", "key2")
await aggregator.release_lock("42", "key1")
for i in ("key2", "key3"):
assert await aggregator.acquire_lock("42", i)
assert not await aggregator.acquire_lock("42", i)
await aggregator.release_lock("42", i)
async def test_expired_objects_removed(self):
aggregator = MemoryMediaGroupAggregator()
await aggregator.add_into_group("42", _get_message(1))
with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1):
with mock.patch("time.monotonic", return_value=time.time() + aggregator.ttl_sec + 1):
new_msg = _get_message(2)
await aggregator.add_into_group("24", new_msg)
assert await aggregator.get_group("42") == []
@ -205,7 +208,7 @@ class TestMediaGroupAggregator:
async def test_get_current_time_memory_aggregator(self):
aggregator = MemoryMediaGroupAggregator()
with mock.patch("time.time", return_value=1.1):
with mock.patch("time.monotonic", return_value=1.1):
assert await aggregator.get_current_time() == 1.1
async def test_get_current_time_redis_aggregator(self):
@ -216,7 +219,7 @@ class TestMediaGroupAggregator:
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
assert await aggregator.get_last_message_time("42") is None
await aggregator.add_into_group("42", _get_message(1))
time_before_second_message = time.time()
time_before_second_message = await aggregator.get_current_time()
assert await aggregator.get_last_message_time("42") <= time_before_second_message
await aggregator.add_into_group("42", _get_message(2))
assert await aggregator.get_last_message_time("42") >= time_before_second_message