diff --git a/CHANGES/1368.bugfix.rst b/CHANGES/1368.bugfix.rst new file mode 100644 index 00000000..f541e016 --- /dev/null +++ b/CHANGES/1368.bugfix.rst @@ -0,0 +1 @@ +Fixed a situation where a :code:`CallbackData` could not be parsed without a default value. diff --git a/aiogram/filters/callback_data.py b/aiogram/filters/callback_data.py index 7a09dedb..17ccffbf 100644 --- a/aiogram/filters/callback_data.py +++ b/aiogram/filters/callback_data.py @@ -1,5 +1,8 @@ from __future__ import annotations +import sys +import types +import typing from decimal import Decimal from enum import Enum from fractions import Fraction @@ -18,6 +21,7 @@ from uuid import UUID from magic_filter import MagicFilter from pydantic import BaseModel +from pydantic.fields import FieldInfo from aiogram.filters.base import Filter from aiogram.types import CallbackQuery @@ -27,6 +31,11 @@ T = TypeVar("T", bound="CallbackData") MAX_CALLBACK_LENGTH: int = 64 +_UNION_TYPES = {typing.Union} +if sys.version_info >= (3, 10): # pragma: no cover + _UNION_TYPES.add(types.UnionType) + + class CallbackDataException(Exception): pass @@ -121,7 +130,7 @@ class CallbackData(BaseModel): payload = {} for k, v in zip(names, parts): # type: str, Optional[str] if field := cls.model_fields.get(k): - if v == "" and not field.is_required(): + if v == "" and _check_field_is_nullable(field): v = None payload[k] = v return cls(**payload) @@ -180,3 +189,19 @@ class CallbackQueryFilter(Filter): if self.rule is None or self.rule.resolve(callback_data): return {"callback_data": callback_data} return False + + +def _check_field_is_nullable(field: FieldInfo) -> bool: + """ + Check if the given field is nullable. + + :param field: The FieldInfo object representing the field to check. + :return: True if the field is nullable, False otherwise. + + """ + if not field.is_required(): + return True + + return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args( + field.annotation + ) diff --git a/tests/test_filters/test_callback_data.py b/tests/test_filters/test_callback_data.py index e8721a41..635b8e9f 100644 --- a/tests/test_filters/test_callback_data.py +++ b/tests/test_filters/test_callback_data.py @@ -1,7 +1,8 @@ +import sys from decimal import Decimal from enum import Enum, auto from fractions import Fraction -from typing import Optional +from typing import Optional, Union from uuid import UUID import pytest @@ -147,6 +148,32 @@ class TestCallbackData: assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42) assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42) + @pytest.mark.parametrize( + "hint", + [ + Union[int, None], + Optional[int], + ], + ) + def test_unpack_optional_wo_default(self, hint): + """Test CallbackData without default optional.""" + + class TgData(CallbackData, prefix="tg"): + chat_id: int + thread_id: hint + + assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None) + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="UnionType is added in Python 3.10") + def test_unpack_optional_wo_default_union_type(self): + """Test CallbackData without default optional.""" + + class TgData(CallbackData, prefix="tg"): + chat_id: int + thread_id: int | None + + assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None) + def test_build_filter(self): filter_object = MyCallback.filter(F.foo == "test") assert isinstance(filter_object.rule, MagicFilter)