diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 0ca75663..e49a9f30 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -7,8 +7,7 @@ import typing from contextvars import ContextVar from aiogram import types -from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, \ - USER_STATE, generate_default_filters +from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter from .handler import CancelHandler, Handler, SkipHandler from .middlewares import MiddlewareManager from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ @@ -38,12 +37,15 @@ class Dispatcher: def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None, run_tasks_by_default: bool = False, - throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False): + throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False, + filters_factory=None): if loop is None: loop = bot.loop if storage is None: storage = DisabledStorage() + if filters_factory is None: + filters_factory = FiltersFactory(self) self.bot: Bot = bot self.loop = loop @@ -55,6 +57,7 @@ class Dispatcher: self.last_update_id = 0 + self.filters_factory: FiltersFactory = filters_factory self.updates_handler = Handler(self, middleware_key='update') self.message_handlers = Handler(self, middleware_key='message') self.edited_message_handlers = Handler(self, middleware_key='edited_message') @@ -75,6 +78,27 @@ class Dispatcher: self._closed = True self._close_waiter = loop.create_future() + filters_factory.bind(filters.CommandsFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers + ]) + filters_factory.bind(filters.RegexpFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + self.callback_query_handlers + + ]) + filters_factory.bind(filters.RegexpCommandsFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers + ]) + filters_factory.bind(filters.ContentTypeFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + ]) + filters_factory.bind(filters.StateFilter) + filters_factory.bind(filters.ExceptionsFilter, event_handlers=[ + self.errors_handlers + ]) + def __del__(self): self.stop_polling() @@ -251,7 +275,7 @@ class Dispatcher: if relax: await asyncio.sleep(relax) finally: - self._close_waiter.set_result(None) + self._close_waiter._set_result(None) log.warning('Polling is stopped.') async def _process_polling_updates(self, updates): @@ -298,8 +322,8 @@ class Dispatcher: """ return self._polling - def register_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None, - state=None, custom_filters=None, run_task=None, **kwargs): + def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None, + func=None, state=None, run_task=None, **kwargs): """ Register handler for message @@ -328,17 +352,17 @@ class Dispatcher: """ if content_types is None: content_types = ContentType.TEXT - if custom_filters is None: - custom_filters = [] + if func is not None: + custom_filters = list(custom_filters) + custom_filters.append(func) - filters_set = generate_default_filters(self, - *custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - func=func, - state=state, - **kwargs) + filters_set = self.filters_factory.resolve(self.message_handlers, + *custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + state=state, + **kwargs) self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set) def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None, @@ -414,9 +438,9 @@ class Dispatcher: """ def decorator(callback): - self.register_message_handler(callback, + self.register_message_handler(callback, *custom_filters, commands=commands, regexp=regexp, content_types=content_types, - func=func, state=state, custom_filters=custom_filters, run_task=run_task, + func=func, state=state, run_task=run_task, **kwargs) return callback diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index fb3f04a0..e69de29b 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -1,289 +0,0 @@ -import asyncio -import inspect -import re - -from ..types import CallbackQuery, ContentType, Message -from ..utils import context -from ..utils.helper import Helper, HelperMode, Item - -USER_STATE = 'USER_STATE' - - -async def check_filter(filter_, args): - """ - Helper for executing filter - - :param filter_: - :param args: - :param kwargs: - :return: - """ - if not callable(filter_): - raise TypeError('Filter must be callable and/or awaitable!') - - if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_): - return await filter_(*args) - else: - return filter_(*args) - - -async def check_filters(filters, args): - """ - Check list of filters - - :param filters: - :param args: - :return: - """ - if filters is not None: - for filter_ in filters: - f = await check_filter(filter_, args) - if not f: - return False - return True - - -class Filter: - """ - Base class for filters - """ - - def __call__(self, *args, **kwargs): - return self.check(*args, **kwargs) - - def check(self, *args, **kwargs): - raise NotImplementedError - - -class AsyncFilter(Filter): - """ - Base class for asynchronous filters - """ - - def __aiter__(self): - return None - - def __await__(self): - return self.check - - async def check(self, *args, **kwargs): - pass - - -class AnyFilter(AsyncFilter): - """ - One filter from many - """ - - def __init__(self, *filters: callable): - self.filters = filters - - async def check(self, *args): - f = (check_filter(filter_, args) for filter_ in self.filters) - return any(await asyncio.gather(*f)) - - -class NotFilter(AsyncFilter): - """ - Reverse filter - """ - - def __init__(self, filter_: callable): - self.filter = filter_ - - async def check(self, *args): - return not await check_filter(self.filter, args) - - -class CommandsFilter(AsyncFilter): - """ - Check commands in message - """ - - def __init__(self, commands): - self.commands = commands - - async def check(self, message): - if not message.is_command(): - return False - - command = message.text.split()[0][1:] - command, _, mention = command.partition('@') - - if mention and mention != (await message.bot.me).username: - return False - - if command not in self.commands: - return False - - return True - - -class RegexpFilter(Filter): - """ - Regexp filter for messages - """ - - def __init__(self, regexp): - self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) - - def check(self, obj): - if isinstance(obj, Message) and obj.text: - return bool(self.regexp.search(obj.text)) - elif isinstance(obj, CallbackQuery) and obj.data: - return bool(self.regexp.search(obj.data)) - return False - - -class RegexpCommandsFilter(AsyncFilter): - """ - Check commands by regexp in message - """ - - def __init__(self, regexp_commands): - self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] - - async def check(self, message): - if not message.is_command(): - return False - - command = message.text.split()[0][1:] - command, _, mention = command.partition('@') - - if mention and mention != (await message.bot.me).username: - return False - - for command in self.regexp_commands: - search = command.search(message.text) - if search: - message.conf['regexp_command'] = search - return True - return False - - -class ContentTypeFilter(Filter): - """ - Check message content type - """ - - def __init__(self, content_types): - self.content_types = content_types - - def check(self, message): - return ContentType.ANY[0] in self.content_types or \ - message.content_type in self.content_types - - -class CancelFilter(Filter): - """ - Find cancel in message text - """ - - def __init__(self, cancel_set=None): - if cancel_set is None: - cancel_set = ['/cancel', 'cancel', 'cancel.'] - self.cancel_set = cancel_set - - def check(self, message): - if message.text: - return message.text.lower() in self.cancel_set - - -class StateFilter(AsyncFilter): - """ - Check user state - """ - - def __init__(self, dispatcher, state): - self.dispatcher = dispatcher - self.state = state - - def get_target(self, obj): - return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) - - async def check(self, obj): - if self.state == '*': - return True - - if context.check_value(USER_STATE): - context_state = context.get_value(USER_STATE) - return self.state == context_state - else: - chat, user = self.get_target(obj) - - if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state - return False - - -class StatesListFilter(StateFilter): - """ - List of states - """ - - async def check(self, obj): - chat, user = self.get_target(obj) - - if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state - return False - - -class ExceptionsFilter(Filter): - """ - Filter for exceptions - """ - - def __init__(self, exception): - self.exception = exception - - def check(self, dispatcher, update, exception): - return isinstance(exception, self.exception) - - -def generate_default_filters(dispatcher, *args, **kwargs): - """ - Prepare filters - - :param dispatcher: - :param args: - :param kwargs: - :return: - """ - filters_set = [] - - for name, filter_ in kwargs.items(): - if filter_ is None and name != DefaultFilters.STATE: - continue - if name == DefaultFilters.COMMANDS: - if isinstance(filter_, str): - filters_set.append(CommandsFilter([filter_])) - else: - filters_set.append(CommandsFilter(filter_)) - elif name == DefaultFilters.REGEXP: - filters_set.append(RegexpFilter(filter_)) - elif name == DefaultFilters.CONTENT_TYPES: - filters_set.append(ContentTypeFilter(filter_)) - elif name == DefaultFilters.FUNC: - filters_set.append(filter_) - elif name == DefaultFilters.STATE: - if isinstance(filter_, (list, set, tuple)): - filters_set.append(StatesListFilter(dispatcher, filter_)) - else: - filters_set.append(StateFilter(dispatcher, filter_)) - elif isinstance(filter_, Filter): - filters_set.append(filter_) - - filters_set += list(args) - - return filters_set - - -class DefaultFilters(Helper): - mode = HelperMode.snake_case - - COMMANDS = Item() # commands - REGEXP = Item() # regexp - CONTENT_TYPES = Item() # content_type - FUNC = Item() # func - STATE = Item() # state diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py new file mode 100644 index 00000000..70d99531 --- /dev/null +++ b/aiogram/dispatcher/filters/__init__.py @@ -0,0 +1,24 @@ +from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \ + RegexpFilter, StateFilter, StatesListFilter +from .factory import FiltersFactory +from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters + +__all__ = [ + 'AbstractFilter', + 'AnyFilter', + 'AsyncFilter', + 'BaseFilter', + 'CommandsFilter', + 'ContentTypeFilter', + 'ExceptionsFilter', + 'Filter', + 'FilterRecord', + 'FiltersFactory', + 'NotFilter', + 'RegexpCommandsFilter', + 'RegexpFilter', + 'StateFilter', + 'StatesListFilter', + 'check_filter', + 'check_filters' +] diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py new file mode 100644 index 00000000..00725b5b --- /dev/null +++ b/aiogram/dispatcher/filters/builtin.py @@ -0,0 +1,187 @@ +import asyncio +import re + +from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter +from aiogram.types import CallbackQuery, ContentType, Message +from aiogram.utils import context + +USER_STATE = 'USER_STATE' + + +class AnyFilter(Filter): + """ + One filter from many + """ + + def __init__(self, *filters: callable): + self.filters = filters + super().__init__() + + async def check(self, *args): + f = (check_filter(filter_, args) for filter_ in self.filters) + return any(await asyncio.gather(*f)) + + +class NotFilter(Filter): + """ + Reverse filter + """ + + def __init__(self, filter_: callable): + self.filter = filter_ + super().__init__() + + async def check(self, *args): + return not await check_filter(self.filter, args) + + +class CommandsFilter(BaseFilter): + """ + Check commands in message + """ + key = 'commands' + + def __init__(self, dispatcher, commands): + super().__init__(dispatcher) + self.commands = commands + + async def check(self, message): + if not message.is_command(): + return False + + command = message.text.split()[0][1:] + command, _, mention = command.partition('@') + + if mention and mention != (await message.bot.me).username: + return False + + if command not in self.commands: + return False + + return True + + +class RegexpFilter(BaseFilter): + """ + Regexp filter for messages and callback query + """ + key = 'regexp' + + def __init__(self, dispatcher, regexp): + super().__init__(dispatcher) + self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) + + async def check(self, obj): + if isinstance(obj, Message) and obj.text: + return bool(self.regexp.search(obj.text)) + elif isinstance(obj, CallbackQuery) and obj.data: + return bool(self.regexp.search(obj.data)) + return False + + +class RegexpCommandsFilter(BaseFilter): + """ + Check commands by regexp in message + """ + + key = 'regexp_commands' + + def __init__(self, dispatcher, regexp_commands): + super().__init__(dispatcher) + self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] + + async def check(self, message): + if not message.is_command(): + return False + + command = message.text.split()[0][1:] + command, _, mention = command.partition('@') + + if mention and mention != (await message.bot.me).username: + return False + + for command in self.regexp_commands: + search = command.search(message.text) + if search: + message.conf['regexp_command'] = search + return True + return False + + +class ContentTypeFilter(BaseFilter): + """ + Check message content type + """ + + key = 'content_types' + + def __init__(self, dispatcher, content_types): + super().__init__(dispatcher) + self.content_types = content_types + + async def check(self, message): + return ContentType.ANY[0] in self.content_types or \ + message.content_type in self.content_types + + +class StateFilter(BaseFilter): + """ + Check user state + """ + key = 'state' + + def __init__(self, dispatcher, state): + super().__init__(dispatcher) + if isinstance(state, str): + state = (state,) + self.state = state + + def get_target(self, obj): + return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) + + async def check(self, obj): + if self.state == '*': + return True + + if context.check_value(USER_STATE): + context_state = context.get_value(USER_STATE) + return self.state == context_state + else: + chat, user = self.get_target(obj) + + if chat or user: + return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state + return False + + +class StatesListFilter(StateFilter): + """ + List of states + """ + + async def check(self, obj): + chat, user = self.get_target(obj) + + if chat or user: + return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state + return False + + +class ExceptionsFilter(BaseFilter): + """ + Filter for exceptions + """ + + key = 'exception' + + def __init__(self, dispatcher, exception): + super().__init__(dispatcher) + self.exception = exception + + async def check(self, dispatcher, update, exception): + try: + raise exception + except self.exception: + return True + except: + return False diff --git a/aiogram/dispatcher/filters/factory.py b/aiogram/dispatcher/filters/factory.py new file mode 100644 index 00000000..cc5d6a0c --- /dev/null +++ b/aiogram/dispatcher/filters/factory.py @@ -0,0 +1,71 @@ +import typing + +from .filters import AbstractFilter, FilterRecord +from ..handler import Handler + + +# TODO: provide to set default filters (Like state. It will be always be added to filters set) +# TODO: provide excluding event handlers +# TODO: move check_filter/check_filters functions to FiltersFactory class + +class FiltersFactory: + """ + Default filters factory + """ + + def __init__(self, dispatcher): + self._dispatcher = dispatcher + self._registered: typing.List[FilterRecord] = [] + + def bind(self, callback: typing.Union[typing.Callable, AbstractFilter], + validator: typing.Optional[typing.Callable] = None, + event_handlers: typing.Optional[typing.List[Handler]] = None): + """ + Register filter + + :param callback: callable or subclass of :obj:`AbstractFilter` + :param validator: custom validator. + :param event_handlers: list of instances of :obj:`Handler` + """ + record = FilterRecord(callback, validator, event_handlers) + self._registered.append(record) + + def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]): + """ + Unregister callback + + :param callback: callable of subclass of :obj:`AbstractFilter` + """ + for record in self._registered: + if record.callback == callback: + self._registered.remove(record) + + def resolve(self, event_handler, *custom_filters, **full_config + ) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]: + """ + Resolve filters to filters-set + + :param event_handler: + :param custom_filters: + :param full_config: + :return: + """ + filters_set = [] + if custom_filters: + filters_set.extend(custom_filters) + if full_config: + filters_set.extend(self._resolve_registered(self._dispatcher, event_handler, + {k: v for k, v in full_config.items() if v is not None})) + return filters_set + + def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator: + for record in self._registered: + if not full_config: + break + + filter_ = record.resolve(dispatcher, event_handler, full_config) + if filter_: + yield filter_ + + if full_config: + raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'') diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py new file mode 100644 index 00000000..98e00d9f --- /dev/null +++ b/aiogram/dispatcher/filters/filters.py @@ -0,0 +1,162 @@ +import abc +import inspect +import typing + +from ..handler import Handler +from ...types.base import TelegramObject +from ...utils.deprecated import deprecated + + +async def check_filter(filter_, args): + """ + Helper for executing filter + + :param filter_: + :param args: + :return: + """ + if not callable(filter_): + raise TypeError('Filter must be callable and/or awaitable!') + + if inspect.isawaitable(filter_) \ + or inspect.iscoroutinefunction(filter_) \ + or isinstance(filter_, (Filter, AbstractFilter)): + return await filter_(*args) + else: + return filter_(*args) + + +async def check_filters(filters, args): + """ + Check list of filters + + :param filters: + :param args: + :return: + """ + if filters is not None: + for filter_ in filters: + f = await check_filter(filter_, args) + if not f: + return False + return True + + +class FilterRecord: + """ + Filters record for factory + """ + + def __init__(self, callback: typing.Callable, + validator: typing.Optional[typing.Callable] = None, + event_handlers: typing.Optional[typing.Iterable[Handler]] = None): + self.callback = callback + self.event_handlers = event_handlers + if validator is not None: + if not callable(validator): + raise TypeError(f"validator must be callable, not {type(validator)}") + self.resolver = validator + elif issubclass(callback, AbstractFilter): + self.resolver = callback.validate + else: + raise RuntimeError('validator is required!') + + def resolve(self, dispatcher, event_observer, full_config): + if not self._check_event_handler(event_observer): + return + config = self.resolver(full_config) + if config: + return self.callback(dispatcher, **config) + + def _check_event_handler(self, event_handler) -> bool: + if not self.event_handlers: + return True + return event_handler in self.event_handlers + + +class AbstractFilter(abc.ABC): + """ + Abstract class for custom filters + """ + + key = None + + def __init__(self, dispatcher, **config): + self.dispatcher = dispatcher + self.config = config + + @classmethod + @abc.abstractmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + """ + Validate and parse config + + :param full_config: + :return: config + """ + pass + + @abc.abstractmethod + async def check(self, *args) -> bool: + """ + Check object + + :param args: + :return: + """ + pass + + async def __call__(self, obj: TelegramObject) -> bool: + return await self.check(obj) + + +class BaseFilter(AbstractFilter): + """ + Abstract class for filters with default validator + """ + + @property + @abc.abstractmethod + def key(self): + pass + + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + if cls.key is not None and cls.key in full_config: + return {cls.key: full_config.pop(cls.key)} + + +class Filter(abc.ABC): + """ + Base class for filters + Subclasses of this class can't be used with FiltersFactory by default. + + (Backward capability) + """ + + def __init__(self, *args, **kwargs): + self._args = args + self._kwargs = kwargs + + def __call__(self, *args, **kwargs): + return self.check(*args, **kwargs) + + @abc.abstractmethod + def check(self, *args, **kwargs): + pass + + +@deprecated +class AsyncFilter(Filter): + """ + Base class for asynchronous filters + """ + + def __aiter__(self): + return None + + def __await__(self): + return self.check + + async def check(self, *args, **kwargs): + pass diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 8d75273f..bee98981 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,4 +1,3 @@ -from .filters import check_filters from ..utils import context @@ -57,6 +56,8 @@ class Handler: :param args: :return: """ + from .filters import check_filters + results = [] if self.middleware_key: