Pass handler args from middlewares

This commit is contained in:
Alex Root Junior 2018-06-25 17:07:11 +03:00
parent 6832e92ca2
commit 21ba9288d0
3 changed files with 33 additions and 27 deletions

View file

@ -9,7 +9,7 @@ class ContextMiddleware(BaseMiddleware):
Allow to store data at all of lifetime of Update object
"""
async def on_pre_process_update(self, update: types.Update):
async def on_pre_process_update(self, update: types.Update, data: dict):
"""
Start of Update lifetime
@ -18,7 +18,7 @@ class ContextMiddleware(BaseMiddleware):
"""
self._configure_update(update)
async def on_post_process_update(self, update: types.Update, result):
async def on_post_process_update(self, update: types.Update, result, data: dict):
"""
On finishing of processing update

View file

@ -23,70 +23,70 @@ class LoggingMiddleware(BaseMiddleware):
return round((time.time() - start) * 1000)
return -1
async def on_pre_process_update(self, update: types.Update):
async def on_pre_process_update(self, update: types.Update, data: dict):
update.conf['_start'] = time.time()
self.logger.debug(f"Received update [ID:{update.update_id}]")
async def on_post_process_update(self, update: types.Update, result):
async def on_post_process_update(self, update: types.Update, result, data: dict):
timeout = self.check_timeout(update)
if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
async def on_pre_process_message(self, message: types.Message):
async def on_pre_process_message(self, message: types.Message, data: dict):
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_post_process_message(self, message: types.Message, results):
async def on_post_process_message(self, message: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_pre_process_edited_message(self, edited_message):
async def on_pre_process_edited_message(self, edited_message, data: dict):
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_post_process_edited_message(self, edited_message, results):
async def on_post_process_edited_message(self, edited_message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_pre_process_channel_post(self, channel_post: types.Message):
async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict):
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
f"in channel [ID:{channel_post.chat.id}]")
async def on_post_process_channel_post(self, channel_post: types.Message, results):
async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"channel post [ID:{channel_post.message_id}] "
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message):
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict):
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results):
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery):
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict):
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results):
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict):
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results):
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery):
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict):
if callback_query.message:
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
@ -96,7 +96,7 @@ class LoggingMiddleware(BaseMiddleware):
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
async def on_post_process_callback_query(self, callback_query, results):
async def on_post_process_callback_query(self, callback_query, results, data: dict):
if callback_query.message:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
@ -108,25 +108,25 @@ class LoggingMiddleware(BaseMiddleware):
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery):
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict):
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
async def on_post_process_shipping_query(self, shipping_query, results):
async def on_post_process_shipping_query(self, shipping_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict):
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results):
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_pre_process_error(self, dispatcher, update, error):
async def on_pre_process_error(self, dispatcher, update, error, data: dict):
timeout = self.check_timeout(update)
if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")

View file

@ -1,4 +1,7 @@
import inspect
from contextvars import ContextVar
ctx_data = ContextVar('ctx_handler_data')
class SkipHandler(BaseException):
@ -68,23 +71,26 @@ class Handler:
results = []
data = {}
ctx_data.set(data)
if self.middleware_key:
try:
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,))
except CancelHandler: # Allow to cancel current event
return results
try:
for filters, handler in self.handlers:
try:
data = await check_filters(filters, args)
data.update(await check_filters(filters, args))
except FilterNotPassed:
continue
else:
try:
if self.middleware_key:
# context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,))
partial_data = _check_spec(handler, data)
response = await handler(*args, **partial_data)
if response is not None:
@ -98,6 +104,6 @@ class Handler:
finally:
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
args + (results, data,))
return results