Added possibility to use empty callback data factory filter

This commit is contained in:
Alex Root Junior 2021-07-29 00:05:15 +03:00
parent 7d56751b09
commit 5e3356fa62
3 changed files with 13 additions and 5 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from decimal import Decimal
from enum import Enum
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union, Literal
from uuid import UUID
from magic_filter import MagicFilter
@ -86,7 +86,7 @@ class CallbackData(BaseModel):
return cls(**payload)
@classmethod
def filter(cls, rule: MagicFilter) -> CallbackQueryFilter:
def filter(cls, rule: Optional[MagicFilter] = None) -> CallbackQueryFilter:
return CallbackQueryFilter(callback_data=cls, rule=rule)
class Config:
@ -95,9 +95,9 @@ class CallbackData(BaseModel):
class CallbackQueryFilter(BaseFilter):
callback_data: Type[CallbackData]
rule: MagicFilter
rule: Optional[MagicFilter] = None
async def __call__(self, query: CallbackQuery) -> Union[bool, Dict[str, Any]]:
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]:
if not isinstance(query, CallbackQuery) or not query.data:
return False
try:
@ -105,7 +105,7 @@ class CallbackQueryFilter(BaseFilter):
except (TypeError, ValueError):
return False
if self.rule.resolve(callback_data):
if self.rule is None or self.rule.resolve(callback_data):
return {"callback_data": callback_data}
return False

View file

@ -42,6 +42,10 @@ async def redis_storage(redis_server):
if not redis_server:
pytest.skip("Redis is not available here")
storage = RedisStorage.from_url(redis_server)
try:
await storage.redis.info()
except ConnectionError as e:
pytest.skip(str(e))
try:
yield storage
finally:

View file

@ -156,7 +156,11 @@ class TestCallbackDataFilter:
["test", F.foo == "test", False],
["test:spam:42", F.foo == "test", False],
["test:test:42", F.foo == "test", {"callback_data": MyCallback(foo="test", bar=42)}],
["test:test:42", None, {"callback_data": MyCallback(foo="test", bar=42)}],
["test:test:777", None, {"callback_data": MyCallback(foo="test", bar=777)}],
["spam:test:777", None, False],
["test:test:", F.foo == "test", False],
["test:test:", None, False],
],
)
@pytest.mark.asyncio