Fixed tests

This commit is contained in:
Alex Root Junior 2022-09-11 23:16:12 +03:00
parent 2aa4a9fe31
commit 2e6eed0949
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
25 changed files with 670 additions and 515 deletions

View file

@ -30,6 +30,7 @@ from aiogram.types import (
Update,
User,
)
from aiogram.types.error_event import ErrorEvent
from tests.mocked_bot import MockedBot
try:
@ -650,15 +651,15 @@ class TestDispatcher:
await dp.feed_update(bot, update)
@router.errors()
async def error_handler(event: Update, exception: Exception):
async def error_handler(event: ErrorEvent):
return "KABOOM"
response = await dp.feed_update(bot, update)
assert response == "KABOOM"
@dp.errors()
async def root_error_handler(event: Update, exception: Exception):
return exception
async def root_error_handler(event: ErrorEvent):
return event.exception
response = await dp.feed_update(bot, update)

View file

@ -28,7 +28,7 @@ async def callback4(foo: int, *, bar: int, baz: int):
return locals()
class Filter(Filter):
class TestFilter(Filter):
async def __call__(self, foo: int, bar: int, baz: int) -> Union[bool, Dict[str, Any]]:
return locals()
@ -39,7 +39,7 @@ class SyncCallable:
class TestCallableMixin:
@pytest.mark.parametrize("callback", [callback2, Filter()])
@pytest.mark.parametrize("callback", [callback2, TestFilter()])
def test_init_awaitable(self, callback):
obj = CallableMixin(callback)
assert obj.awaitable
@ -57,7 +57,7 @@ class TestCallableMixin:
pytest.param(callback1, {"foo", "bar", "baz"}),
pytest.param(callback2, {"foo", "bar", "baz"}),
pytest.param(callback3, {"foo"}),
pytest.param(Filter(), {"self", "foo", "bar", "baz"}),
pytest.param(TestFilter(), {"self", "foo", "bar", "baz"}),
pytest.param(SyncCallable(), {"self", "foo", "bar", "baz"}),
],
)
@ -117,7 +117,7 @@ class TestCallableMixin:
{"foo": 42, "baz": "fuz", "bar": "test"},
),
pytest.param(
Filter(), {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}
TestFilter(), {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}
),
pytest.param(
SyncCallable(), {"foo": 42, "spam": True, "baz": "fuz"}, {"foo": 42, "baz": "fuz"}

View file

@ -1,17 +1,16 @@
import datetime
import functools
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, Union
from typing import Any, Dict, NoReturn, Optional, Union
import pytest
from pydantic import BaseModel
from aiogram.dispatcher.event.bases import REJECTED, SkipHandler
from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.router import Router
from aiogram.exceptions import FiltersResolveError
from aiogram.filters import Command, Filter
from aiogram.filters import Filter
from aiogram.types import Chat, Message, User
from tests.deprecated import check_deprecated
pytestmark = pytest.mark.asyncio
@ -31,7 +30,7 @@ async def pipe_handler(*args, **kwargs):
return args, kwargs
class MyFilter1(Filter):
class MyFilter1(Filter, BaseModel):
test: str
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
@ -46,14 +45,14 @@ class MyFilter3(MyFilter1):
pass
class OptionalFilter(Filter):
class OptionalFilter(Filter, BaseModel):
optional: Optional[str]
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class DefaultFilter(Filter):
class DefaultFilter(Filter, BaseModel):
default: str = "Default"
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
@ -61,144 +60,9 @@ class DefaultFilter(Filter):
class TestTelegramEventObserver:
def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test")
with pytest.raises(TypeError):
event_observer.bind_filter(object) # type: ignore
class MyFilter(Filter):
async def __call__(
self, *args: Any, **kwargs: Any
) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]:
pass
event_observer.bind_filter(MyFilter)
assert event_observer.filters
assert MyFilter in event_observer.filters
def test_resolve_filters_chain(self):
router1 = Router(use_builtin_filters=False)
router2 = Router(use_builtin_filters=False)
router3 = Router(use_builtin_filters=False)
router1.include_router(router2)
router2.include_router(router3)
router1.message.bind_filter(MyFilter1)
router1.message.bind_filter(MyFilter2)
router2.message.bind_filter(MyFilter2)
router3.message.bind_filter(MyFilter3)
filters_chain1 = list(router1.message._resolve_filters_chain())
filters_chain2 = list(router2.message._resolve_filters_chain())
filters_chain3 = list(router3.message._resolve_filters_chain())
assert MyFilter1 in filters_chain1
assert MyFilter1 in filters_chain2
assert MyFilter1 in filters_chain3
assert MyFilter2 in filters_chain1
assert MyFilter2 in filters_chain2
assert MyFilter2 in filters_chain3
assert MyFilter3 in filters_chain3
assert MyFilter3 not in filters_chain1
async def test_resolve_filters_data_from_parent_router(self):
class FilterSet(Filter):
set_filter: bool
async def __call__(self, message: Message) -> dict:
return {"test": "hello world"}
class FilterGet(Filter):
get_filter: bool
async def __call__(self, message: Message, **data) -> bool:
assert "test" in data
return True
router1 = Router(use_builtin_filters=False)
router2 = Router(use_builtin_filters=False)
router1.include_router(router2)
router1.message.bind_filter(FilterSet)
router2.message.bind_filter(FilterGet)
@router2.message(set_filter=True, get_filter=True)
def handler_test(msg: Message, test: str):
assert test == "hello world"
await router1.propagate_event(
"message",
Message(message_id=1, date=datetime.datetime.now(), chat=Chat(id=1, type="private")),
)
def test_resolve_filters(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
resolved = observer.resolve_filters((), {"test": "PASS"})
assert isinstance(resolved, list)
assert any(isinstance(item, MyFilter1) for item in resolved)
# Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters((), {"@bad": "very"})
# Unknown filter
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'@bad'}"):
assert observer.resolve_filters((), {"test": "ok", "@bad": "very"})
# Bad argument type
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
assert observer.resolve_filters((), {"test": ...})
# Disallow same filter using
with pytest.raises(FiltersResolveError, match="Unknown keyword filters: {'test'}"):
observer.resolve_filters((MyFilter1(test="test"),), {"test": ...})
def test_dont_autoresolve_optional_filters_for_router(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
observer.filter(test="test")
assert len(observer._handler.filters) == 1
def test_register_autoresolve_optional_filters(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.bind_filter(MyFilter1)
observer.bind_filter(OptionalFilter)
observer.bind_filter(DefaultFilter)
assert observer.register(my_handler) == my_handler
assert isinstance(observer.handlers[0], HandlerObject)
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert len(observer.handlers[0].filters) == 2
assert isinstance(observer.handlers[0].filters[0].callback, OptionalFilter)
assert isinstance(observer.handlers[0].filters[1].callback, DefaultFilter)
observer.register(my_handler, test="ok")
assert isinstance(observer.handlers[1], HandlerObject)
assert len(observer.handlers[1].filters) == 3
assert isinstance(observer.handlers[1].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[1].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[1].filters[2].callback, DefaultFilter)
observer.register(my_handler, test="ok", optional="ok")
assert isinstance(observer.handlers[2], HandlerObject)
assert len(observer.handlers[2].filters) == 3
assert isinstance(observer.handlers[2].filters[0].callback, MyFilter1)
assert isinstance(observer.handlers[2].filters[1].callback, OptionalFilter)
assert isinstance(observer.handlers[2].filters[2].callback, DefaultFilter)
def test_register(self):
router = Router(use_builtin_filters=False)
router = Router()
observer = router.message
observer.bind_filter(MyFilter1)
assert observer.register(my_handler) == my_handler
assert isinstance(observer.handlers[0], HandlerObject)
@ -210,19 +74,19 @@ class TestTelegramEventObserver:
assert len(observer.handlers[1].filters) == 1
assert observer.handlers[1].filters[0].callback == f
observer.register(my_handler, test="PASS")
observer.register(my_handler, MyFilter1(test="PASS"))
assert isinstance(observer.handlers[2], HandlerObject)
assert any(isinstance(item.callback, MyFilter1) for item in observer.handlers[2].filters)
f2 = MyFilter2(test="ok")
observer.register(my_handler, f2, test="PASS")
observer.register(my_handler, f2, MyFilter1(test="PASS"))
assert isinstance(observer.handlers[3], HandlerObject)
callbacks = [filter_.callback for filter_ in observer.handlers[3].filters]
assert f2 in callbacks
assert MyFilter1(test="PASS") in callbacks
def test_register_decorator(self):
router = Router(use_builtin_filters=False)
router = Router()
observer = router.message
@observer()
@ -233,10 +97,9 @@ class TestTelegramEventObserver:
assert observer.handlers[0].callback == my_handler
async def test_trigger(self):
router = Router(use_builtin_filters=False)
router = Router()
observer = router.message
observer.bind_filter(MyFilter1)
observer.register(my_handler, test="ok")
observer.register(my_handler, MyFilter1(test="ok"))
message = Message(
message_id=42,
@ -258,7 +121,7 @@ class TestTelegramEventObserver:
),
)
def test_register_filters_via_decorator(self, count, handler, filters):
router = Router(use_builtin_filters=False)
router = Router()
observer = router.message
for index in range(count):
@ -272,7 +135,7 @@ class TestTelegramEventObserver:
assert len(registered_handler.filters) == len(filters)
async def test_trigger_right_context_in_handlers(self):
router = Router(use_builtin_filters=False)
router = Router()
observer = router.message
async def mix_unnecessary_data(event):
@ -328,7 +191,7 @@ class TestTelegramEventObserver:
assert list(middlewares) == [my_middleware1, my_middleware2, my_middleware3]
def test_register_global_filters(self):
router = Router(use_builtin_filters=False)
router = Router()
assert isinstance(router.message._handler.filters, list)
assert not router.message._handler.filters
@ -369,13 +232,3 @@ class TestTelegramEventObserver:
r2.message.register(handler)
assert await r1.message.trigger(None) is REJECTED
def test_deprecated_bind_filter(self):
router = Router()
with check_deprecated("3.0b5", exception=AttributeError):
router.message.bind_filter(MyFilter1)
def test_deprecated_resolve_filters(self):
router = Router()
with check_deprecated("3.0b5", exception=AttributeError):
router.message.resolve_filters([Command], full_config={"commands": ["test"]})

View file

@ -2,7 +2,6 @@ import pytest
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler, skip
from aiogram.dispatcher.router import Router
from aiogram.utils.warnings import CodeHasNoEffect
pytestmark = pytest.mark.asyncio
@ -36,15 +35,6 @@ class TestRouter:
assert router3.parent_router is router2
assert router3.sub_routers == []
def test_include_router_code_has_no_effect(self):
router1 = Router()
router2 = Router(use_builtin_filters=False)
assert router1.use_builtin_filters
assert not router2.use_builtin_filters
with pytest.warns(CodeHasNoEffect):
assert router1.include_router(router2)
def test_include_router_by_string_bad_type(self):
router = Router()
with pytest.raises(ValueError, match=r"router should be instance of Router"):

View file

@ -3,6 +3,7 @@ from typing import Awaitable
import pytest
from aiogram.filters import Filter
from aiogram.filters.base import _InvertFilter
try:
from asynctest import CoroutineMock, patch
@ -14,15 +15,13 @@ pytestmark = pytest.mark.asyncio
class MyFilter(Filter):
foo: str
async def __call__(self, event: str):
return
class TestBaseFilter:
async def test_awaitable(self):
my_filter = MyFilter(foo="bar")
my_filter = MyFilter()
assert isinstance(my_filter, Awaitable)
@ -33,3 +32,18 @@ class TestBaseFilter:
call = my_filter(event="test")
await call
mocked_call.assert_awaited_with(event="test")
async def test_invert(self):
my_filter = MyFilter()
my_inverted_filter = ~my_filter
assert isinstance(my_inverted_filter, _InvertFilter)
with patch(
"tests.test_filters.test_base.MyFilter.__call__",
new_callable=CoroutineMock,
) as mocked_call:
call = my_inverted_filter(event="test")
result = await call
mocked_call.assert_awaited_with(event="test")
assert not result

View file

@ -5,6 +5,7 @@ import pytest
from aiogram import Dispatcher
from aiogram.filters import ExceptionMessageFilter, ExceptionTypeFilter
from aiogram.types import Update
from aiogram.types.error_event import ErrorEvent
pytestmark = pytest.mark.asyncio
@ -18,10 +19,10 @@ class TestExceptionMessageFilter:
async def test_match(self):
obj = ExceptionMessageFilter(pattern="KABOOM")
result = await obj(Update(update_id=0), exception=Exception())
result = await obj(ErrorEvent(update=Update(update_id=0), exception=Exception()))
assert not result
result = await obj(Update(update_id=0), exception=Exception("KABOOM"))
result = await obj(ErrorEvent(update=Update(update_id=0), exception=Exception("KABOOM")))
assert isinstance(result, dict)
assert "match_exception" in result
@ -46,9 +47,9 @@ class TestExceptionTypeFilter:
],
)
async def test_check(self, exception: Exception, value: bool):
obj = ExceptionTypeFilter(exceptions=MyException)
obj = ExceptionTypeFilter(MyException)
result = await obj(Update(update_id=0), exception=exception)
result = await obj(ErrorEvent(update=Update(update_id=0), exception=exception))
assert result == value
@ -62,7 +63,7 @@ class TestDispatchException:
raise ValueError("KABOOM")
@dp.errors(ExceptionMessageFilter(pattern="KABOOM"))
async def handler0(update, exception):
async def handler0(event):
return "Handled"
assert await dp.feed_update(bot, Update(update_id=0)) == "Handled"

View file

@ -1,37 +0,0 @@
import pytest
from aiogram.filters import Text, and_f, invert_f, or_f
from aiogram.filters.logic import _AndFilter, _InvertFilter, _OrFilter
class TestLogic:
@pytest.mark.parametrize(
"obj,case,result",
[
[True, and_f(lambda t: t is True, lambda t: t is True), True],
[True, and_f(lambda t: t is True, lambda t: t is False), False],
[True, and_f(lambda t: t is False, lambda t: t is False), False],
[True, and_f(lambda t: {"t": t}, lambda t: t is False), False],
[True, and_f(lambda t: {"t": t}, lambda t: t is True), {"t": True}],
[True, or_f(lambda t: t is True, lambda t: t is True), True],
[True, or_f(lambda t: t is True, lambda t: t is False), True],
[True, or_f(lambda t: t is False, lambda t: t is False), False],
[True, or_f(lambda t: t is False, lambda t: t is True), True],
[True, or_f(lambda t: t is False, lambda t: {"t": t}), {"t": True}],
[True, or_f(lambda t: {"t": t}, lambda t: {"a": 42}), {"t": True}],
[True, invert_f(lambda t: t is False), True],
],
)
async def test_logic(self, obj, case, result):
assert await case(obj) == result
@pytest.mark.parametrize(
"case,type_",
[
[Text(text="test") | Text(text="test"), _OrFilter],
[Text(text="test") & Text(text="test"), _AndFilter],
[~Text(text="test"), _InvertFilter],
],
)
def test_dunder_methods(self, case, type_):
assert isinstance(case, type_)

View file

@ -16,13 +16,11 @@ class MyGroup(StatesGroup):
class TestStateFilter:
@pytest.mark.parametrize(
"state", [None, State("test"), MyGroup, MyGroup(), "state", ["state"]]
)
@pytest.mark.parametrize("state", [None, State("test"), MyGroup, MyGroup(), "state"])
def test_validator(self, state):
f = StateFilter(state=state)
assert isinstance(f.state, list)
value = f.state[0]
f = StateFilter(state)
assert isinstance(f.states, tuple)
value = f.states[0]
assert (
isinstance(value, (State, str, MyGroup))
or (isclass(value) and issubclass(value, StatesGroup))
@ -32,17 +30,11 @@ class TestStateFilter:
@pytest.mark.parametrize(
"state,current_state,result",
[
[State("state"), "@:state", True],
[[State("state")], "@:state", True],
[MyGroup, "MyGroup:state", True],
[[MyGroup], "MyGroup:state", True],
[MyGroup(), "MyGroup:state", True],
[[MyGroup()], "MyGroup:state", True],
["*", "state", True],
[None, None, True],
[["*"], "state", True],
[[None], None, True],
[None, "state", False],
[[], "state", False],
[[State("state"), "state"], "state", True],
[[MyGroup(), State("state")], "@:state", True],
[[MyGroup, State("state")], "state", False],
@ -50,9 +42,13 @@ class TestStateFilter:
)
@pytestmark
async def test_filter(self, state, current_state, result):
f = StateFilter(state=state)
f = StateFilter(*state)
assert bool(await f(obj=Update(update_id=42), raw_state=current_state)) is result
def test_empty_filter(self):
with pytest.raises(ValueError):
StateFilter()
@pytestmark
async def test_create_filter_from_state(self):
FilterObject(callback=State(state="state"))

View file

@ -3,50 +3,38 @@ from itertools import permutations
from typing import Sequence, Type
import pytest
from pydantic import ValidationError
from aiogram.filters import BUILTIN_FILTERS, Text
from aiogram.filters import Text
from aiogram.types import CallbackQuery, Chat, InlineQuery, Message, Poll, PollOption, User
pytestmark = pytest.mark.asyncio
class TestText:
def test_default_for_observer(self):
registered_for = {
update_type for update_type, filters in BUILTIN_FILTERS.items() if Text in filters
}
assert registered_for == {
"message",
"edited_message",
"channel_post",
"edited_channel_post",
"inline_query",
"callback_query",
}
def test_validator_not_enough_arguments(self):
with pytest.raises(ValidationError):
Text()
with pytest.raises(ValidationError):
Text(text_ignore_case=True)
@pytest.mark.parametrize(
"first,last",
permutations(["text", "text_contains", "text_startswith", "text_endswith"], 2),
"kwargs",
[
{},
{"ignore_case": True},
{"ignore_case": False},
],
)
@pytest.mark.parametrize("ignore_case", [True, False])
def test_validator_too_few_arguments(self, first, last, ignore_case):
kwargs = {first: "test", last: "test"}
if ignore_case:
kwargs["text_ignore_case"] = True
with pytest.raises(ValidationError):
def test_not_enough_arguments(self, kwargs):
with pytest.raises(ValueError):
Text(**kwargs)
@pytest.mark.parametrize(
"argument", ["text", "text_contains", "text_startswith", "text_endswith"]
"first,last",
permutations(["text", "contains", "startswith", "endswith"], 2),
)
@pytest.mark.parametrize("ignore_case", [True, False])
def test_validator_too_few_arguments(self, first, last, ignore_case):
kwargs = {first: "test", last: "test", "ignore_case": ignore_case}
with pytest.raises(ValueError):
Text(**kwargs)
@pytest.mark.parametrize("argument", ["text", "contains", "startswith", "endswith"])
@pytest.mark.parametrize("input_type", [str, list, tuple])
def test_validator_convert_to_list(self, argument: str, input_type: Type):
text = Text(**{argument: input_type("test")})
@ -121,7 +109,7 @@ class TestText:
False,
],
[
"text_startswith",
"startswith",
False,
"test",
Message(
@ -134,7 +122,7 @@ class TestText:
True,
],
[
"text_endswith",
"endswith",
False,
"case",
Message(
@ -147,7 +135,7 @@ class TestText:
True,
],
[
"text_contains",
"contains",
False,
" ",
Message(
@ -160,7 +148,7 @@ class TestText:
True,
],
[
"text_startswith",
"startswith",
True,
"question",
Message(
@ -182,7 +170,7 @@ class TestText:
True,
],
[
"text_startswith",
"startswith",
True,
"callback:",
CallbackQuery(
@ -194,7 +182,7 @@ class TestText:
True,
],
[
"text_startswith",
"startswith",
True,
"query",
InlineQuery(
@ -242,5 +230,6 @@ class TestText:
],
)
async def test_check_text(self, argument, ignore_case, input_value, result, update_type):
text = Text(**{argument: input_value}, text_ignore_case=ignore_case)
assert await text(obj=update_type) is result
text = Text(**{argument: input_value}, ignore_case=ignore_case)
test = await text(update_type)
assert test is result