Modify dispatcher. Add some handlers.

This commit is contained in:
Alex Root Junior 2017-06-02 08:50:23 +03:00
parent 0a5d1c6feb
commit 765741e122
4 changed files with 269 additions and 32 deletions

View file

@ -1,8 +1,9 @@
import asyncio
import logging
from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter
from .handler import Handler
from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters
from .handler import Handler, NextStepHandler
from .. import types
from ..bot import Bot
from ..types.message import ContentType
@ -18,12 +19,25 @@ class Dispatcher:
self.loop = loop
self.last_update_id = 0
self.updates = Handler(self)
self.messages = Handler(self)
self.commands = Handler(self)
self.updates.register(self.process_update)
self.updates_handler = Handler(self)
self.message_handlers = Handler(self)
self.edited_message_handlers = Handler(self)
self.channel_post_handlers = Handler(self)
self.edited_channel_post_handlers = Handler(self)
self.inline_query_handlers = Handler(self)
self.chosen_inline_result_handlers = Handler(self)
self.callback_query_handlers = Handler(self)
self.shipping_query_handlers = Handler(self)
self.pre_checkout_query_handlers = Handler(self)
self.next_step_message_handlers = NextStepHandler(self)
self.updates_handler.register(self.process_update)
# self.message_handlers.register(self._notify_next_message)
self._pooling = False
def __del__(self):
self._pooling = False
async def skip_updates(self):
@ -39,11 +53,29 @@ class Dispatcher:
async def process_updates(self, updates):
for update in updates:
self.loop.create_task(self.updates.notify(update))
self.loop.create_task(self.updates_handler.notify(update))
async def process_update(self, update):
self.last_update_id = update.update_id
if update.message:
await self.messages.notify(update.message)
if not await self.next_step_message_handlers.notify(update.message):
await self.message_handlers.notify(update.message)
if update.edited_message:
await self.edited_message_handlers.notify(update.edited_message)
if update.channel_post:
await self.channel_post_handlers.notify(update.channel_post)
if update.edited_channel_post:
await self.edited_channel_post_handlers.notify(update.edited_channel_post)
if update.inline_query:
await self.inline_query_handlers.notify(update.inline_query)
if update.chosen_inline_result:
await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
if update.callback_query:
await self.callback_query_handlers.notify(update.callback_query)
if update.shipping_query:
await self.shipping_query_handlers.notify(update.shipping_query)
if update.pre_checkout_query:
await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
async def start_pooling(self, timeout=20, relax=0.1):
if self._pooling:
@ -72,8 +104,7 @@ class Dispatcher:
def stop_pooling(self):
self._pooling = False
def message_handler(self, commands=None, regexp=None, content_type=None, func=None,
custom_filters=None):
def message_handler(self, commands=None, regexp=None, content_type=None, func=None, custom_filters=None, **kwargs):
if commands is None:
commands = []
if content_type is None:
@ -81,29 +112,160 @@ class Dispatcher:
if custom_filters is None:
custom_filters = []
filters_preset = []
if commands:
if isinstance(commands, str):
commands = [commands]
filters_preset.append(CommandsFilter(commands))
if regexp:
filters_preset.append(RegexpFilter(regexp))
if content_type:
filters_preset.append(ContentTypeFilter(content_type))
if func:
filters_preset.append(func)
if custom_filters:
filters_preset += custom_filters
filters_set = generate_default_filters(*custom_filters,
commands=commands,
regexp=regexp,
content_type=content_type,
func=func,
**kwargs)
def decorator(func):
self.messages.register(func, filters_preset)
self.message_handlers.register(func, filters_set)
return func
return decorator
def __del__(self):
self._pooling = False
def edited_message_handler(self, commands=None, regexp=None, content_type=None, func=None, custom_filters=None,
**kwargs):
if commands is None:
commands = []
if content_type is None:
content_type = [ContentType.TEXT]
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
commands=commands,
regexp=regexp,
content_type=content_type,
func=func,
**kwargs)
def decorator(func):
self.edited_message_handlers.register(func, filters_set)
return func
return decorator
def channel_post_handler(self, commands=None, regexp=None, content_type=None, func=None, *custom_filters, **kwargs):
if commands is None:
commands = []
if content_type is None:
content_type = [ContentType.TEXT]
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
commands=commands,
regexp=regexp,
content_type=content_type,
func=func,
**kwargs)
def decorator(func):
self.channel_post_handlers.register(func, filters_set)
return func
return decorator
def edited_channel_post_handler(self, commands=None, regexp=None, content_type=None, func=None, *custom_filters,
**kwargs):
if commands is None:
commands = []
if content_type is None:
content_type = [ContentType.TEXT]
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
commands=commands,
regexp=regexp,
content_type=content_type,
func=func,
**kwargs)
def decorator(func):
self.edited_channel_post_handlers.register(func, filters_set)
return func
return decorator
def inline_handler(self, func=None, *custom_filters, **kwargs):
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
func=func,
**kwargs)
def decorator(func):
self.inline_query_handlers.register(func, filters_set)
return func
return decorator
def chosen_inline_handler(self, func=None, *custom_filters, **kwargs):
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
func=func,
**kwargs)
def decorator(func):
self.chosen_inline_result_handlers.register(func, filters_set)
return func
return decorator
def callback_query_handler(self, func=None, *custom_filters, **kwargs):
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
func=func,
**kwargs)
def decorator(func):
self.chosen_inline_result_handlers.register(func, filters_set)
return func
return decorator
def shipping_query_handler(self, func=None, *custom_filters, **kwargs):
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
func=func,
**kwargs)
def decorator(func):
self.shipping_query_handlers.register(func, filters_set)
return func
return decorator
def pre_checkout_query_handler(self, func=None, *custom_filters, **kwargs):
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
func=func,
**kwargs)
def decorator(func):
self.pre_checkout_query_handlers.register(func, filters_set)
return func
return decorator
async def next_message(self, message: types.Message, otherwise=None, once=False,
regexp=None, content_type=None, func=None, custom_filters=None, **kwargs):
if content_type is None:
content_type = []
if custom_filters is None:
custom_filters = []
filters_set = generate_default_filters(*custom_filters,
regexp=regexp,
content_type=content_type,
func=func,
**kwargs)
self.next_step_message_handlers.register(message, otherwise, once, filters_set)
return await self.next_step_message_handlers.wait(message)

View file

@ -78,3 +78,33 @@ class ContentTypeFilter(Filter):
def check(self, message):
return message.content_type in self.content_types
def generate_default_filters(*args, **kwargs):
filters_set = []
for name, filter_ in kwargs.items():
if filter_ is None:
continue
if name == 'commands':
if isinstance(filter_, str):
filters_set.append(CommandsFilter([filter_]))
else:
filters_set.append(CommandsFilter(filter_))
elif name == 'regexp':
filters_set.append(RegexpFilter(filter_))
elif name == 'content_type':
filters_set.append(ContentTypeFilter(filter_))
elif name == 'func':
filters_set.append(filter_)
filters_set += list(args)
return filters_set
class DefaultFilters:
COMMANDS = 'commands'
REGEXP = 'regexp'
CONTENT_TYPE = 'content_type'
FUNC = 'func'

View file

@ -1,3 +1,6 @@
from asyncio import Event
from aiogram import types
from .filters import check_filters
@ -27,3 +30,45 @@ class Handler:
await handler(*args, **kwargs)
if self.once:
break
class NextStepHandler:
def __init__(self, dispatcher):
self.dispatcher = dispatcher
self.handlers = {}
def register(self, message, otherwise=None, once=False, filters=None):
chat_id = message.chat.id
if chat_id not in self.handlers:
self.handlers[chat_id] = {'event': Event(), 'filters': filters,
'otherwise': otherwise, 'once': once}
return True
return False
async def notify(self, message):
chat_id = message.chat.id
if chat_id not in self.handlers:
return False
handler = self.handlers[chat_id]
if handler['filters'] and not await check_filters(handler['filters'], [message], {}):
otherwise = handler['otherwise']
if otherwise:
await otherwise(message)
if not handler['once']:
return False
handler['message'] = message
handler['event'].set()
return True
async def wait(self, message) -> types.Message:
chat_id = message.chat.id
handler = self.handlers[chat_id]
event = handler.get('event')
await event.wait()
message = self.handlers[chat_id]['message']
self.reset(chat_id)
return message
def reset(self, identifier):
del self.handlers[identifier]

View file

@ -33,7 +33,7 @@ class Deserializable:
def to_json(self):
result = {}
for name, attr in self.__dict__.items():
if not attr or name == '_bot':
if not attr or name in ['_bot', '_parent']:
continue
if hasattr(attr, 'to_json'):
attr = getattr(attr, 'to_json')()
@ -95,7 +95,7 @@ class Deserializable:
raise ValueError("data should be a json dict or string.")
def __str__(self):
return json.dumps(self.to_json())
return str(self.to_json())
def __repr__(self):
return str(self)