Rework middlewares

This commit is contained in:
Alex Root Junior 2022-04-11 03:30:22 +03:00
parent 5487f3314b
commit 03252c878e
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
35 changed files with 1244 additions and 669 deletions

View file

@ -215,11 +215,11 @@ class TestBaseSession:
return await make_request(bot, method)
session = CustomSession()
assert not session.middlewares
assert not session.middleware._middlewares
session.middleware(my_middleware)
assert my_middleware in session.middlewares
assert len(session.middlewares) == 1
assert my_middleware in session.middleware
assert len(session.middleware) == 1
async def test_use_middleware(self, bot: MockedBot):
flag_before = False

View file

@ -0,0 +1,45 @@
from aiogram import Bot
from aiogram.client.session.middlewares.base import (
BaseRequestMiddleware,
NextRequestMiddlewareType,
)
from aiogram.client.session.middlewares.manager import RequestMiddlewareManager
from aiogram.methods import Response, TelegramMethod
from aiogram.types import TelegramObject
class TestMiddlewareManager:
async def test_register(self):
manager = RequestMiddlewareManager()
@manager
async def middleware(handler, event, data):
await handler(event, data)
assert middleware in manager._middlewares
manager.unregister(middleware)
assert middleware not in manager._middlewares
async def test_wrap_middlewares(self):
manager = RequestMiddlewareManager()
class MyMiddleware(BaseRequestMiddleware):
async def __call__(
self,
make_request: NextRequestMiddlewareType,
bot: Bot,
method: TelegramMethod[TelegramObject],
) -> Response[TelegramObject]:
return await make_request(bot, method)
manager.register(MyMiddleware())
@manager()
@manager
async def middleware(make_request, bot, method):
return await make_request(bot, method)
async def target_call(bot, method, timeout: int = None):
return timeout
assert await manager.wrap_middlewares(target_call, timeout=42)(None, None) == 42

View file

@ -5,27 +5,38 @@ from aiogram.dispatcher.router import Router
from tests.deprecated import check_deprecated
OBSERVERS = {
"callback_query",
"channel_post",
"chosen_inline_result",
"edited_channel_post",
"edited_message",
"errors",
"inline_query",
"message",
"edited_message",
"channel_post",
"edited_channel_post",
"inline_query",
"chosen_inline_result",
"callback_query",
"shipping_query",
"pre_checkout_query",
"poll",
"poll_answer",
"pre_checkout_query",
"shipping_query",
"my_chat_member",
"chat_member",
"chat_join_request",
"errors",
}
DEPRECATED_OBSERVERS = {observer + "_handler" for observer in OBSERVERS}
@pytest.mark.parametrize("observer_name", DEPRECATED_OBSERVERS)
@pytest.mark.parametrize("observer_name", OBSERVERS)
def test_deprecated_handlers_name(observer_name: str):
router = Router()
with check_deprecated("3.2", exception=AttributeError):
observer = getattr(router, observer_name)
observer = getattr(router, f"{observer_name}_handler")
assert isinstance(observer, TelegramEventObserver)
@pytest.mark.parametrize("observer_name", OBSERVERS)
def test_deprecated_register_handlers(observer_name: str):
router = Router()
with check_deprecated("3.2", exception=AttributeError):
register = getattr(router, f"register_{observer_name}")
register(lambda event: True)
assert callable(register)

View file

@ -76,6 +76,21 @@ class TestDispatcher:
assert dp.update.handlers[0].callback == dp._listen_update
assert dp.update.outer_middleware
def test_data_bind(self):
dp = Dispatcher()
assert dp.get("foo") is None
assert dp.get("foo", 42) == 42
dp["foo"] = 1
assert dp._data["foo"] == 1
assert dp["foo"] == 1
del dp["foo"]
assert "foo" not in dp._data
def test_storage_property(self, dispatcher: Dispatcher):
assert dispatcher.storage is dispatcher.fsm.storage
def test_parent_router(self, dispatcher: Dispatcher):
with pytest.raises(RuntimeError):
dispatcher.parent_router = Router()

View file

@ -0,0 +1,42 @@
from functools import partial
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
class TestMiddlewareManager:
async def test_register(self):
manager = MiddlewareManager()
@manager
async def middleware(handler, event, data):
await handler(event, data)
assert middleware in manager._middlewares
manager.unregister(middleware)
assert middleware not in manager._middlewares
async def test_wrap_middlewares(self):
manager = MiddlewareManager()
async def target(*args, **kwargs):
kwargs["target"] = True
kwargs["stack"].append(-1)
return kwargs
async def middleware1(handler, event, data):
data["mw1"] = True
data["stack"].append(1)
return await handler(event, data)
async def middleware2(handler, event, data):
data["mw2"] = True
data["stack"].append(2)
return await handler(event, data)
wrapped = manager.wrap_middlewares([middleware1, middleware2], target)
assert isinstance(wrapped, partial)
assert wrapped.func is middleware1
result = await wrapped(None, {"stack": []})
assert result == {"mw1": True, "mw2": True, "target": True, "stack": [1, 2, -1]}

View file

@ -30,7 +30,7 @@ class MyCallback(CallbackData, prefix="test"):
class TestCallbackData:
def test_init_subclass_prefix_required(self):
assert MyCallback.prefix == "test"
assert MyCallback.__prefix__ == "test"
with pytest.raises(ValueError, match="prefix required.+"):
@ -38,12 +38,12 @@ class TestCallbackData:
pass
def test_init_subclass_sep_validation(self):
assert MyCallback.sep == ":"
assert MyCallback.__separator__ == ":"
class MyCallback2(CallbackData, prefix="test2", sep="@"):
pass
assert MyCallback2.sep == "@"
assert MyCallback2.__separator__ == "@"
with pytest.raises(ValueError, match="Separator symbol '@' .+ 'sp@m'"):

View file

@ -22,9 +22,9 @@ class TestMagicDataFilter:
assert value.spam is True
return value
f = MagicData(magic_data=F.func(check))
f = MagicData(magic_data=F.func(check).as_("test"))
result = await f(Update(update_id=123), "foo", "bar", spam=True)
assert called
assert isinstance(result, bool)
assert result
assert isinstance(result, dict)
assert result["test"]

View file

@ -2,7 +2,7 @@ from typing import Any, Dict
import pytest
from aiogram.utils.link import create_telegram_link, create_tg_link
from aiogram.utils.link import BRANCH, create_telegram_link, create_tg_link, docs_url
class TestLink:
@ -22,3 +22,12 @@ class TestLink:
)
def test_create_telegram_link(self, base: str, params: Dict[str, Any], result: str):
assert create_telegram_link(base, **params) == result
def test_fragment(self):
assert (
docs_url("test.html", fragment_="test")
== f"https://docs.aiogram.dev/en/{BRANCH}/test.html#test"
)
def test_docs(self):
assert docs_url("test.html") == f"https://docs.aiogram.dev/en/{BRANCH}/test.html"

View file

@ -47,6 +47,11 @@ class TestTextDecoration:
'<a href="tg://user?id=42">test</a>',
],
[html_decoration, MessageEntity(type="url", offset=0, length=5), "test"],
[
html_decoration,
MessageEntity(type="spoiler", offset=0, length=5),
'<span class="tg-spoiler">test</span>',
],
[
html_decoration,
MessageEntity(type="text_link", offset=0, length=5, url="https://aiogram.dev"),
@ -76,6 +81,7 @@ class TestTextDecoration:
[markdown_decoration, MessageEntity(type="bot_command", offset=0, length=5), "test"],
[markdown_decoration, MessageEntity(type="email", offset=0, length=5), "test"],
[markdown_decoration, MessageEntity(type="phone_number", offset=0, length=5), "test"],
[markdown_decoration, MessageEntity(type="spoiler", offset=0, length=5), "|test|"],
[
markdown_decoration,
MessageEntity(