fix: instantly release lock if handler failed

This commit is contained in:
Vitaly312 2026-02-28 12:53:35 +03:00
parent f5bf1afafd
commit 8420f8d53d

View file

@ -26,13 +26,32 @@ class BaseMediaGroupAggregator(ABC):
raise NotImplementedError
@abstractmethod
async def get_and_delete_group(self, media_group_id: str) -> list[Message]:
async def release_lock(self, media_group_id: str) -> None:
raise NotImplementedError
@abstractmethod
async def get_group(self, media_group_id: str) -> list[Message]:
raise NotImplementedError
@abstractmethod
async def delete_group(self, media_group_id: str) -> None:
raise NotImplementedError
@abstractmethod
async def get_last_message_time(self, media_group_id: str) -> float | None:
raise NotImplementedError
@staticmethod
def deduplicate_messages(messages: list[Message]) -> list[Message]:
message_ids = set()
result = []
for message in messages:
if message.message_id in message_ids:
continue
result.append(message)
message_ids.add(message.message_id)
return result
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
redis: "Redis"
@ -68,13 +87,18 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
),
)
async def get_and_delete_group(self, media_group_id: str) -> list[Message]:
async def release_lock(self, media_group_id: str) -> None:
await self.redis.delete(self.get_group_lock_key(media_group_id))
async def get_group(self, media_group_id: str) -> list[Message]:
result = await self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
async def delete_group(self, media_group_id: str) -> None:
async with self.redis.pipeline(transaction=True) as pipe:
pipe.lrange(self.get_group_key(media_group_id), 0, -1)
pipe.delete(self.get_group_key(media_group_id))
pipe.delete(self.get_last_message_time_key(media_group_id))
result = await pipe.execute()
return [Message.model_validate_json(msg) for msg in result[0]]
await pipe.execute()
async def get_last_message_time(self, media_group_id: str) -> float | None:
result = await self.redis.get(self.get_last_message_time_key(media_group_id))
@ -90,7 +114,8 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
self.locks: dict[str, bool] = {}
async def add_into_group(self, media_group_id: str, media: Message) -> int:
self.groups[media_group_id].append(media)
if media.message_id not in (msg.message_id for msg in self.groups[media_group_id]):
self.groups[media_group_id].append(media)
self.last_message_timers[media_group_id] = time.time()
return len(self.groups[media_group_id])
@ -100,12 +125,15 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
self.locks[media_group_id] = True
return True
async def get_and_delete_group(self, media_group_id: str) -> list[Message]:
group = self.groups[media_group_id]
async def release_lock(self, media_group_id: str) -> None:
self.locks.pop(media_group_id, None)
async def get_group(self, media_group_id: str) -> list[Message]:
return self.groups.get(media_group_id, [])
async def delete_group(self, media_group_id: str) -> None:
self.groups.pop(media_group_id, None)
self.last_message_timers.pop(media_group_id, None)
self.locks.pop(media_group_id, None)
return group
async def get_last_message_time(self, media_group_id: str) -> float | None:
return self.last_message_timers.get(media_group_id)
@ -115,8 +143,10 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
def __init__(
self,
media_group_aggregator: BaseMediaGroupAggregator | None = None,
delay: float = DELAY_SEC,
) -> None:
self.media_group_aggregator = media_group_aggregator or MemoryMediaGroupAggregator()
self.delay = delay
async def __call__(
self,
@ -129,23 +159,25 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
await self.media_group_aggregator.add_into_group(event.media_group_id, event)
if not await self.media_group_aggregator.acquire_lock(event.media_group_id):
return None
last_message_time = time.time()
while True:
delta = DELAY_SEC - (time.time() - last_message_time)
if delta <= 0:
album = await self.media_group_aggregator.get_and_delete_group(
try:
last_message_time = time.time()
while True:
delta = self.delay - (time.time() - last_message_time)
if delta <= 0:
album = await self.media_group_aggregator.get_group(event.media_group_id)
if not album:
return None
album.sort(key=lambda msg: msg.message_id)
data.update(album=album)
result = await handler(album[0], data)
await self.media_group_aggregator.delete_group(event.media_group_id)
return result
await asyncio.sleep(delta)
new_last_message_time = await self.media_group_aggregator.get_last_message_time(
event.media_group_id
)
if not album:
if not new_last_message_time:
return None
album.sort(key=lambda msg: msg.message_id)
data.update(album=album)
return await handler(album[0], data)
await asyncio.sleep(delta)
new_last_message_time = await self.media_group_aggregator.get_last_message_time(
event.media_group_id
)
if not new_last_message_time:
return None
last_message_time = new_last_message_time
last_message_time = new_last_message_time
finally:
await self.media_group_aggregator.release_lock(event.media_group_id)