mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Experimental: Pass result of filter as arguments in handler.
This commit is contained in:
parent
fc4e6ae69b
commit
1ca0be538b
4 changed files with 71 additions and 100 deletions
|
|
@ -1,24 +1,21 @@
|
|||
from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \
|
||||
from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \
|
||||
RegexpFilter, StateFilter, StatesListFilter
|
||||
from .factory import FiltersFactory
|
||||
from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters
|
||||
from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||
|
||||
__all__ = [
|
||||
'AbstractFilter',
|
||||
'AnyFilter',
|
||||
'AsyncFilter',
|
||||
'BaseFilter',
|
||||
'CommandsFilter',
|
||||
'ContentTypeFilter',
|
||||
'ExceptionsFilter',
|
||||
'Filter',
|
||||
'FilterRecord',
|
||||
'FiltersFactory',
|
||||
'NotFilter',
|
||||
'RegexpCommandsFilter',
|
||||
'RegexpFilter',
|
||||
'StateFilter',
|
||||
'StatesListFilter',
|
||||
'check_filter',
|
||||
'check_filters'
|
||||
'check_filters',
|
||||
'FilterNotPassed'
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,39 +1,9 @@
|
|||
import asyncio
|
||||
import re
|
||||
from _contextvars import ContextVar
|
||||
|
||||
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
|
||||
from aiogram.dispatcher.filters.filters import BaseFilter
|
||||
from aiogram.types import CallbackQuery, ContentType, Message
|
||||
|
||||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
class AnyFilter(Filter):
|
||||
"""
|
||||
One filter from many
|
||||
"""
|
||||
|
||||
def __init__(self, *filters: callable):
|
||||
self.filters = filters
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||
return any(await asyncio.gather(*f))
|
||||
|
||||
|
||||
class NotFilter(Filter):
|
||||
"""
|
||||
Reverse filter
|
||||
"""
|
||||
|
||||
def __init__(self, filter_: callable):
|
||||
self.filter = filter_
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
return not await check_filter(self.filter, args)
|
||||
|
||||
|
||||
class CommandsFilter(BaseFilter):
|
||||
"""
|
||||
|
|
@ -72,10 +42,20 @@ class RegexpFilter(BaseFilter):
|
|||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
async def check(self, obj):
|
||||
if isinstance(obj, Message) and obj.text:
|
||||
return bool(self.regexp.search(obj.text))
|
||||
if isinstance(obj, Message):
|
||||
if obj.text:
|
||||
match = self.regexp.search(obj.text)
|
||||
elif obj.caption:
|
||||
match = self.regexp.search(obj.caption)
|
||||
else:
|
||||
return False
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
return bool(self.regexp.search(obj.data))
|
||||
match = self.regexp.search(obj.data)
|
||||
else:
|
||||
return False
|
||||
|
||||
if match:
|
||||
return {'regexp': match}
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -103,8 +83,7 @@ class RegexpCommandsFilter(BaseFilter):
|
|||
for command in self.regexp_commands:
|
||||
search = command.search(message.text)
|
||||
if search:
|
||||
message.conf['regexp_command'] = search
|
||||
return True
|
||||
return {'regexp_command': search}
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -142,18 +121,22 @@ class StateFilter(BaseFilter):
|
|||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
|
||||
async def check(self, obj):
|
||||
from ..dispatcher import Dispatcher
|
||||
|
||||
if self.state == '*':
|
||||
return True
|
||||
return {'state': Dispatcher.current().current_state()}
|
||||
|
||||
try:
|
||||
return self.state == self.ctx_state.get()
|
||||
if self.state == self.ctx_state.get():
|
||||
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
|
||||
except LookupError:
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
self.ctx_state.set(state)
|
||||
return state == self.state
|
||||
if state == self.state:
|
||||
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
|
||||
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ import typing
|
|||
|
||||
from ..handler import Handler
|
||||
from ...types.base import TelegramObject
|
||||
from ...utils.deprecated import deprecated
|
||||
|
||||
|
||||
class FilterNotPassed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
async def check_filter(filter_, args):
|
||||
|
|
@ -20,7 +23,7 @@ async def check_filter(filter_, args):
|
|||
|
||||
if inspect.isawaitable(filter_) \
|
||||
or inspect.iscoroutinefunction(filter_) \
|
||||
or isinstance(filter_, (Filter, AbstractFilter)):
|
||||
or isinstance(filter_, AbstractFilter):
|
||||
return await filter_(*args)
|
||||
else:
|
||||
return filter_(*args)
|
||||
|
|
@ -34,12 +37,15 @@ async def check_filters(filters, args):
|
|||
:param args:
|
||||
:return:
|
||||
"""
|
||||
data = {}
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(filter_, args)
|
||||
if not f:
|
||||
return False
|
||||
return True
|
||||
raise FilterNotPassed()
|
||||
elif isinstance(f, dict):
|
||||
data.update(f)
|
||||
return data
|
||||
|
||||
|
||||
class FilterRecord:
|
||||
|
|
@ -132,36 +138,3 @@ class BaseFilter(AbstractFilter):
|
|||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||
if cls.key is not None and cls.key in full_config:
|
||||
return {cls.key: full_config.pop(cls.key)}
|
||||
|
||||
|
||||
class Filter(abc.ABC):
|
||||
"""
|
||||
Base class for filters
|
||||
Subclasses of this class can't be used with FiltersFactory by default.
|
||||
|
||||
(Backward capability)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.check(*args, **kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def check(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@deprecated
|
||||
class AsyncFilter(Filter):
|
||||
"""
|
||||
Base class for asynchronous filters
|
||||
"""
|
||||
|
||||
def __await__(self):
|
||||
return self.check
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
import inspect
|
||||
|
||||
|
||||
class SkipHandler(BaseException):
|
||||
pass
|
||||
|
||||
|
|
@ -6,6 +9,14 @@ class CancelHandler(BaseException):
|
|||
pass
|
||||
|
||||
|
||||
def _check_spec(func: callable, kwargs: dict):
|
||||
spec = inspect.getfullargspec(func)
|
||||
if spec.varkw:
|
||||
return kwargs
|
||||
|
||||
return {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
|
||||
|
||||
class Handler:
|
||||
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||
self.dispatcher = dispatcher
|
||||
|
|
@ -53,7 +64,7 @@ class Handler:
|
|||
:param args:
|
||||
:return:
|
||||
"""
|
||||
from .filters import check_filters
|
||||
from .filters import check_filters, FilterNotPassed
|
||||
|
||||
results = []
|
||||
|
||||
|
|
@ -63,23 +74,30 @@ class Handler:
|
|||
except CancelHandler: # Allow to cancel current event
|
||||
return results
|
||||
|
||||
for filters, handler in self.handlers:
|
||||
if await check_filters(filters, args):
|
||||
try:
|
||||
for filters, handler in self.handlers:
|
||||
try:
|
||||
if self.middleware_key:
|
||||
# context.set_value('handler', handler)
|
||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||
response = await handler(*args)
|
||||
if response is not None:
|
||||
results.append(response)
|
||||
if self.once:
|
||||
break
|
||||
except SkipHandler:
|
||||
data = await check_filters(filters, args)
|
||||
except FilterNotPassed:
|
||||
continue
|
||||
except CancelHandler:
|
||||
break
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||
args + (results,))
|
||||
else:
|
||||
try:
|
||||
if self.middleware_key:
|
||||
# context.set_value('handler', handler)
|
||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||
partial_data = _check_spec(handler, data)
|
||||
response = await handler(*args, **partial_data)
|
||||
if response is not None:
|
||||
results.append(response)
|
||||
if self.once:
|
||||
break
|
||||
except SkipHandler:
|
||||
continue
|
||||
except CancelHandler:
|
||||
break
|
||||
finally:
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||
args + (results,))
|
||||
|
||||
return results
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue