From 9ee1a0cbed5e0d6d8bf979eb7554952ee7fa83ee Mon Sep 17 00:00:00 2001 From: Nikita <43146729+gabbhack@users.noreply.github.com> Date: Sat, 6 Apr 2019 22:47:33 +0500 Subject: [PATCH] Optimize filter checking --- aiogram/dispatcher/filters/__init__.py | 8 +++-- aiogram/dispatcher/filters/filters.py | 43 +++++++++++++++++--------- aiogram/dispatcher/handler.py | 16 ++++++++-- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index 489e7c76..49a0e637 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -1,7 +1,8 @@ from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \ ExceptionsFilter, FuncFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text from .factory import FiltersFactory -from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters +from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, FilterObj, execute_filter, \ + check_filters, get_filter_spec, get_filters_spec __all__ = [ 'AbstractFilter', @@ -23,6 +24,9 @@ __all__ = [ 'Regexp', 'StateFilter', 'Text', - 'check_filter', + 'FilterObj', + 'get_filters_spec', + 'get_filters_spec', + 'execute_filter', 'check_filters' ] diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py index 7c0203f2..a6d83c62 100644 --- a/aiogram/dispatcher/filters/filters.py +++ b/aiogram/dispatcher/filters/filters.py @@ -2,7 +2,7 @@ import abc import inspect import typing -from ..handler import Handler +from ..handler import Handler, FilterObj class FilterNotPassed(Exception): @@ -20,15 +20,7 @@ def wrap_async(func): return async_wrapper -async def check_filter(dispatcher, filter_, args): - """ - Helper for executing filter - - :param dispatcher: - :param filter_: - :param args: - :return: - """ +def get_filter_spec(dispatcher, filter_: callable): kwargs = {} if not callable(filter_): raise TypeError('Filter must be callable and/or awaitable!') @@ -39,16 +31,37 @@ async def check_filter(dispatcher, filter_, args): if inspect.isawaitable(filter_) \ or inspect.iscoroutinefunction(filter_) \ or isinstance(filter_, AbstractFilter): - return await filter_(*args, **kwargs) + return FilterObj(filter=filter_, kwargs=kwargs, is_async=True) else: - return filter_(*args, **kwargs) + return FilterObj(filter=filter_, kwargs=kwargs, is_async=False) -async def check_filters(dispatcher, filters, args): +def get_filters_spec(dispatcher, filters: typing.Iterable[callable]): + data = [] + if filters is not None: + for i in filters: + data.append(get_filter_spec(dispatcher, i)) + return data + + +async def execute_filter(filter_: FilterObj, args): + """ + Helper for executing filter + + :param filter_: + :param args: + :return: + """ + if filter_.is_async: + return await filter_.filter(*args, **filter_.kwargs) + else: + return filter_.filter(*args, **filter_.kwargs) + + +async def check_filters(filters: typing.Iterable[FilterObj], args): """ Check list of filters - :param dispatcher: :param filters: :param args: :return: @@ -56,7 +69,7 @@ async def check_filters(dispatcher, filters, args): data = {} if filters is not None: for filter_ in filters: - f = await check_filter(dispatcher, filter_, args) + f = await execute_filter(filter_, args) if not f: raise FilterNotPassed() elif isinstance(f, dict): diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 24771e92..17b715d1 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,12 +1,19 @@ import inspect from contextvars import ContextVar from dataclasses import dataclass -from typing import Union +from typing import Optional, Iterable ctx_data = ContextVar('ctx_handler_data') current_handler = ContextVar('current_handler') +@dataclass +class FilterObj: + filter: callable + kwargs: dict + is_async: bool + + class SkipHandler(Exception): pass @@ -39,6 +46,7 @@ class Handler: self.middleware_key = middleware_key def register(self, handler, filters=None, index=None): + from .filters import get_filters_spec """ Register callback @@ -52,6 +60,8 @@ class Handler: if filters and not isinstance(filters, (list, tuple, set)): filters = [filters] + filters = get_filters_spec(self.dispatcher, filters) + record = Handler.HandlerObj(handler=handler, spec=spec, filters=filters) if index is None: self.handlers.append(record) @@ -95,7 +105,7 @@ class Handler: try: for handler_obj in self.handlers: try: - data.update(await check_filters(self.dispatcher, handler_obj.filters, args)) + data.update(await check_filters(handler_obj.filters, args)) except FilterNotPassed: continue else: @@ -126,4 +136,4 @@ class Handler: class HandlerObj: handler: callable spec: inspect.FullArgSpec - filters: Union[list, tuple, set, None] = None + filters: Optional[Iterable[FilterObj]] = None