Merge branch 'filters-factory' into dev-2.x

# Conflicts:
#	aiogram/dispatcher/__init__.py
#	aiogram/dispatcher/filters.py
#	aiogram/dispatcher/handler.py
This commit is contained in:
Alex Root Junior 2018-06-24 03:17:01 +03:00
commit 4fbddebe8f
7 changed files with 488 additions and 308 deletions

View file

@ -7,8 +7,7 @@ import typing
from contextvars import ContextVar
from aiogram import types
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, \
USER_STATE, generate_default_filters
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter
from .handler import CancelHandler, Handler, SkipHandler
from .middlewares import MiddlewareManager
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
@ -38,12 +37,15 @@ class Dispatcher:
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
run_tasks_by_default: bool = False,
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False,
filters_factory=None):
if loop is None:
loop = bot.loop
if storage is None:
storage = DisabledStorage()
if filters_factory is None:
filters_factory = FiltersFactory(self)
self.bot: Bot = bot
self.loop = loop
@ -55,6 +57,7 @@ class Dispatcher:
self.last_update_id = 0
self.filters_factory: FiltersFactory = filters_factory
self.updates_handler = Handler(self, middleware_key='update')
self.message_handlers = Handler(self, middleware_key='message')
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
@ -75,6 +78,27 @@ class Dispatcher:
self._closed = True
self._close_waiter = loop.create_future()
filters_factory.bind(filters.CommandsFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers
])
filters_factory.bind(filters.RegexpFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
self.callback_query_handlers
])
filters_factory.bind(filters.RegexpCommandsFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers
])
filters_factory.bind(filters.ContentTypeFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
])
filters_factory.bind(filters.StateFilter)
filters_factory.bind(filters.ExceptionsFilter, event_handlers=[
self.errors_handlers
])
def __del__(self):
self.stop_polling()
@ -251,7 +275,7 @@ class Dispatcher:
if relax:
await asyncio.sleep(relax)
finally:
self._close_waiter.set_result(None)
self._close_waiter._set_result(None)
log.warning('Polling is stopped.')
async def _process_polling_updates(self, updates):
@ -298,8 +322,8 @@ class Dispatcher:
"""
return self._polling
def register_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None,
state=None, custom_filters=None, run_task=None, **kwargs):
def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None,
func=None, state=None, run_task=None, **kwargs):
"""
Register handler for message
@ -328,17 +352,17 @@ class Dispatcher:
"""
if content_types is None:
content_types = ContentType.TEXT
if custom_filters is None:
custom_filters = []
if func is not None:
custom_filters = list(custom_filters)
custom_filters.append(func)
filters_set = generate_default_filters(self,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
func=func,
state=state,
**kwargs)
filters_set = self.filters_factory.resolve(self.message_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs)
self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None,
@ -414,9 +438,9 @@ class Dispatcher:
"""
def decorator(callback):
self.register_message_handler(callback,
self.register_message_handler(callback, *custom_filters,
commands=commands, regexp=regexp, content_types=content_types,
func=func, state=state, custom_filters=custom_filters, run_task=run_task,
func=func, state=state, run_task=run_task,
**kwargs)
return callback

View file

@ -1,289 +0,0 @@
import asyncio
import inspect
import re
from ..types import CallbackQuery, ContentType, Message
from ..utils import context
from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE'
async def check_filter(filter_, args):
"""
Helper for executing filter
:param filter_:
:param args:
:param kwargs:
:return:
"""
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
return await filter_(*args)
else:
return filter_(*args)
async def check_filters(filters, args):
"""
Check list of filters
:param filters:
:param args:
:return:
"""
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args)
if not f:
return False
return True
class Filter:
"""
Base class for filters
"""
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
def check(self, *args, **kwargs):
raise NotImplementedError
class AsyncFilter(Filter):
"""
Base class for asynchronous filters
"""
def __aiter__(self):
return None
def __await__(self):
return self.check
async def check(self, *args, **kwargs):
pass
class AnyFilter(AsyncFilter):
"""
One filter from many
"""
def __init__(self, *filters: callable):
self.filters = filters
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
class NotFilter(AsyncFilter):
"""
Reverse filter
"""
def __init__(self, filter_: callable):
self.filter = filter_
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(AsyncFilter):
"""
Check commands in message
"""
def __init__(self, commands):
self.commands = commands
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
if command not in self.commands:
return False
return True
class RegexpFilter(Filter):
"""
Regexp filter for messages
"""
def __init__(self, regexp):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
def check(self, obj):
if isinstance(obj, Message) and obj.text:
return bool(self.regexp.search(obj.text))
elif isinstance(obj, CallbackQuery) and obj.data:
return bool(self.regexp.search(obj.data))
return False
class RegexpCommandsFilter(AsyncFilter):
"""
Check commands by regexp in message
"""
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
for command in self.regexp_commands:
search = command.search(message.text)
if search:
message.conf['regexp_command'] = search
return True
return False
class ContentTypeFilter(Filter):
"""
Check message content type
"""
def __init__(self, content_types):
self.content_types = content_types
def check(self, message):
return ContentType.ANY[0] in self.content_types or \
message.content_type in self.content_types
class CancelFilter(Filter):
"""
Find cancel in message text
"""
def __init__(self, cancel_set=None):
if cancel_set is None:
cancel_set = ['/cancel', 'cancel', 'cancel.']
self.cancel_set = cancel_set
def check(self, message):
if message.text:
return message.text.lower() in self.cancel_set
class StateFilter(AsyncFilter):
"""
Check user state
"""
def __init__(self, dispatcher, state):
self.dispatcher = dispatcher
self.state = state
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
if self.state == '*':
return True
if context.check_value(USER_STATE):
context_state = context.get_value(USER_STATE)
return self.state == context_state
else:
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
return False
class StatesListFilter(StateFilter):
"""
List of states
"""
async def check(self, obj):
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
class ExceptionsFilter(Filter):
"""
Filter for exceptions
"""
def __init__(self, exception):
self.exception = exception
def check(self, dispatcher, update, exception):
return isinstance(exception, self.exception)
def generate_default_filters(dispatcher, *args, **kwargs):
"""
Prepare filters
:param dispatcher:
:param args:
:param kwargs:
:return:
"""
filters_set = []
for name, filter_ in kwargs.items():
if filter_ is None and name != DefaultFilters.STATE:
continue
if name == DefaultFilters.COMMANDS:
if isinstance(filter_, str):
filters_set.append(CommandsFilter([filter_]))
else:
filters_set.append(CommandsFilter(filter_))
elif name == DefaultFilters.REGEXP:
filters_set.append(RegexpFilter(filter_))
elif name == DefaultFilters.CONTENT_TYPES:
filters_set.append(ContentTypeFilter(filter_))
elif name == DefaultFilters.FUNC:
filters_set.append(filter_)
elif name == DefaultFilters.STATE:
if isinstance(filter_, (list, set, tuple)):
filters_set.append(StatesListFilter(dispatcher, filter_))
else:
filters_set.append(StateFilter(dispatcher, filter_))
elif isinstance(filter_, Filter):
filters_set.append(filter_)
filters_set += list(args)
return filters_set
class DefaultFilters(Helper):
mode = HelperMode.snake_case
COMMANDS = Item() # commands
REGEXP = Item() # regexp
CONTENT_TYPES = Item() # content_type
FUNC = Item() # func
STATE = Item() # state

View file

@ -0,0 +1,24 @@
from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \
RegexpFilter, StateFilter, StatesListFilter
from .factory import FiltersFactory
from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters
__all__ = [
'AbstractFilter',
'AnyFilter',
'AsyncFilter',
'BaseFilter',
'CommandsFilter',
'ContentTypeFilter',
'ExceptionsFilter',
'Filter',
'FilterRecord',
'FiltersFactory',
'NotFilter',
'RegexpCommandsFilter',
'RegexpFilter',
'StateFilter',
'StatesListFilter',
'check_filter',
'check_filters'
]

View file

@ -0,0 +1,187 @@
import asyncio
import re
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
from aiogram.types import CallbackQuery, ContentType, Message
from aiogram.utils import context
USER_STATE = 'USER_STATE'
class AnyFilter(Filter):
"""
One filter from many
"""
def __init__(self, *filters: callable):
self.filters = filters
super().__init__()
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
class NotFilter(Filter):
"""
Reverse filter
"""
def __init__(self, filter_: callable):
self.filter = filter_
super().__init__()
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(BaseFilter):
"""
Check commands in message
"""
key = 'commands'
def __init__(self, dispatcher, commands):
super().__init__(dispatcher)
self.commands = commands
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
if command not in self.commands:
return False
return True
class RegexpFilter(BaseFilter):
"""
Regexp filter for messages and callback query
"""
key = 'regexp'
def __init__(self, dispatcher, regexp):
super().__init__(dispatcher)
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
async def check(self, obj):
if isinstance(obj, Message) and obj.text:
return bool(self.regexp.search(obj.text))
elif isinstance(obj, CallbackQuery) and obj.data:
return bool(self.regexp.search(obj.data))
return False
class RegexpCommandsFilter(BaseFilter):
"""
Check commands by regexp in message
"""
key = 'regexp_commands'
def __init__(self, dispatcher, regexp_commands):
super().__init__(dispatcher)
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
for command in self.regexp_commands:
search = command.search(message.text)
if search:
message.conf['regexp_command'] = search
return True
return False
class ContentTypeFilter(BaseFilter):
"""
Check message content type
"""
key = 'content_types'
def __init__(self, dispatcher, content_types):
super().__init__(dispatcher)
self.content_types = content_types
async def check(self, message):
return ContentType.ANY[0] in self.content_types or \
message.content_type in self.content_types
class StateFilter(BaseFilter):
"""
Check user state
"""
key = 'state'
def __init__(self, dispatcher, state):
super().__init__(dispatcher)
if isinstance(state, str):
state = (state,)
self.state = state
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
if self.state == '*':
return True
if context.check_value(USER_STATE):
context_state = context.get_value(USER_STATE)
return self.state == context_state
else:
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
class StatesListFilter(StateFilter):
"""
List of states
"""
async def check(self, obj):
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
class ExceptionsFilter(BaseFilter):
"""
Filter for exceptions
"""
key = 'exception'
def __init__(self, dispatcher, exception):
super().__init__(dispatcher)
self.exception = exception
async def check(self, dispatcher, update, exception):
try:
raise exception
except self.exception:
return True
except:
return False

View file

@ -0,0 +1,71 @@
import typing
from .filters import AbstractFilter, FilterRecord
from ..handler import Handler
# TODO: provide to set default filters (Like state. It will be always be added to filters set)
# TODO: provide excluding event handlers
# TODO: move check_filter/check_filters functions to FiltersFactory class
class FiltersFactory:
"""
Default filters factory
"""
def __init__(self, dispatcher):
self._dispatcher = dispatcher
self._registered: typing.List[FilterRecord] = []
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.List[Handler]] = None):
"""
Register filter
:param callback: callable or subclass of :obj:`AbstractFilter`
:param validator: custom validator.
:param event_handlers: list of instances of :obj:`Handler`
"""
record = FilterRecord(callback, validator, event_handlers)
self._registered.append(record)
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
"""
Unregister callback
:param callback: callable of subclass of :obj:`AbstractFilter`
"""
for record in self._registered:
if record.callback == callback:
self._registered.remove(record)
def resolve(self, event_handler, *custom_filters, **full_config
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
"""
Resolve filters to filters-set
:param event_handler:
:param custom_filters:
:param full_config:
:return:
"""
filters_set = []
if custom_filters:
filters_set.extend(custom_filters)
if full_config:
filters_set.extend(self._resolve_registered(self._dispatcher, event_handler,
{k: v for k, v in full_config.items() if v is not None}))
return filters_set
def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator:
for record in self._registered:
if not full_config:
break
filter_ = record.resolve(dispatcher, event_handler, full_config)
if filter_:
yield filter_
if full_config:
raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'')

View file

@ -0,0 +1,162 @@
import abc
import inspect
import typing
from ..handler import Handler
from ...types.base import TelegramObject
from ...utils.deprecated import deprecated
async def check_filter(filter_, args):
"""
Helper for executing filter
:param filter_:
:param args:
:return:
"""
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, (Filter, AbstractFilter)):
return await filter_(*args)
else:
return filter_(*args)
async def check_filters(filters, args):
"""
Check list of filters
:param filters:
:param args:
:return:
"""
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args)
if not f:
return False
return True
class FilterRecord:
"""
Filters record for factory
"""
def __init__(self, callback: typing.Callable,
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
self.callback = callback
self.event_handlers = event_handlers
if validator is not None:
if not callable(validator):
raise TypeError(f"validator must be callable, not {type(validator)}")
self.resolver = validator
elif issubclass(callback, AbstractFilter):
self.resolver = callback.validate
else:
raise RuntimeError('validator is required!')
def resolve(self, dispatcher, event_observer, full_config):
if not self._check_event_handler(event_observer):
return
config = self.resolver(full_config)
if config:
return self.callback(dispatcher, **config)
def _check_event_handler(self, event_handler) -> bool:
if not self.event_handlers:
return True
return event_handler in self.event_handlers
class AbstractFilter(abc.ABC):
"""
Abstract class for custom filters
"""
key = None
def __init__(self, dispatcher, **config):
self.dispatcher = dispatcher
self.config = config
@classmethod
@abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""
Validate and parse config
:param full_config:
:return: config
"""
pass
@abc.abstractmethod
async def check(self, *args) -> bool:
"""
Check object
:param args:
:return:
"""
pass
async def __call__(self, obj: TelegramObject) -> bool:
return await self.check(obj)
class BaseFilter(AbstractFilter):
"""
Abstract class for filters with default validator
"""
@property
@abc.abstractmethod
def key(self):
pass
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
if cls.key is not None and cls.key in full_config:
return {cls.key: full_config.pop(cls.key)}
class Filter(abc.ABC):
"""
Base class for filters
Subclasses of this class can't be used with FiltersFactory by default.
(Backward capability)
"""
def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
@abc.abstractmethod
def check(self, *args, **kwargs):
pass
@deprecated
class AsyncFilter(Filter):
"""
Base class for asynchronous filters
"""
def __aiter__(self):
return None
def __await__(self):
return self.check
async def check(self, *args, **kwargs):
pass

View file

@ -1,4 +1,3 @@
from .filters import check_filters
from ..utils import context
@ -57,6 +56,8 @@ class Handler:
:param args:
:return:
"""
from .filters import check_filters
results = []
if self.middleware_key: