mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Modify dispatcher. Add some handlers.
This commit is contained in:
parent
0a5d1c6feb
commit
765741e122
4 changed files with 269 additions and 32 deletions
|
|
@ -1,8 +1,9 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter
|
||||
from .handler import Handler
|
||||
from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters
|
||||
from .handler import Handler, NextStepHandler
|
||||
from .. import types
|
||||
from ..bot import Bot
|
||||
from ..types.message import ContentType
|
||||
|
||||
|
|
@ -18,12 +19,25 @@ class Dispatcher:
|
|||
self.loop = loop
|
||||
|
||||
self.last_update_id = 0
|
||||
self.updates = Handler(self)
|
||||
self.messages = Handler(self)
|
||||
self.commands = Handler(self)
|
||||
|
||||
self.updates.register(self.process_update)
|
||||
self.updates_handler = Handler(self)
|
||||
self.message_handlers = Handler(self)
|
||||
self.edited_message_handlers = Handler(self)
|
||||
self.channel_post_handlers = Handler(self)
|
||||
self.edited_channel_post_handlers = Handler(self)
|
||||
self.inline_query_handlers = Handler(self)
|
||||
self.chosen_inline_result_handlers = Handler(self)
|
||||
self.callback_query_handlers = Handler(self)
|
||||
self.shipping_query_handlers = Handler(self)
|
||||
self.pre_checkout_query_handlers = Handler(self)
|
||||
|
||||
self.next_step_message_handlers = NextStepHandler(self)
|
||||
self.updates_handler.register(self.process_update)
|
||||
# self.message_handlers.register(self._notify_next_message)
|
||||
|
||||
self._pooling = False
|
||||
|
||||
def __del__(self):
|
||||
self._pooling = False
|
||||
|
||||
async def skip_updates(self):
|
||||
|
|
@ -39,11 +53,29 @@ class Dispatcher:
|
|||
|
||||
async def process_updates(self, updates):
|
||||
for update in updates:
|
||||
self.loop.create_task(self.updates.notify(update))
|
||||
self.loop.create_task(self.updates_handler.notify(update))
|
||||
|
||||
async def process_update(self, update):
|
||||
self.last_update_id = update.update_id
|
||||
if update.message:
|
||||
await self.messages.notify(update.message)
|
||||
if not await self.next_step_message_handlers.notify(update.message):
|
||||
await self.message_handlers.notify(update.message)
|
||||
if update.edited_message:
|
||||
await self.edited_message_handlers.notify(update.edited_message)
|
||||
if update.channel_post:
|
||||
await self.channel_post_handlers.notify(update.channel_post)
|
||||
if update.edited_channel_post:
|
||||
await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
||||
if update.inline_query:
|
||||
await self.inline_query_handlers.notify(update.inline_query)
|
||||
if update.chosen_inline_result:
|
||||
await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
||||
if update.callback_query:
|
||||
await self.callback_query_handlers.notify(update.callback_query)
|
||||
if update.shipping_query:
|
||||
await self.shipping_query_handlers.notify(update.shipping_query)
|
||||
if update.pre_checkout_query:
|
||||
await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
||||
|
||||
async def start_pooling(self, timeout=20, relax=0.1):
|
||||
if self._pooling:
|
||||
|
|
@ -72,8 +104,7 @@ class Dispatcher:
|
|||
def stop_pooling(self):
|
||||
self._pooling = False
|
||||
|
||||
def message_handler(self, commands=None, regexp=None, content_type=None, func=None,
|
||||
custom_filters=None):
|
||||
def message_handler(self, commands=None, regexp=None, content_type=None, func=None, custom_filters=None, **kwargs):
|
||||
if commands is None:
|
||||
commands = []
|
||||
if content_type is None:
|
||||
|
|
@ -81,29 +112,160 @@ class Dispatcher:
|
|||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
|
||||
filters_preset = []
|
||||
if commands:
|
||||
if isinstance(commands, str):
|
||||
commands = [commands]
|
||||
filters_preset.append(CommandsFilter(commands))
|
||||
|
||||
if regexp:
|
||||
filters_preset.append(RegexpFilter(regexp))
|
||||
|
||||
if content_type:
|
||||
filters_preset.append(ContentTypeFilter(content_type))
|
||||
|
||||
if func:
|
||||
filters_preset.append(func)
|
||||
|
||||
if custom_filters:
|
||||
filters_preset += custom_filters
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_type=content_type,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.messages.register(func, filters_preset)
|
||||
self.message_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def __del__(self):
|
||||
self._pooling = False
|
||||
def edited_message_handler(self, commands=None, regexp=None, content_type=None, func=None, custom_filters=None,
|
||||
**kwargs):
|
||||
if commands is None:
|
||||
commands = []
|
||||
if content_type is None:
|
||||
content_type = [ContentType.TEXT]
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_type=content_type,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.edited_message_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def channel_post_handler(self, commands=None, regexp=None, content_type=None, func=None, *custom_filters, **kwargs):
|
||||
if commands is None:
|
||||
commands = []
|
||||
if content_type is None:
|
||||
content_type = [ContentType.TEXT]
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_type=content_type,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.channel_post_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def edited_channel_post_handler(self, commands=None, regexp=None, content_type=None, func=None, *custom_filters,
|
||||
**kwargs):
|
||||
if commands is None:
|
||||
commands = []
|
||||
if content_type is None:
|
||||
content_type = [ContentType.TEXT]
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_type=content_type,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.edited_channel_post_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def inline_handler(self, func=None, *custom_filters, **kwargs):
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.inline_query_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def chosen_inline_handler(self, func=None, *custom_filters, **kwargs):
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.chosen_inline_result_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def callback_query_handler(self, func=None, *custom_filters, **kwargs):
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.chosen_inline_result_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def shipping_query_handler(self, func=None, *custom_filters, **kwargs):
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.shipping_query_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def pre_checkout_query_handler(self, func=None, *custom_filters, **kwargs):
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
func=func,
|
||||
**kwargs)
|
||||
|
||||
def decorator(func):
|
||||
self.pre_checkout_query_handlers.register(func, filters_set)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def next_message(self, message: types.Message, otherwise=None, once=False,
|
||||
regexp=None, content_type=None, func=None, custom_filters=None, **kwargs):
|
||||
if content_type is None:
|
||||
content_type = []
|
||||
if custom_filters is None:
|
||||
custom_filters = []
|
||||
|
||||
filters_set = generate_default_filters(*custom_filters,
|
||||
regexp=regexp,
|
||||
content_type=content_type,
|
||||
func=func,
|
||||
**kwargs)
|
||||
self.next_step_message_handlers.register(message, otherwise, once, filters_set)
|
||||
return await self.next_step_message_handlers.wait(message)
|
||||
|
|
|
|||
|
|
@ -78,3 +78,33 @@ class ContentTypeFilter(Filter):
|
|||
|
||||
def check(self, message):
|
||||
return message.content_type in self.content_types
|
||||
|
||||
|
||||
def generate_default_filters(*args, **kwargs):
|
||||
filters_set = []
|
||||
|
||||
for name, filter_ in kwargs.items():
|
||||
if filter_ is None:
|
||||
continue
|
||||
if name == 'commands':
|
||||
if isinstance(filter_, str):
|
||||
filters_set.append(CommandsFilter([filter_]))
|
||||
else:
|
||||
filters_set.append(CommandsFilter(filter_))
|
||||
elif name == 'regexp':
|
||||
filters_set.append(RegexpFilter(filter_))
|
||||
elif name == 'content_type':
|
||||
filters_set.append(ContentTypeFilter(filter_))
|
||||
elif name == 'func':
|
||||
filters_set.append(filter_)
|
||||
|
||||
filters_set += list(args)
|
||||
|
||||
return filters_set
|
||||
|
||||
|
||||
class DefaultFilters:
|
||||
COMMANDS = 'commands'
|
||||
REGEXP = 'regexp'
|
||||
CONTENT_TYPE = 'content_type'
|
||||
FUNC = 'func'
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
from asyncio import Event
|
||||
|
||||
from aiogram import types
|
||||
from .filters import check_filters
|
||||
|
||||
|
||||
|
|
@ -27,3 +30,45 @@ class Handler:
|
|||
await handler(*args, **kwargs)
|
||||
if self.once:
|
||||
break
|
||||
|
||||
|
||||
class NextStepHandler:
|
||||
def __init__(self, dispatcher):
|
||||
self.dispatcher = dispatcher
|
||||
self.handlers = {}
|
||||
|
||||
def register(self, message, otherwise=None, once=False, filters=None):
|
||||
chat_id = message.chat.id
|
||||
if chat_id not in self.handlers:
|
||||
self.handlers[chat_id] = {'event': Event(), 'filters': filters,
|
||||
'otherwise': otherwise, 'once': once}
|
||||
return True
|
||||
return False
|
||||
|
||||
async def notify(self, message):
|
||||
chat_id = message.chat.id
|
||||
if chat_id not in self.handlers:
|
||||
return False
|
||||
handler = self.handlers[chat_id]
|
||||
if handler['filters'] and not await check_filters(handler['filters'], [message], {}):
|
||||
otherwise = handler['otherwise']
|
||||
if otherwise:
|
||||
await otherwise(message)
|
||||
if not handler['once']:
|
||||
return False
|
||||
handler['message'] = message
|
||||
handler['event'].set()
|
||||
return True
|
||||
|
||||
async def wait(self, message) -> types.Message:
|
||||
chat_id = message.chat.id
|
||||
handler = self.handlers[chat_id]
|
||||
event = handler.get('event')
|
||||
|
||||
await event.wait()
|
||||
message = self.handlers[chat_id]['message']
|
||||
self.reset(chat_id)
|
||||
return message
|
||||
|
||||
def reset(self, identifier):
|
||||
del self.handlers[identifier]
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class Deserializable:
|
|||
def to_json(self):
|
||||
result = {}
|
||||
for name, attr in self.__dict__.items():
|
||||
if not attr or name == '_bot':
|
||||
if not attr or name in ['_bot', '_parent']:
|
||||
continue
|
||||
if hasattr(attr, 'to_json'):
|
||||
attr = getattr(attr, 'to_json')()
|
||||
|
|
@ -95,7 +95,7 @@ class Deserializable:
|
|||
raise ValueError("data should be a json dict or string.")
|
||||
|
||||
def __str__(self):
|
||||
return json.dumps(self.to_json())
|
||||
return str(self.to_json())
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue