Experimental: Pass result of filter as arguments in handler.

This commit is contained in:
Alex Root Junior 2018-06-25 02:15:37 +03:00
parent fc4e6ae69b
commit 1ca0be538b
4 changed files with 71 additions and 100 deletions

View file

@ -1,24 +1,21 @@
from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \ from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \
RegexpFilter, StateFilter, StatesListFilter RegexpFilter, StateFilter, StatesListFilter
from .factory import FiltersFactory from .factory import FiltersFactory
from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters
__all__ = [ __all__ = [
'AbstractFilter', 'AbstractFilter',
'AnyFilter',
'AsyncFilter',
'BaseFilter', 'BaseFilter',
'CommandsFilter', 'CommandsFilter',
'ContentTypeFilter', 'ContentTypeFilter',
'ExceptionsFilter', 'ExceptionsFilter',
'Filter',
'FilterRecord', 'FilterRecord',
'FiltersFactory', 'FiltersFactory',
'NotFilter',
'RegexpCommandsFilter', 'RegexpCommandsFilter',
'RegexpFilter', 'RegexpFilter',
'StateFilter', 'StateFilter',
'StatesListFilter', 'StatesListFilter',
'check_filter', 'check_filter',
'check_filters' 'check_filters',
'FilterNotPassed'
] ]

View file

@ -1,39 +1,9 @@
import asyncio
import re import re
from _contextvars import ContextVar from _contextvars import ContextVar
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter from aiogram.dispatcher.filters.filters import BaseFilter
from aiogram.types import CallbackQuery, ContentType, Message from aiogram.types import CallbackQuery, ContentType, Message
USER_STATE = 'USER_STATE'
class AnyFilter(Filter):
"""
One filter from many
"""
def __init__(self, *filters: callable):
self.filters = filters
super().__init__()
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
class NotFilter(Filter):
"""
Reverse filter
"""
def __init__(self, filter_: callable):
self.filter = filter_
super().__init__()
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(BaseFilter): class CommandsFilter(BaseFilter):
""" """
@ -72,10 +42,20 @@ class RegexpFilter(BaseFilter):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
async def check(self, obj): async def check(self, obj):
if isinstance(obj, Message) and obj.text: if isinstance(obj, Message):
return bool(self.regexp.search(obj.text)) if obj.text:
match = self.regexp.search(obj.text)
elif obj.caption:
match = self.regexp.search(obj.caption)
else:
return False
elif isinstance(obj, CallbackQuery) and obj.data: elif isinstance(obj, CallbackQuery) and obj.data:
return bool(self.regexp.search(obj.data)) match = self.regexp.search(obj.data)
else:
return False
if match:
return {'regexp': match}
return False return False
@ -103,8 +83,7 @@ class RegexpCommandsFilter(BaseFilter):
for command in self.regexp_commands: for command in self.regexp_commands:
search = command.search(message.text) search = command.search(message.text)
if search: if search:
message.conf['regexp_command'] = search return {'regexp_command': search}
return True
return False return False
@ -142,18 +121,22 @@ class StateFilter(BaseFilter):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj): async def check(self, obj):
from ..dispatcher import Dispatcher
if self.state == '*': if self.state == '*':
return True return {'state': Dispatcher.current().current_state()}
try: try:
return self.state == self.ctx_state.get() if self.state == self.ctx_state.get():
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
except LookupError: except LookupError:
chat, user = self.get_target(obj) chat, user = self.get_target(obj)
if chat or user: if chat or user:
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
self.ctx_state.set(state) self.ctx_state.set(state)
return state == self.state if state == self.state:
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
return False return False

View file

@ -4,7 +4,10 @@ import typing
from ..handler import Handler from ..handler import Handler
from ...types.base import TelegramObject from ...types.base import TelegramObject
from ...utils.deprecated import deprecated
class FilterNotPassed(Exception):
pass
async def check_filter(filter_, args): async def check_filter(filter_, args):
@ -20,7 +23,7 @@ async def check_filter(filter_, args):
if inspect.isawaitable(filter_) \ if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \ or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, (Filter, AbstractFilter)): or isinstance(filter_, AbstractFilter):
return await filter_(*args) return await filter_(*args)
else: else:
return filter_(*args) return filter_(*args)
@ -34,12 +37,15 @@ async def check_filters(filters, args):
:param args: :param args:
:return: :return:
""" """
data = {}
if filters is not None: if filters is not None:
for filter_ in filters: for filter_ in filters:
f = await check_filter(filter_, args) f = await check_filter(filter_, args)
if not f: if not f:
return False raise FilterNotPassed()
return True elif isinstance(f, dict):
data.update(f)
return data
class FilterRecord: class FilterRecord:
@ -132,36 +138,3 @@ class BaseFilter(AbstractFilter):
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None and cls.key in full_config: if cls.key is not None and cls.key in full_config:
return {cls.key: full_config.pop(cls.key)} return {cls.key: full_config.pop(cls.key)}
class Filter(abc.ABC):
"""
Base class for filters
Subclasses of this class can't be used with FiltersFactory by default.
(Backward capability)
"""
def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
@abc.abstractmethod
def check(self, *args, **kwargs):
pass
@deprecated
class AsyncFilter(Filter):
"""
Base class for asynchronous filters
"""
def __await__(self):
return self.check
async def check(self, *args, **kwargs):
pass

View file

@ -1,3 +1,6 @@
import inspect
class SkipHandler(BaseException): class SkipHandler(BaseException):
pass pass
@ -6,6 +9,14 @@ class CancelHandler(BaseException):
pass pass
def _check_spec(func: callable, kwargs: dict):
spec = inspect.getfullargspec(func)
if spec.varkw:
return kwargs
return {k: v for k, v in kwargs.items() if k in spec.args}
class Handler: class Handler:
def __init__(self, dispatcher, once=True, middleware_key=None): def __init__(self, dispatcher, once=True, middleware_key=None):
self.dispatcher = dispatcher self.dispatcher = dispatcher
@ -53,7 +64,7 @@ class Handler:
:param args: :param args:
:return: :return:
""" """
from .filters import check_filters from .filters import check_filters, FilterNotPassed
results = [] results = []
@ -63,23 +74,30 @@ class Handler:
except CancelHandler: # Allow to cancel current event except CancelHandler: # Allow to cancel current event
return results return results
for filters, handler in self.handlers: try:
if await check_filters(filters, args): for filters, handler in self.handlers:
try: try:
if self.middleware_key: data = await check_filters(filters, args)
# context.set_value('handler', handler) except FilterNotPassed:
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
continue continue
except CancelHandler: else:
break try:
if self.middleware_key: if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", # context.set_value('handler', handler)
args + (results,)) await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
partial_data = _check_spec(handler, data)
response = await handler(*args, **partial_data)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
continue
except CancelHandler:
break
finally:
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
return results return results