From 098f4e8fffeb4ed4830631dbe9b903c2df023e53 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Mon, 8 Nov 2021 23:13:15 +0200 Subject: [PATCH] Oops. unstash files. --- .coveragerc | 5 + aiogram/dispatcher/webhook/aiohttp_server.py | 21 +-- tests/mocked_bot.py | 8 +- .../test_webhook/test_aiohtt_server.py | 154 ++++++++++++++++-- 4 files changed, 161 insertions(+), 27 deletions(-) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..9feee202 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[report] +exclude_lines = + pragma: no cover + if TYPE_CHECKING: + @abstractmethod diff --git a/aiogram/dispatcher/webhook/aiohttp_server.py b/aiogram/dispatcher/webhook/aiohttp_server.py index fefcf2b4..36a28a01 100644 --- a/aiogram/dispatcher/webhook/aiohttp_server.py +++ b/aiogram/dispatcher/webhook/aiohttp_server.py @@ -5,6 +5,8 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, cast from aiohttp import web from aiohttp.abc import Application +from aiohttp.typedefs import Handler +from aiohttp.web_middlewares import middleware from aiogram import Bot, Dispatcher, loggers from aiogram.dispatcher.webhook.security import IPFilter @@ -50,21 +52,21 @@ def check_ip(ip_filter: IPFilter, request: web.Request) -> Tuple[str, bool]: host, _ = peer_name return host, host in ip_filter - return "", False + # Potentially impossible case + return "", False # pragma: no cover def ip_filter_middleware( ip_filter: IPFilter, -) -> Callable[[web.Request, Callable[[web.Request], Any]], Awaitable[Any]]: +) -> Callable[[web.Request, Handler], Awaitable[Any]]: """ :param ip_filter: :return: """ - async def _ip_filter_middleware( - request: web.Request, handler: Callable[[web.Request], Any] - ) -> Any: + @middleware + async def _ip_filter_middleware(request: web.Request, handler: Handler) -> Any: ip_address, accept = check_ip(ip_filter=ip_filter, request=request) if not accept: loggers.webhook.warning(f"Blocking request from an unauthorized IP: {ip_address}") @@ -132,17 +134,16 @@ class BaseRequestHandler(ABC): bot=bot, update=await request.json(loads=bot.session.json_loads) ) ) - return web.Response(status=200) + return web.json_response({}) async def _handle_request(self, bot: Bot, request: web.Request) -> web.Response: result = await self.dispatcher.feed_webhook_update( bot, await request.json(loads=bot.session.json_loads), ) - if isinstance(result, TelegramMethod): - request = result.build_request(bot) - return web.json_response(request.render_webhook_request()) - return web.Response(status=200) + if result: + return web.json_response(result) + return web.json_response({}) async def handle(self, request: web.Request) -> web.Response: bot = await self.resolve_bot(request) diff --git a/tests/mocked_bot.py b/tests/mocked_bot.py index 0a496105..1783eb86 100644 --- a/tests/mocked_bot.py +++ b/tests/mocked_bot.py @@ -13,6 +13,7 @@ class MockedSession(BaseSession): super(MockedSession, self).__init__() self.responses: Deque[Response[TelegramType]] = deque() self.requests: Deque[Request] = deque() + self.closed = True def add_result(self, response: Response[TelegramType]) -> Response[TelegramType]: self.responses.append(response) @@ -22,11 +23,12 @@ class MockedSession(BaseSession): return self.requests.pop() async def close(self): - pass + self.closed = True async def make_request( self, bot: Bot, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET ) -> TelegramType: + self.closed = False self.requests.append(method.build_request(bot)) response: Response[TelegramType] = self.responses.pop() self.check_response( @@ -45,7 +47,9 @@ class MockedBot(Bot): session: MockedSession def __init__(self, **kwargs): - super(MockedBot, self).__init__("42:TEST", session=MockedSession(), **kwargs) + super(MockedBot, self).__init__( + kwargs.pop("token", "42:TEST"), session=MockedSession(), **kwargs + ) self._me = User( id=self.id, is_bot=True, diff --git a/tests/test_dispatcher/test_webhook/test_aiohtt_server.py b/tests/test_dispatcher/test_webhook/test_aiohtt_server.py index c75ed518..fa9dad9c 100644 --- a/tests/test_dispatcher/test_webhook/test_aiohtt_server.py +++ b/tests/test_dispatcher/test_webhook/test_aiohtt_server.py @@ -1,12 +1,23 @@ -from typing import Any +import time +from dataclasses import dataclass +from typing import Any, Dict import pytest +from aiohttp import web +from aiohttp.test_utils import TestClient from aiohttp.web_app import Application -from aiogram import Dispatcher -from aiogram.dispatcher.webhook.aiohttp_server import ip_filter_middleware, setup_application +from aiogram import Dispatcher, F +from aiogram.dispatcher.webhook.aiohttp_server import ( + SimpleRequestHandler, + TokenBasedRequestHandler, + ip_filter_middleware, + setup_application, +) from aiogram.dispatcher.webhook.security import IPFilter -from aiogram.methods import Request +from aiogram.methods import GetMe, Request +from aiogram.types import Message, User +from tests.mocked_bot import MockedBot class TestAiohttpServer: @@ -20,18 +31,131 @@ class TestAiohttpServer: assert len(app.on_startup) == 2 assert len(app.on_shutdown) == 1 - async def test_middleware(self, aiohttp_client: Any): + async def test_middleware(self, aiohttp_client): app = Application() - app.middlewares.append(ip_filter_middleware(IPFilter.default())) + ip_filter = IPFilter.default() + app.middlewares.append(ip_filter_middleware(ip_filter)) - async def handler1(request: Request): - pass + async def handler(request: Request): + return web.json_response({"ok": True}) - async def handler2(request: Request): - pytest.fail() + app.router.add_route("POST", "/webhook", handler) + client: TestClient = await aiohttp_client(app) - app.router.add_route("POST", "/good", handler1) - app.router.add_route("POST", "/bad", handler2) - client = await aiohttp_client(app) - resp = await client.get("/bad") - assert resp + resp = await client.post("/webhook") + assert resp.status == 401 + + resp = await client.post("/webhook", headers={"X-Forwarded-For": "149.154.167.220"}) + assert resp.status == 200 + + resp = await client.post( + "/webhook", headers={"X-Forwarded-For": "149.154.167.220,10.111.0.2"} + ) + assert resp.status == 200 + + +class TestSimpleRequestHandler: + async def make_reqest(self, client: TestClient, text: str = "test"): + return await client.post( + "/webhook", + json={ + "update_id": 0, + "message": { + "message_id": 0, + "from": {"id": 42, "first_name": "Test", "is_bot": False}, + "chat": {"id": 42, "is_bot": False, "type": "private"}, + "date": int(time.time()), + "text": text, + }, + }, + ) + + async def test(self, bot: MockedBot, aiohttp_client): + app = Application() + dp = Dispatcher() + + @dp.message(F.text == "test") + def handle_message(msg: Message): + return msg.answer("PASS") + + handler = SimpleRequestHandler( + dispatcher=dp, + bot=bot, + handle_in_background=False, + ) + handler.register(app, path="/webhook") + client: TestClient = await aiohttp_client(app) + + resp = await self.make_reqest(client=client) + assert resp.status == 200 + result = await resp.json() + assert result["method"] == "sendMessage" + + resp = await self.make_reqest(client=client, text="spam") + assert resp.status == 200 + result = await resp.json() + assert not result + + handler.handle_in_background = True + resp = await self.make_reqest(client=client) + assert resp.status == 200 + result = await resp.json() + assert not result + + +class TestTokenBasedRequestHandler: + async def test_register(self): + dispatcher = Dispatcher() + app = Application() + + handler = TokenBasedRequestHandler(dispatcher=dispatcher) + + assert len(app.router.routes()) == 0 + with pytest.raises(ValueError): + handler.register(app, path="/webhook") + + assert len(app.router.routes()) == 0 + handler.register(app, path="/webhook/{bot_token}") + assert len(app.router.routes()) == 1 + + async def test_close(self): + dispatcher = Dispatcher() + + handler = TokenBasedRequestHandler(dispatcher=dispatcher) + + bot1 = handler.bots["42:TEST"] = MockedBot(token="42:TEST") + bot1.add_result_for(GetMe, ok=True, result=User(id=42, is_bot=True, first_name="Test")) + assert await bot1.get_me() + assert not bot1.session.closed + bot2 = handler.bots["1337:TEST"] = MockedBot(token="1337:TEST") + bot2.add_result_for(GetMe, ok=True, result=User(id=1337, is_bot=True, first_name="Test")) + assert await bot2.get_me() + assert not bot2.session.closed + + await handler.close() + assert bot1.session.closed + assert bot2.session.closed + + async def test_resolve_bot(self): + dispatcher = Dispatcher() + handler = TokenBasedRequestHandler(dispatcher=dispatcher) + + @dataclass + class FakeRequest: + match_info: Dict[str, Any] + + bot1 = await handler.resolve_bot(request=FakeRequest(match_info={"bot_token": "42:TEST"})) + assert bot1.id == 42 + + bot2 = await handler.resolve_bot( + request=FakeRequest(match_info={"bot_token": "1337:TEST"}) + ) + assert bot2.id == 1337 + + bot3 = await handler.resolve_bot( + request=FakeRequest(match_info={"bot_token": "1337:TEST"}) + ) + assert bot3.id == 1337 + + assert bot2 == bot3 + assert len(handler.bots) == 2