update tests for media group aggregator

This commit is contained in:
Vitaly312 2026-02-28 20:32:36 +03:00
parent 16b718ca07
commit b7f2da391b
3 changed files with 144 additions and 27 deletions

View file

@ -5,6 +5,7 @@ from collections import defaultdict
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Message, TelegramObject
@ -56,8 +57,9 @@ class BaseMediaGroupAggregator(ABC):
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
redis: "Redis"
def __init__(self, redis: "Redis") -> None:
def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None:
self.redis = redis
self.ttl_sec = ttl_sec
@staticmethod
def get_group_key(media_group_id: str) -> str:
@ -73,9 +75,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
async def add_into_group(self, media_group_id: str, media: Message) -> int:
async with self.redis.pipeline(transaction=True) as pipe:
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=TTL_SEC)
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec)
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
pipe.expire(self.get_group_key(media_group_id), TTL_SEC)
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
res = await pipe.execute()
return cast(int, res[1])
@ -110,16 +112,17 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
def __init__(self) -> None:
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
self.groups: dict[str, list[Message]] = defaultdict(list)
self.last_message_timers: dict[str, float] = {}
self.locks: dict[str, bool] = {}
self.ttl_sec = ttl_sec
def remove_expired_objects(self) -> None:
expired_group_ids = []
current_time = time.time()
for group_id, last_message_time in self.last_message_timers.items():
if current_time - last_message_time > TTL_SEC:
if current_time - last_message_time > self.ttl_sec:
expired_group_ids.append(group_id)
else:
break # the list is sorted in ascending order
@ -188,7 +191,10 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
album = await self.media_group_aggregator.get_group(event.media_group_id)
if not album:
return None
album.sort(key=lambda msg: msg.message_id)
album = sorted(
(msg.as_(cast(Bot, data.get("bot"))) for msg in album),
key=lambda msg: msg.message_id,
)
data.update(album=album)
result = await handler(album[0], data)
await self.media_group_aggregator.delete_group(event.media_group_id)

View file

@ -54,7 +54,7 @@ class MediaGroupFilter(Filter):
)
async def __call__(
self, message: Message, album: list[Message] = None
self, message: Message, album: list[Message] | None = None
) -> Literal[False] | dict[str, Any]:
media_count = len(album or [])
if not (self.min_media_count <= media_count <= self.max_media_count):

View file

@ -1,17 +1,41 @@
import asyncio
from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware
from aiogram.types import Message, Chat
import time
from datetime import datetime
from typing import Any
from typing import Any, Awaitable, Callable
from unittest import mock
import pytest
from redis.asyncio.client import Redis
from aiogram.dispatcher.middlewares.media_group import (
BaseMediaGroupAggregator,
MediaGroupAggregatorMiddleware,
MemoryMediaGroupAggregator,
RedisMediaGroupAggregator,
)
from aiogram.types import Chat, Message
def _get_message(message_id: int, **kwargs):
chat = Chat(id=1, type="private", title="Test")
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
start_sleep = asyncio.Event()
real_sleep = asyncio.sleep
async def mock_sleep(*args, **kwargs):
start_sleep.set()
await real_sleep(0)
with mock.patch("asyncio.sleep", mock_sleep):
task1 = func(*args, **kwargs)
await start_sleep.wait()
return task1
class TestMediaGroupAggregatorMiddleware:
def _get_message(self, message_id: int, **kwargs):
chat = Chat(id=1, type="private", title="Test")
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
def get_middleware(self):
return MediaGroupAggregatorMiddleware(delay=0.1)
@ -22,7 +46,7 @@ class TestMediaGroupAggregatorMiddleware:
nonlocal is_called
is_called = True
await self.get_middleware()(next_handler, self._get_message(1), {})
await self.get_middleware()(next_handler, _get_message(1), {})
assert is_called
async def test_called_once_for_album(self):
@ -36,13 +60,26 @@ class TestMediaGroupAggregatorMiddleware:
album = data.get("album")
await asyncio.gather(
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
middleware(next_handler, _get_message(1, media_group_id="42"), {}),
middleware(next_handler, _get_message(2, media_group_id="42"), {}),
)
assert album is not None
assert len(album) == 2
assert counter == 1
async def test_bot_object_saved(self, bot):
middleware = self.get_middleware()
event = album = None
async def next_handler(message: Message, data: dict[str, Any]):
nonlocal event, album
event = message
album = data.get("album")
await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot})
assert event.bot is bot
assert all(msg.bot is bot for msg in album)
async def test_propagate_first_media_in_album(self):
middleware = self.get_middleware()
first_message = None
@ -51,13 +88,29 @@ class TestMediaGroupAggregatorMiddleware:
nonlocal first_message
first_message = message
await asyncio.gather(
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
task1 = await wait_until_func_call_sleep(
asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {})
)
await middleware(next_handler, _get_message(1, media_group_id="42"), {})
await task1
assert isinstance(first_message, Message)
assert first_message.message_id == 1
async def test_skip_propagating_if_data_deleted(self):
middleware = self.get_middleware()
counter = 0
async def next_handler(*args, **kwargs):
nonlocal counter
counter += 1
task1 = await wait_until_func_call_sleep(
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
)
await middleware.media_group_aggregator.delete_group("42")
await task1
assert counter == 0
async def test_different_albums_non_interfere(self):
middleware = self.get_middleware()
counter = 0
@ -69,8 +122,8 @@ class TestMediaGroupAggregatorMiddleware:
albums.append(data.get("album"))
await asyncio.gather(
middleware(next_handler, self._get_message(1, media_group_id="1"), {}),
middleware(next_handler, self._get_message(2, media_group_id="2"), {}),
middleware(next_handler, _get_message(1, media_group_id="1"), {}),
middleware(next_handler, _get_message(2, media_group_id="2"), {}),
)
assert counter == 2
assert len(albums) == 2
@ -80,18 +133,76 @@ class TestMediaGroupAggregatorMiddleware:
album = None
async def failed_handler(*args, **kwargs):
raise Exception("Failed")
raise RuntimeError("Failed")
async def working_handler(_, data: dict[str, Any]):
nonlocal album
album = data.get("album")
first_message = self._get_message(1, media_group_id="42")
second_message = self._get_message(2, media_group_id="42")
with pytest.raises(Exception):
first_message = _get_message(1, media_group_id="42")
second_message = _get_message(2, media_group_id="42")
with pytest.raises(RuntimeError):
await asyncio.gather(
middleware(failed_handler, first_message, {}),
middleware(failed_handler, second_message, {}),
)
await middleware(working_handler, first_message, {})
assert len(album) == 2
def test_message_deduplication():
message_1, message_2 = _get_message(1), _get_message(2)
res = [message_1, message_2]
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2]) == res
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_2]) == res
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_1]) == res
@pytest.fixture(params=["memory", "redis"], scope="function")
async def aggregator(request):
if request.param == "memory":
yield MemoryMediaGroupAggregator()
else:
redis = Redis.from_url(request.getfixturevalue("redis_server"))
yield RedisMediaGroupAggregator(redis)
keys = await redis.keys("media_group:*")
if keys:
await redis.delete(*keys)
await redis.aclose()
class TestMediaGroupAggregator:
async def test_group_creating(self, aggregator: BaseMediaGroupAggregator):
msg1 = _get_message(1)
msg2 = _get_message(2)
assert await aggregator.add_into_group("42", msg1) == 1
assert await aggregator.add_into_group("42", msg2) == 2
assert {msg.message_id for msg in await aggregator.get_group("42")} == {
msg1.message_id,
msg2.message_id,
}
await aggregator.delete_group("42")
assert await aggregator.get_group("42") == []
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
for _ in range(2):
assert await aggregator.acquire_lock("42")
assert not await aggregator.acquire_lock("42")
await aggregator.release_lock("42")
async def test_expired_objects_removed(self):
aggregator = MemoryMediaGroupAggregator()
await aggregator.add_into_group("42", _get_message(1))
with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1):
new_msg = _get_message(2)
await aggregator.add_into_group("24", new_msg)
assert await aggregator.get_group("42") == []
assert await aggregator.get_group("24") == [new_msg]
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
assert await aggregator.get_last_message_time("42") is None
await aggregator.add_into_group("42", _get_message(1))
time_before_second_message = time.time()
assert await aggregator.get_last_message_time("42") <= time_before_second_message
await aggregator.add_into_group("42", _get_message(2))
assert await aggregator.get_last_message_time("42") >= time_before_second_message