diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 26122b27..6c721154 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -1,8 +1,9 @@ import asyncio import logging -from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter -from .handler import Handler +from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters +from .handler import Handler, NextStepHandler +from .. import types from ..bot import Bot from ..types.message import ContentType @@ -18,12 +19,25 @@ class Dispatcher: self.loop = loop self.last_update_id = 0 - self.updates = Handler(self) - self.messages = Handler(self) - self.commands = Handler(self) - self.updates.register(self.process_update) + self.updates_handler = Handler(self) + self.message_handlers = Handler(self) + self.edited_message_handlers = Handler(self) + self.channel_post_handlers = Handler(self) + self.edited_channel_post_handlers = Handler(self) + self.inline_query_handlers = Handler(self) + self.chosen_inline_result_handlers = Handler(self) + self.callback_query_handlers = Handler(self) + self.shipping_query_handlers = Handler(self) + self.pre_checkout_query_handlers = Handler(self) + self.next_step_message_handlers = NextStepHandler(self) + self.updates_handler.register(self.process_update) + # self.message_handlers.register(self._notify_next_message) + + self._pooling = False + + def __del__(self): self._pooling = False async def skip_updates(self): @@ -39,11 +53,29 @@ class Dispatcher: async def process_updates(self, updates): for update in updates: - self.loop.create_task(self.updates.notify(update)) + self.loop.create_task(self.updates_handler.notify(update)) async def process_update(self, update): + self.last_update_id = update.update_id if update.message: - await self.messages.notify(update.message) + if not await self.next_step_message_handlers.notify(update.message): + await self.message_handlers.notify(update.message) + if update.edited_message: + await self.edited_message_handlers.notify(update.edited_message) + if update.channel_post: + await self.channel_post_handlers.notify(update.channel_post) + if update.edited_channel_post: + await self.edited_channel_post_handlers.notify(update.edited_channel_post) + if update.inline_query: + await self.inline_query_handlers.notify(update.inline_query) + if update.chosen_inline_result: + await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) + if update.callback_query: + await self.callback_query_handlers.notify(update.callback_query) + if update.shipping_query: + await self.shipping_query_handlers.notify(update.shipping_query) + if update.pre_checkout_query: + await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) async def start_pooling(self, timeout=20, relax=0.1): if self._pooling: @@ -72,8 +104,7 @@ class Dispatcher: def stop_pooling(self): self._pooling = False - def message_handler(self, commands=None, regexp=None, content_type=None, func=None, - custom_filters=None): + def message_handler(self, commands=None, regexp=None, content_type=None, func=None, custom_filters=None, **kwargs): if commands is None: commands = [] if content_type is None: @@ -81,29 +112,160 @@ class Dispatcher: if custom_filters is None: custom_filters = [] - filters_preset = [] - if commands: - if isinstance(commands, str): - commands = [commands] - filters_preset.append(CommandsFilter(commands)) - - if regexp: - filters_preset.append(RegexpFilter(regexp)) - - if content_type: - filters_preset.append(ContentTypeFilter(content_type)) - - if func: - filters_preset.append(func) - - if custom_filters: - filters_preset += custom_filters + filters_set = generate_default_filters(*custom_filters, + commands=commands, + regexp=regexp, + content_type=content_type, + func=func, + **kwargs) def decorator(func): - self.messages.register(func, filters_preset) + self.message_handlers.register(func, filters_set) return func return decorator - def __del__(self): - self._pooling = False + def edited_message_handler(self, commands=None, regexp=None, content_type=None, func=None, custom_filters=None, + **kwargs): + if commands is None: + commands = [] + if content_type is None: + content_type = [ContentType.TEXT] + if custom_filters is None: + custom_filters = [] + + filters_set = generate_default_filters(*custom_filters, + commands=commands, + regexp=regexp, + content_type=content_type, + func=func, + **kwargs) + + def decorator(func): + self.edited_message_handlers.register(func, filters_set) + return func + + return decorator + + def channel_post_handler(self, commands=None, regexp=None, content_type=None, func=None, *custom_filters, **kwargs): + if commands is None: + commands = [] + if content_type is None: + content_type = [ContentType.TEXT] + if custom_filters is None: + custom_filters = [] + + filters_set = generate_default_filters(*custom_filters, + commands=commands, + regexp=regexp, + content_type=content_type, + func=func, + **kwargs) + + def decorator(func): + self.channel_post_handlers.register(func, filters_set) + return func + + return decorator + + def edited_channel_post_handler(self, commands=None, regexp=None, content_type=None, func=None, *custom_filters, + **kwargs): + if commands is None: + commands = [] + if content_type is None: + content_type = [ContentType.TEXT] + if custom_filters is None: + custom_filters = [] + + filters_set = generate_default_filters(*custom_filters, + commands=commands, + regexp=regexp, + content_type=content_type, + func=func, + **kwargs) + + def decorator(func): + self.edited_channel_post_handlers.register(func, filters_set) + return func + + return decorator + + def inline_handler(self, func=None, *custom_filters, **kwargs): + if custom_filters is None: + custom_filters = [] + filters_set = generate_default_filters(*custom_filters, + func=func, + **kwargs) + + def decorator(func): + self.inline_query_handlers.register(func, filters_set) + return func + + return decorator + + def chosen_inline_handler(self, func=None, *custom_filters, **kwargs): + if custom_filters is None: + custom_filters = [] + filters_set = generate_default_filters(*custom_filters, + func=func, + **kwargs) + + def decorator(func): + self.chosen_inline_result_handlers.register(func, filters_set) + return func + + return decorator + + def callback_query_handler(self, func=None, *custom_filters, **kwargs): + if custom_filters is None: + custom_filters = [] + filters_set = generate_default_filters(*custom_filters, + func=func, + **kwargs) + + def decorator(func): + self.chosen_inline_result_handlers.register(func, filters_set) + return func + + return decorator + + def shipping_query_handler(self, func=None, *custom_filters, **kwargs): + if custom_filters is None: + custom_filters = [] + filters_set = generate_default_filters(*custom_filters, + func=func, + **kwargs) + + def decorator(func): + self.shipping_query_handlers.register(func, filters_set) + return func + + return decorator + + def pre_checkout_query_handler(self, func=None, *custom_filters, **kwargs): + if custom_filters is None: + custom_filters = [] + filters_set = generate_default_filters(*custom_filters, + func=func, + **kwargs) + + def decorator(func): + self.pre_checkout_query_handlers.register(func, filters_set) + return func + + return decorator + + async def next_message(self, message: types.Message, otherwise=None, once=False, + regexp=None, content_type=None, func=None, custom_filters=None, **kwargs): + if content_type is None: + content_type = [] + if custom_filters is None: + custom_filters = [] + + filters_set = generate_default_filters(*custom_filters, + regexp=regexp, + content_type=content_type, + func=func, + **kwargs) + self.next_step_message_handlers.register(message, otherwise, once, filters_set) + return await self.next_step_message_handlers.wait(message) diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index 2f3a1a6a..99b449ce 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -78,3 +78,33 @@ class ContentTypeFilter(Filter): def check(self, message): return message.content_type in self.content_types + + +def generate_default_filters(*args, **kwargs): + filters_set = [] + + for name, filter_ in kwargs.items(): + if filter_ is None: + continue + if name == 'commands': + if isinstance(filter_, str): + filters_set.append(CommandsFilter([filter_])) + else: + filters_set.append(CommandsFilter(filter_)) + elif name == 'regexp': + filters_set.append(RegexpFilter(filter_)) + elif name == 'content_type': + filters_set.append(ContentTypeFilter(filter_)) + elif name == 'func': + filters_set.append(filter_) + + filters_set += list(args) + + return filters_set + + +class DefaultFilters: + COMMANDS = 'commands' + REGEXP = 'regexp' + CONTENT_TYPE = 'content_type' + FUNC = 'func' diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 2090fe1a..a3c2c971 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,3 +1,6 @@ +from asyncio import Event + +from aiogram import types from .filters import check_filters @@ -27,3 +30,45 @@ class Handler: await handler(*args, **kwargs) if self.once: break + + +class NextStepHandler: + def __init__(self, dispatcher): + self.dispatcher = dispatcher + self.handlers = {} + + def register(self, message, otherwise=None, once=False, filters=None): + chat_id = message.chat.id + if chat_id not in self.handlers: + self.handlers[chat_id] = {'event': Event(), 'filters': filters, + 'otherwise': otherwise, 'once': once} + return True + return False + + async def notify(self, message): + chat_id = message.chat.id + if chat_id not in self.handlers: + return False + handler = self.handlers[chat_id] + if handler['filters'] and not await check_filters(handler['filters'], [message], {}): + otherwise = handler['otherwise'] + if otherwise: + await otherwise(message) + if not handler['once']: + return False + handler['message'] = message + handler['event'].set() + return True + + async def wait(self, message) -> types.Message: + chat_id = message.chat.id + handler = self.handlers[chat_id] + event = handler.get('event') + + await event.wait() + message = self.handlers[chat_id]['message'] + self.reset(chat_id) + return message + + def reset(self, identifier): + del self.handlers[identifier] diff --git a/aiogram/types/base.py b/aiogram/types/base.py index 2604af07..c1cb0e67 100644 --- a/aiogram/types/base.py +++ b/aiogram/types/base.py @@ -33,7 +33,7 @@ class Deserializable: def to_json(self): result = {} for name, attr in self.__dict__.items(): - if not attr or name == '_bot': + if not attr or name in ['_bot', '_parent']: continue if hasattr(attr, 'to_json'): attr = getattr(attr, 'to_json')() @@ -95,7 +95,7 @@ class Deserializable: raise ValueError("data should be a json dict or string.") def __str__(self): - return json.dumps(self.to_json()) + return str(self.to_json()) def __repr__(self): return str(self)