From 8ada376a3c99ee586094a54252150f9912dfbd87 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Tue, 10 Apr 2018 03:33:24 +0300 Subject: [PATCH] Remake filters. Implemented filters factory. --- aiogram/dispatcher/__init__.py | 41 ++- aiogram/dispatcher/filters.py | 420 ------------------------- aiogram/dispatcher/filters/__init__.py | 24 ++ aiogram/dispatcher/filters/builtin.py | 187 +++++++++++ aiogram/dispatcher/filters/factory.py | 71 +++++ aiogram/dispatcher/filters/filters.py | 162 ++++++++++ aiogram/dispatcher/handler.py | 3 +- 7 files changed, 474 insertions(+), 434 deletions(-) delete mode 100644 aiogram/dispatcher/filters.py create mode 100644 aiogram/dispatcher/filters/__init__.py create mode 100644 aiogram/dispatcher/filters/builtin.py create mode 100644 aiogram/dispatcher/filters/factory.py create mode 100644 aiogram/dispatcher/filters/filters.py diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 280a86f1..24972903 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -5,8 +5,7 @@ import logging import time import typing -from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter, USER_STATE, \ - generate_default_filters +from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter from .handler import CancelHandler, Handler, SkipHandler from .middlewares import MiddlewareManager from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ @@ -78,11 +77,26 @@ class Dispatcher: self._closed = True self._close_waiter = loop.create_future() - filters_factory.bind(filters.CommandsFilter, 'commands') - filters_factory.bind(filters.RegexpFilter, 'regexp') - filters_factory.bind(filters.RegexpCommandsFilter, 'regexp_commands') - filters_factory.bind(filters.ContentTypeFilter, 'content_types') - filters_factory.bind(filters.StateFilter, 'state', with_dispatcher=True, default=True) + filters_factory.bind(filters.CommandsFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers + ]) + filters_factory.bind(filters.RegexpFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + self.callback_query_handlers + + ]) + filters_factory.bind(filters.RegexpCommandsFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers + ]) + filters_factory.bind(filters.ContentTypeFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + ]) + filters_factory.bind(filters.StateFilter) + filters_factory.bind(filters.ExceptionsFilter, event_handlers=[ + self.errors_handlers + ]) def __del__(self): self.stop_polling() @@ -355,12 +369,13 @@ class Dispatcher: custom_filters = list(custom_filters) custom_filters.append(func) - filters_set = self.filters_factory.parse(*custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - state=state, - **kwargs) + filters_set = self.filters_factory.resolve(self.message_handlers, + *custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + state=state, + **kwargs) self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set) def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None, diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py deleted file mode 100644 index 8ceab867..00000000 --- a/aiogram/dispatcher/filters.py +++ /dev/null @@ -1,420 +0,0 @@ -import asyncio -import copy -import inspect -import re -import typing - -from ..types import CallbackQuery, ContentType, Message -from ..utils import context -from ..utils.helper import Helper, HelperMode, Item - -USER_STATE = 'USER_STATE' - - -class FiltersFactory: - def __init__(self, dispatcher): - self._dispatcher = dispatcher - self._filters = [] - - @property - def _default_filters(self): - return tuple(filter(lambda item: item[-1], self._filters)) - - def bind(self, filter_, *args, default=False, with_dispatcher=False): - self._filters.append(FilterConfig(self._dispatcher, filter_, default=default, with_dispatcher=with_dispatcher)) - - def unbind(self, filter_): - for item in self._filters: - if filter_ is item[0]: - self._filters.remove(item) - return True - raise ValueError(f'{filter_} is not binded.') - - def replace(self, original, new): - for item in self._filters: - if original is item[0]: - item[0] = new - return True - raise ValueError(f'{original} is not binded.') - - def parse(self, *args, **kwargs): - """ - Generate filters list - - :param args: - :param kwargs: - :return: - """ - used = [] - filters = [] - - filters.extend(args) - - # Registered filters filters - for filterconfig in self._filters: - pass - - # Not registered filters - for key, filter_ in kwargs.items(): - if isinstance(filter_, Filter): - filters.append(filter_) - used.append(filter_.__class__) - else: - raise ValueError(f"Unknown filter with key '{key}'") - - return filters - - -class FilterConfig: - def __init__(self, dispatcher, filter_: typing.Callable, - default: bool = False, with_dispatcher: bool = False, - args: typing.Union[tuple, set, list] = ()): - self.dispatcher = dispatcher - self.filter = filter_ - self.default = default - self.with_dispatcher = with_dispatcher - self.args = args - - def _check_list(self, config): - result = {} - accept = True - - for item in self.args: - value = config.pop(item, None) - if value is None: - accept = False - break - result[item] = value - - return accept or self.default, result - - def _check_dict(self, config): - result = {} - accept = True - - for key, type_ in self.args: - value = config.pop(key, None) - if value is None: - accept = False - break - if type_ is bool: - - return accept or self.default, result - - def check(self, config): - if isinstance(config, dict): - return self._check_dict(config) - else: - return self._check_list(config) - - def parse(self, config): - pass - - def configure(self, config=None): - if config is None: - config = {} - if self.with_dispatcher: - if config: - config = copy.deepcopy(config) - config['dispatcher'] = self.dispatcher - return self.filter(**config) - - -async def check_filter(filter_, args): - """ - Helper for executing filter - - :param filter_: - :param args: - :param kwargs: - :return: - """ - if not callable(filter_): - raise TypeError('Filter must be callable and/or awaitable!') - - if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_): - return await filter_(*args) - else: - return filter_(*args) - - -async def check_filters(filters, args): - """ - Check list of filters - - :param filters: - :param args: - :return: - """ - if filters is not None: - for filter_ in filters: - f = await check_filter(filter_, args) - if not f: - return False - return True - - -class Filter: - """ - Base class for filters - """ - - def __init__(self, *args, **kwargs): - self._args = args - self._kwargs = kwargs - - def __call__(self, *args, **kwargs): - return self.check(*args, **kwargs) - - def check(self, *args, **kwargs): - raise NotImplementedError - - -class AsyncFilter(Filter): - """ - Base class for asynchronous filters - """ - - def __aiter__(self): - return None - - def __await__(self): - return self.check - - async def check(self, *args, **kwargs): - pass - - -class AnyFilter(AsyncFilter): - """ - One filter from many - """ - - def __init__(self, *filters: callable): - self.filters = filters - super().__init__() - - async def check(self, *args): - f = (check_filter(filter_, args) for filter_ in self.filters) - return any(await asyncio.gather(*f)) - - -class NotFilter(AsyncFilter): - """ - Reverse filter - """ - - def __init__(self, filter_: callable): - self.filter = filter_ - super().__init__() - - async def check(self, *args): - return not await check_filter(self.filter, args) - - -class CommandsFilter(AsyncFilter): - """ - Check commands in message - """ - - def __init__(self, commands): - self.commands = commands - super().__init__() - - async def check(self, message): - if not message.is_command(): - return False - - command = message.text.split()[0][1:] - command, _, mention = command.partition('@') - - if mention and mention != (await message.bot.me).username: - return False - - if command not in self.commands: - return False - - return True - - -class RegexpFilter(Filter): - """ - Regexp filter for messages - """ - - def __init__(self, regexp): - self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) - super().__init__() - - def check(self, obj): - if isinstance(obj, Message) and obj.text: - return bool(self.regexp.search(obj.text)) - elif isinstance(obj, CallbackQuery) and obj.data: - return bool(self.regexp.search(obj.data)) - return False - - -class RegexpCommandsFilter(AsyncFilter): - """ - Check commands by regexp in message - """ - - def __init__(self, regexp_commands): - self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] - super().__init__() - - async def check(self, message): - if not message.is_command(): - return False - - command = message.text.split()[0][1:] - command, _, mention = command.partition('@') - - if mention and mention != (await message.bot.me).username: - return False - - for command in self.regexp_commands: - search = command.search(message.text) - if search: - message.conf['regexp_command'] = search - return True - return False - - -class ContentTypeFilter(Filter): - """ - Check message content type - """ - - def __init__(self, content_types): - self.content_types = content_types - super().__init__() - - def check(self, message): - return ContentType.ANY[0] in self.content_types or \ - message.content_type in self.content_types - - -class CancelFilter(Filter): - """ - Find cancel in message text - """ - - def __init__(self, cancel_set=None): - if cancel_set is None: - cancel_set = ['/cancel', 'cancel', 'cancel.'] - self.cancel_set = cancel_set - super().__init__() - - def check(self, message): - if message.text: - return message.text.lower() in self.cancel_set - - -class StateFilter(AsyncFilter): - """ - Check user state - """ - - def __init__(self, dispatcher, state): - self.dispatcher = dispatcher - if isinstance(state, str): - state = (state,) - self.state = state - super().__init__() - - def get_target(self, obj): - return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) - - async def check(self, obj): - if self.state == '*': - return True - - if context.check_value(USER_STATE): - context_state = context.get_value(USER_STATE) - return self.state == context_state - else: - chat, user = self.get_target(obj) - - if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state - return False - - -class StatesListFilter(StateFilter): - """ - List of states - """ - - async def check(self, obj): - chat, user = self.get_target(obj) - - if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state - return False - - -class ExceptionsFilter(Filter): - """ - Filter for exceptions - """ - - def __init__(self, exception): - self.exception = exception - super().__init__() - - def check(self, dispatcher, update, exception): - try: - raise exception - except self.exception: - return True - except: - return False - - -def generate_default_filters(dispatcher, *args, **kwargs): - """ - Prepare filters - - :param dispatcher: - :param args: - :param kwargs: - :return: - """ - filters_set = [] - - for name, filter_ in kwargs.items(): - if filter_ is None and name != DefaultFilters.STATE: - continue - if name == DefaultFilters.COMMANDS: - if isinstance(filter_, str): - filters_set.append(CommandsFilter([filter_])) - else: - filters_set.append(CommandsFilter(filter_)) - elif name == DefaultFilters.REGEXP: - filters_set.append(RegexpFilter(filter_)) - elif name == DefaultFilters.CONTENT_TYPES: - filters_set.append(ContentTypeFilter(filter_)) - elif name == DefaultFilters.FUNC: - filters_set.append(filter_) - elif name == DefaultFilters.STATE: - if isinstance(filter_, (list, set, tuple)): - filters_set.append(StatesListFilter(dispatcher, filter_)) - else: - filters_set.append(StateFilter(dispatcher, filter_)) - elif isinstance(filter_, Filter): - filters_set.append(filter_) - - filters_set += list(args) - - return filters_set - - -class DefaultFilters(Helper): - mode = HelperMode.snake_case - - COMMANDS = Item() # commands - REGEXP = Item() # regexp - CONTENT_TYPES = Item() # content_type - FUNC = Item() # func - STATE = Item() # state diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py new file mode 100644 index 00000000..70d99531 --- /dev/null +++ b/aiogram/dispatcher/filters/__init__.py @@ -0,0 +1,24 @@ +from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \ + RegexpFilter, StateFilter, StatesListFilter +from .factory import FiltersFactory +from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters + +__all__ = [ + 'AbstractFilter', + 'AnyFilter', + 'AsyncFilter', + 'BaseFilter', + 'CommandsFilter', + 'ContentTypeFilter', + 'ExceptionsFilter', + 'Filter', + 'FilterRecord', + 'FiltersFactory', + 'NotFilter', + 'RegexpCommandsFilter', + 'RegexpFilter', + 'StateFilter', + 'StatesListFilter', + 'check_filter', + 'check_filters' +] diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py new file mode 100644 index 00000000..e58ae383 --- /dev/null +++ b/aiogram/dispatcher/filters/builtin.py @@ -0,0 +1,187 @@ +import asyncio +import re + +from aiogram.dispatcher.filters import BaseFilter, Filter, check_filter +from aiogram.types import CallbackQuery, ContentType, Message +from aiogram.utils import context + +USER_STATE = 'USER_STATE' + + +class AnyFilter(Filter): + """ + One filter from many + """ + + def __init__(self, *filters: callable): + self.filters = filters + super().__init__() + + async def check(self, *args): + f = (check_filter(filter_, args) for filter_ in self.filters) + return any(await asyncio.gather(*f)) + + +class NotFilter(Filter): + """ + Reverse filter + """ + + def __init__(self, filter_: callable): + self.filter = filter_ + super().__init__() + + async def check(self, *args): + return not await check_filter(self.filter, args) + + +class CommandsFilter(BaseFilter): + """ + Check commands in message + """ + key = 'commands' + + def __init__(self, dispatcher, commands): + super().__init__(dispatcher) + self.commands = commands + + async def check(self, message): + if not message.is_command(): + return False + + command = message.text.split()[0][1:] + command, _, mention = command.partition('@') + + if mention and mention != (await message.bot.me).username: + return False + + if command not in self.commands: + return False + + return True + + +class RegexpFilter(BaseFilter): + """ + Regexp filter for messages and callback query + """ + key = 'regexp' + + def __init__(self, dispatcher, regexp): + super().__init__(dispatcher) + self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) + + async def check(self, obj): + if isinstance(obj, Message) and obj.text: + return bool(self.regexp.search(obj.text)) + elif isinstance(obj, CallbackQuery) and obj.data: + return bool(self.regexp.search(obj.data)) + return False + + +class RegexpCommandsFilter(BaseFilter): + """ + Check commands by regexp in message + """ + + key = 'regexp_commands' + + def __init__(self, dispatcher, regexp_commands): + super().__init__(dispatcher) + self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] + + async def check(self, message): + if not message.is_command(): + return False + + command = message.text.split()[0][1:] + command, _, mention = command.partition('@') + + if mention and mention != (await message.bot.me).username: + return False + + for command in self.regexp_commands: + search = command.search(message.text) + if search: + message.conf['regexp_command'] = search + return True + return False + + +class ContentTypeFilter(BaseFilter): + """ + Check message content type + """ + + key = 'content_types' + + def __init__(self, dispatcher, content_types): + super().__init__(dispatcher) + self.content_types = content_types + + async def check(self, message): + return ContentType.ANY[0] in self.content_types or \ + message.content_type in self.content_types + + +class StateFilter(BaseFilter): + """ + Check user state + """ + key = 'state' + + def __init__(self, dispatcher, state): + super().__init__(dispatcher) + if isinstance(state, str): + state = (state,) + self.state = state + + def get_target(self, obj): + return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) + + async def check(self, obj): + if self.state == '*': + return True + + if context.check_value(USER_STATE): + context_state = context.get_value(USER_STATE) + return self.state == context_state + else: + chat, user = self.get_target(obj) + + if chat or user: + return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state + return False + + +class StatesListFilter(StateFilter): + """ + List of states + """ + + async def check(self, obj): + chat, user = self.get_target(obj) + + if chat or user: + return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state + return False + + +class ExceptionsFilter(BaseFilter): + """ + Filter for exceptions + """ + + key = 'exception' + + def __init__(self, dispatcher, exception): + super().__init__(dispatcher) + self.exception = exception + + async def check(self, dispatcher, update, exception): + try: + raise exception + except self.exception: + return True + except: + return False diff --git a/aiogram/dispatcher/filters/factory.py b/aiogram/dispatcher/filters/factory.py new file mode 100644 index 00000000..cc5d6a0c --- /dev/null +++ b/aiogram/dispatcher/filters/factory.py @@ -0,0 +1,71 @@ +import typing + +from .filters import AbstractFilter, FilterRecord +from ..handler import Handler + + +# TODO: provide to set default filters (Like state. It will be always be added to filters set) +# TODO: provide excluding event handlers +# TODO: move check_filter/check_filters functions to FiltersFactory class + +class FiltersFactory: + """ + Default filters factory + """ + + def __init__(self, dispatcher): + self._dispatcher = dispatcher + self._registered: typing.List[FilterRecord] = [] + + def bind(self, callback: typing.Union[typing.Callable, AbstractFilter], + validator: typing.Optional[typing.Callable] = None, + event_handlers: typing.Optional[typing.List[Handler]] = None): + """ + Register filter + + :param callback: callable or subclass of :obj:`AbstractFilter` + :param validator: custom validator. + :param event_handlers: list of instances of :obj:`Handler` + """ + record = FilterRecord(callback, validator, event_handlers) + self._registered.append(record) + + def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]): + """ + Unregister callback + + :param callback: callable of subclass of :obj:`AbstractFilter` + """ + for record in self._registered: + if record.callback == callback: + self._registered.remove(record) + + def resolve(self, event_handler, *custom_filters, **full_config + ) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]: + """ + Resolve filters to filters-set + + :param event_handler: + :param custom_filters: + :param full_config: + :return: + """ + filters_set = [] + if custom_filters: + filters_set.extend(custom_filters) + if full_config: + filters_set.extend(self._resolve_registered(self._dispatcher, event_handler, + {k: v for k, v in full_config.items() if v is not None})) + return filters_set + + def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator: + for record in self._registered: + if not full_config: + break + + filter_ = record.resolve(dispatcher, event_handler, full_config) + if filter_: + yield filter_ + + if full_config: + raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'') diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py new file mode 100644 index 00000000..98e00d9f --- /dev/null +++ b/aiogram/dispatcher/filters/filters.py @@ -0,0 +1,162 @@ +import abc +import inspect +import typing + +from ..handler import Handler +from ...types.base import TelegramObject +from ...utils.deprecated import deprecated + + +async def check_filter(filter_, args): + """ + Helper for executing filter + + :param filter_: + :param args: + :return: + """ + if not callable(filter_): + raise TypeError('Filter must be callable and/or awaitable!') + + if inspect.isawaitable(filter_) \ + or inspect.iscoroutinefunction(filter_) \ + or isinstance(filter_, (Filter, AbstractFilter)): + return await filter_(*args) + else: + return filter_(*args) + + +async def check_filters(filters, args): + """ + Check list of filters + + :param filters: + :param args: + :return: + """ + if filters is not None: + for filter_ in filters: + f = await check_filter(filter_, args) + if not f: + return False + return True + + +class FilterRecord: + """ + Filters record for factory + """ + + def __init__(self, callback: typing.Callable, + validator: typing.Optional[typing.Callable] = None, + event_handlers: typing.Optional[typing.Iterable[Handler]] = None): + self.callback = callback + self.event_handlers = event_handlers + if validator is not None: + if not callable(validator): + raise TypeError(f"validator must be callable, not {type(validator)}") + self.resolver = validator + elif issubclass(callback, AbstractFilter): + self.resolver = callback.validate + else: + raise RuntimeError('validator is required!') + + def resolve(self, dispatcher, event_observer, full_config): + if not self._check_event_handler(event_observer): + return + config = self.resolver(full_config) + if config: + return self.callback(dispatcher, **config) + + def _check_event_handler(self, event_handler) -> bool: + if not self.event_handlers: + return True + return event_handler in self.event_handlers + + +class AbstractFilter(abc.ABC): + """ + Abstract class for custom filters + """ + + key = None + + def __init__(self, dispatcher, **config): + self.dispatcher = dispatcher + self.config = config + + @classmethod + @abc.abstractmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + """ + Validate and parse config + + :param full_config: + :return: config + """ + pass + + @abc.abstractmethod + async def check(self, *args) -> bool: + """ + Check object + + :param args: + :return: + """ + pass + + async def __call__(self, obj: TelegramObject) -> bool: + return await self.check(obj) + + +class BaseFilter(AbstractFilter): + """ + Abstract class for filters with default validator + """ + + @property + @abc.abstractmethod + def key(self): + pass + + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + if cls.key is not None and cls.key in full_config: + return {cls.key: full_config.pop(cls.key)} + + +class Filter(abc.ABC): + """ + Base class for filters + Subclasses of this class can't be used with FiltersFactory by default. + + (Backward capability) + """ + + def __init__(self, *args, **kwargs): + self._args = args + self._kwargs = kwargs + + def __call__(self, *args, **kwargs): + return self.check(*args, **kwargs) + + @abc.abstractmethod + def check(self, *args, **kwargs): + pass + + +@deprecated +class AsyncFilter(Filter): + """ + Base class for asynchronous filters + """ + + def __aiter__(self): + return None + + def __await__(self): + return self.check + + async def check(self, *args, **kwargs): + pass diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 139c6011..490c4001 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,5 +1,4 @@ from aiogram.utils import context -from .filters import check_filters class SkipHandler(BaseException): @@ -57,6 +56,8 @@ class Handler: :param args: :return: """ + from .filters import check_filters + results = [] if self.middleware_key: