mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
update tests for media group aggregator
This commit is contained in:
parent
16b718ca07
commit
b7f2da391b
3 changed files with 144 additions and 27 deletions
|
|
@ -5,6 +5,7 @@ from collections import defaultdict
|
|||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.types import Message, TelegramObject
|
||||
|
||||
|
|
@ -56,8 +57,9 @@ class BaseMediaGroupAggregator(ABC):
|
|||
class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||
redis: "Redis"
|
||||
|
||||
def __init__(self, redis: "Redis") -> None:
|
||||
def __init__(self, redis: "Redis", ttl_sec: int = TTL_SEC) -> None:
|
||||
self.redis = redis
|
||||
self.ttl_sec = ttl_sec
|
||||
|
||||
@staticmethod
|
||||
def get_group_key(media_group_id: str) -> str:
|
||||
|
|
@ -73,9 +75,9 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
|
||||
async def add_into_group(self, media_group_id: str, media: Message) -> int:
|
||||
async with self.redis.pipeline(transaction=True) as pipe:
|
||||
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=TTL_SEC)
|
||||
pipe.set(self.get_last_message_time_key(media_group_id), time.time(), ex=self.ttl_sec)
|
||||
pipe.rpush(self.get_group_key(media_group_id), media.model_dump_json())
|
||||
pipe.expire(self.get_group_key(media_group_id), TTL_SEC)
|
||||
pipe.expire(self.get_group_key(media_group_id), self.ttl_sec)
|
||||
res = await pipe.execute()
|
||||
return cast(int, res[1])
|
||||
|
||||
|
|
@ -110,16 +112,17 @@ class RedisMediaGroupAggregator(BaseMediaGroupAggregator):
|
|||
|
||||
|
||||
class MemoryMediaGroupAggregator(BaseMediaGroupAggregator):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, ttl_sec: int = TTL_SEC) -> None:
|
||||
self.groups: dict[str, list[Message]] = defaultdict(list)
|
||||
self.last_message_timers: dict[str, float] = {}
|
||||
self.locks: dict[str, bool] = {}
|
||||
self.ttl_sec = ttl_sec
|
||||
|
||||
def remove_expired_objects(self) -> None:
|
||||
expired_group_ids = []
|
||||
current_time = time.time()
|
||||
for group_id, last_message_time in self.last_message_timers.items():
|
||||
if current_time - last_message_time > TTL_SEC:
|
||||
if current_time - last_message_time > self.ttl_sec:
|
||||
expired_group_ids.append(group_id)
|
||||
else:
|
||||
break # the list is sorted in ascending order
|
||||
|
|
@ -188,7 +191,10 @@ class MediaGroupAggregatorMiddleware(BaseMiddleware):
|
|||
album = await self.media_group_aggregator.get_group(event.media_group_id)
|
||||
if not album:
|
||||
return None
|
||||
album.sort(key=lambda msg: msg.message_id)
|
||||
album = sorted(
|
||||
(msg.as_(cast(Bot, data.get("bot"))) for msg in album),
|
||||
key=lambda msg: msg.message_id,
|
||||
)
|
||||
data.update(album=album)
|
||||
result = await handler(album[0], data)
|
||||
await self.media_group_aggregator.delete_group(event.media_group_id)
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class MediaGroupFilter(Filter):
|
|||
)
|
||||
|
||||
async def __call__(
|
||||
self, message: Message, album: list[Message] = None
|
||||
self, message: Message, album: list[Message] | None = None
|
||||
) -> Literal[False] | dict[str, Any]:
|
||||
media_count = len(album or [])
|
||||
if not (self.min_media_count <= media_count <= self.max_media_count):
|
||||
|
|
|
|||
|
|
@ -1,17 +1,41 @@
|
|||
import asyncio
|
||||
|
||||
from aiogram.dispatcher.middlewares.media_group import MediaGroupAggregatorMiddleware
|
||||
from aiogram.types import Message, Chat
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Awaitable, Callable
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
from aiogram.dispatcher.middlewares.media_group import (
|
||||
BaseMediaGroupAggregator,
|
||||
MediaGroupAggregatorMiddleware,
|
||||
MemoryMediaGroupAggregator,
|
||||
RedisMediaGroupAggregator,
|
||||
)
|
||||
from aiogram.types import Chat, Message
|
||||
|
||||
|
||||
def _get_message(message_id: int, **kwargs):
|
||||
chat = Chat(id=1, type="private", title="Test")
|
||||
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
|
||||
|
||||
|
||||
async def wait_until_func_call_sleep(func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
|
||||
start_sleep = asyncio.Event()
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
async def mock_sleep(*args, **kwargs):
|
||||
start_sleep.set()
|
||||
await real_sleep(0)
|
||||
|
||||
with mock.patch("asyncio.sleep", mock_sleep):
|
||||
task1 = func(*args, **kwargs)
|
||||
await start_sleep.wait()
|
||||
return task1
|
||||
|
||||
|
||||
class TestMediaGroupAggregatorMiddleware:
|
||||
def _get_message(self, message_id: int, **kwargs):
|
||||
chat = Chat(id=1, type="private", title="Test")
|
||||
return Message(message_id=message_id, date=datetime.now(), chat=chat, **kwargs)
|
||||
|
||||
def get_middleware(self):
|
||||
return MediaGroupAggregatorMiddleware(delay=0.1)
|
||||
|
||||
|
|
@ -22,7 +46,7 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
nonlocal is_called
|
||||
is_called = True
|
||||
|
||||
await self.get_middleware()(next_handler, self._get_message(1), {})
|
||||
await self.get_middleware()(next_handler, _get_message(1), {})
|
||||
assert is_called
|
||||
|
||||
async def test_called_once_for_album(self):
|
||||
|
|
@ -36,13 +60,26 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
album = data.get("album")
|
||||
|
||||
await asyncio.gather(
|
||||
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
|
||||
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
|
||||
middleware(next_handler, _get_message(1, media_group_id="42"), {}),
|
||||
middleware(next_handler, _get_message(2, media_group_id="42"), {}),
|
||||
)
|
||||
assert album is not None
|
||||
assert len(album) == 2
|
||||
assert counter == 1
|
||||
|
||||
async def test_bot_object_saved(self, bot):
|
||||
middleware = self.get_middleware()
|
||||
event = album = None
|
||||
|
||||
async def next_handler(message: Message, data: dict[str, Any]):
|
||||
nonlocal event, album
|
||||
event = message
|
||||
album = data.get("album")
|
||||
|
||||
await middleware(next_handler, _get_message(1, media_group_id="42"), {"bot": bot})
|
||||
assert event.bot is bot
|
||||
assert all(msg.bot is bot for msg in album)
|
||||
|
||||
async def test_propagate_first_media_in_album(self):
|
||||
middleware = self.get_middleware()
|
||||
first_message = None
|
||||
|
|
@ -51,13 +88,29 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
nonlocal first_message
|
||||
first_message = message
|
||||
|
||||
await asyncio.gather(
|
||||
middleware(next_handler, self._get_message(2, media_group_id="42"), {}),
|
||||
middleware(next_handler, self._get_message(1, media_group_id="42"), {}),
|
||||
task1 = await wait_until_func_call_sleep(
|
||||
asyncio.create_task, middleware(next_handler, _get_message(2, media_group_id="42"), {})
|
||||
)
|
||||
await middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||
await task1
|
||||
assert isinstance(first_message, Message)
|
||||
assert first_message.message_id == 1
|
||||
|
||||
async def test_skip_propagating_if_data_deleted(self):
|
||||
middleware = self.get_middleware()
|
||||
counter = 0
|
||||
|
||||
async def next_handler(*args, **kwargs):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
|
||||
task1 = await wait_until_func_call_sleep(
|
||||
asyncio.create_task, middleware(next_handler, _get_message(1, media_group_id="42"), {})
|
||||
)
|
||||
await middleware.media_group_aggregator.delete_group("42")
|
||||
await task1
|
||||
assert counter == 0
|
||||
|
||||
async def test_different_albums_non_interfere(self):
|
||||
middleware = self.get_middleware()
|
||||
counter = 0
|
||||
|
|
@ -69,8 +122,8 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
albums.append(data.get("album"))
|
||||
|
||||
await asyncio.gather(
|
||||
middleware(next_handler, self._get_message(1, media_group_id="1"), {}),
|
||||
middleware(next_handler, self._get_message(2, media_group_id="2"), {}),
|
||||
middleware(next_handler, _get_message(1, media_group_id="1"), {}),
|
||||
middleware(next_handler, _get_message(2, media_group_id="2"), {}),
|
||||
)
|
||||
assert counter == 2
|
||||
assert len(albums) == 2
|
||||
|
|
@ -80,18 +133,76 @@ class TestMediaGroupAggregatorMiddleware:
|
|||
album = None
|
||||
|
||||
async def failed_handler(*args, **kwargs):
|
||||
raise Exception("Failed")
|
||||
raise RuntimeError("Failed")
|
||||
|
||||
async def working_handler(_, data: dict[str, Any]):
|
||||
nonlocal album
|
||||
album = data.get("album")
|
||||
|
||||
first_message = self._get_message(1, media_group_id="42")
|
||||
second_message = self._get_message(2, media_group_id="42")
|
||||
with pytest.raises(Exception):
|
||||
first_message = _get_message(1, media_group_id="42")
|
||||
second_message = _get_message(2, media_group_id="42")
|
||||
with pytest.raises(RuntimeError):
|
||||
await asyncio.gather(
|
||||
middleware(failed_handler, first_message, {}),
|
||||
middleware(failed_handler, second_message, {}),
|
||||
)
|
||||
await middleware(working_handler, first_message, {})
|
||||
assert len(album) == 2
|
||||
|
||||
|
||||
def test_message_deduplication():
|
||||
message_1, message_2 = _get_message(1), _get_message(2)
|
||||
res = [message_1, message_2]
|
||||
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2]) == res
|
||||
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_2]) == res
|
||||
assert BaseMediaGroupAggregator.deduplicate_messages([message_1, message_2, message_1]) == res
|
||||
|
||||
|
||||
@pytest.fixture(params=["memory", "redis"], scope="function")
|
||||
async def aggregator(request):
|
||||
if request.param == "memory":
|
||||
yield MemoryMediaGroupAggregator()
|
||||
else:
|
||||
redis = Redis.from_url(request.getfixturevalue("redis_server"))
|
||||
yield RedisMediaGroupAggregator(redis)
|
||||
keys = await redis.keys("media_group:*")
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
await redis.aclose()
|
||||
|
||||
|
||||
class TestMediaGroupAggregator:
|
||||
async def test_group_creating(self, aggregator: BaseMediaGroupAggregator):
|
||||
msg1 = _get_message(1)
|
||||
msg2 = _get_message(2)
|
||||
assert await aggregator.add_into_group("42", msg1) == 1
|
||||
assert await aggregator.add_into_group("42", msg2) == 2
|
||||
assert {msg.message_id for msg in await aggregator.get_group("42")} == {
|
||||
msg1.message_id,
|
||||
msg2.message_id,
|
||||
}
|
||||
await aggregator.delete_group("42")
|
||||
assert await aggregator.get_group("42") == []
|
||||
|
||||
async def test_acquire_lock(self, aggregator: BaseMediaGroupAggregator):
|
||||
for _ in range(2):
|
||||
assert await aggregator.acquire_lock("42")
|
||||
assert not await aggregator.acquire_lock("42")
|
||||
await aggregator.release_lock("42")
|
||||
|
||||
async def test_expired_objects_removed(self):
|
||||
aggregator = MemoryMediaGroupAggregator()
|
||||
await aggregator.add_into_group("42", _get_message(1))
|
||||
with mock.patch("time.time", return_value=time.time() + aggregator.ttl_sec + 1):
|
||||
new_msg = _get_message(2)
|
||||
await aggregator.add_into_group("24", new_msg)
|
||||
assert await aggregator.get_group("42") == []
|
||||
assert await aggregator.get_group("24") == [new_msg]
|
||||
|
||||
async def test_last_message_time(self, aggregator: BaseMediaGroupAggregator):
|
||||
assert await aggregator.get_last_message_time("42") is None
|
||||
await aggregator.add_into_group("42", _get_message(1))
|
||||
time_before_second_message = time.time()
|
||||
assert await aggregator.get_last_message_time("42") <= time_before_second_message
|
||||
await aggregator.add_into_group("42", _get_message(2))
|
||||
assert await aggregator.get_last_message_time("42") >= time_before_second_message
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue