Rewrite filters

This commit is contained in:
Alex Root Junior 2022-08-14 21:43:33 +03:00
parent 3f57c69d4f
commit 8d2aae77c1
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
24 changed files with 311 additions and 539 deletions

View file

@ -5,7 +5,7 @@ import pytest
from aiogram import F
from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject
from aiogram.filters import BaseFilter
from aiogram.filters import Filter
from aiogram.handlers import BaseHandler
from aiogram.types import Update
@ -28,7 +28,7 @@ async def callback4(foo: int, *, bar: int, baz: int):
return locals()
class Filter(BaseFilter):
class Filter(Filter):
async def __call__(self, foo: int, bar: int, baz: int) -> Union[bool, Dict[str, Any]]:
return locals()

View file

@ -9,7 +9,7 @@ from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.router import Router
from aiogram.exceptions import FiltersResolveError
from aiogram.filters import BaseFilter, Command
from aiogram.filters import Command, Filter
from aiogram.types import Chat, Message, User
from tests.deprecated import check_deprecated
@ -31,7 +31,7 @@ async def pipe_handler(*args, **kwargs):
return args, kwargs
class MyFilter1(BaseFilter):
class MyFilter1(Filter):
test: str
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
@ -46,14 +46,14 @@ class MyFilter3(MyFilter1):
pass
class OptionalFilter(BaseFilter):
class OptionalFilter(Filter):
optional: Optional[str]
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
return True
class DefaultFilter(BaseFilter):
class DefaultFilter(Filter):
default: str = "Default"
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
@ -66,7 +66,7 @@ class TestTelegramEventObserver:
with pytest.raises(TypeError):
event_observer.bind_filter(object) # type: ignore
class MyFilter(BaseFilter):
class MyFilter(Filter):
async def __call__(
self, *args: Any, **kwargs: Any
) -> Callable[[Any], Awaitable[Union[bool, Dict[str, Any]]]]:
@ -103,13 +103,13 @@ class TestTelegramEventObserver:
assert MyFilter3 not in filters_chain1
async def test_resolve_filters_data_from_parent_router(self):
class FilterSet(BaseFilter):
class FilterSet(Filter):
set_filter: bool
async def __call__(self, message: Message) -> dict:
return {"test": "hello world"}
class FilterGet(BaseFilter):
class FilterGet(Filter):
get_filter: bool
async def __call__(self, message: Message, **data) -> bool:

View file

@ -2,7 +2,7 @@ from typing import Awaitable
import pytest
from aiogram.filters import BaseFilter
from aiogram.filters import Filter
try:
from asynctest import CoroutineMock, patch
@ -13,7 +13,7 @@ except ImportError:
pytestmark = pytest.mark.asyncio
class MyFilter(BaseFilter):
class MyFilter(Filter):
foo: str
async def __call__(self, event: str):

View file

@ -16,34 +16,34 @@ class TestCommandFilter:
def test_convert_to_list(self):
cmd = Command(commands="start")
assert cmd.commands
assert isinstance(cmd.commands, list)
assert isinstance(cmd.commands, tuple)
assert cmd.commands[0] == "start"
assert cmd == Command(commands=["start"])
# assert cmd == Command(commands=["start"])
@pytest.mark.parametrize(
"text,command,result",
[
["/test@tbot", Command(commands=["test"], commands_prefix="/"), True],
["!test", Command(commands=["test"], commands_prefix="/"), False],
["/test@mention", Command(commands=["test"], commands_prefix="/"), False],
["/tests", Command(commands=["test"], commands_prefix="/"), False],
["/", Command(commands=["test"], commands_prefix="/"), False],
["/ test", Command(commands=["test"], commands_prefix="/"), False],
["", Command(commands=["test"], commands_prefix="/"), False],
[" ", Command(commands=["test"], commands_prefix="/"), False],
["test", Command(commands=["test"], commands_prefix="/"), False],
[" test", Command(commands=["test"], commands_prefix="/"), False],
["a", Command(commands=["test"], commands_prefix="/"), False],
["/test@tbot", Command(commands=["test"], prefix="/"), True],
["!test", Command(commands=["test"], prefix="/"), False],
["/test@mention", Command(commands=["test"], prefix="/"), False],
["/tests", Command(commands=["test"], prefix="/"), False],
["/", Command(commands=["test"], prefix="/"), False],
["/ test", Command(commands=["test"], prefix="/"), False],
["", Command(commands=["test"], prefix="/"), False],
[" ", Command(commands=["test"], prefix="/"), False],
["test", Command(commands=["test"], prefix="/"), False],
[" test", Command(commands=["test"], prefix="/"), False],
["a", Command(commands=["test"], prefix="/"), False],
["/test@tbot some args", Command(commands=["test"]), True],
["/test42@tbot some args", Command(commands=[re.compile(r"test(\d+)")]), True],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "some args"),
Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "some args"),
True,
],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "test"),
Command(commands=[re.compile(r"test(\d+)")], magic=F.args == "test"),
False,
],
["/start test", CommandStart(), True],
@ -99,7 +99,7 @@ class TestCommandFilter:
chat=Chat(id=42, type="private"),
date=datetime.datetime.now(),
)
command = Command(commands=["test"], command_magic=(F.args.as_("args")))
command = Command(commands=["test"], magic=(F.args.as_("args")))
result = await command(message=message, bot=bot)
assert "args" in result
assert result["args"] == "42"

View file

@ -1,52 +0,0 @@
from dataclasses import dataclass
from typing import cast
import pytest
from pydantic import ValidationError
from aiogram.filters import ContentTypesFilter
from aiogram.types import ContentType, Message
pytestmark = pytest.mark.asyncio
@dataclass
class MinimalMessage:
content_type: str
class TestContentTypesFilter:
def test_validator_empty_list(self):
filter_ = ContentTypesFilter(content_types=[])
assert filter_.content_types == []
def test_convert_to_list(self):
filter_ = ContentTypesFilter(content_types="text")
assert filter_.content_types
assert isinstance(filter_.content_types, list)
assert filter_.content_types[0] == "text"
assert filter_ == ContentTypesFilter(content_types=["text"])
@pytest.mark.parametrize("values", [["text", "photo"], ["sticker"]])
def test_validator_with_values(self, values):
filter_ = ContentTypesFilter(content_types=values)
assert filter_.content_types == values
@pytest.mark.parametrize("values", [["test"], ["text", "test"], ["TEXT"]])
def test_validator_with_bad_values(self, values):
with pytest.raises(ValidationError):
ContentTypesFilter(content_types=values)
@pytest.mark.parametrize(
"values,content_type,result",
[
[[ContentType.TEXT], ContentType.TEXT, True],
[[ContentType.PHOTO], ContentType.TEXT, False],
[[ContentType.ANY], ContentType.TEXT, True],
[[ContentType.TEXT, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
[[ContentType.ANY, ContentType.PHOTO, ContentType.DOCUMENT], ContentType.TEXT, True],
],
)
async def test_call(self, values, content_type, result):
filter_ = ContentTypesFilter(content_types=values)
assert await filter_(cast(Message, MinimalMessage(content_type=content_type))) == result

View file

@ -46,7 +46,7 @@ class TestExceptionTypeFilter:
],
)
async def test_check(self, exception: Exception, value: bool):
obj = ExceptionTypeFilter(exception=MyException)
obj = ExceptionTypeFilter(exceptions=MyException)
result = await obj(Update(update_id=0), exception=exception)