Add middlewares (API + Docs + Tests)

This commit is contained in:
Alex Root Junior 2020-04-12 20:27:32 +03:00
parent e4cd4c1763
commit 5b6ec599b1
24 changed files with 1120 additions and 42 deletions

View file

@ -7,6 +7,7 @@ from aiogram.api.methods import (
SendAnimation,
SendAudio,
SendContact,
SendDice,
SendDocument,
SendGame,
SendInvoice,
@ -26,6 +27,7 @@ from aiogram.api.types import (
Audio,
Chat,
Contact,
Dice,
Document,
EncryptedCredentials,
Game,
@ -391,6 +393,16 @@ class TestMessage:
),
ContentType.POLL,
],
[
Message(
message_id=42,
date=datetime.datetime.now(),
chat=Chat(id=42, type="private"),
dice=Dice(value=6),
from_user=User(id=42, is_bot=False, first_name="Test"),
),
ContentType.DICE,
],
[
Message(
message_id=42,
@ -431,6 +443,7 @@ class TestMessage:
["", dict(text="test"), SendMessage],
["photo", dict(photo="photo"), SendPhoto],
["poll", dict(question="Q?", options=[]), SendPoll],
["dice", dict(), SendDice],
["sticker", dict(sticker="sticker"), SendSticker],
["sticker", dict(sticker="sticker"), SendSticker],
[

View file

@ -0,0 +1,241 @@
import datetime
from typing import Any, Dict, Type
import pytest
from aiogram.api.types import (
CallbackQuery,
Chat,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
User,
)
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
class MyMiddleware(BaseMiddleware):
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_pre_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_pre_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
return "channel_post"
async def on_pre_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_pre_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_pre_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_pre_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_pre_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
return "poll_answer"
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any:
return "channel_post"
async def on_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any:
return "poll_answer"
async def on_post_process_update(
self, update: Update, data: Dict[str, Any], result: Any
) -> Any:
return "update"
async def on_post_process_message(
self, message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "message"
async def on_post_process_edited_message(
self, edited_message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_message"
async def on_post_process_channel_post(
self, channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "channel_post"
async def on_post_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_channel_post"
async def on_post_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
) -> Any:
return "inline_query"
async def on_post_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
) -> Any:
return "chosen_inline_result"
async def on_post_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
) -> Any:
return "callback_query"
async def on_post_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
) -> Any:
return "shipping_query"
async def on_post_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
) -> Any:
return "pre_checkout_query"
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
return "poll"
async def on_post_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
) -> Any:
return "poll_answer"
UPDATE = Update(update_id=42)
MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private"))
POLL_ANSWER = PollAnswer(
poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0]
)
class TestBaseMiddleware:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]]
)
@pytest.mark.parametrize(
"step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS]
)
@pytest.mark.parametrize(
"event_name,event",
[["update", UPDATE], ["message", MESSAGE], ["poll_answer", POLL_ANSWER],],
)
async def test_trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
middleware_cls: Type[BaseMiddleware],
should_be_awaited: bool,
):
middleware = middleware_cls()
with patch(
f"tests.test_dispatcher.test_middlewares.test_base."
f"MyMiddleware.on_{step.value}_{event_name}",
new_callable=CoroutineMock,
) as mocked_call:
response = await middleware.trigger(
step=step, event_name=event_name, event=event, data={}
)
if should_be_awaited:
mocked_call.assert_awaited()
assert response is not None
else:
mocked_call.assert_not_awaited()
assert response is None
def test_not_configured(self):
middleware = BaseMiddleware()
assert not middleware.configured
with pytest.raises(RuntimeError):
manager = middleware.manager

View file

@ -0,0 +1,82 @@
import pytest
from aiogram import Router
from aiogram.api.types import Update
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from aiogram.dispatcher.middlewares.types import MiddlewareStep
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
@pytest.fixture("function")
def router():
return Router()
@pytest.fixture("function")
def manager(router: Router):
return MiddlewareManager(router)
class TestManager:
def test_setup(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
returned = manager.setup(middleware)
assert returned is middleware
assert middleware.configured
assert middleware.manager is manager
assert middleware in manager
@pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware])
def test_setup_invalid_type(self, manager: MiddlewareManager, obj):
with pytest.raises(TypeError):
assert manager.setup(obj)
def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
new_manager = MiddlewareManager(router)
with pytest.raises(ValueError):
new_manager.setup(middleware)
with pytest.raises(ValueError):
middleware.setup(new_manager)
def test_configure_twice(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
manager.setup(middleware)
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
middleware.setup(manager)
@pytest.mark.asyncio
@pytest.mark.parametrize("count", range(5))
async def test_trigger(self, manager: MiddlewareManager, count: int):
for _ in range(count):
manager.setup(BaseMiddleware())
with patch(
"aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger",
new_callable=CoroutineMock,
) as mocked_call:
await manager.trigger(
step=MiddlewareStep.PROCESS,
event_name="update",
event=Update(update_id=42),
data={},
result=None,
reverse=True,
)
assert mocked_call.await_count == count

View file

@ -18,6 +18,7 @@ from aiogram.api.types import (
User,
)
from aiogram.dispatcher.event.observer import SkipHandler
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.router import Router
from aiogram.utils.warnings import CodeHasNoEffect
@ -407,3 +408,11 @@ class TestRouter:
await router1.emit_shutdown()
assert results == [2, 1, 2]
def test_use(self):
router = Router()
middleware = router.use(BaseMiddleware())
assert isinstance(middleware, BaseMiddleware)
assert middleware.configured
assert middleware.manager == router.middleware

View file

@ -2,37 +2,54 @@ from typing import Any, Callable, Optional, Tuple
import pytest
from aiogram.utils import markdown
from aiogram.utils.markdown import (
bold,
code,
hbold,
hcode,
hide_link,
hitalic,
hlink,
hpre,
hstrikethrough,
hunderline,
italic,
link,
pre,
strikethrough,
text,
underline,
)
class TestMarkdown:
@pytest.mark.parametrize(
"func,args,sep,result",
[
[markdown.text, ("test", "test"), " ", "test test"],
[markdown.text, ("test", "test"), "\n", "test\ntest"],
[markdown.text, ("test", "test"), None, "test test"],
[markdown.bold, ("test", "test"), " ", "*test test*"],
[markdown.hbold, ("test", "test"), " ", "<b>test test</b>"],
[markdown.italic, ("test", "test"), " ", "_test test_\r"],
[markdown.hitalic, ("test", "test"), " ", "<i>test test</i>"],
[markdown.code, ("test", "test"), " ", "`test test`"],
[markdown.hcode, ("test", "test"), " ", "<code>test test</code>"],
[markdown.pre, ("test", "test"), " ", "```test test```"],
[markdown.hpre, ("test", "test"), " ", "<pre>test test</pre>"],
[markdown.underline, ("test", "test"), " ", "__test test__"],
[markdown.hunderline, ("test", "test"), " ", "<u>test test</u>"],
[markdown.strikethrough, ("test", "test"), " ", "~test test~"],
[markdown.hstrikethrough, ("test", "test"), " ", "<s>test test</s>"],
[markdown.link, ("test", "https://aiogram.dev"), None, "[test](https://aiogram.dev)"],
[text, ("test", "test"), " ", "test test"],
[text, ("test", "test"), "\n", "test\ntest"],
[text, ("test", "test"), None, "test test"],
[bold, ("test", "test"), " ", "*test test*"],
[hbold, ("test", "test"), " ", "<b>test test</b>"],
[italic, ("test", "test"), " ", "_test test_\r"],
[hitalic, ("test", "test"), " ", "<i>test test</i>"],
[code, ("test", "test"), " ", "`test test`"],
[hcode, ("test", "test"), " ", "<code>test test</code>"],
[pre, ("test", "test"), " ", "```test test```"],
[hpre, ("test", "test"), " ", "<pre>test test</pre>"],
[underline, ("test", "test"), " ", "__test test__"],
[hunderline, ("test", "test"), " ", "<u>test test</u>"],
[strikethrough, ("test", "test"), " ", "~test test~"],
[hstrikethrough, ("test", "test"), " ", "<s>test test</s>"],
[link, ("test", "https://aiogram.dev"), None, "[test](https://aiogram.dev)"],
[
markdown.hlink,
hlink,
("test", "https://aiogram.dev"),
None,
'<a href="https://aiogram.dev">test</a>',
],
[
markdown.hide_link,
hide_link,
("https://aiogram.dev",),
None,
'<a href="https://aiogram.dev">&#8203;</a>',