From 21ba9288d0c19375a2ed2b165920a1d3538bab55 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Mon, 25 Jun 2018 17:07:11 +0300 Subject: [PATCH] Pass handler args from middlewares --- aiogram/contrib/middlewares/context.py | 4 +-- aiogram/contrib/middlewares/logging.py | 42 +++++++++++++------------- aiogram/dispatcher/handler.py | 14 ++++++--- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/aiogram/contrib/middlewares/context.py b/aiogram/contrib/middlewares/context.py index 54fca52d..29f45dcb 100644 --- a/aiogram/contrib/middlewares/context.py +++ b/aiogram/contrib/middlewares/context.py @@ -9,7 +9,7 @@ class ContextMiddleware(BaseMiddleware): Allow to store data at all of lifetime of Update object """ - async def on_pre_process_update(self, update: types.Update): + async def on_pre_process_update(self, update: types.Update, data: dict): """ Start of Update lifetime @@ -18,7 +18,7 @@ class ContextMiddleware(BaseMiddleware): """ self._configure_update(update) - async def on_post_process_update(self, update: types.Update, result): + async def on_post_process_update(self, update: types.Update, result, data: dict): """ On finishing of processing update diff --git a/aiogram/contrib/middlewares/logging.py b/aiogram/contrib/middlewares/logging.py index ee9ac65a..04c34938 100644 --- a/aiogram/contrib/middlewares/logging.py +++ b/aiogram/contrib/middlewares/logging.py @@ -23,70 +23,70 @@ class LoggingMiddleware(BaseMiddleware): return round((time.time() - start) * 1000) return -1 - async def on_pre_process_update(self, update: types.Update): + async def on_pre_process_update(self, update: types.Update, data: dict): update.conf['_start'] = time.time() self.logger.debug(f"Received update [ID:{update.update_id}]") - async def on_post_process_update(self, update: types.Update, result): + async def on_post_process_update(self, update: types.Update, result, data: dict): timeout = self.check_timeout(update) if timeout > 0: self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)") - async def on_pre_process_message(self, message: types.Message): + async def on_pre_process_message(self, message: types.Message, data: dict): self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") - async def on_post_process_message(self, message: types.Message, results): + async def on_post_process_message(self, message: types.Message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") - async def on_pre_process_edited_message(self, edited_message): + async def on_pre_process_edited_message(self, edited_message, data: dict): self.logger.info(f"Received edited message [ID:{edited_message.message_id}] " f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") - async def on_post_process_edited_message(self, edited_message, results): + async def on_post_process_edited_message(self, edited_message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"edited message [ID:{edited_message.message_id}] " f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") - async def on_pre_process_channel_post(self, channel_post: types.Message): + async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict): self.logger.info(f"Received channel post [ID:{channel_post.message_id}] " f"in channel [ID:{channel_post.chat.id}]") - async def on_post_process_channel_post(self, channel_post: types.Message, results): + async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"channel post [ID:{channel_post.message_id}] " f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]") - async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message): + async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict): self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] " f"in channel [ID:{edited_channel_post.chat.id}]") - async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results): + async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"edited channel post [ID:{edited_channel_post.message_id}] " f"in channel [ID:{edited_channel_post.chat.id}]") - async def on_pre_process_inline_query(self, inline_query: types.InlineQuery): + async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict): self.logger.info(f"Received inline query [ID:{inline_query.id}] " f"from user [ID:{inline_query.from_user.id}]") - async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results): + async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"inline query [ID:{inline_query.id}] " f"from user [ID:{inline_query.from_user.id}]") - async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult): + async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict): self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " f"from user [ID:{chosen_inline_result.from_user.id}] " f"result [ID:{chosen_inline_result.result_id}]") - async def on_post_process_chosen_inline_result(self, chosen_inline_result, results): + async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " f"from user [ID:{chosen_inline_result.from_user.id}] " f"result [ID:{chosen_inline_result.result_id}]") - async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery): + async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict): if callback_query.message: self.logger.info(f"Received callback query [ID:{callback_query.id}] " f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] " @@ -96,7 +96,7 @@ class LoggingMiddleware(BaseMiddleware): f"from inline message [ID:{callback_query.inline_message_id}] " f"from user [ID:{callback_query.from_user.id}]") - async def on_post_process_callback_query(self, callback_query, results): + async def on_post_process_callback_query(self, callback_query, results, data: dict): if callback_query.message: self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"callback query [ID:{callback_query.id}] " @@ -108,25 +108,25 @@ class LoggingMiddleware(BaseMiddleware): f"from inline message [ID:{callback_query.inline_message_id}] " f"from user [ID:{callback_query.from_user.id}]") - async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery): + async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict): self.logger.info(f"Received shipping query [ID:{shipping_query.id}] " f"from user [ID:{shipping_query.from_user.id}]") - async def on_post_process_shipping_query(self, shipping_query, results): + async def on_post_process_shipping_query(self, shipping_query, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"shipping query [ID:{shipping_query.id}] " f"from user [ID:{shipping_query.from_user.id}]") - async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery): + async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict): self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] " f"from user [ID:{pre_checkout_query.from_user.id}]") - async def on_post_process_pre_checkout_query(self, pre_checkout_query, results): + async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"pre-checkout query [ID:{pre_checkout_query.id}] " f"from user [ID:{pre_checkout_query.from_user.id}]") - async def on_pre_process_error(self, dispatcher, update, error): + async def on_pre_process_error(self, dispatcher, update, error, data: dict): timeout = self.check_timeout(update) if timeout > 0: self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)") diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 516b6114..fc98da2a 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,4 +1,7 @@ import inspect +from contextvars import ContextVar + +ctx_data = ContextVar('ctx_handler_data') class SkipHandler(BaseException): @@ -68,23 +71,26 @@ class Handler: results = [] + data = {} + ctx_data.set(data) + if self.middleware_key: try: - await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args) + await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,)) except CancelHandler: # Allow to cancel current event return results try: for filters, handler in self.handlers: try: - data = await check_filters(filters, args) + data.update(await check_filters(filters, args)) except FilterNotPassed: continue else: try: if self.middleware_key: # context.set_value('handler', handler) - await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args) + await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,)) partial_data = _check_spec(handler, data) response = await handler(*args, **partial_data) if response is not None: @@ -98,6 +104,6 @@ class Handler: finally: if self.middleware_key: await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", - args + (results,)) + args + (results, data,)) return results