mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Add support for nullable fields in callback data
This update extends the callback data handling by adding support for nullable fields. The code now uses the Python typing structures `Optional` and `Union` to parse such fields correctly. A helper function `_check_field_is_nullable` has been added to assist in efficiently checking if a given field is nullable.
This commit is contained in:
parent
7c295f6b3d
commit
42599fa82a
2 changed files with 31 additions and 5 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
|
|
@ -18,6 +19,7 @@ from uuid import UUID
|
|||
|
||||
from magic_filter import MagicFilter
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import CallbackQuery
|
||||
|
|
@ -120,8 +122,9 @@ class CallbackData(BaseModel):
|
|||
raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})")
|
||||
payload = {}
|
||||
for k, v in zip(names, parts): # type: str, Optional[str]
|
||||
if v == "":
|
||||
v = None
|
||||
if field := cls.model_fields.get(k):
|
||||
if v == "" and _check_field_is_nullable(field):
|
||||
v = None
|
||||
payload[k] = v
|
||||
return cls(**payload)
|
||||
|
||||
|
|
@ -179,3 +182,19 @@ class CallbackQueryFilter(Filter):
|
|||
if self.rule is None or self.rule.resolve(callback_data):
|
||||
return {"callback_data": callback_data}
|
||||
return False
|
||||
|
||||
|
||||
def _check_field_is_nullable(field: FieldInfo) -> bool:
|
||||
"""
|
||||
Check if the given field is nullable.
|
||||
|
||||
:param field: The FieldInfo object representing the field to check.
|
||||
:return: True if the field is nullable, False otherwise.
|
||||
|
||||
"""
|
||||
if not field.is_required():
|
||||
return True
|
||||
|
||||
return typing.get_origin(field.annotation) is typing.Union and type(None) in typing.get_args(
|
||||
field.annotation
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from decimal import Decimal
|
||||
from enum import Enum, auto
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
|
@ -147,12 +147,19 @@ class TestCallbackData:
|
|||
assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42)
|
||||
assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42)
|
||||
|
||||
def test_unpack_optional_wo_default(self):
|
||||
@pytest.mark.parametrize(
|
||||
"hint",
|
||||
[
|
||||
Union[int, None],
|
||||
Optional[int],
|
||||
],
|
||||
)
|
||||
def test_unpack_optional_wo_default(self, hint):
|
||||
"""Test CallbackData without default optional."""
|
||||
|
||||
class TgData(CallbackData, prefix="tg"):
|
||||
chat_id: int
|
||||
thread_id: Optional[int]
|
||||
thread_id: hint
|
||||
|
||||
assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue