diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py index 096a2254..cf9d00aa 100644 --- a/aiogram/dispatcher/event/observer.py +++ b/aiogram/dispatcher/event/observer.py @@ -70,18 +70,18 @@ class TelegramEventObserver(EventObserver): self.filters: List[Type[BaseFilter]] = [] def bind_filter(self, bound_filter: Type[BaseFilter]) -> None: - if not isinstance(bound_filter, BaseFilter): - pass + if not issubclass(bound_filter, BaseFilter): + raise TypeError( + "bound_filter() argument 'bound_filter' must be subclass of BaseFilter" + ) self.filters.append(bound_filter) def _resolve_filters_chain(self): registry: List[FilterType] = [] - routers: List[Router] = [] router = self.router - while router and router not in routers: + while router: observer = router.observers[self.event_name] - routers.append(router) router = router.parent_router for filter_ in observer.filters: diff --git a/aiogram/dispatcher/filters/base.py b/aiogram/dispatcher/filters/base.py index e49db340..cab3c4ec 100644 --- a/aiogram/dispatcher/filters/base.py +++ b/aiogram/dispatcher/filters/base.py @@ -16,7 +16,7 @@ class BaseFilter(ABC, BaseModel): @abstractmethod async def __call__( self, *args: Any, **kwargs: Any - ) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]: + ) -> Union[bool, Dict[str, Any]]: pass def __await__(self): diff --git a/tests/test_dispatcher/test_event/test_observer.py b/tests/test_dispatcher/test_event/test_observer.py index 651387dd..0b21b62e 100644 --- a/tests/test_dispatcher/test_event/test_observer.py +++ b/tests/test_dispatcher/test_event/test_observer.py @@ -1,10 +1,11 @@ import functools -from typing import Any, NoReturn +from typing import Any, Awaitable, Callable, Dict, NoReturn, Union import pytest - from aiogram.dispatcher.event.handler import HandlerObject -from aiogram.dispatcher.event.observer import EventObserver, SkipHandler +from aiogram.dispatcher.event.observer import EventObserver, SkipHandler, TelegramEventObserver +from aiogram.dispatcher.filters.base import BaseFilter +from aiogram.dispatcher.router import Router async def my_handler(event: Any, index: int = 0) -> Any: @@ -106,3 +107,65 @@ class TestEventObserver: results = [result async for result in observer.trigger(42)] assert results == [((42,), {"b": 2})] + + +class TestTelegramEventObserver: + def test_bind_filter(self): + event_observer = TelegramEventObserver(Router(), "test") + with pytest.raises(TypeError): + event_observer.bind_filter(object) # type: ignore + + class MyFilter(BaseFilter): + async def __call__( + self, *args: Any, **kwargs: Any + ) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]: + pass + + event_observer.bind_filter(MyFilter) + assert event_observer.filters + assert event_observer.filters[0] == MyFilter + + def test_resolve_filters_chain(self): + router1 = Router() + router2 = Router() + router3 = Router() + router1.include_router(router2) + router2.include_router(router3) + + class MyFilter1(BaseFilter): + test: str + + async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: + return True + + class MyFilter2(MyFilter1): + pass + + class MyFilter3(MyFilter1): + pass + + router1.message_handler.bind_filter(MyFilter1) + router1.message_handler.bind_filter(MyFilter2) + router2.message_handler.bind_filter(MyFilter2) + router3.message_handler.bind_filter(MyFilter3) + + filters_chain1 = list(router1.message_handler._resolve_filters_chain()) + filters_chain2 = list(router2.message_handler._resolve_filters_chain()) + filters_chain3 = list(router3.message_handler._resolve_filters_chain()) + + assert filters_chain1 == [MyFilter1, MyFilter2] + assert filters_chain2 == [MyFilter2, MyFilter1] + assert filters_chain3 == [MyFilter3, MyFilter2, MyFilter1] + + def test_resolve_filters(self): + pass + + def test_register(self): + pass + + def test_register_decorator(self): + pass + + @pytest.mark.asyncio + async def test_trigger(self): + pass