mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Remake filters. Implemented filters factory.
This commit is contained in:
parent
fbcbf9ade4
commit
8ada376a3c
7 changed files with 474 additions and 434 deletions
|
|
@ -5,8 +5,7 @@ import logging
|
|||
import time
|
||||
import typing
|
||||
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter, USER_STATE, \
|
||||
generate_default_filters
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter
|
||||
from .handler import CancelHandler, Handler, SkipHandler
|
||||
from .middlewares import MiddlewareManager
|
||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
||||
|
|
@ -78,11 +77,26 @@ class Dispatcher:
|
|||
self._closed = True
|
||||
self._close_waiter = loop.create_future()
|
||||
|
||||
filters_factory.bind(filters.CommandsFilter, 'commands')
|
||||
filters_factory.bind(filters.RegexpFilter, 'regexp')
|
||||
filters_factory.bind(filters.RegexpCommandsFilter, 'regexp_commands')
|
||||
filters_factory.bind(filters.ContentTypeFilter, 'content_types')
|
||||
filters_factory.bind(filters.StateFilter, 'state', with_dispatcher=True, default=True)
|
||||
filters_factory.bind(filters.CommandsFilter, event_handlers=[
|
||||
self.message_handlers, self.edited_message_handlers
|
||||
])
|
||||
filters_factory.bind(filters.RegexpFilter, event_handlers=[
|
||||
self.message_handlers, self.edited_message_handlers,
|
||||
self.channel_post_handlers, self.edited_channel_post_handlers,
|
||||
self.callback_query_handlers
|
||||
|
||||
])
|
||||
filters_factory.bind(filters.RegexpCommandsFilter, event_handlers=[
|
||||
self.message_handlers, self.edited_message_handlers
|
||||
])
|
||||
filters_factory.bind(filters.ContentTypeFilter, event_handlers=[
|
||||
self.message_handlers, self.edited_message_handlers,
|
||||
self.channel_post_handlers, self.edited_channel_post_handlers,
|
||||
])
|
||||
filters_factory.bind(filters.StateFilter)
|
||||
filters_factory.bind(filters.ExceptionsFilter, event_handlers=[
|
||||
self.errors_handlers
|
||||
])
|
||||
|
||||
def __del__(self):
|
||||
self.stop_polling()
|
||||
|
|
@ -355,12 +369,13 @@ class Dispatcher:
|
|||
custom_filters = list(custom_filters)
|
||||
custom_filters.append(func)
|
||||
|
||||
filters_set = self.filters_factory.parse(*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_types=content_types,
|
||||
state=state,
|
||||
**kwargs)
|
||||
filters_set = self.filters_factory.resolve(self.message_handlers,
|
||||
*custom_filters,
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
content_types=content_types,
|
||||
state=state,
|
||||
**kwargs)
|
||||
self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
|
||||
|
||||
def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None,
|
||||
|
|
|
|||
|
|
@ -1,420 +0,0 @@
|
|||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import re
|
||||
import typing
|
||||
|
||||
from ..types import CallbackQuery, ContentType, Message
|
||||
from ..utils import context
|
||||
from ..utils.helper import Helper, HelperMode, Item
|
||||
|
||||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
class FiltersFactory:
|
||||
def __init__(self, dispatcher):
|
||||
self._dispatcher = dispatcher
|
||||
self._filters = []
|
||||
|
||||
@property
|
||||
def _default_filters(self):
|
||||
return tuple(filter(lambda item: item[-1], self._filters))
|
||||
|
||||
def bind(self, filter_, *args, default=False, with_dispatcher=False):
|
||||
self._filters.append(FilterConfig(self._dispatcher, filter_, default=default, with_dispatcher=with_dispatcher))
|
||||
|
||||
def unbind(self, filter_):
|
||||
for item in self._filters:
|
||||
if filter_ is item[0]:
|
||||
self._filters.remove(item)
|
||||
return True
|
||||
raise ValueError(f'{filter_} is not binded.')
|
||||
|
||||
def replace(self, original, new):
|
||||
for item in self._filters:
|
||||
if original is item[0]:
|
||||
item[0] = new
|
||||
return True
|
||||
raise ValueError(f'{original} is not binded.')
|
||||
|
||||
def parse(self, *args, **kwargs):
|
||||
"""
|
||||
Generate filters list
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
used = []
|
||||
filters = []
|
||||
|
||||
filters.extend(args)
|
||||
|
||||
# Registered filters filters
|
||||
for filterconfig in self._filters:
|
||||
pass
|
||||
|
||||
# Not registered filters
|
||||
for key, filter_ in kwargs.items():
|
||||
if isinstance(filter_, Filter):
|
||||
filters.append(filter_)
|
||||
used.append(filter_.__class__)
|
||||
else:
|
||||
raise ValueError(f"Unknown filter with key '{key}'")
|
||||
|
||||
return filters
|
||||
|
||||
|
||||
class FilterConfig:
|
||||
def __init__(self, dispatcher, filter_: typing.Callable,
|
||||
default: bool = False, with_dispatcher: bool = False,
|
||||
args: typing.Union[tuple, set, list] = ()):
|
||||
self.dispatcher = dispatcher
|
||||
self.filter = filter_
|
||||
self.default = default
|
||||
self.with_dispatcher = with_dispatcher
|
||||
self.args = args
|
||||
|
||||
def _check_list(self, config):
|
||||
result = {}
|
||||
accept = True
|
||||
|
||||
for item in self.args:
|
||||
value = config.pop(item, None)
|
||||
if value is None:
|
||||
accept = False
|
||||
break
|
||||
result[item] = value
|
||||
|
||||
return accept or self.default, result
|
||||
|
||||
def _check_dict(self, config):
|
||||
result = {}
|
||||
accept = True
|
||||
|
||||
for key, type_ in self.args:
|
||||
value = config.pop(key, None)
|
||||
if value is None:
|
||||
accept = False
|
||||
break
|
||||
if type_ is bool:
|
||||
|
||||
return accept or self.default, result
|
||||
|
||||
def check(self, config):
|
||||
if isinstance(config, dict):
|
||||
return self._check_dict(config)
|
||||
else:
|
||||
return self._check_list(config)
|
||||
|
||||
def parse(self, config):
|
||||
pass
|
||||
|
||||
def configure(self, config=None):
|
||||
if config is None:
|
||||
config = {}
|
||||
if self.with_dispatcher:
|
||||
if config:
|
||||
config = copy.deepcopy(config)
|
||||
config['dispatcher'] = self.dispatcher
|
||||
return self.filter(**config)
|
||||
|
||||
|
||||
async def check_filter(filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param filter_:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
||||
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
||||
return await filter_(*args)
|
||||
else:
|
||||
return filter_(*args)
|
||||
|
||||
|
||||
async def check_filters(filters, args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param filters:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(filter_, args)
|
||||
if not f:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Filter:
|
||||
"""
|
||||
Base class for filters
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.check(*args, **kwargs)
|
||||
|
||||
def check(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncFilter(Filter):
|
||||
"""
|
||||
Base class for asynchronous filters
|
||||
"""
|
||||
|
||||
def __aiter__(self):
|
||||
return None
|
||||
|
||||
def __await__(self):
|
||||
return self.check
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class AnyFilter(AsyncFilter):
|
||||
"""
|
||||
One filter from many
|
||||
"""
|
||||
|
||||
def __init__(self, *filters: callable):
|
||||
self.filters = filters
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||
return any(await asyncio.gather(*f))
|
||||
|
||||
|
||||
class NotFilter(AsyncFilter):
|
||||
"""
|
||||
Reverse filter
|
||||
"""
|
||||
|
||||
def __init__(self, filter_: callable):
|
||||
self.filter = filter_
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
return not await check_filter(self.filter, args)
|
||||
|
||||
|
||||
class CommandsFilter(AsyncFilter):
|
||||
"""
|
||||
Check commands in message
|
||||
"""
|
||||
|
||||
def __init__(self, commands):
|
||||
self.commands = commands
|
||||
super().__init__()
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
||||
if command not in self.commands:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class RegexpFilter(Filter):
|
||||
"""
|
||||
Regexp filter for messages
|
||||
"""
|
||||
|
||||
def __init__(self, regexp):
|
||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
super().__init__()
|
||||
|
||||
def check(self, obj):
|
||||
if isinstance(obj, Message) and obj.text:
|
||||
return bool(self.regexp.search(obj.text))
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
return bool(self.regexp.search(obj.data))
|
||||
return False
|
||||
|
||||
|
||||
class RegexpCommandsFilter(AsyncFilter):
|
||||
"""
|
||||
Check commands by regexp in message
|
||||
"""
|
||||
|
||||
def __init__(self, regexp_commands):
|
||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||
super().__init__()
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
||||
for command in self.regexp_commands:
|
||||
search = command.search(message.text)
|
||||
if search:
|
||||
message.conf['regexp_command'] = search
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ContentTypeFilter(Filter):
|
||||
"""
|
||||
Check message content type
|
||||
"""
|
||||
|
||||
def __init__(self, content_types):
|
||||
self.content_types = content_types
|
||||
super().__init__()
|
||||
|
||||
def check(self, message):
|
||||
return ContentType.ANY[0] in self.content_types or \
|
||||
message.content_type in self.content_types
|
||||
|
||||
|
||||
class CancelFilter(Filter):
|
||||
"""
|
||||
Find cancel in message text
|
||||
"""
|
||||
|
||||
def __init__(self, cancel_set=None):
|
||||
if cancel_set is None:
|
||||
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
||||
self.cancel_set = cancel_set
|
||||
super().__init__()
|
||||
|
||||
def check(self, message):
|
||||
if message.text:
|
||||
return message.text.lower() in self.cancel_set
|
||||
|
||||
|
||||
class StateFilter(AsyncFilter):
|
||||
"""
|
||||
Check user state
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher, state):
|
||||
self.dispatcher = dispatcher
|
||||
if isinstance(state, str):
|
||||
state = (state,)
|
||||
self.state = state
|
||||
super().__init__()
|
||||
|
||||
def get_target(self, obj):
|
||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
|
||||
async def check(self, obj):
|
||||
if self.state == '*':
|
||||
return True
|
||||
|
||||
if context.check_value(USER_STATE):
|
||||
context_state = context.get_value(USER_STATE)
|
||||
return self.state == context_state
|
||||
else:
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
return False
|
||||
|
||||
|
||||
class StatesListFilter(StateFilter):
|
||||
"""
|
||||
List of states
|
||||
"""
|
||||
|
||||
async def check(self, obj):
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
return False
|
||||
|
||||
|
||||
class ExceptionsFilter(Filter):
|
||||
"""
|
||||
Filter for exceptions
|
||||
"""
|
||||
|
||||
def __init__(self, exception):
|
||||
self.exception = exception
|
||||
super().__init__()
|
||||
|
||||
def check(self, dispatcher, update, exception):
|
||||
try:
|
||||
raise exception
|
||||
except self.exception:
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def generate_default_filters(dispatcher, *args, **kwargs):
|
||||
"""
|
||||
Prepare filters
|
||||
|
||||
:param dispatcher:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
filters_set = []
|
||||
|
||||
for name, filter_ in kwargs.items():
|
||||
if filter_ is None and name != DefaultFilters.STATE:
|
||||
continue
|
||||
if name == DefaultFilters.COMMANDS:
|
||||
if isinstance(filter_, str):
|
||||
filters_set.append(CommandsFilter([filter_]))
|
||||
else:
|
||||
filters_set.append(CommandsFilter(filter_))
|
||||
elif name == DefaultFilters.REGEXP:
|
||||
filters_set.append(RegexpFilter(filter_))
|
||||
elif name == DefaultFilters.CONTENT_TYPES:
|
||||
filters_set.append(ContentTypeFilter(filter_))
|
||||
elif name == DefaultFilters.FUNC:
|
||||
filters_set.append(filter_)
|
||||
elif name == DefaultFilters.STATE:
|
||||
if isinstance(filter_, (list, set, tuple)):
|
||||
filters_set.append(StatesListFilter(dispatcher, filter_))
|
||||
else:
|
||||
filters_set.append(StateFilter(dispatcher, filter_))
|
||||
elif isinstance(filter_, Filter):
|
||||
filters_set.append(filter_)
|
||||
|
||||
filters_set += list(args)
|
||||
|
||||
return filters_set
|
||||
|
||||
|
||||
class DefaultFilters(Helper):
|
||||
mode = HelperMode.snake_case
|
||||
|
||||
COMMANDS = Item() # commands
|
||||
REGEXP = Item() # regexp
|
||||
CONTENT_TYPES = Item() # content_type
|
||||
FUNC = Item() # func
|
||||
STATE = Item() # state
|
||||
24
aiogram/dispatcher/filters/__init__.py
Normal file
24
aiogram/dispatcher/filters/__init__.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from .builtin import AnyFilter, CommandsFilter, ContentTypeFilter, ExceptionsFilter, NotFilter, RegexpCommandsFilter, \
|
||||
RegexpFilter, StateFilter, StatesListFilter
|
||||
from .factory import FiltersFactory
|
||||
from .filters import AbstractFilter, AsyncFilter, BaseFilter, Filter, FilterRecord, check_filter, check_filters
|
||||
|
||||
__all__ = [
|
||||
'AbstractFilter',
|
||||
'AnyFilter',
|
||||
'AsyncFilter',
|
||||
'BaseFilter',
|
||||
'CommandsFilter',
|
||||
'ContentTypeFilter',
|
||||
'ExceptionsFilter',
|
||||
'Filter',
|
||||
'FilterRecord',
|
||||
'FiltersFactory',
|
||||
'NotFilter',
|
||||
'RegexpCommandsFilter',
|
||||
'RegexpFilter',
|
||||
'StateFilter',
|
||||
'StatesListFilter',
|
||||
'check_filter',
|
||||
'check_filters'
|
||||
]
|
||||
187
aiogram/dispatcher/filters/builtin.py
Normal file
187
aiogram/dispatcher/filters/builtin.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
import asyncio
|
||||
import re
|
||||
|
||||
from aiogram.dispatcher.filters import BaseFilter, Filter, check_filter
|
||||
from aiogram.types import CallbackQuery, ContentType, Message
|
||||
from aiogram.utils import context
|
||||
|
||||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
class AnyFilter(Filter):
|
||||
"""
|
||||
One filter from many
|
||||
"""
|
||||
|
||||
def __init__(self, *filters: callable):
|
||||
self.filters = filters
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||
return any(await asyncio.gather(*f))
|
||||
|
||||
|
||||
class NotFilter(Filter):
|
||||
"""
|
||||
Reverse filter
|
||||
"""
|
||||
|
||||
def __init__(self, filter_: callable):
|
||||
self.filter = filter_
|
||||
super().__init__()
|
||||
|
||||
async def check(self, *args):
|
||||
return not await check_filter(self.filter, args)
|
||||
|
||||
|
||||
class CommandsFilter(BaseFilter):
|
||||
"""
|
||||
Check commands in message
|
||||
"""
|
||||
key = 'commands'
|
||||
|
||||
def __init__(self, dispatcher, commands):
|
||||
super().__init__(dispatcher)
|
||||
self.commands = commands
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
||||
if command not in self.commands:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class RegexpFilter(BaseFilter):
|
||||
"""
|
||||
Regexp filter for messages and callback query
|
||||
"""
|
||||
key = 'regexp'
|
||||
|
||||
def __init__(self, dispatcher, regexp):
|
||||
super().__init__(dispatcher)
|
||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
async def check(self, obj):
|
||||
if isinstance(obj, Message) and obj.text:
|
||||
return bool(self.regexp.search(obj.text))
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
return bool(self.regexp.search(obj.data))
|
||||
return False
|
||||
|
||||
|
||||
class RegexpCommandsFilter(BaseFilter):
|
||||
"""
|
||||
Check commands by regexp in message
|
||||
"""
|
||||
|
||||
key = 'regexp_commands'
|
||||
|
||||
def __init__(self, dispatcher, regexp_commands):
|
||||
super().__init__(dispatcher)
|
||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
||||
for command in self.regexp_commands:
|
||||
search = command.search(message.text)
|
||||
if search:
|
||||
message.conf['regexp_command'] = search
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ContentTypeFilter(BaseFilter):
|
||||
"""
|
||||
Check message content type
|
||||
"""
|
||||
|
||||
key = 'content_types'
|
||||
|
||||
def __init__(self, dispatcher, content_types):
|
||||
super().__init__(dispatcher)
|
||||
self.content_types = content_types
|
||||
|
||||
async def check(self, message):
|
||||
return ContentType.ANY[0] in self.content_types or \
|
||||
message.content_type in self.content_types
|
||||
|
||||
|
||||
class StateFilter(BaseFilter):
|
||||
"""
|
||||
Check user state
|
||||
"""
|
||||
key = 'state'
|
||||
|
||||
def __init__(self, dispatcher, state):
|
||||
super().__init__(dispatcher)
|
||||
if isinstance(state, str):
|
||||
state = (state,)
|
||||
self.state = state
|
||||
|
||||
def get_target(self, obj):
|
||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
|
||||
async def check(self, obj):
|
||||
if self.state == '*':
|
||||
return True
|
||||
|
||||
if context.check_value(USER_STATE):
|
||||
context_state = context.get_value(USER_STATE)
|
||||
return self.state == context_state
|
||||
else:
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
return False
|
||||
|
||||
|
||||
class StatesListFilter(StateFilter):
|
||||
"""
|
||||
List of states
|
||||
"""
|
||||
|
||||
async def check(self, obj):
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
return False
|
||||
|
||||
|
||||
class ExceptionsFilter(BaseFilter):
|
||||
"""
|
||||
Filter for exceptions
|
||||
"""
|
||||
|
||||
key = 'exception'
|
||||
|
||||
def __init__(self, dispatcher, exception):
|
||||
super().__init__(dispatcher)
|
||||
self.exception = exception
|
||||
|
||||
async def check(self, dispatcher, update, exception):
|
||||
try:
|
||||
raise exception
|
||||
except self.exception:
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
71
aiogram/dispatcher/filters/factory.py
Normal file
71
aiogram/dispatcher/filters/factory.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
import typing
|
||||
|
||||
from .filters import AbstractFilter, FilterRecord
|
||||
from ..handler import Handler
|
||||
|
||||
|
||||
# TODO: provide to set default filters (Like state. It will be always be added to filters set)
|
||||
# TODO: provide excluding event handlers
|
||||
# TODO: move check_filter/check_filters functions to FiltersFactory class
|
||||
|
||||
class FiltersFactory:
|
||||
"""
|
||||
Default filters factory
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher):
|
||||
self._dispatcher = dispatcher
|
||||
self._registered: typing.List[FilterRecord] = []
|
||||
|
||||
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
|
||||
validator: typing.Optional[typing.Callable] = None,
|
||||
event_handlers: typing.Optional[typing.List[Handler]] = None):
|
||||
"""
|
||||
Register filter
|
||||
|
||||
:param callback: callable or subclass of :obj:`AbstractFilter`
|
||||
:param validator: custom validator.
|
||||
:param event_handlers: list of instances of :obj:`Handler`
|
||||
"""
|
||||
record = FilterRecord(callback, validator, event_handlers)
|
||||
self._registered.append(record)
|
||||
|
||||
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
|
||||
"""
|
||||
Unregister callback
|
||||
|
||||
:param callback: callable of subclass of :obj:`AbstractFilter`
|
||||
"""
|
||||
for record in self._registered:
|
||||
if record.callback == callback:
|
||||
self._registered.remove(record)
|
||||
|
||||
def resolve(self, event_handler, *custom_filters, **full_config
|
||||
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
|
||||
"""
|
||||
Resolve filters to filters-set
|
||||
|
||||
:param event_handler:
|
||||
:param custom_filters:
|
||||
:param full_config:
|
||||
:return:
|
||||
"""
|
||||
filters_set = []
|
||||
if custom_filters:
|
||||
filters_set.extend(custom_filters)
|
||||
if full_config:
|
||||
filters_set.extend(self._resolve_registered(self._dispatcher, event_handler,
|
||||
{k: v for k, v in full_config.items() if v is not None}))
|
||||
return filters_set
|
||||
|
||||
def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator:
|
||||
for record in self._registered:
|
||||
if not full_config:
|
||||
break
|
||||
|
||||
filter_ = record.resolve(dispatcher, event_handler, full_config)
|
||||
if filter_:
|
||||
yield filter_
|
||||
|
||||
if full_config:
|
||||
raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'')
|
||||
162
aiogram/dispatcher/filters/filters.py
Normal file
162
aiogram/dispatcher/filters/filters.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import abc
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
from ..handler import Handler
|
||||
from ...types.base import TelegramObject
|
||||
from ...utils.deprecated import deprecated
|
||||
|
||||
|
||||
async def check_filter(filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param filter_:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
||||
if inspect.isawaitable(filter_) \
|
||||
or inspect.iscoroutinefunction(filter_) \
|
||||
or isinstance(filter_, (Filter, AbstractFilter)):
|
||||
return await filter_(*args)
|
||||
else:
|
||||
return filter_(*args)
|
||||
|
||||
|
||||
async def check_filters(filters, args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param filters:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(filter_, args)
|
||||
if not f:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class FilterRecord:
|
||||
"""
|
||||
Filters record for factory
|
||||
"""
|
||||
|
||||
def __init__(self, callback: typing.Callable,
|
||||
validator: typing.Optional[typing.Callable] = None,
|
||||
event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||
self.callback = callback
|
||||
self.event_handlers = event_handlers
|
||||
if validator is not None:
|
||||
if not callable(validator):
|
||||
raise TypeError(f"validator must be callable, not {type(validator)}")
|
||||
self.resolver = validator
|
||||
elif issubclass(callback, AbstractFilter):
|
||||
self.resolver = callback.validate
|
||||
else:
|
||||
raise RuntimeError('validator is required!')
|
||||
|
||||
def resolve(self, dispatcher, event_observer, full_config):
|
||||
if not self._check_event_handler(event_observer):
|
||||
return
|
||||
config = self.resolver(full_config)
|
||||
if config:
|
||||
return self.callback(dispatcher, **config)
|
||||
|
||||
def _check_event_handler(self, event_handler) -> bool:
|
||||
if not self.event_handlers:
|
||||
return True
|
||||
return event_handler in self.event_handlers
|
||||
|
||||
|
||||
class AbstractFilter(abc.ABC):
|
||||
"""
|
||||
Abstract class for custom filters
|
||||
"""
|
||||
|
||||
key = None
|
||||
|
||||
def __init__(self, dispatcher, **config):
|
||||
self.dispatcher = dispatcher
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
"""
|
||||
Validate and parse config
|
||||
|
||||
:param full_config:
|
||||
:return: config
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def check(self, *args) -> bool:
|
||||
"""
|
||||
Check object
|
||||
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
async def __call__(self, obj: TelegramObject) -> bool:
|
||||
return await self.check(obj)
|
||||
|
||||
|
||||
class BaseFilter(AbstractFilter):
|
||||
"""
|
||||
Abstract class for filters with default validator
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def key(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
if cls.key is not None and cls.key in full_config:
|
||||
return {cls.key: full_config.pop(cls.key)}
|
||||
|
||||
|
||||
class Filter(abc.ABC):
|
||||
"""
|
||||
Base class for filters
|
||||
Subclasses of this class can't be used with FiltersFactory by default.
|
||||
|
||||
(Backward capability)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.check(*args, **kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def check(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@deprecated
|
||||
class AsyncFilter(Filter):
|
||||
"""
|
||||
Base class for asynchronous filters
|
||||
"""
|
||||
|
||||
def __aiter__(self):
|
||||
return None
|
||||
|
||||
def __await__(self):
|
||||
return self.check
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
from aiogram.utils import context
|
||||
from .filters import check_filters
|
||||
|
||||
|
||||
class SkipHandler(BaseException):
|
||||
|
|
@ -57,6 +56,8 @@ class Handler:
|
|||
:param args:
|
||||
:return:
|
||||
"""
|
||||
from .filters import check_filters
|
||||
|
||||
results = []
|
||||
|
||||
if self.middleware_key:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue