From 9e673998f0266a3a911b6f1e45056e2ed30ab8fb Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Sun, 12 Apr 2020 22:13:25 +0300 Subject: [PATCH] Errors handler --- aiogram/dispatcher/event/observer.py | 24 ++++++- aiogram/dispatcher/filters/__init__.py | 1 + aiogram/dispatcher/middlewares/base.py | 17 +++++ aiogram/dispatcher/middlewares/types.py | 1 + aiogram/dispatcher/router.py | 12 ++++ docs/dispatcher/middlewares/basics.md | 4 ++ docs/dispatcher/middlewares/index.md | 12 ++++ .../test_middlewares/test_base.py | 18 +++++- tests/test_dispatcher/test_router.py | 64 ++++++++++++++++++- 9 files changed, 150 insertions(+), 3 deletions(-) diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index 756d57f2..cea2eb6a 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -1,7 +1,18 @@ from __future__ import annotations from itertools import chain -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, Generator, List, Type +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Generator, + List, + NoReturn, + Optional, + Type, +) from pydantic import ValidationError @@ -17,6 +28,17 @@ class SkipHandler(Exception): pass +class CancelHandler(Exception): + pass + + +def skip(message: Optional[str] = None) -> NoReturn: + """ + Raise an SkipHandler + """ + raise SkipHandler(message or "Event skipped") + + class EventObserver: """ Base events observer diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index 25db7020..b9612ad4 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -27,4 +27,5 @@ BUILTIN_FILTERS: Dict[str, Tuple[Type[BaseFilter], ...]] = { "pre_checkout_query": (), "poll": (), "poll_answer": (), + "errors": (), } diff --git a/aiogram/dispatcher/middlewares/base.py b/aiogram/dispatcher/middlewares/base.py index 2ec921b7..8766f9dc 100644 --- a/aiogram/dispatcher/middlewares/base.py +++ b/aiogram/dispatcher/middlewares/base.py @@ -133,6 +133,11 @@ class BaseMiddleware(AbstractMiddleware): Event that triggers before process poll_answer """ + async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: + """ + Event that triggers before process error + """ + # ============================================================================================= # Event that triggers on process after filters. # ============================================================================================= @@ -214,6 +219,11 @@ class BaseMiddleware(AbstractMiddleware): Event that triggers on process poll_answer """ + async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: + """ + Event that triggers on process error + """ + # ============================================================================================= # Event that triggers after process . # ============================================================================================= @@ -298,3 +308,10 @@ class BaseMiddleware(AbstractMiddleware): """ Event that triggers after processing poll_answer """ + + async def on_post_process_error( + self, exception: Exception, data: Dict[str, Any], result: Any + ) -> Any: + """ + Event that triggers after processing error + """ diff --git a/aiogram/dispatcher/middlewares/types.py b/aiogram/dispatcher/middlewares/types.py index 3d1da420..bc173025 100644 --- a/aiogram/dispatcher/middlewares/types.py +++ b/aiogram/dispatcher/middlewares/types.py @@ -25,6 +25,7 @@ UpdateType = Union[ PreCheckoutQuery, ShippingQuery, Update, + BaseException, ] diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index 888117be..dab48c25 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -48,6 +48,8 @@ class Router: ) self.poll_handler = TelegramEventObserver(router=self, event_name="poll") self.poll_answer_handler = TelegramEventObserver(router=self, event_name="poll_answer") + self.errors_handler = TelegramEventObserver(router=self, event_name="error") + self.middleware = MiddlewareManager(router=self) self.startup = EventObserver() @@ -66,6 +68,7 @@ class Router: "pre_checkout_query": self.pre_checkout_query_handler, "poll": self.poll_handler, "poll_answer": self.poll_answer_handler, + "error": self.errors_handler, } # Root handler @@ -291,6 +294,15 @@ class Router: continue raise SkipHandler + + except SkipHandler: + raise + + except Exception as e: + async for result in self.errors_handler.trigger(e, **kwargs): + return result + raise + finally: if user_token: User.reset_current(user_token) diff --git a/docs/dispatcher/middlewares/basics.md b/docs/dispatcher/middlewares/basics.md index 973ffe98..83b58f07 100644 --- a/docs/dispatcher/middlewares/basics.md +++ b/docs/dispatcher/middlewares/basics.md @@ -29,6 +29,7 @@ Where is: - `#!python3 pre_checkout_query` - `#!python3 poll` - `#!python3 poll_answer` + - `#!python3 error` ## Connecting middleware with router @@ -109,3 +110,6 @@ Returns `#!python3 Any` - `#!python3 on_pre_process_poll_answer` - will be triggered on **pre process** `#!python3 poll_answer` event - `#!python3 on_process_poll_answer` - will be triggered on **process** `#!python3 poll_answer` event - `#!python3 on_post_process_poll_answer` - will be triggered on **post process** `#!python3 poll_answer` event +- `#!python3 on_pre_process_error` - will be triggered on **pre process** `#!python3 error` event +- `#!python3 on_process_error` - will be triggered on **process** `#!python3 error` event +- `#!python3 on_post_process_error` - will be triggered on **post process** `#!python3 error` event diff --git a/docs/dispatcher/middlewares/index.md b/docs/dispatcher/middlewares/index.md index 12baf473..6815a565 100644 --- a/docs/dispatcher/middlewares/index.md +++ b/docs/dispatcher/middlewares/index.md @@ -46,6 +46,18 @@ Simple workflow: 1. Call **post-process** update middleware in all routers tree 1. Emit response into webhook (when it needed) +!!! warning + When filters does not match any handler with this event the `#!python3 process` + step will not be called. + +!!! warning + When exception will be caused in handlers pipeline will be stopped immediately + and then start processing error via errors handler and it own middleware callbacks. + +!!! warning + Middlewares for updates will be called for all routers in tree but callbacks for events + will be called only for specific branch of routers. + ### Pipeline in pictures: #### Simple pipeline diff --git a/tests/test_dispatcher/test_middlewares/test_base.py b/tests/test_dispatcher/test_middlewares/test_base.py index 203028ec..7899324d 100644 --- a/tests/test_dispatcher/test_middlewares/test_base.py +++ b/tests/test_dispatcher/test_middlewares/test_base.py @@ -80,6 +80,9 @@ class MyMiddleware(BaseMiddleware): ) -> Any: return "poll_answer" + async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: + return "error" + async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any: return "update" @@ -130,6 +133,9 @@ class MyMiddleware(BaseMiddleware): async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any: return "poll_answer" + async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any: + return "error" + async def on_post_process_update( self, update: Update, data: Dict[str, Any], result: Any ) -> Any: @@ -188,6 +194,11 @@ class MyMiddleware(BaseMiddleware): ) -> Any: return "poll_answer" + async def on_post_process_error( + self, exception: Exception, data: Dict[str, Any], result: Any + ) -> Any: + return "error" + UPDATE = Update(update_id=42) MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private")) @@ -206,7 +217,12 @@ class TestBaseMiddleware: ) @pytest.mark.parametrize( "event_name,event", - [["update", UPDATE], ["message", MESSAGE], ["poll_answer", POLL_ANSWER],], + [ + ["update", UPDATE], + ["message", MESSAGE], + ["poll_answer", POLL_ANSWER], + ["error", Exception("KABOOM")], + ], ) async def test_trigger( self, diff --git a/tests/test_dispatcher/test_router.py b/tests/test_dispatcher/test_router.py index eacb8d0c..2d26a445 100644 --- a/tests/test_dispatcher/test_router.py +++ b/tests/test_dispatcher/test_router.py @@ -17,7 +17,7 @@ from aiogram.api.types import ( Update, User, ) -from aiogram.dispatcher.event.observer import SkipHandler +from aiogram.dispatcher.event.observer import SkipHandler, skip from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.router import Router from aiogram.utils.warnings import CodeHasNoEffect @@ -416,3 +416,65 @@ class TestRouter: assert isinstance(middleware, BaseMiddleware) assert middleware.configured assert middleware.manager == router.middleware + + def test_skip(self): + with pytest.raises(SkipHandler): + skip() + with pytest.raises(SkipHandler, match="KABOOM"): + skip("KABOOM") + + @pytest.mark.asyncio + async def test_exception_handler_catch_exceptions(self): + root_router = Router() + router = Router() + root_router.include_router(router) + + @router.message_handler() + async def message_handler(message: Message): + raise Exception("KABOOM") + + update = Update( + update_id=42, + message=Message( + message_id=42, + date=datetime.datetime.now(), + text="test", + chat=Chat(id=42, type="private"), + from_user=User(id=42, is_bot=False, first_name="Test"), + ), + ) + with pytest.raises(Exception, match="KABOOM"): + await root_router.listen_update( + update_type="message", + update=update, + event=update.message, + from_user=update.message.from_user, + chat=update.message.chat, + ) + + @root_router.errors_handler() + async def root_error_handler(exception: Exception): + return exception + + response = await root_router.listen_update( + update_type="message", + update=update, + event=update.message, + from_user=update.message.from_user, + chat=update.message.chat, + ) + assert isinstance(response, Exception) + assert str(response) == "KABOOM" + + @router.errors_handler() + async def error_handler(exception: Exception): + return "KABOOM" + + response = await root_router.listen_update( + update_type="message", + update=update, + event=update.message, + from_user=update.message.from_user, + chat=update.message.chat, + ) + assert response == "KABOOM"