Experimental: Pass result of filter as arguments in handler.

This commit is contained in:
Alex Root Junior 2018-06-25 02:15:37 +03:00
parent fc4e6ae69b
commit 1ca0be538b
4 changed files with 71 additions and 100 deletions

View file

@ -1,24 +1,21 @@
from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \
from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \
RegexpFilter, StateFilter, StatesListFilter
from .factory import FiltersFactory
from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters
from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters
__all__ = [
'AbstractFilter',
'AnyFilter',
'AsyncFilter',
'BaseFilter',
'CommandsFilter',
'ContentTypeFilter',
'ExceptionsFilter',
'Filter',
'FilterRecord',
'FiltersFactory',
'NotFilter',
'RegexpCommandsFilter',
'RegexpFilter',
'StateFilter',
'StatesListFilter',
'check_filter',
'check_filters'
'check_filters',
'FilterNotPassed'
]

View file

@ -1,39 +1,9 @@
import asyncio
import re
from _contextvars import ContextVar
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
from aiogram.dispatcher.filters.filters import BaseFilter
from aiogram.types import CallbackQuery, ContentType, Message
USER_STATE = 'USER_STATE'
class AnyFilter(Filter):
"""
One filter from many
"""
def __init__(self, *filters: callable):
self.filters = filters
super().__init__()
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
class NotFilter(Filter):
"""
Reverse filter
"""
def __init__(self, filter_: callable):
self.filter = filter_
super().__init__()
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(BaseFilter):
"""
@ -72,10 +42,20 @@ class RegexpFilter(BaseFilter):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
async def check(self, obj):
if isinstance(obj, Message) and obj.text:
return bool(self.regexp.search(obj.text))
if isinstance(obj, Message):
if obj.text:
match = self.regexp.search(obj.text)
elif obj.caption:
match = self.regexp.search(obj.caption)
else:
return False
elif isinstance(obj, CallbackQuery) and obj.data:
return bool(self.regexp.search(obj.data))
match = self.regexp.search(obj.data)
else:
return False
if match:
return {'regexp': match}
return False
@ -103,8 +83,7 @@ class RegexpCommandsFilter(BaseFilter):
for command in self.regexp_commands:
search = command.search(message.text)
if search:
message.conf['regexp_command'] = search
return True
return {'regexp_command': search}
return False
@ -142,18 +121,22 @@ class StateFilter(BaseFilter):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
from ..dispatcher import Dispatcher
if self.state == '*':
return True
return {'state': Dispatcher.current().current_state()}
try:
return self.state == self.ctx_state.get()
if self.state == self.ctx_state.get():
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
except LookupError:
chat, user = self.get_target(obj)
if chat or user:
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
self.ctx_state.set(state)
return state == self.state
if state == self.state:
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
return False

View file

@ -4,7 +4,10 @@ import typing
from ..handler import Handler
from ...types.base import TelegramObject
from ...utils.deprecated import deprecated
class FilterNotPassed(Exception):
pass
async def check_filter(filter_, args):
@ -20,7 +23,7 @@ async def check_filter(filter_, args):
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, (Filter, AbstractFilter)):
or isinstance(filter_, AbstractFilter):
return await filter_(*args)
else:
return filter_(*args)
@ -34,12 +37,15 @@ async def check_filters(filters, args):
:param args:
:return:
"""
data = {}
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args)
if not f:
return False
return True
raise FilterNotPassed()
elif isinstance(f, dict):
data.update(f)
return data
class FilterRecord:
@ -132,36 +138,3 @@ class BaseFilter(AbstractFilter):
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None and cls.key in full_config:
return {cls.key: full_config.pop(cls.key)}
class Filter(abc.ABC):
"""
Base class for filters
Subclasses of this class can't be used with FiltersFactory by default.
(Backward capability)
"""
def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
@abc.abstractmethod
def check(self, *args, **kwargs):
pass
@deprecated
class AsyncFilter(Filter):
"""
Base class for asynchronous filters
"""
def __await__(self):
return self.check
async def check(self, *args, **kwargs):
pass

View file

@ -1,3 +1,6 @@
import inspect
class SkipHandler(BaseException):
pass
@ -6,6 +9,14 @@ class CancelHandler(BaseException):
pass
def _check_spec(func: callable, kwargs: dict):
spec = inspect.getfullargspec(func)
if spec.varkw:
return kwargs
return {k: v for k, v in kwargs.items() if k in spec.args}
class Handler:
def __init__(self, dispatcher, once=True, middleware_key=None):
self.dispatcher = dispatcher
@ -53,7 +64,7 @@ class Handler:
:param args:
:return:
"""
from .filters import check_filters
from .filters import check_filters, FilterNotPassed
results = []
@ -63,23 +74,30 @@ class Handler:
except CancelHandler: # Allow to cancel current event
return results
for filters, handler in self.handlers:
if await check_filters(filters, args):
try:
for filters, handler in self.handlers:
try:
if self.middleware_key:
# context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
data = await check_filters(filters, args)
except FilterNotPassed:
continue
except CancelHandler:
break
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
else:
try:
if self.middleware_key:
# context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
partial_data = _check_spec(handler, data)
response = await handler(*args, **partial_data)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
continue
except CancelHandler:
break
finally:
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
return results