Merge branch 'dev-3.x' into scenes

# Conflicts:
#	aiogram/__init__.py
This commit is contained in:
JRoot Junior 2023-11-23 00:33:19 +02:00
commit 09d0244a79
No known key found for this signature in database
GPG key ID: 738964250D5FF6E2
3 changed files with 55 additions and 2 deletions

1
CHANGES/1368.bugfix.rst Normal file
View file

@ -0,0 +1 @@
Fixed a situation where a :code:`CallbackData` could not be parsed without a default value.

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import sys
import types
import typing
from decimal import Decimal
from enum import Enum
from fractions import Fraction
@ -18,6 +21,7 @@ from uuid import UUID
from magic_filter import MagicFilter
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from aiogram.filters.base import Filter
from aiogram.types import CallbackQuery
@ -27,6 +31,11 @@ T = TypeVar("T", bound="CallbackData")
MAX_CALLBACK_LENGTH: int = 64
_UNION_TYPES = {typing.Union}
if sys.version_info >= (3, 10): # pragma: no cover
_UNION_TYPES.add(types.UnionType)
class CallbackDataException(Exception):
pass
@ -121,7 +130,7 @@ class CallbackData(BaseModel):
payload = {}
for k, v in zip(names, parts): # type: str, Optional[str]
if field := cls.model_fields.get(k):
if v == "" and not field.is_required():
if v == "" and _check_field_is_nullable(field):
v = None
payload[k] = v
return cls(**payload)
@ -180,3 +189,19 @@ class CallbackQueryFilter(Filter):
if self.rule is None or self.rule.resolve(callback_data):
return {"callback_data": callback_data}
return False
def _check_field_is_nullable(field: FieldInfo) -> bool:
"""
Check if the given field is nullable.
:param field: The FieldInfo object representing the field to check.
:return: True if the field is nullable, False otherwise.
"""
if not field.is_required():
return True
return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args(
field.annotation
)

View file

@ -1,7 +1,8 @@
import sys
from decimal import Decimal
from enum import Enum, auto
from fractions import Fraction
from typing import Optional
from typing import Optional, Union
from uuid import UUID
import pytest
@ -147,6 +148,32 @@ class TestCallbackData:
assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42)
assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42)
@pytest.mark.parametrize(
"hint",
[
Union[int, None],
Optional[int],
],
)
def test_unpack_optional_wo_default(self, hint):
"""Test CallbackData without default optional."""
class TgData(CallbackData, prefix="tg"):
chat_id: int
thread_id: hint
assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)
@pytest.mark.skipif(sys.version_info < (3, 10), reason="UnionType is added in Python 3.10")
def test_unpack_optional_wo_default_union_type(self):
"""Test CallbackData without default optional."""
class TgData(CallbackData, prefix="tg"):
chat_id: int
thread_id: int | None
assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)
def test_build_filter(self):
filter_object = MyCallback.filter(F.foo == "test")
assert isinstance(filter_object.rule, MagicFilter)