mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge branch 'dev-3.x' into scenes
# Conflicts: # aiogram/__init__.py
This commit is contained in:
commit
09d0244a79
3 changed files with 55 additions and 2 deletions
1
CHANGES/1368.bugfix.rst
Normal file
1
CHANGES/1368.bugfix.rst
Normal file
|
|
@ -0,0 +1 @@
|
|||
Fixed a situation where a :code:`CallbackData` could not be parsed without a default value.
|
||||
|
|
@ -1,5 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
|
|
@ -18,6 +21,7 @@ from uuid import UUID
|
|||
|
||||
from magic_filter import MagicFilter
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from aiogram.filters.base import Filter
|
||||
from aiogram.types import CallbackQuery
|
||||
|
|
@ -27,6 +31,11 @@ T = TypeVar("T", bound="CallbackData")
|
|||
MAX_CALLBACK_LENGTH: int = 64
|
||||
|
||||
|
||||
_UNION_TYPES = {typing.Union}
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
_UNION_TYPES.add(types.UnionType)
|
||||
|
||||
|
||||
class CallbackDataException(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -121,7 +130,7 @@ class CallbackData(BaseModel):
|
|||
payload = {}
|
||||
for k, v in zip(names, parts): # type: str, Optional[str]
|
||||
if field := cls.model_fields.get(k):
|
||||
if v == "" and not field.is_required():
|
||||
if v == "" and _check_field_is_nullable(field):
|
||||
v = None
|
||||
payload[k] = v
|
||||
return cls(**payload)
|
||||
|
|
@ -180,3 +189,19 @@ class CallbackQueryFilter(Filter):
|
|||
if self.rule is None or self.rule.resolve(callback_data):
|
||||
return {"callback_data": callback_data}
|
||||
return False
|
||||
|
||||
|
||||
def _check_field_is_nullable(field: FieldInfo) -> bool:
|
||||
"""
|
||||
Check if the given field is nullable.
|
||||
|
||||
:param field: The FieldInfo object representing the field to check.
|
||||
:return: True if the field is nullable, False otherwise.
|
||||
|
||||
"""
|
||||
if not field.is_required():
|
||||
return True
|
||||
|
||||
return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args(
|
||||
field.annotation
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import sys
|
||||
from decimal import Decimal
|
||||
from enum import Enum, auto
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
|
@ -147,6 +148,32 @@ class TestCallbackData:
|
|||
assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42)
|
||||
assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hint",
|
||||
[
|
||||
Union[int, None],
|
||||
Optional[int],
|
||||
],
|
||||
)
|
||||
def test_unpack_optional_wo_default(self, hint):
|
||||
"""Test CallbackData without default optional."""
|
||||
|
||||
class TgData(CallbackData, prefix="tg"):
|
||||
chat_id: int
|
||||
thread_id: hint
|
||||
|
||||
assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 10), reason="UnionType is added in Python 3.10")
|
||||
def test_unpack_optional_wo_default_union_type(self):
|
||||
"""Test CallbackData without default optional."""
|
||||
|
||||
class TgData(CallbackData, prefix="tg"):
|
||||
chat_id: int
|
||||
thread_id: int | None
|
||||
|
||||
assert TgData.unpack("tg:123:") == TgData(chat_id=123, thread_id=None)
|
||||
|
||||
def test_build_filter(self):
|
||||
filter_object = MyCallback.filter(F.foo == "test")
|
||||
assert isinstance(filter_object.rule, MagicFilter)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue