mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Add middlewares (API + Docs + Tests)
This commit is contained in:
parent
e4cd4c1763
commit
5b6ec599b1
24 changed files with 1120 additions and 42 deletions
|
|
@ -7,6 +7,7 @@ from aiogram.api.methods import (
|
|||
SendAnimation,
|
||||
SendAudio,
|
||||
SendContact,
|
||||
SendDice,
|
||||
SendDocument,
|
||||
SendGame,
|
||||
SendInvoice,
|
||||
|
|
@ -26,6 +27,7 @@ from aiogram.api.types import (
|
|||
Audio,
|
||||
Chat,
|
||||
Contact,
|
||||
Dice,
|
||||
Document,
|
||||
EncryptedCredentials,
|
||||
Game,
|
||||
|
|
@ -391,6 +393,16 @@ class TestMessage:
|
|||
),
|
||||
ContentType.POLL,
|
||||
],
|
||||
[
|
||||
Message(
|
||||
message_id=42,
|
||||
date=datetime.datetime.now(),
|
||||
chat=Chat(id=42, type="private"),
|
||||
dice=Dice(value=6),
|
||||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
ContentType.DICE,
|
||||
],
|
||||
[
|
||||
Message(
|
||||
message_id=42,
|
||||
|
|
@ -431,6 +443,7 @@ class TestMessage:
|
|||
["", dict(text="test"), SendMessage],
|
||||
["photo", dict(photo="photo"), SendPhoto],
|
||||
["poll", dict(question="Q?", options=[]), SendPoll],
|
||||
["dice", dict(), SendDice],
|
||||
["sticker", dict(sticker="sticker"), SendSticker],
|
||||
["sticker", dict(sticker="sticker"), SendSticker],
|
||||
[
|
||||
|
|
|
|||
0
tests/test_dispatcher/test_middlewares/__init__.py
Normal file
0
tests/test_dispatcher/test_middlewares/__init__.py
Normal file
241
tests/test_dispatcher/test_middlewares/test_base.py
Normal file
241
tests/test_dispatcher/test_middlewares/test_base.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
import datetime
|
||||
from typing import Any, Dict, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from aiogram.api.types import (
|
||||
CallbackQuery,
|
||||
Chat,
|
||||
ChosenInlineResult,
|
||||
InlineQuery,
|
||||
Message,
|
||||
Poll,
|
||||
PollAnswer,
|
||||
PreCheckoutQuery,
|
||||
ShippingQuery,
|
||||
Update,
|
||||
User,
|
||||
)
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
|
||||
|
||||
try:
|
||||
from asynctest import CoroutineMock, patch
|
||||
except ImportError:
|
||||
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
|
||||
|
||||
|
||||
class MyMiddleware(BaseMiddleware):
|
||||
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
|
||||
return "update"
|
||||
|
||||
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
|
||||
return "message"
|
||||
|
||||
async def on_pre_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_message"
|
||||
|
||||
async def on_pre_process_channel_post(
|
||||
self, channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "channel_post"
|
||||
|
||||
async def on_pre_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_channel_post"
|
||||
|
||||
async def on_pre_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "inline_query"
|
||||
|
||||
async def on_pre_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "chosen_inline_result"
|
||||
|
||||
async def on_pre_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "callback_query"
|
||||
|
||||
async def on_pre_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "shipping_query"
|
||||
|
||||
async def on_pre_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "pre_checkout_query"
|
||||
|
||||
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
|
||||
return "poll"
|
||||
|
||||
async def on_pre_process_poll_answer(
|
||||
self, poll_answer: PollAnswer, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "poll_answer"
|
||||
|
||||
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
|
||||
return "update"
|
||||
|
||||
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
|
||||
return "message"
|
||||
|
||||
async def on_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_message"
|
||||
|
||||
async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any:
|
||||
return "channel_post"
|
||||
|
||||
async def on_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_channel_post"
|
||||
|
||||
async def on_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "inline_query"
|
||||
|
||||
async def on_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "chosen_inline_result"
|
||||
|
||||
async def on_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "callback_query"
|
||||
|
||||
async def on_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "shipping_query"
|
||||
|
||||
async def on_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "pre_checkout_query"
|
||||
|
||||
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
|
||||
return "poll"
|
||||
|
||||
async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any:
|
||||
return "poll_answer"
|
||||
|
||||
async def on_post_process_update(
|
||||
self, update: Update, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "update"
|
||||
|
||||
async def on_post_process_message(
|
||||
self, message: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "message"
|
||||
|
||||
async def on_post_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "edited_message"
|
||||
|
||||
async def on_post_process_channel_post(
|
||||
self, channel_post: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "channel_post"
|
||||
|
||||
async def on_post_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "edited_channel_post"
|
||||
|
||||
async def on_post_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "inline_query"
|
||||
|
||||
async def on_post_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "chosen_inline_result"
|
||||
|
||||
async def on_post_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "callback_query"
|
||||
|
||||
async def on_post_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "shipping_query"
|
||||
|
||||
async def on_post_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "pre_checkout_query"
|
||||
|
||||
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
|
||||
return "poll"
|
||||
|
||||
async def on_post_process_poll_answer(
|
||||
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "poll_answer"
|
||||
|
||||
|
||||
UPDATE = Update(update_id=42)
|
||||
MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private"))
|
||||
POLL_ANSWER = PollAnswer(
|
||||
poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0]
|
||||
)
|
||||
|
||||
|
||||
class TestBaseMiddleware:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"event_name,event",
|
||||
[["update", UPDATE], ["message", MESSAGE], ["poll_answer", POLL_ANSWER],],
|
||||
)
|
||||
async def test_trigger(
|
||||
self,
|
||||
step: MiddlewareStep,
|
||||
event_name: str,
|
||||
event: UpdateType,
|
||||
middleware_cls: Type[BaseMiddleware],
|
||||
should_be_awaited: bool,
|
||||
):
|
||||
middleware = middleware_cls()
|
||||
|
||||
with patch(
|
||||
f"tests.test_dispatcher.test_middlewares.test_base."
|
||||
f"MyMiddleware.on_{step.value}_{event_name}",
|
||||
new_callable=CoroutineMock,
|
||||
) as mocked_call:
|
||||
response = await middleware.trigger(
|
||||
step=step, event_name=event_name, event=event, data={}
|
||||
)
|
||||
if should_be_awaited:
|
||||
mocked_call.assert_awaited()
|
||||
assert response is not None
|
||||
else:
|
||||
mocked_call.assert_not_awaited()
|
||||
assert response is None
|
||||
|
||||
def test_not_configured(self):
|
||||
middleware = BaseMiddleware()
|
||||
assert not middleware.configured
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
manager = middleware.manager
|
||||
82
tests/test_dispatcher/test_middlewares/test_manager.py
Normal file
82
tests/test_dispatcher/test_middlewares/test_manager.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
import pytest
|
||||
|
||||
from aiogram import Router
|
||||
from aiogram.api.types import Update
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
||||
from aiogram.dispatcher.middlewares.types import MiddlewareStep
|
||||
|
||||
try:
|
||||
from asynctest import CoroutineMock, patch
|
||||
except ImportError:
|
||||
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture("function")
|
||||
def router():
|
||||
return Router()
|
||||
|
||||
|
||||
@pytest.fixture("function")
|
||||
def manager(router: Router):
|
||||
return MiddlewareManager(router)
|
||||
|
||||
|
||||
class TestManager:
|
||||
def test_setup(self, manager: MiddlewareManager):
|
||||
middleware = BaseMiddleware()
|
||||
returned = manager.setup(middleware)
|
||||
assert returned is middleware
|
||||
assert middleware.configured
|
||||
assert middleware.manager is manager
|
||||
assert middleware in manager
|
||||
|
||||
@pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware])
|
||||
def test_setup_invalid_type(self, manager: MiddlewareManager, obj):
|
||||
with pytest.raises(TypeError):
|
||||
assert manager.setup(obj)
|
||||
|
||||
def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router):
|
||||
middleware = BaseMiddleware()
|
||||
manager.setup(middleware)
|
||||
|
||||
assert middleware.configured
|
||||
|
||||
new_manager = MiddlewareManager(router)
|
||||
with pytest.raises(ValueError):
|
||||
new_manager.setup(middleware)
|
||||
with pytest.raises(ValueError):
|
||||
middleware.setup(new_manager)
|
||||
|
||||
def test_configure_twice(self, manager: MiddlewareManager):
|
||||
middleware = BaseMiddleware()
|
||||
manager.setup(middleware)
|
||||
|
||||
assert middleware.configured
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
|
||||
manager.setup(middleware)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
|
||||
middleware.setup(manager)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("count", range(5))
|
||||
async def test_trigger(self, manager: MiddlewareManager, count: int):
|
||||
for _ in range(count):
|
||||
manager.setup(BaseMiddleware())
|
||||
|
||||
with patch(
|
||||
"aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger",
|
||||
new_callable=CoroutineMock,
|
||||
) as mocked_call:
|
||||
await manager.trigger(
|
||||
step=MiddlewareStep.PROCESS,
|
||||
event_name="update",
|
||||
event=Update(update_id=42),
|
||||
data={},
|
||||
result=None,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
assert mocked_call.await_count == count
|
||||
|
|
@ -18,6 +18,7 @@ from aiogram.api.types import (
|
|||
User,
|
||||
)
|
||||
from aiogram.dispatcher.event.observer import SkipHandler
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.dispatcher.router import Router
|
||||
from aiogram.utils.warnings import CodeHasNoEffect
|
||||
|
||||
|
|
@ -407,3 +408,11 @@ class TestRouter:
|
|||
|
||||
await router1.emit_shutdown()
|
||||
assert results == [2, 1, 2]
|
||||
|
||||
def test_use(self):
|
||||
router = Router()
|
||||
|
||||
middleware = router.use(BaseMiddleware())
|
||||
assert isinstance(middleware, BaseMiddleware)
|
||||
assert middleware.configured
|
||||
assert middleware.manager == router.middleware
|
||||
|
|
|
|||
|
|
@ -2,37 +2,54 @@ from typing import Any, Callable, Optional, Tuple
|
|||
|
||||
import pytest
|
||||
|
||||
from aiogram.utils import markdown
|
||||
from aiogram.utils.markdown import (
|
||||
bold,
|
||||
code,
|
||||
hbold,
|
||||
hcode,
|
||||
hide_link,
|
||||
hitalic,
|
||||
hlink,
|
||||
hpre,
|
||||
hstrikethrough,
|
||||
hunderline,
|
||||
italic,
|
||||
link,
|
||||
pre,
|
||||
strikethrough,
|
||||
text,
|
||||
underline,
|
||||
)
|
||||
|
||||
|
||||
class TestMarkdown:
|
||||
@pytest.mark.parametrize(
|
||||
"func,args,sep,result",
|
||||
[
|
||||
[markdown.text, ("test", "test"), " ", "test test"],
|
||||
[markdown.text, ("test", "test"), "\n", "test\ntest"],
|
||||
[markdown.text, ("test", "test"), None, "test test"],
|
||||
[markdown.bold, ("test", "test"), " ", "*test test*"],
|
||||
[markdown.hbold, ("test", "test"), " ", "<b>test test</b>"],
|
||||
[markdown.italic, ("test", "test"), " ", "_test test_\r"],
|
||||
[markdown.hitalic, ("test", "test"), " ", "<i>test test</i>"],
|
||||
[markdown.code, ("test", "test"), " ", "`test test`"],
|
||||
[markdown.hcode, ("test", "test"), " ", "<code>test test</code>"],
|
||||
[markdown.pre, ("test", "test"), " ", "```test test```"],
|
||||
[markdown.hpre, ("test", "test"), " ", "<pre>test test</pre>"],
|
||||
[markdown.underline, ("test", "test"), " ", "__test test__"],
|
||||
[markdown.hunderline, ("test", "test"), " ", "<u>test test</u>"],
|
||||
[markdown.strikethrough, ("test", "test"), " ", "~test test~"],
|
||||
[markdown.hstrikethrough, ("test", "test"), " ", "<s>test test</s>"],
|
||||
[markdown.link, ("test", "https://aiogram.dev"), None, "[test](https://aiogram.dev)"],
|
||||
[text, ("test", "test"), " ", "test test"],
|
||||
[text, ("test", "test"), "\n", "test\ntest"],
|
||||
[text, ("test", "test"), None, "test test"],
|
||||
[bold, ("test", "test"), " ", "*test test*"],
|
||||
[hbold, ("test", "test"), " ", "<b>test test</b>"],
|
||||
[italic, ("test", "test"), " ", "_test test_\r"],
|
||||
[hitalic, ("test", "test"), " ", "<i>test test</i>"],
|
||||
[code, ("test", "test"), " ", "`test test`"],
|
||||
[hcode, ("test", "test"), " ", "<code>test test</code>"],
|
||||
[pre, ("test", "test"), " ", "```test test```"],
|
||||
[hpre, ("test", "test"), " ", "<pre>test test</pre>"],
|
||||
[underline, ("test", "test"), " ", "__test test__"],
|
||||
[hunderline, ("test", "test"), " ", "<u>test test</u>"],
|
||||
[strikethrough, ("test", "test"), " ", "~test test~"],
|
||||
[hstrikethrough, ("test", "test"), " ", "<s>test test</s>"],
|
||||
[link, ("test", "https://aiogram.dev"), None, "[test](https://aiogram.dev)"],
|
||||
[
|
||||
markdown.hlink,
|
||||
hlink,
|
||||
("test", "https://aiogram.dev"),
|
||||
None,
|
||||
'<a href="https://aiogram.dev">test</a>',
|
||||
],
|
||||
[
|
||||
markdown.hide_link,
|
||||
hide_link,
|
||||
("https://aiogram.dev",),
|
||||
None,
|
||||
'<a href="https://aiogram.dev">​</a>',
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue