mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
80b4abd3b5
commit
24fdd285fd
4 changed files with 21 additions and 19 deletions
|
|
@ -71,9 +71,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
) -> None:
|
||||
"""
|
||||
:param ttl_sec: ttl for media group data in seconds
|
||||
:param lock_ttl_sec: ttl for lock in seconds. Value should be too big to prevent lock
|
||||
releasing until handler finished and too small to expire until telegram send retry if
|
||||
handler failed.
|
||||
:param lock_ttl_sec: ttl for lock in seconds. Value should be large enough to prevent the
|
||||
lock from expiring before the handler finishes, but small enough to expire before
|
||||
Telegram retries a failed delivery.
|
||||
"""
|
||||
self.redis = redis
|
||||
self.ttl_sec = ttl_sec
|
||||
|
|
@ -115,7 +115,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
"else return 0 end"
|
||||
)
|
||||
await cast(
|
||||
Awaitable[str],
|
||||
Awaitable[int],
|
||||
self.redis.eval(release_script, 1, self.get_group_lock_key(media_group_id), lock_id),
|
||||
)
|
||||
|
||||
|
|
@ -123,7 +123,7 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
result = await cast(
|
||||
Awaitable[list[str]], self.redis.lrange(self.get_group_key(media_group_id), 0, -1)
|
||||
)
|
||||
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in set(result)])
|
||||
return self.deduplicate_messages([Message.model_validate_json(msg) for msg in result])
|
||||
|
||||
async def delete_group(self, media_group_id: str) -> None:
|
||||
async with self.redis.pipeline(transaction=True) as pipe:
|
||||
|
|
@ -172,7 +172,7 @@ class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
return len(self.groups[media_group_id])
|
||||
|
||||
async def acquire_lock(self, media_group_id: str, lock_id: str) -> bool:
|
||||
if self.locks.get(media_group_id):
|
||||
if self.locks.get(media_group_id) is not None:
|
||||
return False
|
||||
self.locks[media_group_id] = lock_id
|
||||
return True
|
||||
|
|
@ -220,8 +220,12 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
|||
if not await self.media_group_aggregator.acquire_lock(event.media_group_id, lock_id):
|
||||
return None
|
||||
try:
|
||||
last_message_time = await self.media_group_aggregator.get_current_time()
|
||||
while True:
|
||||
last_message_time = await self.media_group_aggregator.get_last_message_time(
|
||||
event.media_group_id
|
||||
)
|
||||
if not last_message_time:
|
||||
return None
|
||||
delta = self.delay - (
|
||||
await self.media_group_aggregator.get_current_time() - last_message_time
|
||||
)
|
||||
|
|
@ -238,11 +242,5 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
|||
await self.media_group_aggregator.delete_group(event.media_group_id)
|
||||
return result
|
||||
await asyncio.sleep(delta)
|
||||
new_last_message_time = await self.media_group_aggregator.get_last_message_time(
|
||||
event.media_group_id
|
||||
)
|
||||
if not new_last_message_time:
|
||||
return None
|
||||
last_message_time = new_last_message_time
|
||||
finally:
|
||||
await self.media_group_aggregator.release_lock(event.media_group_id, lock_id)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from .command import Command, CommandObject, CommandStart
|
|||
from .exception import ExceptionMessageFilter, ExceptionTypeFilter
|
||||
from .logic import and_f, invert_f, or_f
|
||||
from .magic_data import MagicData
|
||||
from .media_group import MediaGroupFilter
|
||||
from .state import StateFilter
|
||||
|
||||
BaseFilter = Filter
|
||||
|
|
@ -44,6 +45,7 @@ __all__ = (
|
|||
"ExceptionTypeFilter",
|
||||
"Filter",
|
||||
"MagicData",
|
||||
"MediaGroupFilter",
|
||||
"StateFilter",
|
||||
"and_f",
|
||||
"invert_f",
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ You also can use :class:`aiogram.filters.media_group.MediaGroupFilter`
|
|||
to filter media groups.
|
||||
|
||||
Usage
|
||||
=====
|
||||
-----
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ Usage
|
|||
|
||||
# register middleware
|
||||
from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware
|
||||
from aiogram.filters.media_group import MediaGroupFilter
|
||||
from aiogram.filters import MediaGroupFilter
|
||||
|
||||
router.message.outer_middleware(MediaGroupAggregatorMiddleware())
|
||||
|
||||
|
|
|
|||
|
|
@ -189,14 +189,16 @@ class TestMediaGroupAggregator:
|
|||
assert await aggregator.get_group("42") == []
|
||||
|
||||
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
|
||||
await aggregator.acquire_lock("42", "key1")
|
||||
assert not await aggregator.acquire_lock("42", "key2")
|
||||
await aggregator.release_lock("42", "key1")
|
||||
for i in ("key2", "key3"):
|
||||
for i in ("key1", "key2"):
|
||||
assert await aggregator.acquire_lock("42", i)
|
||||
assert not await aggregator.acquire_lock("42", i)
|
||||
await aggregator.release_lock("42", i)
|
||||
|
||||
async def test_lock_not_acquired_with_wrong_key(self, aggregator: BaseMediaGroupAggregator):
|
||||
await aggregator.acquire_lock("42", "key1")
|
||||
await aggregator.release_lock("42", "key2")
|
||||
assert not await aggregator.acquire_lock("42", "key1")
|
||||
|
||||
async def test_expired_objects_removed(self):
|
||||
aggregator = MemoryMediaGroupAggregator()
|
||||
await aggregator.add_into_group("42", _get_message(1))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue