mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Experimental: Pass result of filter as arguments in handler.
This commit is contained in:
parent
fc4e6ae69b
commit
1ca0be538b
4 changed files with 71 additions and 100 deletions
|
|
@ -1,24 +1,21 @@
|
||||||
from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \
|
from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \
|
||||||
RegexpFilter, StateFilter, StatesListFilter
|
RegexpFilter, StateFilter, StatesListFilter
|
||||||
from .factory import FiltersFactory
|
from .factory import FiltersFactory
|
||||||
from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters
|
from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AbstractFilter',
|
'AbstractFilter',
|
||||||
'AnyFilter',
|
|
||||||
'AsyncFilter',
|
|
||||||
'BaseFilter',
|
'BaseFilter',
|
||||||
'CommandsFilter',
|
'CommandsFilter',
|
||||||
'ContentTypeFilter',
|
'ContentTypeFilter',
|
||||||
'ExceptionsFilter',
|
'ExceptionsFilter',
|
||||||
'Filter',
|
|
||||||
'FilterRecord',
|
'FilterRecord',
|
||||||
'FiltersFactory',
|
'FiltersFactory',
|
||||||
'NotFilter',
|
|
||||||
'RegexpCommandsFilter',
|
'RegexpCommandsFilter',
|
||||||
'RegexpFilter',
|
'RegexpFilter',
|
||||||
'StateFilter',
|
'StateFilter',
|
||||||
'StatesListFilter',
|
'StatesListFilter',
|
||||||
'check_filter',
|
'check_filter',
|
||||||
'check_filters'
|
'check_filters',
|
||||||
|
'FilterNotPassed'
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,39 +1,9 @@
|
||||||
import asyncio
|
|
||||||
import re
|
import re
|
||||||
from _contextvars import ContextVar
|
from _contextvars import ContextVar
|
||||||
|
|
||||||
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
|
from aiogram.dispatcher.filters.filters import BaseFilter
|
||||||
from aiogram.types import CallbackQuery, ContentType, Message
|
from aiogram.types import CallbackQuery, ContentType, Message
|
||||||
|
|
||||||
USER_STATE = 'USER_STATE'
|
|
||||||
|
|
||||||
|
|
||||||
class AnyFilter(Filter):
|
|
||||||
"""
|
|
||||||
One filter from many
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *filters: callable):
|
|
||||||
self.filters = filters
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
async def check(self, *args):
|
|
||||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
|
||||||
return any(await asyncio.gather(*f))
|
|
||||||
|
|
||||||
|
|
||||||
class NotFilter(Filter):
|
|
||||||
"""
|
|
||||||
Reverse filter
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, filter_: callable):
|
|
||||||
self.filter = filter_
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
async def check(self, *args):
|
|
||||||
return not await check_filter(self.filter, args)
|
|
||||||
|
|
||||||
|
|
||||||
class CommandsFilter(BaseFilter):
|
class CommandsFilter(BaseFilter):
|
||||||
"""
|
"""
|
||||||
|
|
@ -72,10 +42,20 @@ class RegexpFilter(BaseFilter):
|
||||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||||
|
|
||||||
async def check(self, obj):
|
async def check(self, obj):
|
||||||
if isinstance(obj, Message) and obj.text:
|
if isinstance(obj, Message):
|
||||||
return bool(self.regexp.search(obj.text))
|
if obj.text:
|
||||||
|
match = self.regexp.search(obj.text)
|
||||||
|
elif obj.caption:
|
||||||
|
match = self.regexp.search(obj.caption)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||||
return bool(self.regexp.search(obj.data))
|
match = self.regexp.search(obj.data)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if match:
|
||||||
|
return {'regexp': match}
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -103,8 +83,7 @@ class RegexpCommandsFilter(BaseFilter):
|
||||||
for command in self.regexp_commands:
|
for command in self.regexp_commands:
|
||||||
search = command.search(message.text)
|
search = command.search(message.text)
|
||||||
if search:
|
if search:
|
||||||
message.conf['regexp_command'] = search
|
return {'regexp_command': search}
|
||||||
return True
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -142,18 +121,22 @@ class StateFilter(BaseFilter):
|
||||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||||
|
|
||||||
async def check(self, obj):
|
async def check(self, obj):
|
||||||
|
from ..dispatcher import Dispatcher
|
||||||
|
|
||||||
if self.state == '*':
|
if self.state == '*':
|
||||||
return True
|
return {'state': Dispatcher.current().current_state()}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.state == self.ctx_state.get()
|
if self.state == self.ctx_state.get():
|
||||||
|
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
|
||||||
except LookupError:
|
except LookupError:
|
||||||
chat, user = self.get_target(obj)
|
chat, user = self.get_target(obj)
|
||||||
|
|
||||||
if chat or user:
|
if chat or user:
|
||||||
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||||
self.ctx_state.set(state)
|
self.ctx_state.set(state)
|
||||||
return state == self.state
|
if state == self.state:
|
||||||
|
return {'state': Dispatcher.current().current_state(), 'raw_state': self.state}
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,10 @@ import typing
|
||||||
|
|
||||||
from ..handler import Handler
|
from ..handler import Handler
|
||||||
from ...types.base import TelegramObject
|
from ...types.base import TelegramObject
|
||||||
from ...utils.deprecated import deprecated
|
|
||||||
|
|
||||||
|
class FilterNotPassed(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def check_filter(filter_, args):
|
async def check_filter(filter_, args):
|
||||||
|
|
@ -20,7 +23,7 @@ async def check_filter(filter_, args):
|
||||||
|
|
||||||
if inspect.isawaitable(filter_) \
|
if inspect.isawaitable(filter_) \
|
||||||
or inspect.iscoroutinefunction(filter_) \
|
or inspect.iscoroutinefunction(filter_) \
|
||||||
or isinstance(filter_, (Filter, AbstractFilter)):
|
or isinstance(filter_, AbstractFilter):
|
||||||
return await filter_(*args)
|
return await filter_(*args)
|
||||||
else:
|
else:
|
||||||
return filter_(*args)
|
return filter_(*args)
|
||||||
|
|
@ -34,12 +37,15 @@ async def check_filters(filters, args):
|
||||||
:param args:
|
:param args:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
data = {}
|
||||||
if filters is not None:
|
if filters is not None:
|
||||||
for filter_ in filters:
|
for filter_ in filters:
|
||||||
f = await check_filter(filter_, args)
|
f = await check_filter(filter_, args)
|
||||||
if not f:
|
if not f:
|
||||||
return False
|
raise FilterNotPassed()
|
||||||
return True
|
elif isinstance(f, dict):
|
||||||
|
data.update(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class FilterRecord:
|
class FilterRecord:
|
||||||
|
|
@ -132,36 +138,3 @@ class BaseFilter(AbstractFilter):
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||||
if cls.key is not None and cls.key in full_config:
|
if cls.key is not None and cls.key in full_config:
|
||||||
return {cls.key: full_config.pop(cls.key)}
|
return {cls.key: full_config.pop(cls.key)}
|
||||||
|
|
||||||
|
|
||||||
class Filter(abc.ABC):
|
|
||||||
"""
|
|
||||||
Base class for filters
|
|
||||||
Subclasses of this class can't be used with FiltersFactory by default.
|
|
||||||
|
|
||||||
(Backward capability)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
self._args = args
|
|
||||||
self._kwargs = kwargs
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
return self.check(*args, **kwargs)
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def check(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated
|
|
||||||
class AsyncFilter(Filter):
|
|
||||||
"""
|
|
||||||
Base class for asynchronous filters
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __await__(self):
|
|
||||||
return self.check
|
|
||||||
|
|
||||||
async def check(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,6 @@
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
class SkipHandler(BaseException):
|
class SkipHandler(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -6,6 +9,14 @@ class CancelHandler(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _check_spec(func: callable, kwargs: dict):
|
||||||
|
spec = inspect.getfullargspec(func)
|
||||||
|
if spec.varkw:
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
return {k: v for k, v in kwargs.items() if k in spec.args}
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
def __init__(self, dispatcher, once=True, middleware_key=None):
|
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
|
|
@ -53,7 +64,7 @@ class Handler:
|
||||||
:param args:
|
:param args:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
from .filters import check_filters
|
from .filters import check_filters, FilterNotPassed
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
|
@ -63,23 +74,30 @@ class Handler:
|
||||||
except CancelHandler: # Allow to cancel current event
|
except CancelHandler: # Allow to cancel current event
|
||||||
return results
|
return results
|
||||||
|
|
||||||
for filters, handler in self.handlers:
|
try:
|
||||||
if await check_filters(filters, args):
|
for filters, handler in self.handlers:
|
||||||
try:
|
try:
|
||||||
if self.middleware_key:
|
data = await check_filters(filters, args)
|
||||||
# context.set_value('handler', handler)
|
except FilterNotPassed:
|
||||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
|
||||||
response = await handler(*args)
|
|
||||||
if response is not None:
|
|
||||||
results.append(response)
|
|
||||||
if self.once:
|
|
||||||
break
|
|
||||||
except SkipHandler:
|
|
||||||
continue
|
continue
|
||||||
except CancelHandler:
|
else:
|
||||||
break
|
try:
|
||||||
if self.middleware_key:
|
if self.middleware_key:
|
||||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
# context.set_value('handler', handler)
|
||||||
args + (results,))
|
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||||
|
partial_data = _check_spec(handler, data)
|
||||||
|
response = await handler(*args, **partial_data)
|
||||||
|
if response is not None:
|
||||||
|
results.append(response)
|
||||||
|
if self.once:
|
||||||
|
break
|
||||||
|
except SkipHandler:
|
||||||
|
continue
|
||||||
|
except CancelHandler:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
if self.middleware_key:
|
||||||
|
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||||
|
args + (results,))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue