diff --git a/aiogram/filters/callback_data.py b/aiogram/filters/callback_data.py index fde39761..7c0dadf8 100644 --- a/aiogram/filters/callback_data.py +++ b/aiogram/filters/callback_data.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from decimal import Decimal from enum import Enum from fractions import Fraction @@ -18,6 +19,7 @@ from uuid import UUID from magic_filter import MagicFilter from pydantic import BaseModel +from pydantic.fields import FieldInfo from aiogram.filters.base import Filter from aiogram.types import CallbackQuery @@ -120,8 +122,9 @@ class CallbackData(BaseModel): raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})") payload = {} for k, v in zip(names, parts): # type: str, Optional[str] - if v == "": - v = None + if field := cls.model_fields.get(k): + if v == "" and _check_field_is_nullable(field): + v = None payload[k] = v return cls(**payload) @@ -179,3 +182,19 @@ class CallbackQueryFilter(Filter): if self.rule is None or self.rule.resolve(callback_data): return {"callback_data": callback_data} return False + + +def _check_field_is_nullable(field: FieldInfo) -> bool: + """ + Check if the given field is nullable. + + :param field: The FieldInfo object representing the field to check. + :return: True if the field is nullable, False otherwise. + + """ + if not field.is_required(): + return True + + return typing.get_origin(field.annotation) is typing.Union and type(None) in typing.get_args( + field.annotation + ) diff --git a/tests/test_filters/test_callback_data.py b/tests/test_filters/test_callback_data.py index e52e66b2..4314aa34 100644 --- a/tests/test_filters/test_callback_data.py +++ b/tests/test_filters/test_callback_data.py @@ -1,7 +1,7 @@ from decimal import Decimal from enum import Enum, auto from fractions import Fraction -from typing import Optional +from typing import Optional, Union from uuid import UUID import pytest @@ -147,12 +147,19 @@ class TestCallbackData: assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42) assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42) - def test_unpack_optional_wo_default(self): + @pytest.mark.parametrize( + "hint", + [ + Union[int, None], + Optional[int], + ], + ) + def test_unpack_optional_wo_default(self, hint): """Test CallbackData without default optional.""" class TgData(CallbackData, prefix="tg"): chat_id: int - thread_id: Optional[int] + thread_id: hint assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)