Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Виталий 2026-03-05 18:16:47 +03:00 committed by Vitaly312
parent 80b4abd3b5
commit 24fdd285fd
4 changed files with 21 additions and 19 deletions

View file

@ -71,9 +71,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
) -> None:
"""
:param ttl_sec: ttl for media group data in seconds
:param lock_ttl_sec: ttl for lock in seconds. Value should be too big to prevent lock
releasing until handler finished and too small to expire until telegram send retry if
handler failed.
:param lock_ttl_sec: ttl for lock in seconds. Value should be large enough to prevent the
lock from expiring before the handler finishes, but small enough to expire before
Telegram retries a failed delivery.
"""
self.redis = redis
self.ttl_sec = ttl_sec
@ -115,7 +115,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
"else return 0 end"
)
await cast(
Awaitable[str],
Awaitable[int],
self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id),
)
@ -123,7 +123,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
result = await cast(
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
)
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)])
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
async def delete_group(self, media_group_id: str) -> None:
async with self.redis.pipeline(transaction=True) as pipe:
@ -172,7 +172,7 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
return len(self.groups[media_group_id])
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
if self.locks.get(media_group_id):
if self.locks.get(media_group_id) is not None:
return False
self.locks[media_group_id] = lock_id
return True
@ -220,8 +220,12 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id):
return None
try:
last_message_time = await self.media_group_aggregator.get_current_time()
while True:
last_message_time = await self.media_group_aggregator.get_last_message_time(
event.media_group_id
)
if not last_message_time:
return None
delta = self.delay - (
await self.media_group_aggregator.get_current_time() - last_message_time
)
@ -238,11 +242,5 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
await self.media_group_aggregator.delete_group(event.media_group_id)
return result
await asyncio.sleep(delta)
new_last_message_time = await self.media_group_aggregator.get_last_message_time(
event.media_group_id
)
if not new_last_message_time:
return None
last_message_time = new_last_message_time
finally:
await self.media_group_aggregator.release_lock(event.media_group_id, lock_id)

View file

@ -18,6 +18,7 @@ from .command import Command, CommandObject, CommandStart
from .exception import ExceptionMessageFilter, ExceptionTypeFilter
from .logic import and_f, invert_f, or_f
from .magic_data import MagicData
from .media_group import MediaGroupFilter
from .state import StateFilter
BaseFilter = Filter
@ -44,6 +45,7 @@ __all__ = (
"ExceptionTypeFilter",
"Filter",
"MagicData",
"MediaGroupFilter",
"StateFilter",
"and_f",
"invert_f",

View file

@ -60,7 +60,7 @@ You also can use :class:`aiogram.filters.media_group.MediaGroupFilter`
to filter media groups.
Usage
=====
-----
.. code-block:: python
@ -69,7 +69,7 @@ Usage
# register middleware
from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware
from aiogram.filters.media_group import MediaGroupFilter
from aiogram.filters import MediaGroupFilter
router.message.outer_middleware(MediaGroupAggregatorMiddleware())

View file

@ -189,14 +189,16 @@ class TestMediaGroupAggregator:
assert await aggregator.get_group("42") == []
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
await aggregator.acquire_lock("42", "key1")
assert not await aggregator.acquire_lock("42", "key2")
await aggregator.release_lock("42", "key1")
for i in ("key2", "key3"):
for i in ("key1", "key2"):
assert await aggregator.acquire_lock("42", i)
assert not await aggregator.acquire_lock("42", i)
await aggregator.release_lock("42", i)
async def test_lock_not_acquired_with_wrong_key(self, aggregator: BaseMediaGroupAggregator):
await aggregator.acquire_lock("42", "key1")
await aggregator.release_lock("42", "key2")
assert not await aggregator.acquire_lock("42", "key1")
async def test_expired_objects_removed(self):
aggregator = MemoryMediaGroupAggregator()
await aggregator.add_into_group("42", _get_message(1))