From 1ca0be538bcb70a8d94d5cccad600705ea6325de Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Mon, 25 Jun 2018 02:15:37 +0300 Subject: [PATCH] Experimental: Pass result of filter as arguments in handler. --- aiogram/dispatcher/filters/__init__.py | 11 ++--- aiogram/dispatcher/filters/builtin.py | 61 ++++++++++---------------- aiogram/dispatcher/filters/filters.py | 47 +++++--------------- aiogram/dispatcher/handler.py | 52 +++++++++++++++------- 4 files changed, 71 insertions(+), 100 deletions(-) diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index 70d99531..48fe4e82 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -1,24 +1,21 @@ -from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \ +from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \ RegexpFilter, StateFilter, StatesListFilter from .factory import FiltersFactory -from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters +from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters __all__ = [ 'AbstractFilter', - 'AnyFilter', - 'AsyncFilter', 'BaseFilter', 'CommandsFilter', 'ContentTypeFilter', 'ExceptionsFilter', - 'Filter', 'FilterRecord', 'FiltersFactory', - 'NotFilter', 'RegexpCommandsFilter', 'RegexpFilter', 'StateFilter', 'StatesListFilter', 'check_filter', - 'check_filters' + 'check_filters', + 'FilterNotPassed' ] diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 135a7338..5a2d9199 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -1,39 +1,9 @@ -import asyncio import re from _contextvars import ContextVar -from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter +from aiogram.dispatcher.filters.filters import BaseFilter from aiogram.types import CallbackQuery, ContentType, Message -USER_STATE = 'USER_STATE' - - -class AnyFilter(Filter): - """ - One filter from many - """ - - def __init__(self, *filters: callable): - self.filters = filters - super().__init__() - - async def check(self, *args): - f = (check_filter(filter_, args) for filter_ in self.filters) - return any(await asyncio.gather(*f)) - - -class NotFilter(Filter): - """ - Reverse filter - """ - - def __init__(self, filter_: callable): - self.filter = filter_ - super().__init__() - - async def check(self, *args): - return not await check_filter(self.filter, args) - class CommandsFilter(BaseFilter): """ @@ -72,10 +42,20 @@ class RegexpFilter(BaseFilter): self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) async def check(self, obj): - if isinstance(obj, Message) and obj.text: - return bool(self.regexp.search(obj.text)) + if isinstance(obj, Message): + if obj.text: + match = self.regexp.search(obj.text) + elif obj.caption: + match = self.regexp.search(obj.caption) + else: + return False elif isinstance(obj, CallbackQuery) and obj.data: - return bool(self.regexp.search(obj.data)) + match = self.regexp.search(obj.data) + else: + return False + + if match: + return {'regexp': match} return False @@ -103,8 +83,7 @@ class RegexpCommandsFilter(BaseFilter): for command in self.regexp_commands: search = command.search(message.text) if search: - message.conf['regexp_command'] = search - return True + return {'regexp_command': search} return False @@ -142,18 +121,22 @@ class StateFilter(BaseFilter): return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) async def check(self, obj): + from ..dispatcher import Dispatcher + if self.state == '*': - return True + return {'state': Dispatcher.current().current_state()} try: - return self.state == self.ctx_state.get() + if self.state == self.ctx_state.get(): + return {'state': Dispatcher.current().current_state(), 'raw_state': self.state} except LookupError: chat, user = self.get_target(obj) if chat or user: state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state self.ctx_state.set(state) - return state == self.state + if state == self.state: + return {'state': Dispatcher.current().current_state(), 'raw_state': self.state} return False diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py index bbdf909f..716edaac 100644 --- a/aiogram/dispatcher/filters/filters.py +++ b/aiogram/dispatcher/filters/filters.py @@ -4,7 +4,10 @@ import typing from ..handler import Handler from ...types.base import TelegramObject -from ...utils.deprecated import deprecated + + +class FilterNotPassed(Exception): + pass async def check_filter(filter_, args): @@ -20,7 +23,7 @@ async def check_filter(filter_, args): if inspect.isawaitable(filter_) \ or inspect.iscoroutinefunction(filter_) \ - or isinstance(filter_, (Filter, AbstractFilter)): + or isinstance(filter_, AbstractFilter): return await filter_(*args) else: return filter_(*args) @@ -34,12 +37,15 @@ async def check_filters(filters, args): :param args: :return: """ + data = {} if filters is not None: for filter_ in filters: f = await check_filter(filter_, args) if not f: - return False - return True + raise FilterNotPassed() + elif isinstance(f, dict): + data.update(f) + return data class FilterRecord: @@ -132,36 +138,3 @@ class BaseFilter(AbstractFilter): def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: if cls.key is not None and cls.key in full_config: return {cls.key: full_config.pop(cls.key)} - - -class Filter(abc.ABC): - """ - Base class for filters - Subclasses of this class can't be used with FiltersFactory by default. - - (Backward capability) - """ - - def __init__(self, *args, **kwargs): - self._args = args - self._kwargs = kwargs - - def __call__(self, *args, **kwargs): - return self.check(*args, **kwargs) - - @abc.abstractmethod - def check(self, *args, **kwargs): - pass - - -@deprecated -class AsyncFilter(Filter): - """ - Base class for asynchronous filters - """ - - def __await__(self): - return self.check - - async def check(self, *args, **kwargs): - pass diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 58b285a6..516b6114 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,3 +1,6 @@ +import inspect + + class SkipHandler(BaseException): pass @@ -6,6 +9,14 @@ class CancelHandler(BaseException): pass +def _check_spec(func: callable, kwargs: dict): + spec = inspect.getfullargspec(func) + if spec.varkw: + return kwargs + + return {k: v for k, v in kwargs.items() if k in spec.args} + + class Handler: def __init__(self, dispatcher, once=True, middleware_key=None): self.dispatcher = dispatcher @@ -53,7 +64,7 @@ class Handler: :param args: :return: """ - from .filters import check_filters + from .filters import check_filters, FilterNotPassed results = [] @@ -63,23 +74,30 @@ class Handler: except CancelHandler: # Allow to cancel current event return results - for filters, handler in self.handlers: - if await check_filters(filters, args): + try: + for filters, handler in self.handlers: try: - if self.middleware_key: - # context.set_value('handler', handler) - await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args) - response = await handler(*args) - if response is not None: - results.append(response) - if self.once: - break - except SkipHandler: + data = await check_filters(filters, args) + except FilterNotPassed: continue - except CancelHandler: - break - if self.middleware_key: - await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", - args + (results,)) + else: + try: + if self.middleware_key: + # context.set_value('handler', handler) + await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args) + partial_data = _check_spec(handler, data) + response = await handler(*args, **partial_data) + if response is not None: + results.append(response) + if self.once: + break + except SkipHandler: + continue + except CancelHandler: + break + finally: + if self.middleware_key: + await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", + args + (results,)) return results