EOL of Py3.9 (#1726)
Some checks failed
Tests / tests (macos-latest, 3.10) (push) Has been cancelled
Tests / tests (macos-latest, 3.11) (push) Has been cancelled
Tests / tests (macos-latest, 3.12) (push) Has been cancelled
Tests / tests (macos-latest, 3.13) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / tests (windows-latest, 3.10) (push) Has been cancelled
Tests / tests (windows-latest, 3.11) (push) Has been cancelled
Tests / tests (windows-latest, 3.12) (push) Has been cancelled
Tests / tests (windows-latest, 3.13) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.11) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.11) (push) Has been cancelled

* Drop py3.9 and pypy3.9

Add pypy3.11 (testing) into `tests.yml`

Remove py3.9 from matrix in `tests.yml`

Refactor not auto-gen code to be compatible with py3.10+, droping ugly 3.9 annotation.

Replace some `from typing` imports to `from collections.abc`, due to deprecation

Add `from __future__ import annotations` and `if TYPE_CHECKING:` where possible

Add some `noqa` to calm down Ruff in some places, if Ruff will be used as default linting+formatting tool in future

Replace some relative imports to absolute

Sort `__all__` tuples in `__init__.py` and some other `.py` files

Sort `__slots__` tuples in classes

Split raises into `msg` and `raise` (`EM101`, `EM102`) to not duplicate error message in the traceback

Add `Self` from `typing_extenstion` where possible

Resolve typing problem in `aiogram/filters/command.py:18`

Concatenate nested `if` statements

Convert `HandlerContainer` into a dataclass in `aiogram/fsm/scene.py`

Bump tests docker-compose.yml `redis:6-alpine` -> `redis:8-alpine`

Bump tests docker-compose.yml `mongo:7.0.6` -> `mongo:8.0.14`

Bump pre-commit-config `black==24.4.2` -> `black==25.9.0`

Bump pre-commit-config `ruff==0.5.1` -> `ruff==0.13.3`

Update Makefile lint for ruff to show fixes

Add `make outdated` into Makefile

Use `pathlib` instead of `os.path`

Bump `redis[hiredis]>=5.0.1,<5.3.0` -> `redis[hiredis]>=6.2.0,<7`

Bump `cryptography>=43.0.0` -> `cryptography>=46.0.0` due to security reasons

Bump `pytz~=2023.3` -> `pytz~=2025.2`

Bump `pycryptodomex~=3.19.0` -> `pycryptodomex~=3.23.0` due to security reasons

Bump linting and formatting tools

* Add `1726.removal.rst`

* Update aiogram/utils/dataclass.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update aiogram/filters/callback_data.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update 1726.removal.rst

* Remove `outdated` from Makefile

* Add `__slots__` to `HandlerContainer`

* Remove unused imports

* Add `@dataclass` with `slots=True` to `HandlerContainer`

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Andrew 2025-10-06 19:19:23 +03:00 committed by GitHub
parent ab32296d07
commit df7b16d5b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
94 changed files with 1383 additions and 1215 deletions

View file

@ -19,7 +19,7 @@ jobs:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: '0'
- name: Set up Python 3.10
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"

View file

@ -29,7 +29,6 @@ jobs:
- macos-latest
- windows-latest
python-version:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
@ -111,8 +110,8 @@ jobs:
- macos-latest
# - windows-latest
python-version:
- "pypy3.9"
- "pypy3.10"
- "pypy3.11"
defaults:
# Windows sucks. Force use bash instead of PowerShell

View file

@ -14,12 +14,12 @@ repos:
- id: "check-json"
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 25.9.0
hooks:
- id: black
files: &files '^(aiogram|tests|examples)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.5.1'
rev: 'v0.13.3'
hooks:
- id: ruff

7
CHANGES/1726.removal.rst Normal file
View file

@ -0,0 +1,7 @@
This PR updates the codebase following the end of life for Python 3.9.
Reference: https://devguide.python.org/versions/
- Updated type annotations to Python 3.10+ style, replacing deprecated ``List``, ``Set``, etc., with built-in ``list``, ``set``, and related types.
- Refactored code by simplifying nested ``if`` expressions.
- Updated several dependencies, including security-related upgrades.

View file

@ -39,7 +39,7 @@ install: clean
lint:
isort --check-only $(code_dir)
black --check --diff $(code_dir)
ruff check $(package_dir) $(examples_dir)
ruff check --show-fixes --preview $(package_dir) $(examples_dir)
mypy $(package_dir)
.PHONY: reformat

View file

@ -35,7 +35,7 @@ aiogram
:alt: Codecov
**aiogram** is a modern and fully asynchronous framework for
`Telegram Bot API <https://core.telegram.org/bots/api>`_ written in Python 3.8+ using
`Telegram Bot API <https://core.telegram.org/bots/api>`_ written in Python 3.10+ using
`asyncio <https://docs.python.org/3/library/asyncio.html>`_ and
`aiohttp <https://github.com/aio-libs/aiohttp>`_.

View file

@ -24,18 +24,18 @@ F = MagicFilter()
flags = FlagGenerator()
__all__ = (
"BaseMiddleware",
"Bot",
"Dispatcher",
"F",
"Router",
"__api_version__",
"__version__",
"types",
"methods",
"enums",
"Bot",
"session",
"Dispatcher",
"Router",
"BaseMiddleware",
"F",
"flags",
"html",
"md",
"flags",
"methods",
"session",
"types",
)

View file

@ -10,7 +10,7 @@ if TYPE_CHECKING:
class BotContextController(BaseModel):
_bot: Optional["Bot"] = PrivateAttr()
def model_post_init(self, __context: Any) -> None:
def model_post_init(self, __context: Any) -> None: # noqa: PYI063
self._bot = __context.get("bot") if __context else None
def as_(self, bot: Optional["Bot"]) -> Self:

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from aiogram.utils.dataclass import dataclass_kwargs
@ -35,25 +35,25 @@ class DefaultBotProperties:
Default bot properties.
"""
parse_mode: Optional[str] = None
parse_mode: str | None = None
"""Default parse mode for messages."""
disable_notification: Optional[bool] = None
disable_notification: bool | None = None
"""Sends the message silently. Users will receive a notification with no sound."""
protect_content: Optional[bool] = None
protect_content: bool | None = None
"""Protects content from copying."""
allow_sending_without_reply: Optional[bool] = None
allow_sending_without_reply: bool | None = None
"""Allows to send messages without reply."""
link_preview: Optional[LinkPreviewOptions] = None
link_preview: LinkPreviewOptions | None = None
"""Link preview settings."""
link_preview_is_disabled: Optional[bool] = None
link_preview_is_disabled: bool | None = None
"""Disables link preview."""
link_preview_prefer_small_media: Optional[bool] = None
link_preview_prefer_small_media: bool | None = None
"""Prefer small media in link preview."""
link_preview_prefer_large_media: Optional[bool] = None
link_preview_prefer_large_media: bool | None = None
"""Prefer large media in link preview."""
link_preview_show_above_text: Optional[bool] = None
link_preview_show_above_text: bool | None = None
"""Show link preview above text."""
show_caption_above_media: Optional[bool] = None
show_caption_above_media: bool | None = None
"""Show caption above media."""
def __post_init__(self) -> None:
@ -63,11 +63,11 @@ class DefaultBotProperties:
self.link_preview_prefer_small_media,
self.link_preview_prefer_large_media,
self.link_preview_show_above_text,
)
),
)
if has_any_link_preview_option and self.link_preview is None:
from ..types import LinkPreviewOptions
from aiogram.types import LinkPreviewOptions
self.link_preview = LinkPreviewOptions(
is_disabled=self.link_preview_is_disabled,

View file

@ -2,45 +2,35 @@ from __future__ import annotations
import asyncio
import ssl
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from collections.abc import AsyncGenerator, Iterable
from typing import TYPE_CHECKING, Any, cast
import certifi
from aiohttp import BasicAuth, ClientError, ClientSession, FormData, TCPConnector
from aiohttp.hdrs import USER_AGENT
from aiohttp.http import SERVER_SOFTWARE
from typing_extensions import Self
from aiogram.__meta__ import __version__
from aiogram.methods import TelegramMethod
from aiogram.exceptions import TelegramNetworkError
from aiogram.methods.base import TelegramType
from ...exceptions import TelegramNetworkError
from ...methods.base import TelegramType
from ...types import InputFile
from .base import BaseSession
if TYPE_CHECKING:
from ..bot import Bot
from aiogram.client.bot import Bot
from aiogram.methods import TelegramMethod
from aiogram.types import InputFile
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
_ProxyBasic = str | tuple[str, BasicAuth]
_ProxyChain = Iterable[_ProxyBasic]
_ProxyType = Union[_ProxyChain, _ProxyBasic]
_ProxyType = _ProxyChain | _ProxyBasic
def _retrieve_basic(basic: _ProxyBasic) -> Dict[str, Any]:
def _retrieve_basic(basic: _ProxyBasic) -> dict[str, Any]:
from aiohttp_socks.utils import parse_proxy_url
proxy_auth: Optional[BasicAuth] = None
proxy_auth: BasicAuth | None = None
if isinstance(basic, str):
proxy_url = basic
@ -62,7 +52,7 @@ def _retrieve_basic(basic: _ProxyBasic) -> Dict[str, Any]:
}
def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"], Dict[str, Any]]:
def _prepare_connector(chain_or_plain: _ProxyType) -> tuple[type[TCPConnector], dict[str, Any]]:
from aiohttp_socks import ChainProxyConnector, ProxyConnector, ProxyInfo
# since tuple is Iterable(compatible with _ProxyChain) object, we assume that
@ -74,17 +64,13 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
return ProxyConnector, _retrieve_basic(chain_or_plain)
chain_or_plain = cast(_ProxyChain, chain_or_plain)
infos: List[ProxyInfo] = []
for basic in chain_or_plain:
infos.append(ProxyInfo(**_retrieve_basic(basic)))
infos: list[ProxyInfo] = [ProxyInfo(**_retrieve_basic(basic)) for basic in chain_or_plain]
return ChainProxyConnector, {"proxy_infos": infos}
class AiohttpSession(BaseSession):
def __init__(
self, proxy: Optional[_ProxyType] = None, limit: int = 100, **kwargs: Any
) -> None:
def __init__(self, proxy: _ProxyType | None = None, limit: int = 100, **kwargs: Any) -> None:
"""
Client session based on aiohttp.
@ -94,31 +80,32 @@ class AiohttpSession(BaseSession):
"""
super().__init__(**kwargs)
self._session: Optional[ClientSession] = None
self._connector_type: Type[TCPConnector] = TCPConnector
self._connector_init: Dict[str, Any] = {
self._session: ClientSession | None = None
self._connector_type: type[TCPConnector] = TCPConnector
self._connector_init: dict[str, Any] = {
"ssl": ssl.create_default_context(cafile=certifi.where()),
"limit": limit,
"ttl_dns_cache": 3600, # Workaround for https://github.com/aiogram/aiogram/issues/1500
}
self._should_reset_connector = True # flag determines connector state
self._proxy: Optional[_ProxyType] = None
self._proxy: _ProxyType | None = None
if proxy is not None:
try:
self._setup_proxy_connector(proxy)
except ImportError as exc: # pragma: no cover
raise RuntimeError(
msg = (
"In order to use aiohttp client for proxy requests, install "
"https://pypi.org/project/aiohttp-socks/"
) from exc
)
raise RuntimeError(msg) from exc
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
self._connector_type, self._connector_init = _prepare_connector(proxy)
self._proxy = proxy
@property
def proxy(self) -> Optional[_ProxyType]:
def proxy(self) -> _ProxyType | None:
return self._proxy
@proxy.setter
@ -151,7 +138,7 @@ class AiohttpSession(BaseSession):
def build_form_data(self, bot: Bot, method: TelegramMethod[TelegramType]) -> FormData:
form = FormData(quote_fields=False)
files: Dict[str, InputFile] = {}
files: dict[str, InputFile] = {}
for key, value in method.model_dump(warnings=False).items():
value = self.prepare_value(value, bot=bot, files=files)
if not value:
@ -166,7 +153,10 @@ class AiohttpSession(BaseSession):
return form
async def make_request(
self, bot: Bot, method: TelegramMethod[TelegramType], timeout: Optional[int] = None
self,
bot: Bot,
method: TelegramMethod[TelegramType],
timeout: int | None = None,
) -> TelegramType:
session = await self.create_session()
@ -175,7 +165,9 @@ class AiohttpSession(BaseSession):
try:
async with session.post(
url, data=form, timeout=self.timeout if timeout is None else timeout
url,
data=form,
timeout=self.timeout if timeout is None else timeout,
) as resp:
raw_result = await resp.text()
except asyncio.TimeoutError:
@ -183,14 +175,17 @@ class AiohttpSession(BaseSession):
except ClientError as e:
raise TelegramNetworkError(method=method, message=f"{type(e).__name__}: {e}")
response = self.check_response(
bot=bot, method=method, status_code=resp.status, content=raw_result
bot=bot,
method=method,
status_code=resp.status,
content=raw_result,
)
return cast(TelegramType, response.result)
async def stream_content(
self,
url: str,
headers: Optional[Dict[str, Any]] = None,
headers: dict[str, Any] | None = None,
timeout: int = 30,
chunk_size: int = 65536,
raise_for_status: bool = True,
@ -201,11 +196,14 @@ class AiohttpSession(BaseSession):
session = await self.create_session()
async with session.get(
url, timeout=timeout, headers=headers, raise_for_status=raise_for_status
url,
timeout=timeout,
headers=headers,
raise_for_status=raise_for_status,
) as resp:
async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk
async def __aenter__(self) -> AiohttpSession:
async def __aenter__(self) -> Self:
await self.create_session()
return self

View file

@ -4,23 +4,16 @@ import abc
import datetime
import json
import secrets
from collections.abc import AsyncGenerator, Callable
from enum import Enum
from http import HTTPStatus
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Final,
Optional,
Type,
cast,
)
from typing import TYPE_CHECKING, Any, Final, cast
from pydantic import ValidationError
from typing_extensions import Self
from aiogram.client.default import Default
from aiogram.client.telegram import PRODUCTION, TelegramAPIServer
from aiogram.exceptions import (
ClientDecodeError,
RestartingTelegram,
@ -35,16 +28,16 @@ from aiogram.exceptions import (
TelegramServerError,
TelegramUnauthorizedError,
)
from aiogram.methods import Response, TelegramMethod
from aiogram.methods.base import TelegramType
from aiogram.types import InputFile, TelegramObject
from ...methods import Response, TelegramMethod
from ...methods.base import TelegramType
from ...types import InputFile, TelegramObject
from ..default import Default
from ..telegram import PRODUCTION, TelegramAPIServer
from .middlewares.manager import RequestMiddlewareManager
if TYPE_CHECKING:
from ..bot import Bot
from types import TracebackType
from aiogram.client.bot import Bot
_JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]
@ -81,24 +74,30 @@ class BaseSession(abc.ABC):
self.middleware = RequestMiddlewareManager()
def check_response(
self, bot: Bot, method: TelegramMethod[TelegramType], status_code: int, content: str
self,
bot: Bot,
method: TelegramMethod[TelegramType],
status_code: int,
content: str,
) -> Response[TelegramType]:
"""
Check response status
"""
try:
json_data = self.json_loads(content)
except Exception as e:
except Exception as e: # noqa: BLE001
# Handled error type can't be classified as specific error
# in due to decoder can be customized and raise any exception
raise ClientDecodeError("Failed to decode object", e, content)
msg = "Failed to decode object"
raise ClientDecodeError(msg, e, content)
try:
response_type = Response[method.__returning__] # type: ignore
response = response_type.model_validate(json_data, context={"bot": bot})
except ValidationError as e:
raise ClientDecodeError("Failed to deserialize object", e, json_data)
msg = "Failed to deserialize object"
raise ClientDecodeError(msg, e, json_data)
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED and response.ok:
return response
@ -108,7 +107,9 @@ class BaseSession(abc.ABC):
if parameters := response.parameters:
if parameters.retry_after:
raise TelegramRetryAfter(
method=method, message=description, retry_after=parameters.retry_after
method=method,
message=description,
retry_after=parameters.retry_after,
)
if parameters.migrate_to_chat_id:
raise TelegramMigrateToChat(
@ -143,14 +144,13 @@ class BaseSession(abc.ABC):
"""
Close client session
"""
pass
@abc.abstractmethod
async def make_request(
self,
bot: Bot,
method: TelegramMethod[TelegramType],
timeout: Optional[int] = None,
timeout: int | None = None,
) -> TelegramType: # pragma: no cover
"""
Make request to Telegram Bot API
@ -161,13 +161,12 @@ class BaseSession(abc.ABC):
:return:
:raise TelegramApiError:
"""
pass
@abc.abstractmethod
async def stream_content(
self,
url: str,
headers: Optional[Dict[str, Any]] = None,
headers: dict[str, Any] | None = None,
timeout: int = 30,
chunk_size: int = 65536,
raise_for_status: bool = True,
@ -181,7 +180,7 @@ class BaseSession(abc.ABC):
self,
value: Any,
bot: Bot,
files: Dict[str, Any],
files: dict[str, Any],
_dumps_json: bool = True,
) -> Any:
"""
@ -204,7 +203,10 @@ class BaseSession(abc.ABC):
for key, item in value.items()
if (
prepared_item := self.prepare_value(
item, bot=bot, files=files, _dumps_json=False
item,
bot=bot,
files=files,
_dumps_json=False,
)
)
is not None
@ -218,7 +220,10 @@ class BaseSession(abc.ABC):
for item in value
if (
prepared_item := self.prepare_value(
item, bot=bot, files=files, _dumps_json=False
item,
bot=bot,
files=files,
_dumps_json=False,
)
)
is not None
@ -227,7 +232,7 @@ class BaseSession(abc.ABC):
return self.json_dumps(value)
return value
if isinstance(value, datetime.timedelta):
now = datetime.datetime.now()
now = datetime.datetime.now() # noqa: DTZ005
return str(round((now + value).timestamp()))
if isinstance(value, datetime.datetime):
return str(round(value.timestamp()))
@ -248,18 +253,18 @@ class BaseSession(abc.ABC):
self,
bot: Bot,
method: TelegramMethod[TelegramType],
timeout: Optional[int] = None,
timeout: int | None = None,
) -> TelegramType:
middleware = self.middleware.wrap_middlewares(self.make_request, timeout=timeout)
return cast(TelegramType, await middleware(bot, method))
async def __aenter__(self) -> BaseSession:
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
await self.close()

View file

@ -3,17 +3,17 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol
from aiogram.methods import Response, TelegramMethod
from aiogram.methods.base import TelegramType
if TYPE_CHECKING:
from ...bot import Bot
from aiogram.client.bot import Bot
from aiogram.methods import Response, TelegramMethod
class NextRequestMiddlewareType(Protocol[TelegramType]): # pragma: no cover
async def __call__(
self,
bot: "Bot",
bot: Bot,
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
pass
@ -23,7 +23,7 @@ class RequestMiddlewareType(Protocol): # pragma: no cover
async def __call__(
self,
make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot",
bot: Bot,
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
pass
@ -38,7 +38,7 @@ class BaseRequestMiddleware(ABC):
async def __call__(
self,
make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot",
bot: Bot,
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
"""
@ -50,4 +50,3 @@ class BaseRequestMiddleware(ABC):
:return: :class:`aiogram.methods.Response`
"""
pass

View file

@ -1,7 +1,8 @@
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload
from typing import Any, cast, overload
from aiogram.client.session.middlewares.base import (
NextRequestMiddlewareType,
@ -12,7 +13,7 @@ from aiogram.methods.base import TelegramType
class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
def __init__(self) -> None:
self._middlewares: List[RequestMiddlewareType] = []
self._middlewares: list[RequestMiddlewareType] = []
def register(
self,
@ -26,11 +27,8 @@ class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
def __call__(
self,
middleware: Optional[RequestMiddlewareType] = None,
) -> Union[
Callable[[RequestMiddlewareType], RequestMiddlewareType],
RequestMiddlewareType,
]:
middleware: RequestMiddlewareType | None = None,
) -> Callable[[RequestMiddlewareType], RequestMiddlewareType] | RequestMiddlewareType:
if middleware is None:
return self.register
return self.register(middleware)
@ -44,8 +42,9 @@ class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
pass
def __getitem__(
self, item: Union[int, slice]
) -> Union[RequestMiddlewareType, Sequence[RequestMiddlewareType]]:
self,
item: int | slice,
) -> RequestMiddlewareType | Sequence[RequestMiddlewareType]:
return self._middlewares[item]
def __len__(self) -> int:

View file

@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Any, List, Optional, Type
from typing import TYPE_CHECKING, Any
from aiogram import loggers
from aiogram.methods import TelegramMethod
@ -8,19 +8,19 @@ from aiogram.methods.base import Response, TelegramType
from .base import BaseRequestMiddleware, NextRequestMiddlewareType
if TYPE_CHECKING:
from ...bot import Bot
from aiogram.client.bot import Bot
logger = logging.getLogger(__name__)
class RequestLogging(BaseRequestMiddleware):
def __init__(self, ignore_methods: Optional[List[Type[TelegramMethod[Any]]]] = None):
def __init__(self, ignore_methods: list[type[TelegramMethod[Any]]] | None = None):
"""
Middleware for logging outgoing requests
:param ignore_methods: methods to ignore in logging middleware
"""
self.ignore_methods = ignore_methods if ignore_methods else []
self.ignore_methods = ignore_methods or []
async def __call__(
self,

View file

@ -1,24 +1,24 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Union
from typing import Any
class FilesPathWrapper(ABC):
@abstractmethod
def to_local(self, path: Union[Path, str]) -> Union[Path, str]:
def to_local(self, path: Path | str) -> Path | str:
pass
@abstractmethod
def to_server(self, path: Union[Path, str]) -> Union[Path, str]:
def to_server(self, path: Path | str) -> Path | str:
pass
class BareFilesPathWrapper(FilesPathWrapper):
def to_local(self, path: Union[Path, str]) -> Union[Path, str]:
def to_local(self, path: Path | str) -> Path | str:
return path
def to_server(self, path: Union[Path, str]) -> Union[Path, str]:
def to_server(self, path: Path | str) -> Path | str:
return path
@ -29,15 +29,18 @@ class SimpleFilesPathWrapper(FilesPathWrapper):
@classmethod
def _resolve(
cls, base1: Union[Path, str], base2: Union[Path, str], value: Union[Path, str]
cls,
base1: Path | str,
base2: Path | str,
value: Path | str,
) -> Path:
relative = Path(value).relative_to(base1)
return base2 / relative
def to_local(self, path: Union[Path, str]) -> Union[Path, str]:
def to_local(self, path: Path | str) -> Path | str:
return self._resolve(base1=self.server_path, base2=self.local_path, value=path)
def to_server(self, path: Union[Path, str]) -> Union[Path, str]:
def to_server(self, path: Path | str) -> Path | str:
return self._resolve(base1=self.local_path, base2=self.server_path, value=path)
@ -54,7 +57,7 @@ class TelegramAPIServer:
is_local: bool = False
"""Mark this server is
in `local mode <https://core.telegram.org/bots/api#using-a-local-bot-api-server>`_."""
wrap_local_file: FilesPathWrapper = BareFilesPathWrapper()
wrap_local_file: FilesPathWrapper = field(default=BareFilesPathWrapper())
"""Callback to wrap files path in local mode"""
def api_url(self, token: str, method: str) -> str:
@ -67,7 +70,7 @@ class TelegramAPIServer:
"""
return self.base.format(token=token, method=method)
def file_url(self, token: str, path: Union[str, Path]) -> str:
def file_url(self, token: str, path: str | Path) -> str:
"""
Generate URL for downloading files

View file

@ -5,28 +5,32 @@ import contextvars
import signal
import warnings
from asyncio import CancelledError, Event, Future, Lock
from collections.abc import AsyncGenerator, Awaitable
from contextlib import suppress
from typing import Any, AsyncGenerator, Awaitable, Dict, List, Optional, Set, Union
from typing import TYPE_CHECKING, Any
from aiogram import loggers
from aiogram.exceptions import TelegramAPIError
from aiogram.fsm.middleware import FSMContextMiddleware
from aiogram.fsm.storage.base import BaseEventIsolation, BaseStorage
from aiogram.fsm.storage.memory import DisabledEventIsolation, MemoryStorage
from aiogram.fsm.strategy import FSMStrategy
from aiogram.methods import GetUpdates, TelegramMethod
from aiogram.types import Update, User
from aiogram.types.base import UNSET, UNSET_TYPE
from aiogram.types.update import UpdateTypeLookupError
from aiogram.utils.backoff import Backoff, BackoffConfig
from .. import loggers
from ..client.bot import Bot
from ..exceptions import TelegramAPIError
from ..fsm.middleware import FSMContextMiddleware
from ..fsm.storage.base import BaseEventIsolation, BaseStorage
from ..fsm.storage.memory import DisabledEventIsolation, MemoryStorage
from ..fsm.strategy import FSMStrategy
from ..methods import GetUpdates, TelegramMethod
from ..methods.base import TelegramType
from ..types import Update, User
from ..types.base import UNSET, UNSET_TYPE
from ..types.update import UpdateTypeLookupError
from ..utils.backoff import Backoff, BackoffConfig
from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver
from .middlewares.error import ErrorsMiddleware
from .middlewares.user_context import UserContextMiddleware
from .router import Router
if TYPE_CHECKING:
from aiogram.client.bot import Bot
from aiogram.methods.base import TelegramType
DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1)
@ -38,11 +42,11 @@ class Dispatcher(Router):
def __init__(
self,
*, # * - Preventing to pass instance of Bot to the FSM storage
storage: Optional[BaseStorage] = None,
storage: BaseStorage | None = None,
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
events_isolation: Optional[BaseEventIsolation] = None,
events_isolation: BaseEventIsolation | None = None,
disable_fsm: bool = False,
name: Optional[str] = None,
name: str | None = None,
**kwargs: Any,
) -> None:
"""
@ -55,18 +59,18 @@ class Dispatcher(Router):
then you should not use storage and events isolation
:param kwargs: Other arguments, will be passed as keyword arguments to handlers
"""
super(Dispatcher, self).__init__(name=name)
super().__init__(name=name)
if storage and not isinstance(storage, BaseStorage):
raise TypeError(
f"FSM storage should be instance of 'BaseStorage' not {type(storage).__name__}"
)
msg = f"FSM storage should be instance of 'BaseStorage' not {type(storage).__name__}"
raise TypeError(msg)
# Telegram API provides originally only one event type - Update
# For making easily interactions with events here is registered handler which helps
# to separate Update to different event types like Message, CallbackQuery etc.
self.update = self.observers["update"] = TelegramEventObserver(
router=self, event_name="update"
router=self,
event_name="update",
)
self.update.register(self._listen_update)
@ -91,11 +95,11 @@ class Dispatcher(Router):
self.update.outer_middleware(self.fsm)
self.shutdown.register(self.fsm.close)
self.workflow_data: Dict[str, Any] = kwargs
self.workflow_data: dict[str, Any] = kwargs
self._running_lock = Lock()
self._stop_signal: Optional[Event] = None
self._stopped_signal: Optional[Event] = None
self._handle_update_tasks: Set[asyncio.Task[Any]] = set()
self._stop_signal: Event | None = None
self._stopped_signal: Event | None = None
self._handle_update_tasks: set[asyncio.Task[Any]] = set()
def __getitem__(self, item: str) -> Any:
return self.workflow_data[item]
@ -106,7 +110,7 @@ class Dispatcher(Router):
def __delitem__(self, key: str) -> None:
del self.workflow_data[key]
def get(self, key: str, /, default: Optional[Any] = None) -> Optional[Any]:
def get(self, key: str, /, default: Any | None = None) -> Any | None:
return self.workflow_data.get(key, default)
@property
@ -114,13 +118,13 @@ class Dispatcher(Router):
return self.fsm.storage
@property
def parent_router(self) -> Optional[Router]:
def parent_router(self) -> Router | None:
"""
Dispatcher has no parent router and can't be included to any other routers or dispatchers
:return:
"""
return None # noqa: RET501
return None
@parent_router.setter
def parent_router(self, value: Router) -> None:
@ -130,7 +134,8 @@ class Dispatcher(Router):
:param value:
:return:
"""
raise RuntimeError("Dispatcher can not be attached to another Router.")
msg = "Dispatcher can not be attached to another Router."
raise RuntimeError(msg)
async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
"""
@ -177,7 +182,7 @@ class Dispatcher(Router):
bot.id,
)
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
async def feed_raw_update(self, bot: Bot, update: dict[str, Any], **kwargs: Any) -> Any:
"""
Main entry point for incoming updates with automatic Dict->Update serializer
@ -194,7 +199,7 @@ class Dispatcher(Router):
bot: Bot,
polling_timeout: int = 30,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
allowed_updates: list[str] | None = None,
) -> AsyncGenerator[Update, None]:
"""
Endless updates reader with correctly handling any server-side or connection errors.
@ -212,7 +217,7 @@ class Dispatcher(Router):
while True:
try:
updates = await bot(get_updates, **kwargs)
except Exception as e:
except Exception as e: # noqa: BLE001
failed = True
# In cases when Telegram Bot API was inaccessible don't need to stop polling
# process because some developers can't make auto-restarting of the script
@ -268,6 +273,7 @@ class Dispatcher(Router):
"installed not latest version of aiogram framework"
f"\nUpdate: {update.model_dump_json(exclude_unset=True)}",
RuntimeWarning,
stacklevel=2,
)
raise SkipHandler() from e
@ -294,7 +300,11 @@ class Dispatcher(Router):
loggers.event.error("Failed to make answer: %s: %s", e.__class__.__name__, e)
async def _process_update(
self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any
self,
bot: Bot,
update: Update,
call_answer: bool = True,
**kwargs: Any,
) -> bool:
"""
Propagate update to event listeners
@ -309,9 +319,8 @@ class Dispatcher(Router):
response = await self.feed_update(bot, update, **kwargs)
if call_answer and isinstance(response, TelegramMethod):
await self.silent_call_request(bot=bot, result=response)
return response is not UNHANDLED
except Exception as e:
except Exception as e: # noqa: BLE001
loggers.event.exception(
"Cause exception while process update id=%d by bot id=%d\n%s: %s",
update.update_id,
@ -321,8 +330,13 @@ class Dispatcher(Router):
)
return True # because update was processed but unsuccessful
else:
return response is not UNHANDLED
async def _process_with_semaphore(
self, handle_update: Awaitable[bool], semaphore: asyncio.Semaphore
self,
handle_update: Awaitable[bool],
semaphore: asyncio.Semaphore,
) -> bool:
"""
Process update with semaphore to limit concurrent tasks
@ -342,8 +356,8 @@ class Dispatcher(Router):
polling_timeout: int = 30,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
tasks_concurrency_limit: Optional[int] = None,
allowed_updates: list[str] | None = None,
tasks_concurrency_limit: int | None = None,
**kwargs: Any,
) -> None:
"""
@ -361,7 +375,10 @@ class Dispatcher(Router):
"""
user: User = await bot.me()
loggers.dispatcher.info(
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
"Run polling for bot @%s id=%d - %r",
user.username,
bot.id,
user.full_name,
)
# Create semaphore if tasks_concurrency_limit is specified
@ -382,7 +399,7 @@ class Dispatcher(Router):
# Use semaphore to limit concurrent tasks
await semaphore.acquire()
handle_update_task = asyncio.create_task(
self._process_with_semaphore(handle_update, semaphore)
self._process_with_semaphore(handle_update, semaphore),
)
else:
handle_update_task = asyncio.create_task(handle_update)
@ -393,7 +410,10 @@ class Dispatcher(Router):
await handle_update
finally:
loggers.dispatcher.info(
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name
"Polling stopped for bot @%s id=%d - %r",
user.username,
bot.id,
user.full_name,
)
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
@ -413,8 +433,12 @@ class Dispatcher(Router):
raise
async def feed_webhook_update(
self, bot: Bot, update: Union[Update, Dict[str, Any]], _timeout: float = 55, **kwargs: Any
) -> Optional[TelegramMethod[TelegramType]]:
self,
bot: Bot,
update: Update | dict[str, Any],
_timeout: float = 55,
**kwargs: Any,
) -> TelegramMethod[TelegramType] | None:
if not isinstance(update, Update): # Allow to use raw updates
update = Update.model_validate(update, context={"bot": bot})
@ -429,7 +453,7 @@ class Dispatcher(Router):
timeout_handle = loop.call_later(_timeout, release_waiter)
process_updates: Future[Any] = asyncio.ensure_future(
self._feed_webhook_update(bot=bot, update=update, **kwargs)
self._feed_webhook_update(bot=bot, update=update, **kwargs),
)
process_updates.add_done_callback(release_waiter, context=ctx)
@ -440,11 +464,9 @@ class Dispatcher(Router):
"For preventing this situation response into webhook returned immediately "
"and handler is moved to background and still processing update.",
RuntimeWarning,
stacklevel=2,
)
try:
result = task.result()
except Exception as e:
raise e
if isinstance(result, TelegramMethod):
asyncio.ensure_future(self.silent_call_request(bot=bot, result=result))
@ -478,7 +500,8 @@ class Dispatcher(Router):
:return:
"""
if not self._running_lock.locked():
raise RuntimeError("Polling is not started")
msg = "Polling is not started"
raise RuntimeError(msg)
if not self._stop_signal or not self._stopped_signal:
return
self._stop_signal.set()
@ -499,10 +522,10 @@ class Dispatcher(Router):
polling_timeout: int = 10,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET,
allowed_updates: list[str] | UNSET_TYPE | None = UNSET,
handle_signals: bool = True,
close_bot_session: bool = True,
tasks_concurrency_limit: Optional[int] = None,
tasks_concurrency_limit: int | None = None,
**kwargs: Any,
) -> None:
"""
@ -522,12 +545,14 @@ class Dispatcher(Router):
:return:
"""
if not bots:
raise ValueError("At least one bot instance is required to start polling")
msg = "At least one bot instance is required to start polling"
raise ValueError(msg)
if "bot" in kwargs:
raise ValueError(
msg = (
"Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument"
)
raise ValueError(msg)
async with self._running_lock: # Prevent to run this method twice at a once
if self._stop_signal is None:
@ -547,10 +572,14 @@ class Dispatcher(Router):
# Signals handling is not supported on Windows
# It also can't be covered on Windows
loop.add_signal_handler(
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
signal.SIGTERM,
self._signal_stop_polling,
signal.SIGTERM,
)
loop.add_signal_handler(
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
signal.SIGINT,
self._signal_stop_polling,
signal.SIGINT,
)
workflow_data = {
@ -565,7 +594,7 @@ class Dispatcher(Router):
await self.emit_startup(bot=bots[-1], **workflow_data)
loggers.dispatcher.info("Start polling")
try:
tasks: List[asyncio.Task[Any]] = [
tasks: list[asyncio.Task[Any]] = [
asyncio.create_task(
self._polling(
bot=bot,
@ -575,7 +604,7 @@ class Dispatcher(Router):
allowed_updates=allowed_updates,
tasks_concurrency_limit=tasks_concurrency_limit,
**workflow_data,
)
),
)
for bot in bots
]
@ -605,10 +634,10 @@ class Dispatcher(Router):
polling_timeout: int = 10,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET,
allowed_updates: list[str] | UNSET_TYPE | None = UNSET,
handle_signals: bool = True,
close_bot_session: bool = True,
tasks_concurrency_limit: Optional[int] = None,
tasks_concurrency_limit: int | None = None,
**kwargs: Any,
) -> None:
"""
@ -638,5 +667,5 @@ class Dispatcher(Router):
handle_signals=handle_signals,
close_bot_session=close_bot_session,
tasks_concurrency_limit=tasks_concurrency_limit,
)
),
)

View file

@ -1,20 +1,22 @@
from __future__ import annotations
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, TypeVar, Union
from collections.abc import Awaitable, Callable
from typing import Any, NoReturn, TypeVar
from unittest.mock import sentinel
from ...types import TelegramObject
from ..middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import TelegramObject
MiddlewareEventType = TypeVar("MiddlewareEventType", bound=TelegramObject)
NextMiddlewareType = Callable[[MiddlewareEventType, Dict[str, Any]], Awaitable[Any]]
MiddlewareType = Union[
BaseMiddleware,
Callable[
[NextMiddlewareType[MiddlewareEventType], MiddlewareEventType, Dict[str, Any]],
NextMiddlewareType = Callable[[MiddlewareEventType, dict[str, Any]], Awaitable[Any]]
MiddlewareType = (
BaseMiddleware
| Callable[
[NextMiddlewareType[MiddlewareEventType], MiddlewareEventType, dict[str, Any]],
Awaitable[Any],
],
]
)
UNHANDLED = sentinel.UNHANDLED
REJECTED = sentinel.REJECTED
@ -28,7 +30,7 @@ class CancelHandler(Exception):
pass
def skip(message: Optional[str] = None) -> NoReturn:
def skip(message: str | None = None) -> NoReturn:
"""
Raise an SkipHandler
"""

View file

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any, Callable, List
from collections.abc import Callable
from typing import Any
from .handler import CallbackType, HandlerObject
@ -25,7 +26,7 @@ class EventObserver:
"""
def __init__(self) -> None:
self.handlers: List[HandlerObject] = []
self.handlers: list[HandlerObject] = []
def register(self, callback: CallbackType) -> None:
"""

View file

@ -1,10 +1,10 @@
import asyncio
import contextvars
import inspect
import warnings
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any
from magic_filter.magic import MagicFilter as OriginalMagicFilter
@ -21,7 +21,7 @@ CallbackType = Callable[..., Any]
class CallableObject:
callback: CallbackType
awaitable: bool = field(init=False)
params: Set[str] = field(init=False)
params: set[str] = field(init=False)
varkw: bool = field(init=False)
def __post_init__(self) -> None:
@ -31,7 +31,7 @@ class CallableObject:
self.params = {*spec.args, *spec.kwonlyargs}
self.varkw = spec.varkw is not None
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
if self.varkw:
return kwargs
@ -46,7 +46,7 @@ class CallableObject:
@dataclass
class FilterObject(CallableObject):
magic: Optional[MagicFilter] = None
magic: MagicFilter | None = None
def __post_init__(self) -> None:
if isinstance(self.callback, OriginalMagicFilter):
@ -65,7 +65,7 @@ class FilterObject(CallableObject):
stacklevel=6,
)
super(FilterObject, self).__post_init__()
super().__post_init__()
if isinstance(self.callback, Filter):
self.awaitable = True
@ -73,17 +73,17 @@ class FilterObject(CallableObject):
@dataclass
class HandlerObject(CallableObject):
filters: Optional[List[FilterObject]] = None
flags: Dict[str, Any] = field(default_factory=dict)
filters: list[FilterObject] | None = None
flags: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
super(HandlerObject, self).__post_init__()
super().__post_init__()
callback = inspect.unwrap(self.callback)
if inspect.isclass(callback) and issubclass(callback, BaseHandler):
self.awaitable = True
self.flags.update(extract_flags_from_object(callback))
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
async def check(self, *args: Any, **kwargs: Any) -> tuple[bool, dict[str, Any]]:
if not self.filters:
return True, kwargs
for event_filter in self.filters:

View file

@ -1,17 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from aiogram.exceptions import UnsupportedKeywordArgument
from aiogram.filters.base import Filter
from ...exceptions import UnsupportedKeywordArgument
from ...filters.base import Filter
from ...types import TelegramObject
from .bases import UNHANDLED, MiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, HandlerObject
if TYPE_CHECKING:
from aiogram.dispatcher.router import Router
from aiogram.types import TelegramObject
class TelegramEventObserver:
@ -26,7 +27,7 @@ class TelegramEventObserver:
self.router: Router = router
self.event_name: str = event_name
self.handlers: List[HandlerObject] = []
self.handlers: list[HandlerObject] = []
self.middleware = MiddlewareManager()
self.outer_middleware = MiddlewareManager()
@ -45,8 +46,8 @@ class TelegramEventObserver:
self._handler.filters = []
self._handler.filters.extend([FilterObject(filter_) for filter_ in filters])
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
middlewares: List[MiddlewareType[TelegramObject]] = []
def _resolve_middlewares(self) -> list[MiddlewareType[TelegramObject]]:
middlewares: list[MiddlewareType[TelegramObject]] = []
for router in reversed(tuple(self.router.chain_head)):
observer = router.observers.get(self.event_name)
if observer:
@ -58,14 +59,14 @@ class TelegramEventObserver:
self,
callback: CallbackType,
*filters: CallbackType,
flags: Optional[Dict[str, Any]] = None,
flags: dict[str, Any] | None = None,
**kwargs: Any,
) -> CallbackType:
"""
Register event handler
"""
if kwargs:
raise UnsupportedKeywordArgument(
msg = (
"Passing any additional keyword arguments to the registrar method "
"is not supported.\n"
"This error may be caused when you are trying to register filters like in 2.x "
@ -73,6 +74,7 @@ class TelegramEventObserver:
"documentation pages.\n"
f"Please remove the {set(kwargs.keys())} arguments from this call.\n"
)
raise UnsupportedKeywordArgument(msg)
if flags is None:
flags = {}
@ -86,13 +88,16 @@ class TelegramEventObserver:
callback=callback,
filters=[FilterObject(filter_) for filter_ in filters],
flags=flags,
)
),
)
return callback
def wrap_outer_middleware(
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
self,
callback: Any,
event: TelegramObject,
data: dict[str, Any],
) -> Any:
wrapped_outer = self.middleware.wrap_middlewares(
self.outer_middleware,
@ -127,7 +132,7 @@ class TelegramEventObserver:
def __call__(
self,
*filters: CallbackType,
flags: Optional[Dict[str, Any]] = None,
flags: dict[str, Any] | None = None,
**kwargs: Any,
) -> Callable[[CallbackType], CallbackType]:
"""

View file

@ -1,5 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast, overload
from typing import TYPE_CHECKING, Any, Union, cast, overload
from magic_filter import AttrDict, MagicFilter
@ -39,11 +40,12 @@ class FlagDecorator:
def __call__(
self,
value: Optional[Any] = None,
value: Any | None = None,
**kwargs: Any,
) -> Union[Callable[..., Any], "FlagDecorator"]:
if value and kwargs:
raise ValueError("The arguments `value` and **kwargs can not be used together")
msg = "The arguments `value` and **kwargs can not be used together"
raise ValueError(msg)
if value is not None and callable(value):
value.aiogram_flag = {
@ -70,20 +72,21 @@ if TYPE_CHECKING:
class FlagGenerator:
def __getattr__(self, name: str) -> FlagDecorator:
if name[0] == "_":
raise AttributeError("Flag name must NOT start with underscore")
msg = "Flag name must NOT start with underscore"
raise AttributeError(msg)
return FlagDecorator(Flag(name, True))
if TYPE_CHECKING:
chat_action: _ChatActionFlagProtocol
def extract_flags_from_object(obj: Any) -> Dict[str, Any]:
def extract_flags_from_object(obj: Any) -> dict[str, Any]:
if not hasattr(obj, "aiogram_flag"):
return {}
return cast(Dict[str, Any], obj.aiogram_flag)
return cast(dict[str, Any], obj.aiogram_flag)
def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str, Any]:
def extract_flags(handler: Union["HandlerObject", dict[str, Any]]) -> dict[str, Any]:
"""
Extract flags from handler or middleware context data
@ -98,10 +101,10 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
def get_flag(
handler: Union["HandlerObject", Dict[str, Any]],
handler: Union["HandlerObject", dict[str, Any]],
name: str,
*,
default: Optional[Any] = None,
default: Any | None = None,
) -> Any:
"""
Get flag by name
@ -115,7 +118,7 @@ def get_flag(
return flags.get(name, default)
def check_flags(handler: Union["HandlerObject", Dict[str, Any]], magic: MagicFilter) -> Any:
def check_flags(handler: Union["HandlerObject", dict[str, Any]], magic: MagicFilter) -> Any:
"""
Check flags via magic filter

View file

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, TypeVar
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from aiogram.types import TelegramObject
@ -14,9 +15,9 @@ class BaseMiddleware(ABC):
@abstractmethod
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any: # pragma: no cover
"""
Execute middleware
@ -26,4 +27,3 @@ class BaseMiddleware(ABC):
:param data: Contextual data. Will be mapped to handler arguments
:return: :class:`Any`
"""
pass

View file

@ -1,14 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, cast
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from aiogram.dispatcher.event.bases import UNHANDLED, CancelHandler, SkipHandler
from aiogram.types import TelegramObject, Update
from aiogram.types.error_event import ErrorEvent
from ...types import TelegramObject, Update
from ...types.error_event import ErrorEvent
from ..event.bases import UNHANDLED, CancelHandler, SkipHandler
from .base import BaseMiddleware
if TYPE_CHECKING:
from ..router import Router
from aiogram.dispatcher.router import Router
class ErrorsMiddleware(BaseMiddleware):
@ -17,9 +19,9 @@ class ErrorsMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
try:
return await handler(event, data)

View file

@ -1,5 +1,6 @@
import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
from collections.abc import Callable, Sequence
from typing import Any, overload
from aiogram.dispatcher.event.bases import (
MiddlewareEventType,
@ -12,7 +13,7 @@ from aiogram.types import TelegramObject
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
def __init__(self) -> None:
self._middlewares: List[MiddlewareType[TelegramObject]] = []
self._middlewares: list[MiddlewareType[TelegramObject]] = []
def register(
self,
@ -26,11 +27,11 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
def __call__(
self,
middleware: Optional[MiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
MiddlewareType[TelegramObject],
]:
middleware: MiddlewareType[TelegramObject] | None = None,
) -> (
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]]
| MiddlewareType[TelegramObject]
):
if middleware is None:
return self.register
return self.register(middleware)
@ -44,8 +45,9 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
pass
def __getitem__(
self, item: Union[int, slice]
) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]:
self,
item: int | slice,
) -> MiddlewareType[TelegramObject] | Sequence[MiddlewareType[TelegramObject]]:
return self._middlewares[item]
def __len__(self) -> int:
@ -53,10 +55,11 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
@staticmethod
def wrap_middlewares(
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: CallbackType
middlewares: Sequence[MiddlewareType[MiddlewareEventType]],
handler: CallbackType,
) -> NextMiddlewareType[MiddlewareEventType]:
@functools.wraps(handler)
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
def handler_wrapper(event: TelegramObject, kwargs: dict[str, Any]) -> Any:
return handler(event, **kwargs)
middleware = handler_wrapper

View file

@ -1,5 +1,6 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional
from typing import Any
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import (
@ -20,29 +21,30 @@ EVENT_THREAD_ID_KEY = "event_thread_id"
@dataclass(frozen=True)
class EventContext:
chat: Optional[Chat] = None
user: Optional[User] = None
thread_id: Optional[int] = None
business_connection_id: Optional[str] = None
chat: Chat | None = None
user: User | None = None
thread_id: int | None = None
business_connection_id: str | None = None
@property
def user_id(self) -> Optional[int]:
def user_id(self) -> int | None:
return self.user.id if self.user else None
@property
def chat_id(self) -> Optional[int]:
def chat_id(self) -> int | None:
return self.chat.id if self.chat else None
class UserContextMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
msg = "UserContextMiddleware got an unexpected event type!"
raise RuntimeError(msg)
event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event)
# Backward compatibility
@ -116,13 +118,15 @@ class UserContextMiddleware(BaseMiddleware):
)
if event.my_chat_member:
return EventContext(
chat=event.my_chat_member.chat, user=event.my_chat_member.from_user
chat=event.my_chat_member.chat,
user=event.my_chat_member.from_user,
)
if event.chat_member:
return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user)
if event.chat_join_request:
return EventContext(
chat=event.chat_join_request.chat, user=event.chat_join_request.from_user
chat=event.chat_join_request.chat,
user=event.chat_join_request.from_user,
)
if event.message_reaction:
return EventContext(

View file

@ -1,12 +1,15 @@
from __future__ import annotations
from typing import Any, Dict, Final, Generator, List, Optional, Set
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Final
from ..types import TelegramObject
from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver
from .event.telegram import TelegramEventObserver
if TYPE_CHECKING:
from aiogram.types import TelegramObject
INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"})
@ -21,31 +24,34 @@ class Router:
- By decorator - :obj:`@router.<event_type>(<filters, ...>)`
"""
def __init__(self, *, name: Optional[str] = None) -> None:
def __init__(self, *, name: str | None = None) -> None:
"""
:param name: Optional router name, can be useful for debugging
"""
self.name = name or hex(id(self))
self._parent_router: Optional[Router] = None
self.sub_routers: List[Router] = []
self._parent_router: Router | None = None
self.sub_routers: list[Router] = []
# Observers
self.message = TelegramEventObserver(router=self, event_name="message")
self.edited_message = TelegramEventObserver(router=self, event_name="edited_message")
self.channel_post = TelegramEventObserver(router=self, event_name="channel_post")
self.edited_channel_post = TelegramEventObserver(
router=self, event_name="edited_channel_post"
router=self,
event_name="edited_channel_post",
)
self.inline_query = TelegramEventObserver(router=self, event_name="inline_query")
self.chosen_inline_result = TelegramEventObserver(
router=self, event_name="chosen_inline_result"
router=self,
event_name="chosen_inline_result",
)
self.callback_query = TelegramEventObserver(router=self, event_name="callback_query")
self.shipping_query = TelegramEventObserver(router=self, event_name="shipping_query")
self.pre_checkout_query = TelegramEventObserver(
router=self, event_name="pre_checkout_query"
router=self,
event_name="pre_checkout_query",
)
self.poll = TelegramEventObserver(router=self, event_name="poll")
self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer")
@ -54,24 +60,30 @@ class Router:
self.chat_join_request = TelegramEventObserver(router=self, event_name="chat_join_request")
self.message_reaction = TelegramEventObserver(router=self, event_name="message_reaction")
self.message_reaction_count = TelegramEventObserver(
router=self, event_name="message_reaction_count"
router=self,
event_name="message_reaction_count",
)
self.chat_boost = TelegramEventObserver(router=self, event_name="chat_boost")
self.removed_chat_boost = TelegramEventObserver(
router=self, event_name="removed_chat_boost"
router=self,
event_name="removed_chat_boost",
)
self.deleted_business_messages = TelegramEventObserver(
router=self, event_name="deleted_business_messages"
router=self,
event_name="deleted_business_messages",
)
self.business_connection = TelegramEventObserver(
router=self, event_name="business_connection"
router=self,
event_name="business_connection",
)
self.edited_business_message = TelegramEventObserver(
router=self, event_name="edited_business_message"
router=self,
event_name="edited_business_message",
)
self.business_message = TelegramEventObserver(router=self, event_name="business_message")
self.purchased_paid_media = TelegramEventObserver(
router=self, event_name="purchased_paid_media"
router=self,
event_name="purchased_paid_media",
)
self.errors = self.error = TelegramEventObserver(router=self, event_name="error")
@ -79,7 +91,7 @@ class Router:
self.startup = EventObserver()
self.shutdown = EventObserver()
self.observers: Dict[str, TelegramEventObserver] = {
self.observers: dict[str, TelegramEventObserver] = {
"message": self.message,
"edited_message": self.edited_message,
"channel_post": self.channel_post,
@ -112,7 +124,7 @@ class Router:
def __repr__(self) -> str:
return f"<{self}>"
def resolve_used_update_types(self, skip_events: Optional[Set[str]] = None) -> List[str]:
def resolve_used_update_types(self, skip_events: set[str] | None = None) -> list[str]:
"""
Resolve registered event names
@ -121,7 +133,7 @@ class Router:
:param skip_events: skip specified event names
:return: set of registered names
"""
handlers_in_use: Set[str] = set()
handlers_in_use: set[str] = set()
if skip_events is None:
skip_events = set()
skip_events = {*skip_events, *INTERNAL_UPDATE_TYPES}
@ -139,7 +151,10 @@ class Router:
async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any:
return await self._propagate_event(
observer=observer, update_type=update_type, event=telegram_event, **data
observer=observer,
update_type=update_type,
event=telegram_event,
**data,
)
if observer:
@ -148,7 +163,7 @@ class Router:
async def _propagate_event(
self,
observer: Optional[TelegramEventObserver],
observer: TelegramEventObserver | None,
update_type: str,
event: TelegramObject,
**kwargs: Any,
@ -179,7 +194,7 @@ class Router:
@property
def chain_head(self) -> Generator[Router, None, None]:
router: Optional[Router] = self
router: Router | None = self
while router:
yield router
router = router.parent_router
@ -191,7 +206,7 @@ class Router:
yield from router.chain_tail
@property
def parent_router(self) -> Optional[Router]:
def parent_router(self) -> Router | None:
return self._parent_router
@parent_router.setter
@ -206,16 +221,20 @@ class Router:
:param router:
"""
if not isinstance(router, Router):
raise ValueError(f"router should be instance of Router not {type(router).__name__!r}")
msg = f"router should be instance of Router not {type(router).__name__!r}"
raise ValueError(msg)
if self._parent_router:
raise RuntimeError(f"Router is already attached to {self._parent_router!r}")
msg = f"Router is already attached to {self._parent_router!r}"
raise RuntimeError(msg)
if self == router:
raise RuntimeError("Self-referencing routers is not allowed")
msg = "Self-referencing routers is not allowed"
raise RuntimeError(msg)
parent: Optional[Router] = router
parent: Router | None = router
while parent is not None:
if parent == self:
raise RuntimeError("Circular referencing of Router is not allowed")
msg = "Circular referencing of Router is not allowed"
raise RuntimeError(msg)
parent = parent.parent_router
@ -230,7 +249,8 @@ class Router:
:return:
"""
if not routers:
raise ValueError("At least one router must be provided")
msg = "At least one router must be provided"
raise ValueError(msg)
for router in routers:
self.include_router(router)
@ -242,9 +262,8 @@ class Router:
:return:
"""
if not isinstance(router, Router):
raise ValueError(
f"router should be instance of Router not {type(router).__class__.__name__}"
)
msg = f"router should be instance of Router not {type(router).__class__.__name__}"
raise ValueError(msg)
router.parent_router = self
return router

View file

@ -16,7 +16,7 @@ class DetailedAiogramError(AiogramError):
Base exception for all aiogram errors with detailed message.
"""
url: Optional[str] = None
url: str | None = None
def __init__(self, message: str) -> None:
self.message = message

View file

@ -23,29 +23,29 @@ from .state import StateFilter
BaseFilter = Filter
__all__ = (
"Filter",
"ADMINISTRATOR",
"CREATOR",
"IS_ADMIN",
"IS_MEMBER",
"IS_NOT_MEMBER",
"JOIN_TRANSITION",
"KICKED",
"LEAVE_TRANSITION",
"LEFT",
"MEMBER",
"PROMOTED_TRANSITION",
"RESTRICTED",
"BaseFilter",
"ChatMemberUpdatedFilter",
"Command",
"CommandObject",
"CommandStart",
"ExceptionMessageFilter",
"ExceptionTypeFilter",
"StateFilter",
"Filter",
"MagicData",
"ChatMemberUpdatedFilter",
"CREATOR",
"ADMINISTRATOR",
"MEMBER",
"RESTRICTED",
"LEFT",
"KICKED",
"IS_MEMBER",
"IS_ADMIN",
"PROMOTED_TRANSITION",
"IS_NOT_MEMBER",
"JOIN_TRANSITION",
"LEAVE_TRANSITION",
"StateFilter",
"and_f",
"or_f",
"invert_f",
"or_f",
)

View file

@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from aiogram.filters.logic import _InvertFilter
class Filter(ABC):
class Filter(ABC): # noqa: B024
"""
If you want to register own filters like builtin filters you will need to write subclass
of this class with overriding the :code:`__call__`
@ -16,11 +17,11 @@ class Filter(ABC):
# This checking type-hint is needed because mypy checks validity of overrides and raises:
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
__call__: Callable[..., Awaitable[Union[bool, Dict[str, Any]]]]
__call__: Callable[..., Awaitable[bool | dict[str, Any]]]
else: # pragma: no cover
@abstractmethod
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
"""
This method should be overridden.
@ -28,21 +29,19 @@ class Filter(ABC):
:return: :class:`bool` or :class:`Dict[str, Any]`
"""
pass
def __invert__(self) -> "_InvertFilter":
from aiogram.filters.logic import invert_f
return invert_f(self)
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
def update_handler_flags(self, flags: dict[str, Any]) -> None: # noqa: B027
"""
Also if you want to extend handler flags with using this filter
you should implement this method
:param flags: existing flags, can be updated directly
"""
pass
def _signature_to_string(self, *args: Any, **kwargs: Any) -> str:
items = [repr(arg) for arg in args]

View file

@ -1,40 +1,30 @@
from __future__ import annotations
import sys
import types
import typing
from decimal import Decimal
from enum import Enum
from fractions import Fraction
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Literal,
Optional,
Type,
TypeVar,
Union,
)
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar
from uuid import UUID
from magic_filter import MagicFilter
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import Self
from aiogram.filters.base import Filter
from aiogram.types import CallbackQuery
if TYPE_CHECKING:
from magic_filter import MagicFilter
from pydantic.fields import FieldInfo
T = TypeVar("T", bound="CallbackData")
MAX_CALLBACK_LENGTH: int = 64
_UNION_TYPES = {typing.Union}
if sys.version_info >= (3, 10): # pragma: no cover
_UNION_TYPES.add(types.UnionType)
_UNION_TYPES = {typing.Union, types.UnionType}
class CallbackDataException(Exception):
@ -59,17 +49,19 @@ class CallbackData(BaseModel):
def __init_subclass__(cls, **kwargs: Any) -> None:
if "prefix" not in kwargs:
raise ValueError(
msg = (
f"prefix required, usage example: "
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
)
raise ValueError(msg)
cls.__separator__ = kwargs.pop("sep", ":")
cls.__prefix__ = kwargs.pop("prefix")
if cls.__separator__ in cls.__prefix__:
raise ValueError(
msg = (
f"Separator symbol {cls.__separator__!r} can not be used "
f"inside prefix {cls.__prefix__!r}"
)
raise ValueError(msg)
super().__init_subclass__(**kwargs)
def _encode_value(self, key: str, value: Any) -> str:
@ -83,10 +75,11 @@ class CallbackData(BaseModel):
return str(int(value))
if isinstance(value, (int, str, float, Decimal, Fraction)):
return str(value)
raise ValueError(
msg = (
f"Attribute {key}={value!r} of type {type(value).__name__!r}"
f" can not be packed to callback data"
)
raise ValueError(msg)
def pack(self) -> str:
"""
@ -98,21 +91,23 @@ class CallbackData(BaseModel):
for key, value in self.model_dump(mode="python").items():
encoded = self._encode_value(key, value)
if self.__separator__ in encoded:
raise ValueError(
msg = (
f"Separator symbol {self.__separator__!r} can not be used "
f"in value {key}={encoded!r}"
)
raise ValueError(msg)
result.append(encoded)
callback_data = self.__separator__.join(result)
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
raise ValueError(
msg = (
f"Resulted callback data is too long! "
f"len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
)
raise ValueError(msg)
return callback_data
@classmethod
def unpack(cls: Type[T], value: str) -> T:
def unpack(cls, value: str) -> Self:
"""
Parse callback data string
@ -122,22 +117,28 @@ class CallbackData(BaseModel):
prefix, *parts = value.split(cls.__separator__)
names = cls.model_fields.keys()
if len(parts) != len(names):
raise TypeError(
msg = (
f"Callback data {cls.__name__!r} takes {len(names)} arguments "
f"but {len(parts)} were given"
)
raise TypeError(msg)
if prefix != cls.__prefix__:
raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})")
msg = f"Bad prefix ({prefix!r} != {cls.__prefix__!r})"
raise ValueError(msg)
payload = {}
for k, v in zip(names, parts): # type: str, Optional[str]
if field := cls.model_fields.get(k):
if v == "" and _check_field_is_nullable(field) and field.default != "":
for k, v in zip(names, parts, strict=True): # type: str, str
if (
(field := cls.model_fields.get(k))
and v == ""
and _check_field_is_nullable(field)
and field.default != ""
):
v = field.default if field.default is not PydanticUndefined else None
payload[k] = v
return cls(**payload)
@classmethod
def filter(cls, rule: Optional[MagicFilter] = None) -> CallbackQueryFilter:
def filter(cls, rule: MagicFilter | None = None) -> CallbackQueryFilter:
"""
Generates a filter for callback query with rule
@ -163,8 +164,8 @@ class CallbackQueryFilter(Filter):
def __init__(
self,
*,
callback_data: Type[CallbackData],
rule: Optional[MagicFilter] = None,
callback_data: type[CallbackData],
rule: MagicFilter | None = None,
):
"""
:param callback_data: Expected type of callback data
@ -179,7 +180,7 @@ class CallbackQueryFilter(Filter):
rule=self.rule,
)
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]:
async def __call__(self, query: CallbackQuery) -> Literal[False] | dict[str, Any]:
if not isinstance(query, CallbackQuery) or not query.data:
return False
try:
@ -204,5 +205,5 @@ def _check_field_is_nullable(field: FieldInfo) -> bool:
return True
return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args(
field.annotation
field.annotation,
)

View file

@ -1,4 +1,6 @@
from typing import Any, Dict, Optional, TypeVar, Union
from typing import Any, TypeVar, Union
from typing_extensions import Self
from aiogram.filters.base import Filter
from aiogram.types import ChatMember, ChatMemberUpdated
@ -10,11 +12,11 @@ TransitionT = TypeVar("TransitionT", bound="_MemberStatusTransition")
class _MemberStatusMarker:
__slots__ = (
"name",
"is_member",
"name",
)
def __init__(self, name: str, *, is_member: Optional[bool] = None) -> None:
def __init__(self, name: str, *, is_member: bool | None = None) -> None:
self.name = name
self.is_member = is_member
@ -22,53 +24,59 @@ class _MemberStatusMarker:
result = self.name.upper()
if self.is_member is not None:
result = ("+" if self.is_member else "-") + result
return result # noqa: RET504
return result
def __pos__(self: MarkerT) -> MarkerT:
def __pos__(self) -> Self:
return type(self)(name=self.name, is_member=True)
def __neg__(self: MarkerT) -> MarkerT:
def __neg__(self) -> Self:
return type(self)(name=self.name, is_member=False)
def __or__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusGroupMarker":
if isinstance(other, _MemberStatusMarker):
return _MemberStatusGroupMarker(self, other)
if isinstance(other, _MemberStatusGroupMarker):
return other | self
raise TypeError(
msg = (
f"unsupported operand type(s) for |: "
f"{type(self).__name__!r} and {type(other).__name__!r}"
)
raise TypeError(msg)
__ror__ = __or__
def __rshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition":
old = _MemberStatusGroupMarker(self)
if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=old, new=_MemberStatusGroupMarker(other))
if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=old, new=other)
raise TypeError(
msg = (
f"unsupported operand type(s) for >>: "
f"{type(self).__name__!r} and {type(other).__name__!r}"
)
raise TypeError(msg)
def __lshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition":
new = _MemberStatusGroupMarker(self)
if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=new)
if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=other, new=new)
raise TypeError(
msg = (
f"unsupported operand type(s) for <<: "
f"{type(self).__name__!r} and {type(other).__name__!r}"
)
raise TypeError(msg)
def __hash__(self) -> int:
return hash((self.name, self.is_member))
@ -87,44 +95,51 @@ class _MemberStatusGroupMarker:
def __init__(self, *statuses: _MemberStatusMarker) -> None:
if not statuses:
raise ValueError("Member status group should have at least one status included")
msg = "Member status group should have at least one status included"
raise ValueError(msg)
self.statuses = frozenset(statuses)
def __or__(
self: MarkerGroupT, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
) -> MarkerGroupT:
self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> Self:
if isinstance(other, _MemberStatusMarker):
return type(self)(*self.statuses, other)
if isinstance(other, _MemberStatusGroupMarker):
return type(self)(*self.statuses, *other.statuses)
raise TypeError(
msg = (
f"unsupported operand type(s) for |: "
f"{type(self).__name__!r} and {type(other).__name__!r}"
)
raise TypeError(msg)
def __rshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition":
if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=self, new=_MemberStatusGroupMarker(other))
if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=self, new=other)
raise TypeError(
msg = (
f"unsupported operand type(s) for >>: "
f"{type(self).__name__!r} and {type(other).__name__!r}"
)
raise TypeError(msg)
def __lshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition":
if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=self)
if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=other, new=self)
raise TypeError(
msg = (
f"unsupported operand type(s) for <<: "
f"{type(self).__name__!r} and {type(other).__name__!r}"
)
raise TypeError(msg)
def __str__(self) -> str:
result = " | ".join(map(str, sorted(self.statuses, key=str)))
@ -138,8 +153,8 @@ class _MemberStatusGroupMarker:
class _MemberStatusTransition:
__slots__ = (
"old",
"new",
"old",
)
def __init__(self, *, old: _MemberStatusGroupMarker, new: _MemberStatusGroupMarker) -> None:
@ -149,7 +164,7 @@ class _MemberStatusTransition:
def __str__(self) -> str:
return f"{self.old} >> {self.new}"
def __invert__(self: TransitionT) -> TransitionT:
def __invert__(self) -> Self:
return type(self)(old=self.new, new=self.old)
def check(self, *, old: ChatMember, new: ChatMember) -> bool:
@ -177,11 +192,9 @@ class ChatMemberUpdatedFilter(Filter):
def __init__(
self,
member_status_changed: Union[
_MemberStatusMarker,
_MemberStatusGroupMarker,
_MemberStatusTransition,
],
member_status_changed: (
_MemberStatusMarker | _MemberStatusGroupMarker | _MemberStatusTransition
),
):
self.member_status_changed = member_status_changed
@ -190,7 +203,7 @@ class ChatMemberUpdatedFilter(Filter):
member_status_changed=self.member_status_changed,
)
async def __call__(self, member_updated: ChatMemberUpdated) -> Union[bool, Dict[str, Any]]:
async def __call__(self, member_updated: ChatMemberUpdated) -> bool | dict[str, Any]:
old = member_updated.old_chat_member
new = member_updated.new_chat_member
rule = self.member_status_changed

View file

@ -1,31 +1,21 @@
from __future__ import annotations
import re
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Match,
Optional,
Pattern,
Sequence,
Union,
cast,
)
from magic_filter import MagicFilter
from re import Match, Pattern
from typing import TYPE_CHECKING, Any, cast
from aiogram.filters.base import Filter
from aiogram.types import BotCommand, Message
from aiogram.utils.deep_linking import decode_payload
if TYPE_CHECKING:
from magic_filter import MagicFilter
from aiogram import Bot
# TODO: rm type ignore after py3.8 support expiration or mypy bug fix
CommandPatternType = Union[str, re.Pattern, BotCommand] # type: ignore[type-arg]
CommandPatternType = str | re.Pattern[str] | BotCommand
class CommandException(Exception):
@ -41,20 +31,20 @@ class Command(Filter):
__slots__ = (
"commands",
"prefix",
"ignore_case",
"ignore_mention",
"magic",
"prefix",
)
def __init__(
self,
*values: CommandPatternType,
commands: Optional[Union[Sequence[CommandPatternType], CommandPatternType]] = None,
commands: Sequence[CommandPatternType] | CommandPatternType | None = None,
prefix: str = "/",
ignore_case: bool = False,
ignore_mention: bool = False,
magic: Optional[MagicFilter] = None,
magic: MagicFilter | None = None,
):
"""
List of commands (string or compiled regexp patterns)
@ -74,26 +64,29 @@ class Command(Filter):
commands = [commands]
if not isinstance(commands, Iterable):
raise ValueError(
msg = (
"Command filter only supports str, re.Pattern, BotCommand object"
" or their Iterable"
)
raise ValueError(msg)
items = []
for command in (*values, *commands):
if isinstance(command, BotCommand):
command = command.command
if not isinstance(command, (str, re.Pattern)):
raise ValueError(
msg = (
"Command filter only supports str, re.Pattern, BotCommand object"
" or their Iterable"
)
raise ValueError(msg)
if ignore_case and isinstance(command, str):
command = command.casefold()
items.append(command)
if not items:
raise ValueError("At least one command should be specified")
msg = "At least one command should be specified"
raise ValueError(msg)
self.commands = tuple(items)
self.prefix = prefix
@ -110,11 +103,11 @@ class Command(Filter):
magic=self.magic,
)
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
def update_handler_flags(self, flags: dict[str, Any]) -> None:
commands = flags.setdefault("commands", [])
commands.append(self)
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
async def __call__(self, message: Message, bot: Bot) -> bool | dict[str, Any]:
if not isinstance(message, Message):
return False
@ -137,7 +130,8 @@ class Command(Filter):
try:
full_command, *args = text.split(maxsplit=1)
except ValueError:
raise CommandException("not enough values to unpack")
msg = "not enough values to unpack"
raise CommandException(msg)
# Separate command into valuable parts
# "/command@mention" -> "/", ("command", "@", "mention")
@ -151,13 +145,15 @@ class Command(Filter):
def validate_prefix(self, command: CommandObject) -> None:
if command.prefix not in self.prefix:
raise CommandException("Invalid command prefix")
msg = "Invalid command prefix"
raise CommandException(msg)
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
if command.mention and not self.ignore_mention:
me = await bot.me()
if me.username and command.mention.lower() != me.username.lower():
raise CommandException("Mention did not match")
msg = "Mention did not match"
raise CommandException(msg)
def validate_command(self, command: CommandObject) -> CommandObject:
for allowed_command in cast(Sequence[CommandPatternType], self.commands):
@ -174,7 +170,8 @@ class Command(Filter):
if command_name == allowed_command: # String
return command
raise CommandException("Command did not match pattern")
msg = "Command did not match pattern"
raise CommandException(msg)
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
"""
@ -196,7 +193,8 @@ class Command(Filter):
return command
result = self.magic.resolve(command)
if not result:
raise CommandException("Rejected via magic filter")
msg = "Rejected via magic filter"
raise CommandException(msg)
return replace(command, magic_result=result)
@ -211,13 +209,13 @@ class CommandObject:
"""Command prefix"""
command: str = ""
"""Command without prefix and mention"""
mention: Optional[str] = None
mention: str | None = None
"""Mention (if available)"""
args: Optional[str] = field(repr=False, default=None)
args: str | None = field(repr=False, default=None)
"""Command argument"""
regexp_match: Optional[Match[str]] = field(repr=False, default=None)
regexp_match: Match[str] | None = field(repr=False, default=None)
"""Will be presented match result if the command is presented as regexp in filter"""
magic_result: Optional[Any] = field(repr=False, default=None)
magic_result: Any | None = field(repr=False, default=None)
@property
def mentioned(self) -> bool:
@ -246,7 +244,7 @@ class CommandStart(Command):
deep_link_encoded: bool = False,
ignore_case: bool = False,
ignore_mention: bool = False,
magic: Optional[MagicFilter] = None,
magic: MagicFilter | None = None,
):
super().__init__(
"start",
@ -287,12 +285,14 @@ class CommandStart(Command):
if not self.deep_link:
return command
if not command.args:
raise CommandException("Deep-link was missing")
msg = "Deep-link was missing"
raise CommandException(msg)
args = command.args
if self.deep_link_encoded:
try:
args = decode_payload(args)
except UnicodeDecodeError as e:
raise CommandException(f"Failed to decode Base64: {e}")
msg = f"Failed to decode Base64: {e}"
raise CommandException(msg)
return replace(command, args=args)
return command

View file

@ -1,5 +1,6 @@
import re
from typing import Any, Dict, Pattern, Type, Union, cast
from re import Pattern
from typing import Any, cast
from aiogram.filters.base import Filter
from aiogram.types import TelegramObject
@ -13,15 +14,16 @@ class ExceptionTypeFilter(Filter):
__slots__ = ("exceptions",)
def __init__(self, *exceptions: Type[Exception]):
def __init__(self, *exceptions: type[Exception]):
"""
:param exceptions: Exception type(s)
"""
if not exceptions:
raise ValueError("At least one exception type is required")
msg = "At least one exception type is required"
raise ValueError(msg)
self.exceptions = exceptions
async def __call__(self, obj: TelegramObject) -> Union[bool, Dict[str, Any]]:
async def __call__(self, obj: TelegramObject) -> bool | dict[str, Any]:
return isinstance(cast(ErrorEvent, obj).exception, self.exceptions)
@ -32,7 +34,7 @@ class ExceptionMessageFilter(Filter):
__slots__ = ("pattern",)
def __init__(self, pattern: Union[str, Pattern[str]]):
def __init__(self, pattern: str | Pattern[str]):
"""
:param pattern: Regexp pattern
"""
@ -48,7 +50,7 @@ class ExceptionMessageFilter(Filter):
async def __call__(
self,
obj: TelegramObject,
) -> Union[bool, Dict[str, Any]]:
) -> bool | dict[str, Any]:
result = self.pattern.match(str(cast(ErrorEvent, obj).exception))
if not result:
return False

View file

@ -1,5 +1,5 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, Union
from typing import TYPE_CHECKING, Any
from aiogram.filters import Filter
@ -17,7 +17,7 @@ class _InvertFilter(_LogicFilter):
def __init__(self, target: "FilterObject") -> None:
self.target = target
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
return not bool(await self.target.call(*args, **kwargs))
@ -27,7 +27,7 @@ class _AndFilter(_LogicFilter):
def __init__(self, *targets: "FilterObject") -> None:
self.targets = targets
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
final_result = {}
for target in self.targets:
@ -48,7 +48,7 @@ class _OrFilter(_LogicFilter):
def __init__(self, *targets: "FilterObject") -> None:
self.targets = targets
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
for target in self.targets:
result = await target.call(*args, **kwargs)
if not result:

View file

@ -18,7 +18,7 @@ class MagicData(Filter):
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
return self.magic_data.resolve(
AttrDict({"event": event, **dict(enumerate(args)), **kwargs})
AttrDict({"event": event, **dict(enumerate(args)), **kwargs}),
)
def __str__(self) -> str:

View file

@ -1,11 +1,12 @@
from collections.abc import Sequence
from inspect import isclass
from typing import Any, Dict, Optional, Sequence, Type, Union, cast
from typing import Any, cast
from aiogram.filters.base import Filter
from aiogram.fsm.state import State, StatesGroup
from aiogram.types import TelegramObject
StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]]
StateType = str | State | StatesGroup | type[StatesGroup] | None
class StateFilter(Filter):
@ -17,7 +18,8 @@ class StateFilter(Filter):
def __init__(self, *states: StateType) -> None:
if not states:
raise ValueError("At least one state is required")
msg = "At least one state is required"
raise ValueError(msg)
self.states = states
@ -27,17 +29,22 @@ class StateFilter(Filter):
)
async def __call__(
self, obj: TelegramObject, raw_state: Optional[str] = None
) -> Union[bool, Dict[str, Any]]:
self,
obj: TelegramObject,
raw_state: str | None = None,
) -> bool | dict[str, Any]:
allowed_states = cast(Sequence[StateType], self.states)
for allowed_state in allowed_states:
if isinstance(allowed_state, str) or allowed_state is None:
if allowed_state == "*" or raw_state == allowed_state:
if allowed_state in {"*", raw_state}:
return True
elif isinstance(allowed_state, (State, StatesGroup)):
if allowed_state(event=obj, raw_state=raw_state):
return True
elif isclass(allowed_state) and issubclass(allowed_state, StatesGroup):
if allowed_state()(event=obj, raw_state=raw_state):
elif (
isclass(allowed_state)
and issubclass(allowed_state, StatesGroup)
and allowed_state()(event=obj, raw_state=raw_state)
):
return True
return False

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Mapping, Optional, overload
from collections.abc import Mapping
from typing import Any, overload
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
@ -11,27 +12,29 @@ class FSMContext:
async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(key=self.key, state=state)
async def get_state(self) -> Optional[str]:
async def get_state(self) -> str | None:
return await self.storage.get_state(key=self.key)
async def set_data(self, data: Mapping[str, Any]) -> None:
await self.storage.set_data(key=self.key, data=data)
async def get_data(self) -> Dict[str, Any]:
async def get_data(self) -> dict[str, Any]:
return await self.storage.get_data(key=self.key)
@overload
async def get_value(self, key: str) -> Optional[Any]: ...
async def get_value(self, key: str) -> Any | None: ...
@overload
async def get_value(self, key: str, default: Any) -> Any: ...
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
async def get_value(self, key: str, default: Any | None = None) -> Any | None:
return await self.storage.get_value(storage_key=self.key, dict_key=key, default=default)
async def update_data(
self, data: Optional[Mapping[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]:
self,
data: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
if data:
kwargs.update(data)
return await self.storage.update_data(key=self.key, data=kwargs)

View file

@ -1,4 +1,5 @@
from typing import Any, Awaitable, Callable, Dict, Optional, cast
from collections.abc import Awaitable, Callable
from typing import Any, cast
from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware
@ -27,9 +28,9 @@ class FSMContextMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
bot: Bot = cast(Bot, data["bot"])
context = self.resolve_event_context(bot, data)
@ -45,9 +46,9 @@ class FSMContextMiddleware(BaseMiddleware):
def resolve_event_context(
self,
bot: Bot,
data: Dict[str, Any],
data: dict[str, Any],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
) -> FSMContext | None:
event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
return self.resolve_context(
bot=bot,
@ -61,12 +62,12 @@ class FSMContextMiddleware(BaseMiddleware):
def resolve_context(
self,
bot: Bot,
chat_id: Optional[int],
user_id: Optional[int],
thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
chat_id: int | None,
user_id: int | None,
thread_id: int | None = None,
business_connection_id: str | None = None,
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
) -> FSMContext | None:
if chat_id is None:
chat_id = user_id
@ -92,8 +93,8 @@ class FSMContextMiddleware(BaseMiddleware):
bot: Bot,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
thread_id: int | None = None,
business_connection_id: str | None = None,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(

View file

@ -2,26 +2,15 @@ from __future__ import annotations
import inspect
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass, replace
from enum import Enum, auto
from typing import (
Any,
ClassVar,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
overload,
)
from typing import TYPE_CHECKING, Any, ClassVar, overload
from typing_extensions import Self
from aiogram import loggers
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NextMiddlewareType
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.dispatcher.flags import extract_flags_from_object
from aiogram.dispatcher.router import Router
@ -36,16 +25,20 @@ from aiogram.utils.class_attrs_resolver import (
get_sorted_mro_attrs_resolver,
)
if TYPE_CHECKING:
from aiogram.dispatcher.event.bases import NextMiddlewareType
class HistoryManager:
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
self._size = size
self._state = state
self._history_state = FSMContext(
storage=state.storage, key=replace(state.key, destiny=destiny)
storage=state.storage,
key=replace(state.key, destiny=destiny),
)
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
async def push(self, state: str | None, data: dict[str, Any]) -> None:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
history.append({"state": state, "data": data})
@ -55,7 +48,7 @@ class HistoryManager:
await self._history_state.update_data(history=history)
async def pop(self) -> Optional[MemoryStorageRecord]:
async def pop(self) -> MemoryStorageRecord | None:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
if not history:
@ -70,14 +63,14 @@ class HistoryManager:
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
return MemoryStorageRecord(state=state, data=data)
async def get(self) -> Optional[MemoryStorageRecord]:
async def get(self) -> MemoryStorageRecord | None:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
if not history:
return None
return MemoryStorageRecord(**history[-1])
async def all(self) -> List[MemoryStorageRecord]:
async def all(self) -> list[MemoryStorageRecord]:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
return [MemoryStorageRecord(**item) for item in history]
@ -91,11 +84,11 @@ class HistoryManager:
data = await self._state.get_data()
await self.push(state, data)
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None:
async def _set_state(self, state: str | None, data: dict[str, Any]) -> None:
await self._state.set_state(state)
await self._state.set_data(data)
async def rollback(self) -> Optional[str]:
async def rollback(self) -> str | None:
previous_state = await self.pop()
if not previous_state:
await self._set_state(None, {})
@ -116,14 +109,14 @@ class ObserverDecorator:
name: str,
filters: tuple[CallbackType, ...],
action: SceneAction | None = None,
after: Optional[After] = None,
after: After | None = None,
) -> None:
self.name = name
self.filters = filters
self.action = action
self.after = after
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None:
def _wrap_filter(self, target: type[Scene] | CallbackType) -> None:
handlers = getattr(target, "__aiogram_handler__", None)
if not handlers:
handlers = []
@ -135,7 +128,7 @@ class ObserverDecorator:
handler=target,
filters=self.filters,
after=self.after,
)
),
)
def _wrap_action(self, target: CallbackType) -> None:
@ -154,13 +147,14 @@ class ObserverDecorator:
else:
self._wrap_action(target)
else:
raise TypeError("Only function or method is allowed")
msg = "Only function or method is allowed"
raise TypeError(msg)
return target
def leave(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.leave)
def enter(self, target: Type[Scene]) -> ActionContainer:
def enter(self, target: type[Scene]) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.enter, target)
def exit(self) -> ActionContainer:
@ -181,9 +175,9 @@ class ActionContainer:
def __init__(
self,
name: str,
filters: Tuple[CallbackType, ...],
filters: tuple[CallbackType, ...],
action: SceneAction,
target: Optional[Union[Type[Scene], State, str]] = None,
target: type[Scene] | State | str | None = None,
) -> None:
self.name = name
self.filters = filters
@ -201,33 +195,27 @@ class ActionContainer:
await wizard.back()
@dataclass(slots=True)
class HandlerContainer:
def __init__(
self,
name: str,
handler: CallbackType,
filters: Tuple[CallbackType, ...],
after: Optional[After] = None,
) -> None:
self.name = name
self.handler = handler
self.filters = filters
self.after = after
name: str
handler: CallbackType
filters: tuple[CallbackType, ...]
after: After | None = None
@dataclass()
@dataclass
class SceneConfig:
state: Optional[str]
state: str | None
"""Scene state"""
handlers: List[HandlerContainer]
handlers: list[HandlerContainer]
"""Scene handlers"""
actions: Dict[SceneAction, Dict[str, CallableObject]]
actions: dict[SceneAction, dict[str, CallableObject]]
"""Scene actions"""
reset_data_on_enter: Optional[bool] = None
reset_data_on_enter: bool | None = None
"""Reset scene data on enter"""
reset_history_on_enter: Optional[bool] = None
reset_history_on_enter: bool | None = None
"""Reset scene history on enter"""
callback_query_without_state: Optional[bool] = None
callback_query_without_state: bool | None = None
"""Allow callback query without state"""
attrs_resolver: ClassAttrsResolver = get_sorted_mro_attrs_resolver
"""
@ -247,9 +235,9 @@ async def _empty_handler(*args: Any, **kwargs: Any) -> None:
class SceneHandlerWrapper:
def __init__(
self,
scene: Type[Scene],
scene: type[Scene],
handler: CallbackType,
after: Optional[After] = None,
after: After | None = None,
) -> None:
self.scene = scene
self.handler = CallableObject(handler)
@ -271,7 +259,7 @@ class SceneHandlerWrapper:
update_type=event_update.event_type,
event=event,
data=kwargs,
)
),
)
result = await self.handler.call(scene, event, **kwargs)
@ -331,7 +319,7 @@ class Scene:
super().__init_subclass__(**kwargs)
handlers: list[HandlerContainer] = []
actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict)
actions: defaultdict[SceneAction, dict[str, CallableObject]] = defaultdict(dict)
for base in cls.__bases__:
if not issubclass(base, Scene):
@ -353,7 +341,7 @@ class Scene:
if attrs_resolver is None:
attrs_resolver = get_sorted_mro_attrs_resolver
for name, value in attrs_resolver(cls):
for _name, value in attrs_resolver(cls):
if scene_handlers := getattr(value, "__aiogram_handler__", None):
handlers.extend(scene_handlers)
if isinstance(value, ObserverDecorator):
@ -363,7 +351,7 @@ class Scene:
_empty_handler,
value.filters,
after=value.after,
)
),
)
if hasattr(value, "__aiogram_action__"):
for action, action_handlers in value.__aiogram_action__.items():
@ -408,7 +396,7 @@ class Scene:
router.observers[observer_name].filter(StateFilter(scene_config.state))
@classmethod
def as_router(cls, name: Optional[str] = None) -> Router:
def as_router(cls, name: str | None = None) -> Router:
"""
Returns the scene as a router.
@ -433,7 +421,9 @@ class Scene:
"""
async def enter_to_scene_handler(
event: TelegramObject, scenes: ScenesManager, **middleware_kwargs: Any
event: TelegramObject,
scenes: ScenesManager,
**middleware_kwargs: Any,
) -> None:
await scenes.enter(cls, **{**handler_kwargs, **middleware_kwargs})
@ -461,7 +451,7 @@ class SceneWizard:
state: FSMContext,
update_type: str,
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
):
"""
A class that represents a wizard for managing scenes in a Telegram bot.
@ -480,7 +470,7 @@ class SceneWizard:
self.event = event
self.data = data
self.scene: Optional[Scene] = None
self.scene: Scene | None = None
async def enter(self, **kwargs: Any) -> None:
"""
@ -548,7 +538,7 @@ class SceneWizard:
assert self.scene_config.state is not None, "Scene state is not specified"
await self.goto(self.scene_config.state, **kwargs)
async def goto(self, scene: Union[Type[Scene], State, str], **kwargs: Any) -> None:
async def goto(self, scene: type[Scene] | State | str, **kwargs: Any) -> None:
"""
The `goto` method transitions to a new scene.
It first calls the `leave` method to perform any necessary cleanup
@ -565,13 +555,16 @@ class SceneWizard:
async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool:
if not self.scene:
raise SceneException("Scene is not initialized")
msg = "Scene is not initialized"
raise SceneException(msg)
loggers.scene.debug("Call action %r in scene %r", action.name, self.scene_config.state)
action_config = self.scene_config.actions.get(action, {})
if not action_config:
loggers.scene.debug(
"Action %r not found in scene %r", action.name, self.scene_config.state
"Action %r not found in scene %r",
action.name,
self.scene_config.state,
)
return False
@ -597,7 +590,7 @@ class SceneWizard:
"""
await self.state.set_data(data=data)
async def get_data(self) -> Dict[str, Any]:
async def get_data(self) -> dict[str, Any]:
"""
This method returns the data stored in the current state.
@ -606,7 +599,7 @@ class SceneWizard:
return await self.state.get_data()
@overload
async def get_value(self, key: str) -> Optional[Any]:
async def get_value(self, key: str) -> Any | None:
"""
This method returns the value from key in the data of the current state.
@ -614,7 +607,6 @@ class SceneWizard:
:return: A dictionary containing the data stored in the scene state.
"""
pass
@overload
async def get_value(self, key: str, default: Any) -> Any:
@ -626,14 +618,15 @@ class SceneWizard:
:return: A dictionary containing the data stored in the scene state.
"""
pass
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
async def get_value(self, key: str, default: Any | None = None) -> Any | None:
return await self.state.get_value(key, default)
async def update_data(
self, data: Optional[Mapping[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]:
self,
data: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""
This method updates the data stored in the current state
@ -666,7 +659,7 @@ class ScenesManager:
update_type: str,
event: TelegramObject,
state: FSMContext,
data: Dict[str, Any],
data: dict[str, Any],
) -> None:
self.registry = registry
self.update_type = update_type
@ -676,7 +669,7 @@ class ScenesManager:
self.history = HistoryManager(self.state)
async def _get_scene(self, scene_type: Optional[Union[Type[Scene], State, str]]) -> Scene:
async def _get_scene(self, scene_type: type[Scene] | State | str | None) -> Scene:
scene_type = self.registry.get(scene_type)
return scene_type(
wizard=SceneWizard(
@ -689,7 +682,7 @@ class ScenesManager:
),
)
async def _get_active_scene(self) -> Optional[Scene]:
async def _get_active_scene(self) -> Scene | None:
state = await self.state.get_state()
try:
return await self._get_scene(state)
@ -698,7 +691,7 @@ class ScenesManager:
async def enter(
self,
scene_type: Optional[Union[Type[Scene], State, str]],
scene_type: type[Scene] | State | str | None,
_check_active: bool = True,
**kwargs: Any,
) -> None:
@ -753,7 +746,7 @@ class SceneRegistry:
self.router = router
self.register_on_add = register_on_add
self._scenes: Dict[Optional[str], Type[Scene]] = {}
self._scenes: dict[str | None, type[Scene]] = {}
self._setup_middleware(router)
def _setup_middleware(self, router: Router) -> None:
@ -772,7 +765,7 @@ class SceneRegistry:
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
assert isinstance(event, Update), "Event must be an Update instance"
@ -789,7 +782,7 @@ class SceneRegistry:
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
update: Update = data["event_update"]
data["scenes"] = ScenesManager(
@ -801,7 +794,7 @@ class SceneRegistry:
)
return await handler(event, data)
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None:
def add(self, *scenes: type[Scene], router: Router | None = None) -> None:
"""
This method adds the specified scenes to the registry
and optionally registers it to the router.
@ -820,13 +813,13 @@ class SceneRegistry:
:return: None
"""
if not scenes:
raise ValueError("At least one scene must be specified")
msg = "At least one scene must be specified"
raise ValueError(msg)
for scene in scenes:
if scene.__scene_config__.state in self._scenes:
raise SceneException(
f"Scene with state {scene.__scene_config__.state!r} already exists"
)
msg = f"Scene with state {scene.__scene_config__.state!r} already exists"
raise SceneException(msg)
self._scenes[scene.__scene_config__.state] = scene
@ -835,7 +828,7 @@ class SceneRegistry:
elif self.register_on_add:
self.router.include_router(scene.as_router())
def register(self, *scenes: Type[Scene]) -> None:
def register(self, *scenes: type[Scene]) -> None:
"""
Registers one or more scenes to the SceneRegistry.
@ -844,7 +837,7 @@ class SceneRegistry:
"""
self.add(*scenes, router=self.router)
def get(self, scene: Optional[Union[Type[Scene], State, str]]) -> Type[Scene]:
def get(self, scene: type[Scene] | State | str | None) -> type[Scene]:
"""
This method returns the registered Scene object for the specified scene.
The scene parameter can be either a Scene object, State object or a string representing
@ -865,18 +858,20 @@ class SceneRegistry:
if isinstance(scene, State):
scene = scene.state
if scene is not None and not isinstance(scene, str):
raise SceneException("Scene must be a subclass of Scene, State or a string")
msg = "Scene must be a subclass of Scene, State or a string"
raise SceneException(msg)
try:
return self._scenes[scene]
except KeyError:
raise SceneException(f"Scene {scene!r} is not registered")
msg = f"Scene {scene!r} is not registered"
raise SceneException(msg)
@dataclass
class After:
action: SceneAction
scene: Optional[Union[Type[Scene], State, str]] = None
scene: type[Scene] | State | str | None = None
@classmethod
def exit(cls) -> After:
@ -887,7 +882,7 @@ class After:
return cls(action=SceneAction.back)
@classmethod
def goto(cls, scene: Optional[Union[Type[Scene], State, str]]) -> After:
def goto(cls, scene: type[Scene] | State | str | None) -> After:
return cls(action=SceneAction.enter, scene=scene)
@ -898,7 +893,7 @@ class ObserverMarker:
def __call__(
self,
*filters: CallbackType,
after: Optional[After] = None,
after: After | None = None,
) -> ObserverDecorator:
return ObserverDecorator(
self.name,

View file

@ -1,5 +1,6 @@
import inspect
from typing import Any, Iterator, Optional, Tuple, Type, no_type_check
from collections.abc import Iterator
from typing import Any, no_type_check
from aiogram.types import TelegramObject
@ -9,19 +10,20 @@ class State:
State object
"""
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None) -> None:
def __init__(self, state: str | None = None, group_name: str | None = None) -> None:
self._state = state
self._group_name = group_name
self._group: Optional[Type[StatesGroup]] = None
self._group: type[StatesGroup] | None = None
@property
def group(self) -> "Type[StatesGroup]":
def group(self) -> "type[StatesGroup]":
if not self._group:
raise RuntimeError("This state is not in any group.")
msg = "This state is not in any group."
raise RuntimeError(msg)
return self._group
@property
def state(self) -> Optional[str]:
def state(self) -> str | None:
if self._state is None or self._state == "*":
return self._state
@ -34,12 +36,13 @@ class State:
return f"{group}:{self._state}"
def set_parent(self, group: "Type[StatesGroup]") -> None:
def set_parent(self, group: "type[StatesGroup]") -> None:
if not issubclass(group, StatesGroup):
raise ValueError("Group must be subclass of StatesGroup")
msg = "Group must be subclass of StatesGroup"
raise ValueError(msg)
self._group = group
def __set_name__(self, owner: "Type[StatesGroup]", name: str) -> None:
def __set_name__(self, owner: "type[StatesGroup]", name: str) -> None:
if self._state is None:
self._state = name
self.set_parent(owner)
@ -49,12 +52,12 @@ class State:
__repr__ = __str__
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
def __call__(self, event: TelegramObject, raw_state: str | None = None) -> bool:
if self.state == "*":
return True
return raw_state == self.state
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.state == other.state
if isinstance(other, str):
@ -66,13 +69,13 @@ class State:
class StatesGroupMeta(type):
__parent__: "Optional[Type[StatesGroup]]"
__childs__: "Tuple[Type[StatesGroup], ...]"
__states__: Tuple[State, ...]
__state_names__: Tuple[str, ...]
__all_childs__: Tuple[Type["StatesGroup"], ...]
__all_states__: Tuple[State, ...]
__all_states_names__: Tuple[str, ...]
__parent__: type["StatesGroup"] | None
__childs__: tuple[type["StatesGroup"], ...]
__states__: tuple[State, ...]
__state_names__: tuple[str, ...]
__all_childs__: tuple[type["StatesGroup"], ...]
__all_states__: tuple[State, ...]
__all_states_names__: tuple[str, ...]
@no_type_check
def __new__(mcs, name, bases, namespace, **kwargs):
@ -81,7 +84,7 @@ class StatesGroupMeta(type):
states = []
childs = []
for name, arg in namespace.items():
for arg in namespace.values():
if isinstance(arg, State):
states.append(arg)
elif inspect.isclass(arg) and issubclass(arg, StatesGroup):
@ -106,10 +109,10 @@ class StatesGroupMeta(type):
@property
def __full_group_name__(cls) -> str:
if cls.__parent__:
return ".".join((cls.__parent__.__full_group_name__, cls.__name__))
return f"{cls.__parent__.__full_group_name__}.{cls.__name__}"
return cls.__name__
def _prepare_child(cls, child: Type["StatesGroup"]) -> Type["StatesGroup"]:
def _prepare_child(cls, child: type["StatesGroup"]) -> type["StatesGroup"]:
"""Prepare child.
While adding `cls` for its children, we also need to recalculate
@ -123,19 +126,19 @@ class StatesGroupMeta(type):
child.__all_states_names__ = child._get_all_states_names()
return child
def _get_all_childs(cls) -> Tuple[Type["StatesGroup"], ...]:
def _get_all_childs(cls) -> tuple[type["StatesGroup"], ...]:
result = cls.__childs__
for child in cls.__childs__:
result += child.__childs__
return result
def _get_all_states(cls) -> Tuple[State, ...]:
def _get_all_states(cls) -> tuple[State, ...]:
result = cls.__states__
for group in cls.__childs__:
result += group.__all_states__
return result
def _get_all_states_names(cls) -> Tuple[str, ...]:
def _get_all_states_names(cls) -> tuple[str, ...]:
return tuple(state.state for state in cls.__all_states__ if state.state)
def __contains__(cls, item: Any) -> bool:
@ -156,12 +159,12 @@ class StatesGroupMeta(type):
class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
def get_root(cls) -> Type["StatesGroup"]:
def get_root(cls) -> type["StatesGroup"]:
if cls.__parent__ is None:
return cls
return cls.__parent__.get_root()
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
def __call__(self, event: TelegramObject, raw_state: str | None = None) -> bool:
return raw_state in type(self).__all_states_names__
def __str__(self) -> str:

View file

@ -1,20 +1,12 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Mapping
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
Dict,
Literal,
Mapping,
Optional,
Union,
overload,
)
from typing import Any, Literal, overload
from aiogram.fsm.state import State
StateType = Optional[Union[str, State]]
StateType = str | State | None
DEFAULT_DESTINY = "default"
@ -24,8 +16,8 @@ class StorageKey:
bot_id: int
chat_id: int
user_id: int
thread_id: Optional[int] = None
business_connection_id: Optional[str] = None
thread_id: int | None = None
business_connection_id: str | None = None
destiny: str = DEFAULT_DESTINY
@ -36,7 +28,7 @@ class KeyBuilder(ABC):
def build(
self,
key: StorageKey,
part: Optional[Literal["data", "state", "lock"]] = None,
part: Literal["data", "state", "lock"] | None = None,
) -> str:
"""
Build key to be used in storage's db queries
@ -45,7 +37,6 @@ class KeyBuilder(ABC):
:param part: part of the record
:return: key to be used in storage's db queries
"""
pass
class DefaultKeyBuilder(KeyBuilder):
@ -84,7 +75,7 @@ class DefaultKeyBuilder(KeyBuilder):
def build(
self,
key: StorageKey,
part: Optional[Literal["data", "state", "lock"]] = None,
part: Literal["data", "state", "lock"] | None = None,
) -> str:
parts = [self.prefix]
if self.with_bot_id:
@ -121,17 +112,15 @@ class BaseStorage(ABC):
:param key: storage key
:param state: new state
"""
pass
@abstractmethod
async def get_state(self, key: StorageKey) -> Optional[str]:
async def get_state(self, key: StorageKey) -> str | None:
"""
Get key state
:param key: storage key
:return: current state
"""
pass
@abstractmethod
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
@ -141,20 +130,18 @@ class BaseStorage(ABC):
:param key: storage key
:param data: new data
"""
pass
@abstractmethod
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
async def get_data(self, key: StorageKey) -> dict[str, Any]:
"""
Get current data for key
:param key: storage key
:return: current data
"""
pass
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]:
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Any | None:
"""
Get single value from data by key
@ -162,7 +149,6 @@ class BaseStorage(ABC):
:param dict_key: value key
:return: value stored in key of dict or ``None``
"""
pass
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any:
@ -174,15 +160,17 @@ class BaseStorage(ABC):
:param default: default value to return
:return: value stored in key of dict or default
"""
pass
async def get_value(
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None
) -> Optional[Any]:
self,
storage_key: StorageKey,
dict_key: str,
default: Any | None = None,
) -> Any | None:
data = await self.get_data(storage_key)
return data.get(dict_key, default)
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]:
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
"""
Update date in the storage for key (like dict.update)
@ -200,7 +188,6 @@ class BaseStorage(ABC):
"""
Close storage (database connection, file or etc.)
"""
pass
class BaseEventIsolation(ABC):

View file

@ -1,18 +1,10 @@
from asyncio import Lock
from collections import defaultdict
from collections.abc import AsyncGenerator, Hashable, Mapping
from contextlib import asynccontextmanager
from copy import copy
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
DefaultDict,
Dict,
Hashable,
Mapping,
Optional,
overload,
)
from typing import Any, overload
from aiogram.exceptions import DataNotDictLikeError
from aiogram.fsm.state import State
@ -26,8 +18,8 @@ from aiogram.fsm.storage.base import (
@dataclass
class MemoryStorageRecord:
data: Dict[str, Any] = field(default_factory=dict)
state: Optional[str] = None
data: dict[str, Any] = field(default_factory=dict)
state: str | None = None
class MemoryStorage(BaseStorage):
@ -41,8 +33,8 @@ class MemoryStorage(BaseStorage):
"""
def __init__(self) -> None:
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
MemoryStorageRecord
self.storage: defaultdict[StorageKey, MemoryStorageRecord] = defaultdict(
MemoryStorageRecord,
)
async def close(self) -> None:
@ -51,28 +43,30 @@ class MemoryStorage(BaseStorage):
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
self.storage[key].state = state.state if isinstance(state, State) else state
async def get_state(self, key: StorageKey) -> Optional[str]:
async def get_state(self, key: StorageKey) -> str | None:
return self.storage[key].state
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
if not isinstance(data, dict):
raise DataNotDictLikeError(
f"Data must be a dict or dict-like object, got {type(data).__name__}"
)
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
raise DataNotDictLikeError(msg)
self.storage[key].data = data.copy()
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
async def get_data(self, key: StorageKey) -> dict[str, Any]:
return self.storage[key].data.copy()
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]: ...
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Any | None: ...
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: ...
async def get_value(
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None
) -> Optional[Any]:
self,
storage_key: StorageKey,
dict_key: str,
default: Any | None = None,
) -> Any | None:
data = self.storage[storage_key].data
return copy(data.get(dict_key, default))
@ -89,7 +83,7 @@ class DisabledEventIsolation(BaseEventIsolation):
class SimpleEventIsolation(BaseEventIsolation):
def __init__(self) -> None:
# TODO: Unused locks cleaner is needed
self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock)
self._locks: defaultdict[Hashable, Lock] = defaultdict(Lock)
@asynccontextmanager
async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Mapping, Optional, cast
from collections.abc import Mapping
from typing import Any, cast
from motor.motor_asyncio import AsyncIOMotorClient
@ -27,7 +28,7 @@ class MongoStorage(BaseStorage):
def __init__(
self,
client: AsyncIOMotorClient,
key_builder: Optional[KeyBuilder] = None,
key_builder: KeyBuilder | None = None,
db_name: str = "aiogram_fsm",
collection_name: str = "states_and_data",
) -> None:
@ -46,7 +47,10 @@ class MongoStorage(BaseStorage):
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
cls,
url: str,
connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> "MongoStorage":
"""
Create an instance of :class:`MongoStorage` with specifying the connection string
@ -65,7 +69,7 @@ class MongoStorage(BaseStorage):
"""Cleanup client resources and disconnect from MongoDB."""
self._client.close()
def resolve_state(self, value: StateType) -> Optional[str]:
def resolve_state(self, value: StateType) -> str | None:
if value is None:
return None
if isinstance(value, State):
@ -90,7 +94,7 @@ class MongoStorage(BaseStorage):
upsert=True,
)
async def get_state(self, key: StorageKey) -> Optional[str]:
async def get_state(self, key: StorageKey) -> str | None:
document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id})
if document is None:
@ -99,9 +103,8 @@ class MongoStorage(BaseStorage):
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
if not isinstance(data, dict):
raise DataNotDictLikeError(
f"Data must be a dict or dict-like object, got {type(data).__name__}"
)
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
raise DataNotDictLikeError(msg)
document_id = self._key_builder.build(key)
if not data:
@ -120,14 +123,14 @@ class MongoStorage(BaseStorage):
upsert=True,
)
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
async def get_data(self, key: StorageKey) -> dict[str, Any]:
document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id})
if document is None or not document.get("data"):
return {}
return cast(Dict[str, Any], document["data"])
return cast(dict[str, Any], document["data"])
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]:
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
document_id = self._key_builder.build(key)
update_with = {f"data.{key}": value for key, value in data.items()}
update_result = await self._collection.find_one_and_update(

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Mapping, Optional, cast
from collections.abc import Mapping
from typing import Any, cast
from pymongo import AsyncMongoClient
@ -21,7 +22,7 @@ class PyMongoStorage(BaseStorage):
def __init__(
self,
client: AsyncMongoClient[Any],
key_builder: Optional[KeyBuilder] = None,
key_builder: KeyBuilder | None = None,
db_name: str = "aiogram_fsm",
collection_name: str = "states_and_data",
) -> None:
@ -40,7 +41,10 @@ class PyMongoStorage(BaseStorage):
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
cls,
url: str,
connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> "PyMongoStorage":
"""
Create an instance of :class:`PyMongoStorage` with specifying the connection string
@ -59,7 +63,7 @@ class PyMongoStorage(BaseStorage):
"""Cleanup client resources and disconnect from MongoDB."""
return await self._client.close()
def resolve_state(self, value: StateType) -> Optional[str]:
def resolve_state(self, value: StateType) -> str | None:
if value is None:
return None
if isinstance(value, State):
@ -84,18 +88,17 @@ class PyMongoStorage(BaseStorage):
upsert=True,
)
async def get_state(self, key: StorageKey) -> Optional[str]:
async def get_state(self, key: StorageKey) -> str | None:
document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id})
if document is None:
return None
return cast(Optional[str], document.get("state"))
return cast(str | None, document.get("state"))
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
if not isinstance(data, dict):
raise DataNotDictLikeError(
f"Data must be a dict or dict-like object, got {type(data).__name__}"
)
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
raise DataNotDictLikeError(msg)
document_id = self._key_builder.build(key)
if not data:
@ -114,14 +117,14 @@ class PyMongoStorage(BaseStorage):
upsert=True,
)
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
async def get_data(self, key: StorageKey) -> dict[str, Any]:
document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id})
if document is None or not document.get("data"):
return {}
return cast(Dict[str, Any], document["data"])
return cast(dict[str, Any], document["data"])
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]:
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
document_id = self._key_builder.build(key)
update_with = {f"data.{key}": value for key, value in data.items()}
update_result = await self._collection.find_one_and_update(
@ -133,4 +136,4 @@ class PyMongoStorage(BaseStorage):
)
if not update_result:
await self._collection.delete_one({"_id": document_id})
return cast(Dict[str, Any], update_result.get("data", {}))
return cast(dict[str, Any], update_result.get("data", {}))

View file

@ -1,6 +1,7 @@
import json
from collections.abc import AsyncGenerator, Callable, Mapping
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Optional, cast
from typing import Any, cast
from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool
@ -31,9 +32,9 @@ class RedisStorage(BaseStorage):
def __init__(
self,
redis: Redis,
key_builder: Optional[KeyBuilder] = None,
state_ttl: Optional[ExpiryT] = None,
data_ttl: Optional[ExpiryT] = None,
key_builder: KeyBuilder | None = None,
state_ttl: ExpiryT | None = None,
data_ttl: ExpiryT | None = None,
json_loads: _JsonLoads = json.loads,
json_dumps: _JsonDumps = json.dumps,
) -> None:
@ -54,7 +55,10 @@ class RedisStorage(BaseStorage):
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
cls,
url: str,
connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> "RedisStorage":
"""
Create an instance of :class:`RedisStorage` with specifying the connection string
@ -94,12 +98,12 @@ class RedisStorage(BaseStorage):
async def get_state(
self,
key: StorageKey,
) -> Optional[str]:
) -> str | None:
redis_key = self.key_builder.build(key, "state")
value = await self.redis.get(redis_key)
if isinstance(value, bytes):
return value.decode("utf-8")
return cast(Optional[str], value)
return cast(str | None, value)
async def set_data(
self,
@ -107,9 +111,8 @@ class RedisStorage(BaseStorage):
data: Mapping[str, Any],
) -> None:
if not isinstance(data, dict):
raise DataNotDictLikeError(
f"Data must be a dict or dict-like object, got {type(data).__name__}"
)
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
raise DataNotDictLikeError(msg)
redis_key = self.key_builder.build(key, "data")
if not data:
@ -124,22 +127,22 @@ class RedisStorage(BaseStorage):
async def get_data(
self,
key: StorageKey,
) -> Dict[str, Any]:
) -> dict[str, Any]:
redis_key = self.key_builder.build(key, "data")
value = await self.redis.get(redis_key)
if value is None:
return {}
if isinstance(value, bytes):
value = value.decode("utf-8")
return cast(Dict[str, Any], self.json_loads(value))
return cast(dict[str, Any], self.json_loads(value))
class RedisEventIsolation(BaseEventIsolation):
def __init__(
self,
redis: Redis,
key_builder: Optional[KeyBuilder] = None,
lock_kwargs: Optional[Dict[str, Any]] = None,
key_builder: KeyBuilder | None = None,
lock_kwargs: dict[str, Any] | None = None,
) -> None:
if key_builder is None:
key_builder = DefaultKeyBuilder()
@ -153,7 +156,7 @@ class RedisEventIsolation(BaseEventIsolation):
def from_url(
cls,
url: str,
connection_kwargs: Optional[Dict[str, Any]] = None,
connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> "RedisEventIsolation":
if connection_kwargs is None:

View file

@ -1,5 +1,4 @@
from enum import Enum, auto
from typing import Optional, Tuple
class FSMStrategy(Enum):
@ -23,8 +22,8 @@ def apply_strategy(
strategy: FSMStrategy,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
) -> Tuple[int, int, Optional[int]]:
thread_id: int | None = None,
) -> tuple[int, int, int | None]:
if strategy == FSMStrategy.CHAT:
return chat_id, chat_id, None
if strategy == FSMStrategy.GLOBAL_USER:

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from aiogram.types import Update
@ -14,7 +14,7 @@ T = TypeVar("T")
class BaseHandlerMixin(Generic[T]):
if TYPE_CHECKING:
event: T
data: Dict[str, Any]
data: dict[str, Any]
class BaseHandler(BaseHandlerMixin[T], ABC):
@ -24,7 +24,7 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
def __init__(self, event: T, **kwargs: Any) -> None:
self.event: T = event
self.data: Dict[str, Any] = kwargs
self.data: dict[str, Any] = kwargs
@property
def bot(self) -> Bot:
@ -32,7 +32,8 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
if "bot" in self.data:
return cast(Bot, self.data["bot"])
raise RuntimeError("Bot instance not found in the context")
msg = "Bot instance not found in the context"
raise RuntimeError(msg)
@property
def update(self) -> Update:

View file

@ -29,14 +29,14 @@ class CallbackQueryHandler(BaseHandler[CallbackQuery], ABC):
return self.event.from_user
@property
def message(self) -> Optional[MaybeInaccessibleMessage]:
def message(self) -> MaybeInaccessibleMessage | None:
"""
Is alias for `event.message`
"""
return self.event.message
@property
def callback_data(self) -> Optional[str]:
def callback_data(self) -> str | None:
"""
Is alias for `event.data`
"""

View file

@ -12,7 +12,7 @@ class MessageHandler(BaseHandler[Message], ABC):
"""
@property
def from_user(self) -> Optional[User]:
def from_user(self) -> User | None:
return self.event.from_user
@property
@ -22,7 +22,7 @@ class MessageHandler(BaseHandler[Message], ABC):
class MessageHandlerCommandMixin(BaseHandlerMixin[Message]):
@property
def command(self) -> Optional[CommandObject]:
def command(self) -> CommandObject | None:
if "command" in self.data:
return cast(CommandObject, self.data["command"])
return None

View file

@ -1,5 +1,4 @@
from abc import ABC
from typing import List
from aiogram.handlers import BaseHandler
from aiogram.types import Poll, PollOption
@ -15,5 +14,5 @@ class PollHandler(BaseHandler[Poll], ABC):
return self.event.question
@property
def options(self) -> List[PollOption]:
def options(self) -> list[PollOption]:
return self.event.options

View file

@ -1,6 +1,6 @@
import hashlib
import hmac
from typing import Any, Dict
from typing import Any
def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
@ -17,12 +17,14 @@ def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
secret = hashlib.sha256(token.encode("utf-8"))
check_string = "\n".join(f"{k}={kwargs[k]}" for k in sorted(kwargs))
hmac_string = hmac.new(
secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256
secret.digest(),
check_string.encode("utf-8"),
digestmod=hashlib.sha256,
).hexdigest()
return hmac_string == hash
def check_integrity(token: str, data: Dict[str, Any]) -> bool:
def check_integrity(token: str, data: dict[str, Any]) -> bool:
"""
Verify the authentication and the integrity
of the data received on user's auth

View file

@ -13,9 +13,11 @@ class BackoffConfig:
def __post_init__(self) -> None:
if self.max_delay <= self.min_delay:
raise ValueError("`max_delay` should be greater than `min_delay`")
msg = "`max_delay` should be greater than `min_delay`"
raise ValueError(msg)
if self.factor <= 1:
raise ValueError("`factor` should be greater than 1")
msg = "`factor` should be greater than 1"
raise ValueError(msg)
class Backoff:

View file

@ -1,4 +1,5 @@
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from collections.abc import Awaitable, Callable
from typing import Any
from aiogram import BaseMiddleware, loggers
from aiogram.dispatcher.flags import get_flag
@ -12,10 +13,10 @@ class CallbackAnswer:
self,
answered: bool,
disabled: bool = False,
text: Optional[str] = None,
show_alert: Optional[bool] = None,
url: Optional[str] = None,
cache_time: Optional[int] = None,
text: str | None = None,
show_alert: bool | None = None,
url: str | None = None,
cache_time: int | None = None,
) -> None:
"""
Callback answer configuration
@ -48,7 +49,8 @@ class CallbackAnswer:
@disabled.setter
def disabled(self, value: bool) -> None:
if self._answered:
raise CallbackAnswerException("Can't change disabled state after answer")
msg = "Can't change disabled state after answer"
raise CallbackAnswerException(msg)
self._disabled = value
@property
@ -59,7 +61,7 @@ class CallbackAnswer:
return self._answered
@property
def text(self) -> Optional[str]:
def text(self) -> str | None:
"""
Response text
:return:
@ -67,48 +69,52 @@ class CallbackAnswer:
return self._text
@text.setter
def text(self, value: Optional[str]) -> None:
def text(self, value: str | None) -> None:
if self._answered:
raise CallbackAnswerException("Can't change text after answer")
msg = "Can't change text after answer"
raise CallbackAnswerException(msg)
self._text = value
@property
def show_alert(self) -> Optional[bool]:
def show_alert(self) -> bool | None:
"""
Whether to display an alert
"""
return self._show_alert
@show_alert.setter
def show_alert(self, value: Optional[bool]) -> None:
def show_alert(self, value: bool | None) -> None:
if self._answered:
raise CallbackAnswerException("Can't change show_alert after answer")
msg = "Can't change show_alert after answer"
raise CallbackAnswerException(msg)
self._show_alert = value
@property
def url(self) -> Optional[str]:
def url(self) -> str | None:
"""
Game url
"""
return self._url
@url.setter
def url(self, value: Optional[str]) -> None:
def url(self, value: str | None) -> None:
if self._answered:
raise CallbackAnswerException("Can't change url after answer")
msg = "Can't change url after answer"
raise CallbackAnswerException(msg)
self._url = value
@property
def cache_time(self) -> Optional[int]:
def cache_time(self) -> int | None:
"""
Response cache time
"""
return self._cache_time
@cache_time.setter
def cache_time(self, value: Optional[int]) -> None:
def cache_time(self, value: int | None) -> None:
if self._answered:
raise CallbackAnswerException("Can't change cache_time after answer")
msg = "Can't change cache_time after answer"
raise CallbackAnswerException(msg)
self._cache_time = value
def __str__(self) -> str:
@ -131,10 +137,10 @@ class CallbackAnswerMiddleware(BaseMiddleware):
def __init__(
self,
pre: bool = False,
text: Optional[str] = None,
show_alert: Optional[bool] = None,
url: Optional[str] = None,
cache_time: Optional[int] = None,
text: str | None = None,
show_alert: bool | None = None,
url: str | None = None,
cache_time: int | None = None,
) -> None:
"""
Inner middleware for callback query handlers, can be useful in bots with a lot of callback
@ -154,15 +160,15 @@ class CallbackAnswerMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
if not isinstance(event, CallbackQuery):
return await handler(event, data)
callback_answer = data["callback_answer"] = self.construct_callback_answer(
properties=get_flag(data, "callback_answer")
properties=get_flag(data, "callback_answer"),
)
if not callback_answer.disabled and callback_answer.answered:
@ -174,7 +180,8 @@ class CallbackAnswerMiddleware(BaseMiddleware):
await self.answer(event, callback_answer)
def construct_callback_answer(
self, properties: Optional[Union[Dict[str, Any], bool]]
self,
properties: dict[str, Any] | bool | None,
) -> CallbackAnswer:
pre, disabled, text, show_alert, url, cache_time = (
self.pre,

View file

@ -2,9 +2,10 @@ import asyncio
import logging
import time
from asyncio import Event, Lock
from collections.abc import Awaitable, Callable
from contextlib import suppress
from types import TracebackType
from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union
from typing import Any
from aiogram import BaseMiddleware, Bot
from aiogram.dispatcher.flags import get_flag
@ -32,8 +33,8 @@ class ChatActionSender:
self,
*,
bot: Bot,
chat_id: Union[str, int],
message_thread_id: Optional[int] = None,
chat_id: str | int,
message_thread_id: int | None = None,
action: str = "typing",
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
@ -56,7 +57,7 @@ class ChatActionSender:
self._lock = Lock()
self._close_event = Event()
self._closed_event = Event()
self._task: Optional[asyncio.Task[Any]] = None
self._task: asyncio.Task[Any] | None = None
@property
def running(self) -> bool:
@ -108,7 +109,8 @@ class ChatActionSender:
self._close_event.clear()
self._closed_event.clear()
if self.running:
raise RuntimeError("Already running")
msg = "Already running"
raise RuntimeError(msg)
self._task = asyncio.create_task(self._worker())
async def _stop(self) -> None:
@ -126,18 +128,18 @@ class ChatActionSender:
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> Any:
await self._stop()
@classmethod
def typing(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -154,9 +156,9 @@ class ChatActionSender:
@classmethod
def upload_photo(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -173,9 +175,9 @@ class ChatActionSender:
@classmethod
def record_video(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -192,9 +194,9 @@ class ChatActionSender:
@classmethod
def upload_video(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -211,9 +213,9 @@ class ChatActionSender:
@classmethod
def record_voice(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -230,9 +232,9 @@ class ChatActionSender:
@classmethod
def upload_voice(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -249,9 +251,9 @@ class ChatActionSender:
@classmethod
def upload_document(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -268,9 +270,9 @@ class ChatActionSender:
@classmethod
def choose_sticker(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -287,9 +289,9 @@ class ChatActionSender:
@classmethod
def find_location(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -306,9 +308,9 @@ class ChatActionSender:
@classmethod
def record_video_note(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -325,9 +327,9 @@ class ChatActionSender:
@classmethod
def upload_video_note(
cls,
chat_id: Union[int, str],
chat_id: int | str,
bot: Bot,
message_thread_id: Optional[int] = None,
message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -349,9 +351,9 @@ class ChatActionMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
if not isinstance(event, Message):
return await handler(event, data)

View file

@ -1,7 +1,6 @@
from typing import Tuple, Type, Union
from typing import Annotated
from pydantic import Field, TypeAdapter
from typing_extensions import Annotated
from aiogram.types import (
ChatMember,
@ -13,22 +12,22 @@ from aiogram.types import (
ChatMemberRestricted,
)
ChatMemberUnion = Union[
ChatMemberOwner,
ChatMemberAdministrator,
ChatMemberMember,
ChatMemberRestricted,
ChatMemberLeft,
ChatMemberBanned,
]
ChatMemberUnion = (
ChatMemberOwner
| ChatMemberAdministrator
| ChatMemberMember
| ChatMemberRestricted
| ChatMemberLeft
| ChatMemberBanned
)
ChatMemberCollection = Tuple[Type[ChatMember], ...]
ChatMemberCollection = tuple[type[ChatMember], ...]
ChatMemberAdapter: TypeAdapter[ChatMemberUnion] = TypeAdapter(
Annotated[
ChatMemberUnion,
Field(discriminator="status"),
]
],
)
ADMINS: ChatMemberCollection = (ChatMemberOwner, ChatMemberAdministrator)

View file

@ -1,7 +1,8 @@
import inspect
from collections.abc import Generator
from dataclasses import dataclass
from operator import itemgetter
from typing import Any, Generator, NamedTuple, Protocol
from typing import Any, NamedTuple, Protocol
from aiogram.utils.dataclass import dataclass_kwargs

View file

@ -9,16 +9,16 @@ from typing import Any, Union
def dataclass_kwargs(
init: Union[bool, None] = None,
repr: Union[bool, None] = None,
eq: Union[bool, None] = None,
order: Union[bool, None] = None,
unsafe_hash: Union[bool, None] = None,
frozen: Union[bool, None] = None,
match_args: Union[bool, None] = None,
kw_only: Union[bool, None] = None,
slots: Union[bool, None] = None,
weakref_slot: Union[bool, None] = None,
init: bool | None = None,
repr: bool | None = None,
eq: bool | None = None,
order: bool | None = None,
unsafe_hash: bool | None = None,
frozen: bool | None = None,
match_args: bool | None = None,
kw_only: bool | None = None,
slots: bool | None = None,
weakref_slot: bool | None = None,
) -> dict[str, Any]:
"""
Generates a dictionary of keyword arguments that can be passed to a Python
@ -48,7 +48,6 @@ def dataclass_kwargs(
params["frozen"] = frozen
# Added in 3.10
if sys.version_info >= (3, 10):
if match_args is not None:
params["match_args"] = match_args
if kw_only is not None:

View file

@ -1,22 +1,24 @@
from __future__ import annotations
__all__ = [
"create_start_link",
"create_startgroup_link",
"create_startapp_link",
"create_deep_link",
"create_start_link",
"create_startapp_link",
"create_startgroup_link",
"create_telegram_link",
"encode_payload",
"decode_payload",
"encode_payload",
]
import re
from typing import TYPE_CHECKING, Callable, Literal, Optional, cast
from typing import TYPE_CHECKING, Literal, Optional, cast
from aiogram.utils.link import create_telegram_link
from aiogram.utils.payload import decode_payload, encode_payload
if TYPE_CHECKING:
from collections.abc import Callable
from aiogram import Bot
BAD_PATTERN = re.compile(r"[^a-zA-Z0-9-_]")
@ -26,7 +28,7 @@ async def create_start_link(
bot: Bot,
payload: str,
encode: bool = False,
encoder: Optional[Callable[[bytes], bytes]] = None,
encoder: Callable[[bytes], bytes] | None = None,
) -> str:
"""
Create 'start' deep link with your payload.
@ -53,7 +55,7 @@ async def create_startgroup_link(
bot: Bot,
payload: str,
encode: bool = False,
encoder: Optional[Callable[[bytes], bytes]] = None,
encoder: Callable[[bytes], bytes] | None = None,
) -> str:
"""
Create 'startgroup' deep link with your payload.
@ -80,8 +82,8 @@ async def create_startapp_link(
bot: Bot,
payload: str,
encode: bool = False,
app_name: Optional[str] = None,
encoder: Optional[Callable[[bytes], bytes]] = None,
app_name: str | None = None,
encoder: Callable[[bytes], bytes] | None = None,
) -> str:
"""
Create 'startapp' deep link with your payload.
@ -115,9 +117,9 @@ def create_deep_link(
username: str,
link_type: Literal["start", "startgroup", "startapp"],
payload: str,
app_name: Optional[str] = None,
app_name: str | None = None,
encode: bool = False,
encoder: Optional[Callable[[bytes], bytes]] = None,
encoder: Callable[[bytes], bytes] | None = None,
) -> str:
"""
Create deep link.
@ -137,13 +139,15 @@ def create_deep_link(
payload = encode_payload(payload, encoder=encoder)
if re.search(BAD_PATTERN, payload):
raise ValueError(
msg = (
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
"Pass `encode=True` or encode payload manually."
)
raise ValueError(msg)
if len(payload) > 64:
raise ValueError("Payload must be up to 64 characters long.")
msg = "Payload must be up to 64 characters long."
raise ValueError(msg)
if not app_name:
deep_link = create_telegram_link(username, **{cast(str, link_type): payload})

View file

@ -1,16 +1,8 @@
from __future__ import annotations
import textwrap
from typing import (
Any,
ClassVar,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
)
from collections.abc import Generator, Iterable, Iterator
from typing import TYPE_CHECKING, Any, ClassVar
from typing_extensions import Self
@ -35,7 +27,7 @@ class Text(Iterable[NodeType]):
Simple text element
"""
type: ClassVar[Optional[str]] = None
type: ClassVar[str | None] = None
__slots__ = ("_body", "_params")
@ -44,16 +36,16 @@ class Text(Iterable[NodeType]):
*body: NodeType,
**params: Any,
) -> None:
self._body: Tuple[NodeType, ...] = body
self._params: Dict[str, Any] = params
self._body: tuple[NodeType, ...] = body
self._params: dict[str, Any] = params
@classmethod
def from_entities(cls, text: str, entities: List[MessageEntity]) -> "Text":
def from_entities(cls, text: str, entities: list[MessageEntity]) -> Text:
return cls(
*_unparse_entities(
text=add_surrogates(text),
entities=sorted(entities, key=lambda item: item.offset) if entities else [],
)
),
)
def render(
@ -62,7 +54,7 @@ class Text(Iterable[NodeType]):
_offset: int = 0,
_sort: bool = True,
_collect_entities: bool = True,
) -> Tuple[str, List[MessageEntity]]:
) -> tuple[str, list[MessageEntity]]:
"""
Render elements tree as text with entities list
@ -108,7 +100,7 @@ class Text(Iterable[NodeType]):
entities_key: str = "entities",
replace_parse_mode: bool = True,
parse_mode_key: str = "parse_mode",
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Render element tree as keyword arguments for usage in an API call, for example:
@ -124,7 +116,7 @@ class Text(Iterable[NodeType]):
:return:
"""
text_value, entities_value = self.render()
result: Dict[str, Any] = {
result: dict[str, Any] = {
text_key: text_value,
entities_key: entities_value,
}
@ -132,7 +124,7 @@ class Text(Iterable[NodeType]):
result[parse_mode_key] = None
return result
def as_caption_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
def as_caption_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
"""
Shortcut for :meth:`as_kwargs` for usage with API calls that take
``caption`` as a parameter.
@ -151,7 +143,7 @@ class Text(Iterable[NodeType]):
replace_parse_mode=replace_parse_mode,
)
def as_poll_question_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
def as_poll_question_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
"""
Shortcut for :meth:`as_kwargs` for usage with
method :class:`aiogram.methods.send_poll.SendPoll`.
@ -171,7 +163,7 @@ class Text(Iterable[NodeType]):
replace_parse_mode=replace_parse_mode,
)
def as_poll_explanation_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
def as_poll_explanation_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
"""
Shortcut for :meth:`as_kwargs` for usage with
method :class:`aiogram.methods.send_poll.SendPoll`.
@ -196,7 +188,7 @@ class Text(Iterable[NodeType]):
replace_parse_mode=replace_parse_mode,
)
def as_gift_text_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
def as_gift_text_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
"""
Shortcut for :meth:`as_kwargs` for usage with
method :class:`aiogram.methods.send_gift.SendGift`.
@ -252,7 +244,7 @@ class Text(Iterable[NodeType]):
args_str = textwrap.indent("\n" + args_str + "\n", " ")
return f"{type(self).__name__}({args_str})"
def __add__(self, other: NodeType) -> "Text":
def __add__(self, other: NodeType) -> Text:
if isinstance(other, Text) and other.type == self.type and self._params == other._params:
return type(self)(*self, *other, **self._params)
if type(self) is Text and isinstance(other, str):
@ -266,9 +258,10 @@ class Text(Iterable[NodeType]):
text, _ = self.render(_collect_entities=False)
return sizeof(text)
def __getitem__(self, item: slice) -> "Text":
def __getitem__(self, item: slice) -> Text:
if not isinstance(item, slice):
raise TypeError("Can only be sliced")
msg = "Can only be sliced"
raise TypeError(msg)
if (item.start is None or item.start == 0) and item.stop is None:
return self.replace(*self._body)
start = 0 if item.start is None else item.start
@ -313,9 +306,11 @@ class HashTag(Text):
def __init__(self, *body: NodeType, **params: Any) -> None:
if len(body) != 1:
raise ValueError("Hashtag can contain only one element")
msg = "Hashtag can contain only one element"
raise ValueError(msg)
if not isinstance(body[0], str):
raise ValueError("Hashtag can contain only string")
msg = "Hashtag can contain only string"
raise ValueError(msg)
if not body[0].startswith("#"):
body = ("#" + body[0],)
super().__init__(*body, **params)
@ -337,9 +332,11 @@ class CashTag(Text):
def __init__(self, *body: NodeType, **params: Any) -> None:
if len(body) != 1:
raise ValueError("Cashtag can contain only one element")
msg = "Cashtag can contain only one element"
raise ValueError(msg)
if not isinstance(body[0], str):
raise ValueError("Cashtag can contain only string")
msg = "Cashtag can contain only string"
raise ValueError(msg)
if not body[0].startswith("$"):
body = ("$" + body[0],)
super().__init__(*body, **params)
@ -469,7 +466,7 @@ class Pre(Text):
type = MessageEntityType.PRE
def __init__(self, *body: NodeType, language: Optional[str] = None, **params: Any) -> None:
def __init__(self, *body: NodeType, language: str | None = None, **params: Any) -> None:
super().__init__(*body, language=language, **params)
@ -537,7 +534,7 @@ class ExpandableBlockQuote(Text):
type = MessageEntityType.EXPANDABLE_BLOCKQUOTE
NODE_TYPES: Dict[Optional[str], Type[Text]] = {
NODE_TYPES: dict[str | None, type[Text]] = {
Text.type: Text,
HashTag.type: HashTag,
CashTag.type: CashTag,
@ -570,15 +567,16 @@ def _apply_entity(entity: MessageEntity, *nodes: NodeType) -> NodeType:
"""
node_type = NODE_TYPES.get(entity.type, Text)
return node_type(
*nodes, **entity.model_dump(exclude={"type", "offset", "length"}, warnings=False)
*nodes,
**entity.model_dump(exclude={"type", "offset", "length"}, warnings=False),
)
def _unparse_entities(
text: bytes,
entities: List[MessageEntity],
offset: Optional[int] = None,
length: Optional[int] = None,
entities: list[MessageEntity],
offset: int | None = None,
length: int | None = None,
) -> Generator[NodeType, None, None]:
if offset is None:
offset = 0
@ -615,8 +613,7 @@ def as_line(*items: NodeType, end: str = "\n", sep: str = "") -> Text:
nodes = []
for item in items[:-1]:
nodes.extend([item, sep])
nodes.append(items[-1])
nodes.append(end)
nodes.extend([items[-1], end])
else:
nodes = [*items, end]
return Text(*nodes)

View file

@ -8,14 +8,14 @@ from .middleware import (
)
__all__ = (
"ConstI18nMiddleware",
"FSMI18nMiddleware",
"I18n",
"I18nMiddleware",
"SimpleI18nMiddleware",
"ConstI18nMiddleware",
"FSMI18nMiddleware",
"get_i18n",
"gettext",
"lazy_gettext",
"ngettext",
"lazy_ngettext",
"get_i18n",
"ngettext",
)

View file

@ -7,7 +7,8 @@ from aiogram.utils.i18n.lazy_proxy import LazyProxy
def get_i18n() -> I18n:
i18n = I18n.get_current(no_error=True)
if i18n is None:
raise LookupError("I18n context is not set")
msg = "I18n context is not set"
raise LookupError(msg)
return i18n

View file

@ -1,23 +1,27 @@
from __future__ import annotations
import gettext
import os
from contextlib import contextmanager
from contextvars import ContextVar
from pathlib import Path
from typing import Dict, Generator, Optional, Tuple, Union
from typing import TYPE_CHECKING
from aiogram.utils.i18n.lazy_proxy import LazyProxy
from aiogram.utils.mixins import ContextInstanceMixin
if TYPE_CHECKING:
from collections.abc import Generator
class I18n(ContextInstanceMixin["I18n"]):
def __init__(
self,
*,
path: Union[str, Path],
path: str | Path,
default_locale: str = "en",
domain: str = "messages",
) -> None:
self.path = path
self.path = Path(path)
self.default_locale = default_locale
self.domain = domain
self.ctx_locale = ContextVar("aiogram_ctx_locale", default=default_locale)
@ -43,7 +47,7 @@ class I18n(ContextInstanceMixin["I18n"]):
self.ctx_locale.reset(ctx_token)
@contextmanager
def context(self) -> Generator["I18n", None, None]:
def context(self) -> Generator[I18n, None, None]:
"""
Use I18n context
"""
@ -53,24 +57,25 @@ class I18n(ContextInstanceMixin["I18n"]):
finally:
self.reset_current(token)
def find_locales(self) -> Dict[str, gettext.GNUTranslations]:
def find_locales(self) -> dict[str, gettext.GNUTranslations]:
"""
Load all compiled locales from path
:return: dict with locales
"""
translations: Dict[str, gettext.GNUTranslations] = {}
translations: dict[str, gettext.GNUTranslations] = {}
for name in os.listdir(self.path):
if not os.path.isdir(os.path.join(self.path, name)):
for name in self.path.iterdir():
if not (self.path / name).is_dir():
continue
mo_path = os.path.join(self.path, name, "LC_MESSAGES", self.domain + ".mo")
mo_path = self.path / name / "LC_MESSAGES" / (self.domain + ".mo")
if os.path.exists(mo_path):
with open(mo_path, "rb") as fp:
translations[name] = gettext.GNUTranslations(fp)
elif os.path.exists(mo_path[:-2] + "po"): # pragma: no cover
raise RuntimeError(f"Found locale '{name}' but this language is not compiled!")
if mo_path.exists():
with mo_path.open("rb") as fp:
translations[name.name] = gettext.GNUTranslations(fp)
elif mo_path.with_suffix(".po").exists(): # pragma: no cover
msg = f"Found locale '{name.name}' but this language is not compiled!"
raise RuntimeError(msg)
return translations
@ -81,7 +86,7 @@ class I18n(ContextInstanceMixin["I18n"]):
self.locales = self.find_locales()
@property
def available_locales(self) -> Tuple[str, ...]:
def available_locales(self) -> tuple[str, ...]:
"""
list of loaded locales
@ -90,7 +95,11 @@ class I18n(ContextInstanceMixin["I18n"]):
return tuple(self.locales.keys())
def gettext(
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None
self,
singular: str,
plural: str | None = None,
n: int = 1,
locale: str | None = None,
) -> str:
"""
Get text
@ -107,7 +116,7 @@ class I18n(ContextInstanceMixin["I18n"]):
if locale not in self.locales:
if n == 1:
return singular
return plural if plural else singular
return plural or singular
translator = self.locales[locale]
@ -116,8 +125,17 @@ class I18n(ContextInstanceMixin["I18n"]):
return translator.ngettext(singular, plural, n)
def lazy_gettext(
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None
self,
singular: str,
plural: str | None = None,
n: int = 1,
locale: str | None = None,
) -> LazyProxy:
return LazyProxy(
self.gettext, singular=singular, plural=plural, n=n, locale=locale, enable_cache=False
self.gettext,
singular=singular,
plural=plural,
n=n,
locale=locale,
enable_cache=False,
)

View file

@ -6,8 +6,9 @@ except ImportError: # pragma: no cover
class LazyProxy: # type: ignore
def __init__(self, func: Any, *args: Any, **kwargs: Any) -> None:
raise RuntimeError(
msg = (
"LazyProxy can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
)
raise RuntimeError(msg)

View file

@ -1,5 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, Optional, Set
from typing import TYPE_CHECKING, Any
try:
from babel import Locale, UnknownLocaleError
@ -11,6 +13,10 @@ except ImportError: # pragma: no cover
from aiogram import BaseMiddleware, Router
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from aiogram.fsm.context import FSMContext
from aiogram.types import TelegramObject, User
from aiogram.utils.i18n.core import I18n
@ -24,7 +30,7 @@ class I18nMiddleware(BaseMiddleware, ABC):
def __init__(
self,
i18n: I18n,
i18n_key: Optional[str] = "i18n",
i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware",
) -> None:
"""
@ -39,7 +45,9 @@ class I18nMiddleware(BaseMiddleware, ABC):
self.middleware_key = middleware_key
def setup(
self: BaseMiddleware, router: Router, exclude: Optional[Set[str]] = None
self: BaseMiddleware,
router: Router,
exclude: set[str] | None = None,
) -> BaseMiddleware:
"""
Register middleware for all events in the Router
@ -59,9 +67,9 @@ class I18nMiddleware(BaseMiddleware, ABC):
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
data: dict[str, Any],
) -> Any:
current_locale = await self.get_locale(event=event, data=data) or self.i18n.default_locale
@ -74,7 +82,7 @@ class I18nMiddleware(BaseMiddleware, ABC):
return await handler(event, data)
@abstractmethod
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
"""
Detect current user locale based on event and context.
@ -84,7 +92,6 @@ class I18nMiddleware(BaseMiddleware, ABC):
:param data:
:return:
"""
pass
class SimpleI18nMiddleware(I18nMiddleware):
@ -97,27 +104,29 @@ class SimpleI18nMiddleware(I18nMiddleware):
def __init__(
self,
i18n: I18n,
i18n_key: Optional[str] = "i18n",
i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware",
) -> None:
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
if Locale is None: # pragma: no cover
raise RuntimeError(
msg = (
f"{type(self).__name__} can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
)
raise RuntimeError(msg)
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
if Locale is None: # pragma: no cover
raise RuntimeError(
msg = (
f"{type(self).__name__} can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
)
raise RuntimeError(msg)
event_from_user: Optional[User] = data.get("event_from_user", None)
event_from_user: User | None = data.get("event_from_user")
if event_from_user is None or event_from_user.language_code is None:
return self.i18n.default_locale
try:
@ -139,13 +148,13 @@ class ConstI18nMiddleware(I18nMiddleware):
self,
locale: str,
i18n: I18n,
i18n_key: Optional[str] = "i18n",
i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware",
) -> None:
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
self.locale = locale
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
return self.locale
@ -158,14 +167,14 @@ class FSMI18nMiddleware(SimpleI18nMiddleware):
self,
i18n: I18n,
key: str = "locale",
i18n_key: Optional[str] = "i18n",
i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware",
) -> None:
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
self.key = key
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
fsm_context: Optional[FSMContext] = data.get("state")
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
fsm_context: FSMContext | None = data.get("state")
locale = None
if fsm_context:
fsm_data = await fsm_context.get_data()

View file

@ -4,19 +4,7 @@ from abc import ABC
from copy import deepcopy
from itertools import chain
from itertools import cycle as repeat_all
from typing import (
TYPE_CHECKING,
Any,
Generator,
Generic,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from aiogram.filters.callback_data import CallbackData
from aiogram.types import (
@ -34,11 +22,14 @@ from aiogram.types import (
WebAppInfo,
)
if TYPE_CHECKING:
from collections.abc import Generator, Iterable
ButtonType = TypeVar("ButtonType", InlineKeyboardButton, KeyboardButton)
T = TypeVar("T")
class KeyboardBuilder(Generic[ButtonType], ABC):
class KeyboardBuilder(ABC, Generic[ButtonType]):
"""
Generic keyboard builder that helps to adjust your markup with defined shape of lines.
@ -50,16 +41,19 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
max_buttons: int = 0
def __init__(
self, button_type: Type[ButtonType], markup: Optional[List[List[ButtonType]]] = None
self,
button_type: type[ButtonType],
markup: list[list[ButtonType]] | None = None,
) -> None:
if not issubclass(button_type, (InlineKeyboardButton, KeyboardButton)):
raise ValueError(f"Button type {button_type} are not allowed here")
self._button_type: Type[ButtonType] = button_type
msg = f"Button type {button_type} are not allowed here"
raise ValueError(msg)
self._button_type: type[ButtonType] = button_type
if markup:
self._validate_markup(markup)
else:
markup = []
self._markup: List[List[ButtonType]] = markup
self._markup: list[list[ButtonType]] = markup
@property
def buttons(self) -> Generator[ButtonType, None, None]:
@ -79,9 +73,8 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
"""
allowed = self._button_type
if not isinstance(button, allowed):
raise ValueError(
f"{button!r} should be type {allowed.__name__!r} not {type(button).__name__!r}"
)
msg = f"{button!r} should be type {allowed.__name__!r} not {type(button).__name__!r}"
raise ValueError(msg)
return True
def _validate_buttons(self, *buttons: ButtonType) -> bool:
@ -93,7 +86,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
"""
return all(map(self._validate_button, buttons))
def _validate_row(self, row: List[ButtonType]) -> bool:
def _validate_row(self, row: list[ButtonType]) -> bool:
"""
Check that row of buttons are correct
Row can be only list of allowed button types and has length 0 <= n <= 8
@ -102,16 +95,18 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
:return:
"""
if not isinstance(row, list):
raise ValueError(
msg = (
f"Row {row!r} should be type 'List[{self._button_type.__name__}]' "
f"not type {type(row).__name__}"
)
raise ValueError(msg)
if len(row) > self.max_width:
raise ValueError(f"Row {row!r} is too long (max width: {self.max_width})")
msg = f"Row {row!r} is too long (max width: {self.max_width})"
raise ValueError(msg)
self._validate_buttons(*row)
return True
def _validate_markup(self, markup: List[List[ButtonType]]) -> bool:
def _validate_markup(self, markup: list[list[ButtonType]]) -> bool:
"""
Check that passed markup has correct data structure
Markup is list of lists of buttons
@ -121,15 +116,17 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
"""
count = 0
if not isinstance(markup, list):
raise ValueError(
msg = (
f"Markup should be type 'List[List[{self._button_type.__name__}]]' "
f"not type {type(markup).__name__!r}"
)
raise ValueError(msg)
for row in markup:
self._validate_row(row)
count += len(row)
if count > self.max_buttons:
raise ValueError(f"Too much buttons detected Max allowed count - {self.max_buttons}")
msg = f"Too much buttons detected Max allowed count - {self.max_buttons}"
raise ValueError(msg)
return True
def _validate_size(self, size: Any) -> int:
@ -140,14 +137,14 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
:return:
"""
if not isinstance(size, int):
raise ValueError("Only int sizes are allowed")
msg = "Only int sizes are allowed"
raise ValueError(msg)
if size not in range(self.min_width, self.max_width + 1):
raise ValueError(
f"Row size {size} is not allowed, range: [{self.min_width}, {self.max_width}]"
)
msg = f"Row size {size} is not allowed, range: [{self.min_width}, {self.max_width}]"
raise ValueError(msg)
return size
def export(self) -> List[List[ButtonType]]:
def export(self) -> list[list[ButtonType]]:
"""
Export configured markup as list of lists of buttons
@ -161,7 +158,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
"""
return deepcopy(self._markup)
def add(self, *buttons: ButtonType) -> "KeyboardBuilder[ButtonType]":
def add(self, *buttons: ButtonType) -> KeyboardBuilder[ButtonType]:
"""
Add one or many buttons to markup.
@ -189,9 +186,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
self._markup = markup
return self
def row(
self, *buttons: ButtonType, width: Optional[int] = None
) -> "KeyboardBuilder[ButtonType]":
def row(self, *buttons: ButtonType, width: int | None = None) -> KeyboardBuilder[ButtonType]:
"""
Add row to markup
@ -211,7 +206,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
)
return self
def adjust(self, *sizes: int, repeat: bool = False) -> "KeyboardBuilder[ButtonType]":
def adjust(self, *sizes: int, repeat: bool = False) -> KeyboardBuilder[ButtonType]:
"""
Adjust previously added buttons to specific row sizes.
@ -232,7 +227,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
size = next(sizes_iter)
markup = []
row: List[ButtonType] = []
row: list[ButtonType] = []
for button in self.buttons:
if len(row) >= size:
markup.append(row)
@ -244,33 +239,35 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
self._markup = markup
return self
def _button(self, **kwargs: Any) -> "KeyboardBuilder[ButtonType]":
def _button(self, **kwargs: Any) -> KeyboardBuilder[ButtonType]:
"""
Add button to markup
:param kwargs:
:return:
"""
if isinstance(callback_data := kwargs.get("callback_data", None), CallbackData):
if isinstance(callback_data := kwargs.get("callback_data"), CallbackData):
kwargs["callback_data"] = callback_data.pack()
button = self._button_type(**kwargs)
return self.add(button)
def as_markup(self, **kwargs: Any) -> Union[InlineKeyboardMarkup, ReplyKeyboardMarkup]:
def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup | ReplyKeyboardMarkup:
if self._button_type is KeyboardButton:
keyboard = cast(List[List[KeyboardButton]], self.export()) # type: ignore
keyboard = cast(list[list[KeyboardButton]], self.export()) # type: ignore
return ReplyKeyboardMarkup(keyboard=keyboard, **kwargs)
inline_keyboard = cast(List[List[InlineKeyboardButton]], self.export()) # type: ignore
inline_keyboard = cast(list[list[InlineKeyboardButton]], self.export()) # type: ignore
return InlineKeyboardMarkup(inline_keyboard=inline_keyboard)
def attach(self, builder: "KeyboardBuilder[ButtonType]") -> "KeyboardBuilder[ButtonType]":
def attach(self, builder: KeyboardBuilder[ButtonType]) -> KeyboardBuilder[ButtonType]:
if not isinstance(builder, KeyboardBuilder):
raise ValueError(f"Only KeyboardBuilder can be attached, not {type(builder).__name__}")
msg = f"Only KeyboardBuilder can be attached, not {type(builder).__name__}"
raise ValueError(msg)
if builder._button_type is not self._button_type:
raise ValueError(
msg = (
f"Only builders with same button type can be attached, "
f"not {self._button_type.__name__} and {builder._button_type.__name__}"
)
raise ValueError(msg)
self._markup.extend(builder.export())
return self
@ -306,18 +303,18 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
self,
*,
text: str,
url: Optional[str] = None,
callback_data: Optional[Union[str, CallbackData]] = None,
web_app: Optional[WebAppInfo] = None,
login_url: Optional[LoginUrl] = None,
switch_inline_query: Optional[str] = None,
switch_inline_query_current_chat: Optional[str] = None,
switch_inline_query_chosen_chat: Optional[SwitchInlineQueryChosenChat] = None,
copy_text: Optional[CopyTextButton] = None,
callback_game: Optional[CallbackGame] = None,
pay: Optional[bool] = None,
url: str | None = None,
callback_data: str | CallbackData | None = None,
web_app: WebAppInfo | None = None,
login_url: LoginUrl | None = None,
switch_inline_query: str | None = None,
switch_inline_query_current_chat: str | None = None,
switch_inline_query_chosen_chat: SwitchInlineQueryChosenChat | None = None,
copy_text: CopyTextButton | None = None,
callback_game: CallbackGame | None = None,
pay: bool | None = None,
**kwargs: Any,
) -> "InlineKeyboardBuilder":
) -> InlineKeyboardBuilder:
return cast(
InlineKeyboardBuilder,
self._button(
@ -340,10 +337,10 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
"""Construct an InlineKeyboardMarkup"""
return cast(InlineKeyboardMarkup, super().as_markup(**kwargs))
def __init__(self, markup: Optional[List[List[InlineKeyboardButton]]] = None) -> None:
def __init__(self, markup: list[list[InlineKeyboardButton]] | None = None) -> None:
super().__init__(button_type=InlineKeyboardButton, markup=markup)
def copy(self: "InlineKeyboardBuilder") -> "InlineKeyboardBuilder":
def copy(self: InlineKeyboardBuilder) -> InlineKeyboardBuilder:
"""
Make full copy of current builder with markup
@ -353,8 +350,9 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
@classmethod
def from_markup(
cls: Type["InlineKeyboardBuilder"], markup: InlineKeyboardMarkup
) -> "InlineKeyboardBuilder":
cls: type[InlineKeyboardBuilder],
markup: InlineKeyboardMarkup,
) -> InlineKeyboardBuilder:
"""
Create builder from existing markup
@ -377,14 +375,14 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
self,
*,
text: str,
request_users: Optional[KeyboardButtonRequestUsers] = None,
request_chat: Optional[KeyboardButtonRequestChat] = None,
request_contact: Optional[bool] = None,
request_location: Optional[bool] = None,
request_poll: Optional[KeyboardButtonPollType] = None,
web_app: Optional[WebAppInfo] = None,
request_users: KeyboardButtonRequestUsers | None = None,
request_chat: KeyboardButtonRequestChat | None = None,
request_contact: bool | None = None,
request_location: bool | None = None,
request_poll: KeyboardButtonPollType | None = None,
web_app: WebAppInfo | None = None,
**kwargs: Any,
) -> "ReplyKeyboardBuilder":
) -> ReplyKeyboardBuilder:
return cast(
ReplyKeyboardBuilder,
self._button(
@ -403,10 +401,10 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
"""Construct a ReplyKeyboardMarkup"""
return cast(ReplyKeyboardMarkup, super().as_markup(**kwargs))
def __init__(self, markup: Optional[List[List[KeyboardButton]]] = None) -> None:
def __init__(self, markup: list[list[KeyboardButton]] | None = None) -> None:
super().__init__(button_type=KeyboardButton, markup=markup)
def copy(self: "ReplyKeyboardBuilder") -> "ReplyKeyboardBuilder":
def copy(self: ReplyKeyboardBuilder) -> ReplyKeyboardBuilder:
"""
Make full copy of current builder with markup
@ -415,7 +413,7 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
return ReplyKeyboardBuilder(markup=self.export())
@classmethod
def from_markup(cls, markup: ReplyKeyboardMarkup) -> "ReplyKeyboardBuilder":
def from_markup(cls, markup: ReplyKeyboardMarkup) -> ReplyKeyboardBuilder:
"""
Create builder from existing markup

View file

@ -7,7 +7,7 @@ BRANCH = "dev-3.x"
BASE_PAGE_URL = f"{BASE_DOCS_URL}/en/{BRANCH}/"
def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query: Any) -> str:
def _format_url(url: str, *path: str, fragment_: str | None = None, **query: Any) -> str:
url = urljoin(url, "/".join(path), allow_fragments=True)
if query:
url += "?" + urlencode(query)
@ -16,7 +16,7 @@ def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query:
return url
def docs_url(*path: str, fragment_: Optional[str] = None, **query: Any) -> str:
def docs_url(*path: str, fragment_: str | None = None, **query: Any) -> str:
return _format_url(BASE_PAGE_URL, *path, fragment_=fragment_, **query)
@ -30,7 +30,7 @@ def create_telegram_link(*path: str, **kwargs: Any) -> str:
def create_channel_bot_link(
username: str,
parameter: Optional[str] = None,
parameter: str | None = None,
change_info: bool = False,
post_messages: bool = False,
edit_messages: bool = False,

View file

@ -1,4 +1,5 @@
from typing import Any, Iterable
from collections.abc import Iterable
from typing import Any
from magic_filter import MagicFilter as _MagicFilter
from magic_filter import MagicT as _MagicT

View file

@ -137,7 +137,7 @@ def strikethrough(*content: Any, sep: str = " ") -> str:
:return:
"""
return markdown_decoration.strikethrough(
value=markdown_decoration.quote(_join(*content, sep=sep))
value=markdown_decoration.quote(_join(*content, sep=sep)),
)
@ -183,7 +183,7 @@ def blockquote(*content: Any, sep: str = "\n") -> str:
:return:
"""
return markdown_decoration.blockquote(
value=markdown_decoration.quote(_join(*content, sep=sep))
value=markdown_decoration.quote(_join(*content, sep=sep)),
)

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Optional, Union, overload
from typing import Any, Literal, overload
from aiogram.enums import InputMediaType
from aiogram.types import (
@ -12,12 +12,7 @@ from aiogram.types import (
MessageEntity,
)
MediaType = Union[
InputMediaAudio,
InputMediaPhoto,
InputMediaVideo,
InputMediaDocument,
]
MediaType = InputMediaAudio | InputMediaPhoto | InputMediaVideo | InputMediaDocument
MAX_MEDIA_GROUP_SIZE = 10
@ -27,9 +22,9 @@ class MediaGroupBuilder:
def __init__(
self,
media: Optional[List[MediaType]] = None,
caption: Optional[str] = None,
caption_entities: Optional[List[MessageEntity]] = None,
media: list[MediaType] | None = None,
caption: str | None = None,
caption_entities: list[MessageEntity] | None = None,
) -> None:
"""
Helper class for building media groups.
@ -39,7 +34,7 @@ class MediaGroupBuilder:
:param caption_entities: List of special entities in the caption,
like usernames, URLs, etc. (optional)
"""
self._media: List[MediaType] = []
self._media: list[MediaType] = []
self.caption = caption
self.caption_entities = caption_entities
@ -47,14 +42,16 @@ class MediaGroupBuilder:
def _add(self, media: MediaType) -> None:
if not isinstance(media, InputMedia):
raise ValueError("Media must be instance of InputMedia")
msg = "Media must be instance of InputMedia"
raise ValueError(msg)
if len(self._media) >= MAX_MEDIA_GROUP_SIZE:
raise ValueError("Media group can't contain more than 10 elements")
msg = "Media group can't contain more than 10 elements"
raise ValueError(msg)
self._media.append(media)
def _extend(self, media: List[MediaType]) -> None:
def _extend(self, media: list[MediaType]) -> None:
for m in media:
self._add(m)
@ -63,13 +60,13 @@ class MediaGroupBuilder:
self,
*,
type: Literal[InputMediaType.AUDIO],
media: Union[str, InputFile],
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
duration: Optional[int] = None,
performer: Optional[str] = None,
title: Optional[str] = None,
media: str | InputFile,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
duration: int | None = None,
performer: str | None = None,
title: str | None = None,
**kwargs: Any,
) -> None:
pass
@ -79,11 +76,11 @@ class MediaGroupBuilder:
self,
*,
type: Literal[InputMediaType.PHOTO],
media: Union[str, InputFile],
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
has_spoiler: Optional[bool] = None,
media: str | InputFile,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
has_spoiler: bool | None = None,
**kwargs: Any,
) -> None:
pass
@ -93,16 +90,16 @@ class MediaGroupBuilder:
self,
*,
type: Literal[InputMediaType.VIDEO],
media: Union[str, InputFile],
thumbnail: Optional[Union[InputFile, str]] = None,
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
duration: Optional[int] = None,
supports_streaming: Optional[bool] = None,
has_spoiler: Optional[bool] = None,
media: str | InputFile,
thumbnail: InputFile | str | None = None,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
width: int | None = None,
height: int | None = None,
duration: int | None = None,
supports_streaming: bool | None = None,
has_spoiler: bool | None = None,
**kwargs: Any,
) -> None:
pass
@ -112,12 +109,12 @@ class MediaGroupBuilder:
self,
*,
type: Literal[InputMediaType.DOCUMENT],
media: Union[str, InputFile],
thumbnail: Optional[Union[InputFile, str]] = None,
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
disable_content_type_detection: Optional[bool] = None,
media: str | InputFile,
thumbnail: InputFile | str | None = None,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
disable_content_type_detection: bool | None = None,
**kwargs: Any,
) -> None:
pass
@ -140,18 +137,19 @@ class MediaGroupBuilder:
elif type_ == InputMediaType.DOCUMENT:
self.add_document(**kwargs)
else:
raise ValueError(f"Unknown media type: {type_!r}")
msg = f"Unknown media type: {type_!r}"
raise ValueError(msg)
def add_audio(
self,
media: Union[str, InputFile],
thumbnail: Optional[InputFile] = None,
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
duration: Optional[int] = None,
performer: Optional[str] = None,
title: Optional[str] = None,
media: str | InputFile,
thumbnail: InputFile | None = None,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
duration: int | None = None,
performer: str | None = None,
title: str | None = None,
**kwargs: Any,
) -> None:
"""
@ -189,16 +187,16 @@ class MediaGroupBuilder:
performer=performer,
title=title,
**kwargs,
)
),
)
def add_photo(
self,
media: Union[str, InputFile],
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
has_spoiler: Optional[bool] = None,
media: str | InputFile,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
has_spoiler: bool | None = None,
**kwargs: Any,
) -> None:
"""
@ -228,21 +226,21 @@ class MediaGroupBuilder:
caption_entities=caption_entities,
has_spoiler=has_spoiler,
**kwargs,
)
),
)
def add_video(
self,
media: Union[str, InputFile],
thumbnail: Optional[InputFile] = None,
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
duration: Optional[int] = None,
supports_streaming: Optional[bool] = None,
has_spoiler: Optional[bool] = None,
media: str | InputFile,
thumbnail: InputFile | None = None,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
width: int | None = None,
height: int | None = None,
duration: int | None = None,
supports_streaming: bool | None = None,
has_spoiler: bool | None = None,
**kwargs: Any,
) -> None:
"""
@ -290,17 +288,17 @@ class MediaGroupBuilder:
supports_streaming=supports_streaming,
has_spoiler=has_spoiler,
**kwargs,
)
),
)
def add_document(
self,
media: Union[str, InputFile],
thumbnail: Optional[InputFile] = None,
caption: Optional[str] = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None,
disable_content_type_detection: Optional[bool] = None,
media: str | InputFile,
thumbnail: InputFile | None = None,
caption: str | None = None,
parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: list[MessageEntity] | None = None,
disable_content_type_detection: bool | None = None,
**kwargs: Any,
) -> None:
"""
@ -342,10 +340,10 @@ class MediaGroupBuilder:
caption_entities=caption_entities,
disable_content_type_detection=disable_content_type_detection,
**kwargs,
)
),
)
def build(self) -> List[MediaType]:
def build(self) -> list[MediaType]:
"""
Builds a list of media objects for a media group.
@ -353,7 +351,7 @@ class MediaGroupBuilder:
:return: List of media objects.
"""
update_first_media: Dict[str, Any] = {"caption": self.caption}
update_first_media: dict[str, Any] = {"caption": self.caption}
if self.caption_entities is not None:
update_first_media["caption_entities"] = self.caption_entities
update_first_media["parse_mode"] = None

View file

@ -1,18 +1,18 @@
from __future__ import annotations
import contextvars
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload
if TYPE_CHECKING:
from typing_extensions import Literal
from typing import Literal
__all__ = ("ContextInstanceMixin", "DataMixin")
class DataMixin:
@property
def data(self) -> Dict[str, Any]:
data: Optional[Dict[str, Any]] = getattr(self, "_data", None)
def data(self) -> dict[str, Any]:
data: dict[str, Any] | None = getattr(self, "_data", None)
if data is None:
data = {}
setattr(self, "_data", data)
@ -30,7 +30,7 @@ class DataMixin:
def __contains__(self, key: str) -> bool:
return key in self.data
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
def get(self, key: str, default: Any | None = None) -> Any | None:
return self.data.get(key, default)
@ -44,36 +44,40 @@ class ContextInstanceMixin(Generic[ContextInstance]):
super().__init_subclass__()
cls.__context_instance = contextvars.ContextVar(f"instance_{cls.__name__}")
@overload # noqa: F811
@overload
@classmethod
def get_current(cls) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811
def get_current(cls) -> ContextInstance | None: # pragma: no cover
...
@overload # noqa: F811
@overload
@classmethod
def get_current( # noqa: F811
cls, no_error: Literal[True]
) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811
def get_current(
cls,
no_error: Literal[True],
) -> ContextInstance | None: # pragma: no cover
...
@overload # noqa: F811
@overload
@classmethod
def get_current( # noqa: F811
cls, no_error: Literal[False]
) -> ContextInstance: # pragma: no cover # noqa: F811
def get_current(
cls,
no_error: Literal[False],
) -> ContextInstance: # pragma: no cover
...
@classmethod # noqa: F811
def get_current( # noqa: F811
cls, no_error: bool = True
) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811
@classmethod
def get_current(
cls,
no_error: bool = True,
) -> ContextInstance | None: # pragma: no cover
# on mypy 0.770 I catch that contextvars.ContextVar always contextvars.ContextVar[Any]
cls.__context_instance = cast(
contextvars.ContextVar[ContextInstance], cls.__context_instance
contextvars.ContextVar[ContextInstance],
cls.__context_instance,
)
try:
current: Optional[ContextInstance] = cls.__context_instance.get()
current: ContextInstance | None = cls.__context_instance.get()
except LookupError:
if no_error:
current = None
@ -85,9 +89,8 @@ class ContextInstanceMixin(Generic[ContextInstance]):
@classmethod
def set_current(cls, value: ContextInstance) -> contextvars.Token[ContextInstance]:
if not isinstance(value, cls):
raise TypeError(
f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}"
)
msg = f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}"
raise TypeError(msg)
return cls.__context_instance.set(value)
@classmethod

View file

@ -1,5 +1,10 @@
from __future__ import annotations
import functools
from typing import Callable, TypeVar
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from collections.abc import Callable
T = TypeVar("T")

View file

@ -61,13 +61,18 @@ Encoding and decoding with your own methods:
"""
from __future__ import annotations
from base64 import urlsafe_b64decode, urlsafe_b64encode
from typing import Callable, Optional
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
def encode_payload(
payload: str,
encoder: Optional[Callable[[bytes], bytes]] = None,
encoder: Callable[[bytes], bytes] | None = None,
) -> str:
"""Encode payload with encoder.
@ -85,7 +90,7 @@ def encode_payload(
def decode_payload(
payload: str,
decoder: Optional[Callable[[bytes], bytes]] = None,
decoder: Callable[[bytes], bytes] | None = None,
) -> str:
"""Decode URL-safe base64url payload with decoder."""
original_payload = _decode_b64(payload)

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel
@ -9,7 +9,7 @@ from aiogram.methods import TelegramMethod
from aiogram.types import InputFile
def _get_fake_bot(default: Optional[DefaultBotProperties] = None) -> Bot:
def _get_fake_bot(default: DefaultBotProperties | None = None) -> Bot:
if default is None:
default = DefaultBotProperties()
return Bot(token="42:Fake", default=default)
@ -28,12 +28,12 @@ class DeserializedTelegramObject:
"""
data: Any
files: Dict[str, InputFile]
files: dict[str, InputFile]
def deserialize_telegram_object(
obj: Any,
default: Optional[DefaultBotProperties] = None,
default: DefaultBotProperties | None = None,
include_api_method_name: bool = True,
) -> DeserializedTelegramObject:
"""
@ -55,7 +55,7 @@ def deserialize_telegram_object(
# Fake bot is needed to exclude global defaults from the object.
fake_bot = _get_fake_bot(default=default)
files: Dict[str, InputFile] = {}
files: dict[str, InputFile] = {}
prepared = fake_bot.session.prepare_value(
obj,
bot=fake_bot,
@ -70,7 +70,7 @@ def deserialize_telegram_object(
def deserialize_telegram_object_to_python(
obj: Any,
default: Optional[DefaultBotProperties] = None,
default: DefaultBotProperties | None = None,
include_api_method_name: bool = True,
) -> Any:
"""

View file

@ -3,20 +3,23 @@ from __future__ import annotations
import html
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generator, List, Optional, Pattern, cast
from typing import TYPE_CHECKING, cast
from aiogram.enums import MessageEntityType
if TYPE_CHECKING:
from collections.abc import Generator
from re import Pattern
from aiogram.types import MessageEntity
__all__ = (
"HtmlDecoration",
"MarkdownDecoration",
"TextDecoration",
"add_surrogates",
"html_decoration",
"markdown_decoration",
"add_surrogates",
"remove_surrogates",
)
@ -80,7 +83,7 @@ class TextDecoration(ABC):
# API it will be here too
return self.quote(text)
def unparse(self, text: str, entities: Optional[List[MessageEntity]] = None) -> str:
def unparse(self, text: str, entities: list[MessageEntity] | None = None) -> str:
"""
Unparse message entities
@ -92,15 +95,15 @@ class TextDecoration(ABC):
self._unparse_entities(
add_surrogates(text),
sorted(entities, key=lambda item: item.offset) if entities else [],
)
),
)
def _unparse_entities(
self,
text: bytes,
entities: List[MessageEntity],
offset: Optional[int] = None,
length: Optional[int] = None,
entities: list[MessageEntity],
offset: int | None = None,
length: int | None = None,
) -> Generator[str, None, None]:
if offset is None:
offset = 0
@ -115,7 +118,7 @@ class TextDecoration(ABC):
offset = entity.offset * 2 + entity.length * 2
sub_entities = list(
filter(lambda e: e.offset * 2 < (offset or 0), entities[index + 1 :])
filter(lambda e: e.offset * 2 < (offset or 0), entities[index + 1 :]),
)
yield self.apply_entity(
entity,

View file

@ -5,7 +5,7 @@ class TokenValidationError(Exception):
pass
@lru_cache()
@lru_cache
def validate_token(token: str) -> bool:
"""
Validate Telegram token
@ -14,9 +14,8 @@ def validate_token(token: str) -> bool:
:return:
"""
if not isinstance(token, str):
raise TokenValidationError(
f"Token is invalid! It must be 'str' type instead of {type(token)} type."
)
msg = f"Token is invalid! It must be 'str' type instead of {type(token)} type."
raise TokenValidationError(msg)
if any(x.isspace() for x in token):
message = "Token is invalid! It can't contains spaces."
@ -24,12 +23,13 @@ def validate_token(token: str) -> bool:
left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (not right):
raise TokenValidationError("Token is invalid!")
msg = "Token is invalid!"
raise TokenValidationError(msg)
return True
@lru_cache()
@lru_cache
def extract_bot_id(token: str) -> int:
"""
Extract bot ID from Telegram token

View file

@ -1,9 +1,10 @@
import hashlib
import hmac
import json
from collections.abc import Callable
from datetime import datetime
from operator import itemgetter
from typing import Any, Callable, Optional
from typing import Any
from urllib.parse import parse_qsl
from aiogram.types import TelegramObject
@ -25,9 +26,9 @@ class WebAppChat(TelegramObject):
"""Type of chat, can be either “group”, “supergroup” or “channel”"""
title: str
"""Title of the chat"""
username: Optional[str] = None
username: str | None = None
"""Username of the chat"""
photo_url: Optional[str] = None
photo_url: str | None = None
"""URL of the chats photo. The photo can be in .jpeg or .svg formats.
Only returned for Web Apps launched from the attachment menu."""
@ -44,23 +45,23 @@ class WebAppUser(TelegramObject):
and some programming languages may have difficulty/silent defects in interpreting it.
It has at most 52 significant bits, so a 64-bit integer or a double-precision float type
is safe for storing this identifier."""
is_bot: Optional[bool] = None
is_bot: bool | None = None
"""True, if this user is a bot. Returns in the receiver field only."""
first_name: str
"""First name of the user or bot."""
last_name: Optional[str] = None
last_name: str | None = None
"""Last name of the user or bot."""
username: Optional[str] = None
username: str | None = None
"""Username of the user or bot."""
language_code: Optional[str] = None
language_code: str | None = None
"""IETF language tag of the user's language. Returns in user field only."""
is_premium: Optional[bool] = None
is_premium: bool | None = None
"""True, if this user is a Telegram Premium user."""
added_to_attachment_menu: Optional[bool] = None
added_to_attachment_menu: bool | None = None
"""True, if this user added the bot to the attachment menu."""
allows_write_to_pm: Optional[bool] = None
allows_write_to_pm: bool | None = None
"""True, if this user allowed the bot to message them."""
photo_url: Optional[str] = None
photo_url: str | None = None
"""URL of the users profile photo. The photo can be in .jpeg or .svg formats.
Only returned for Web Apps launched from the attachment menu."""
@ -73,33 +74,33 @@ class WebAppInitData(TelegramObject):
Source: https://core.telegram.org/bots/webapps#webappinitdata
"""
query_id: Optional[str] = None
query_id: str | None = None
"""A unique identifier for the Web App session, required for sending messages
via the answerWebAppQuery method."""
user: Optional[WebAppUser] = None
user: WebAppUser | None = None
"""An object containing data about the current user."""
receiver: Optional[WebAppUser] = None
receiver: WebAppUser | None = None
"""An object containing data about the chat partner of the current user in the chat where
the bot was launched via the attachment menu.
Returned only for Web Apps launched via the attachment menu."""
chat: Optional[WebAppChat] = None
chat: WebAppChat | None = None
"""An object containing data about the chat where the bot was launched via the attachment menu.
Returned for supergroups, channels, and group chats only for Web Apps launched via the
attachment menu."""
chat_type: Optional[str] = None
chat_type: str | None = None
"""Type of the chat from which the Web App was opened.
Can be either sender for a private chat with the user opening the link,
private, group, supergroup, or channel.
Returned only for Web Apps launched from direct links."""
chat_instance: Optional[str] = None
chat_instance: str | None = None
"""Global identifier, uniquely corresponding to the chat from which the Web App was opened.
Returned only for Web Apps launched from a direct link."""
start_param: Optional[str] = None
start_param: str | None = None
"""The value of the startattach parameter, passed via link.
Only returned for Web Apps when launched from the attachment menu via link.
The value of the start_param parameter will also be passed in the GET-parameter
tgWebAppStartParam, so the Web App can load the correct interface right away."""
can_send_after: Optional[int] = None
can_send_after: int | None = None
"""Time in seconds, after which a message can be sent via the answerWebAppQuery method."""
auth_date: datetime
"""Unix time when the form was opened."""
@ -132,7 +133,9 @@ def check_webapp_signature(token: str, init_data: str) -> bool:
)
secret_key = hmac.new(key=b"WebAppData", msg=token.encode(), digestmod=hashlib.sha256)
calculated_hash = hmac.new(
key=secret_key.digest(), msg=data_check_string.encode(), digestmod=hashlib.sha256
key=secret_key.digest(),
msg=data_check_string.encode(),
digestmod=hashlib.sha256,
).hexdigest()
return hmac.compare_digest(calculated_hash, hash_)
@ -180,4 +183,5 @@ def safe_parse_webapp_init_data(
"""
if check_webapp_signature(token, init_data):
return parse_webapp_init_data(init_data, loads=loads)
raise ValueError("Invalid init data signature")
msg = "Invalid init data signature"
raise ValueError(msg)

View file

@ -8,13 +8,15 @@ from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from .web_app import WebAppInitData, parse_webapp_init_data
PRODUCTION_PUBLIC_KEY = bytes.fromhex(
"e7bf03a2fa4602af4580703d88dda5bb59f32ed8b02a56c187fe7d34caed242d"
"e7bf03a2fa4602af4580703d88dda5bb59f32ed8b02a56c187fe7d34caed242d",
)
TEST_PUBLIC_KEY = bytes.fromhex("40055058a4ee38156a06562e52eece92a771bcd8346a8c4615cb7376eddf72ec")
def check_webapp_signature(
bot_id: int, init_data: str, public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY
bot_id: int,
init_data: str,
public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY,
) -> bool:
"""
Check incoming WebApp init data signature without bot token using only bot id.
@ -49,13 +51,16 @@ def check_webapp_signature(
try:
public_key.verify(signature, message)
return True
except InvalidSignature:
return False
else:
return True
def safe_check_webapp_init_data_from_signature(
bot_id: int, init_data: str, public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY
bot_id: int,
init_data: str,
public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY,
) -> WebAppInitData:
"""
Validate raw WebApp init data using only bot id and return it as WebAppInitData object
@ -67,4 +72,5 @@ def safe_check_webapp_init_data_from_signature(
"""
if check_webapp_signature(bot_id, init_data, public_key_bytes):
return parse_webapp_init_data(init_data)
raise ValueError("Invalid init data signature")
msg = "Invalid init data signature"
raise ValueError(msg)

View file

@ -2,7 +2,8 @@ import asyncio
import secrets
from abc import ABC, abstractmethod
from asyncio import Transport
from typing import Any, Awaitable, Callable, Dict, Optional, Set, Tuple, cast
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from aiohttp import JsonPayload, MultipartWriter, Payload, web
from aiohttp.typedefs import Handler
@ -12,9 +13,11 @@ from aiohttp.web_middlewares import middleware
from aiogram import Bot, Dispatcher, loggers
from aiogram.methods import TelegramMethod
from aiogram.methods.base import TelegramType
from aiogram.types import InputFile
from aiogram.webhook.security import IPFilter
if TYPE_CHECKING:
from aiogram.types import InputFile
def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any) -> None:
"""
@ -42,7 +45,7 @@ def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any
app.on_shutdown.append(on_shutdown)
def check_ip(ip_filter: IPFilter, request: web.Request) -> Tuple[str, bool]:
def check_ip(ip_filter: IPFilter, request: web.Request) -> tuple[str, bool]:
# Try to resolve client IP over reverse proxy
if forwarded_for := request.headers.get("X-Forwarded-For", ""):
# Get the left-most ip when there is multiple ips
@ -98,7 +101,7 @@ class BaseRequestHandler(ABC):
self.dispatcher = dispatcher
self.handle_in_background = handle_in_background
self.data = data
self._background_feed_update_tasks: Set[asyncio.Task[Any]] = set()
self._background_feed_update_tasks: set[asyncio.Task[Any]] = set()
def register(self, app: Application, /, path: str, **kwargs: Any) -> None:
"""
@ -128,13 +131,12 @@ class BaseRequestHandler(ABC):
:param request:
:return: Bot instance
"""
pass
@abstractmethod
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
pass
async def _background_feed_update(self, bot: Bot, update: Dict[str, Any]) -> None:
async def _background_feed_update(self, bot: Bot, update: dict[str, Any]) -> None:
result = await self.dispatcher.feed_raw_update(bot=bot, update=update, **self.data)
if isinstance(result, TelegramMethod):
await self.dispatcher.silent_call_request(bot=bot, result=result)
@ -142,15 +144,18 @@ class BaseRequestHandler(ABC):
async def _handle_request_background(self, bot: Bot, request: web.Request) -> web.Response:
feed_update_task = asyncio.create_task(
self._background_feed_update(
bot=bot, update=await request.json(loads=bot.session.json_loads)
)
bot=bot,
update=await request.json(loads=bot.session.json_loads),
),
)
self._background_feed_update_tasks.add(feed_update_task)
feed_update_task.add_done_callback(self._background_feed_update_tasks.discard)
return web.json_response({}, dumps=bot.session.json_dumps)
def _build_response_writer(
self, bot: Bot, result: Optional[TelegramMethod[TelegramType]]
self,
bot: Bot,
result: TelegramMethod[TelegramType] | None,
) -> Payload:
if not result:
# we need to return something "empty"
@ -166,7 +171,7 @@ class BaseRequestHandler(ABC):
payload = writer.append(result.__api_method__)
payload.set_content_disposition("form-data", name="method")
files: Dict[str, InputFile] = {}
files: dict[str, InputFile] = {}
for key, value in result.model_dump(warnings=False).items():
value = bot.session.prepare_value(value, bot=bot, files=files)
if not value:
@ -185,7 +190,7 @@ class BaseRequestHandler(ABC):
return writer
async def _handle_request(self, bot: Bot, request: web.Request) -> web.Response:
result: Optional[TelegramMethod[Any]] = await self.dispatcher.feed_webhook_update(
result: TelegramMethod[Any] | None = await self.dispatcher.feed_webhook_update(
bot,
await request.json(loads=bot.session.json_loads),
**self.data,
@ -209,7 +214,7 @@ class SimpleRequestHandler(BaseRequestHandler):
dispatcher: Dispatcher,
bot: Bot,
handle_in_background: bool = True,
secret_token: Optional[str] = None,
secret_token: str | None = None,
**data: Any,
) -> None:
"""
@ -244,7 +249,7 @@ class TokenBasedRequestHandler(BaseRequestHandler):
self,
dispatcher: Dispatcher,
handle_in_background: bool = True,
bot_settings: Optional[Dict[str, Any]] = None,
bot_settings: dict[str, Any] | None = None,
**data: Any,
) -> None:
"""
@ -265,7 +270,7 @@ class TokenBasedRequestHandler(BaseRequestHandler):
if bot_settings is None:
bot_settings = {}
self.bot_settings = bot_settings
self.bots: Dict[str, Bot] = {}
self.bots: dict[str, Bot] = {}
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
return True
@ -283,7 +288,8 @@ class TokenBasedRequestHandler(BaseRequestHandler):
:param kwargs:
"""
if "{bot_token}" not in path:
raise ValueError("Path should contains '{bot_token}' substring")
msg = "Path should contains '{bot_token}' substring"
raise ValueError(msg)
super().register(app, path=path, **kwargs)
async def resolve_bot(self, request: web.Request) -> Bot:

View file

@ -1,5 +1,5 @@
from collections.abc import Sequence
from ipaddress import IPv4Address, IPv4Network
from typing import Optional, Sequence, Set, Union
DEFAULT_TELEGRAM_NETWORKS = [
IPv4Network("149.154.160.0/20"),
@ -8,17 +8,17 @@ DEFAULT_TELEGRAM_NETWORKS = [
class IPFilter:
def __init__(self, ips: Optional[Sequence[Union[str, IPv4Network, IPv4Address]]] = None):
self._allowed_ips: Set[IPv4Address] = set()
def __init__(self, ips: Sequence[str | IPv4Network | IPv4Address] | None = None):
self._allowed_ips: set[IPv4Address] = set()
if ips:
self.allow(*ips)
def allow(self, *ips: Union[str, IPv4Network, IPv4Address]) -> None:
def allow(self, *ips: str | IPv4Network | IPv4Address) -> None:
for ip in ips:
self.allow_ip(ip)
def allow_ip(self, ip: Union[str, IPv4Network, IPv4Address]) -> None:
def allow_ip(self, ip: str | IPv4Network | IPv4Address) -> None:
if isinstance(ip, str):
ip = IPv4Network(ip) if "/" in ip else IPv4Address(ip)
if isinstance(ip, IPv4Address):
@ -26,16 +26,17 @@ class IPFilter:
elif isinstance(ip, IPv4Network):
self._allowed_ips.update(ip.hosts())
else:
raise ValueError(f"Invalid type of ipaddress: {type(ip)} ('{ip}')")
msg = f"Invalid type of ipaddress: {type(ip)} ('{ip}')"
raise ValueError(msg)
@classmethod
def default(cls) -> "IPFilter":
return cls(DEFAULT_TELEGRAM_NETWORKS)
def check(self, ip: Union[str, IPv4Address]) -> bool:
def check(self, ip: str | IPv4Address) -> bool:
if not isinstance(ip, IPv4Address):
ip = IPv4Address(ip)
return ip in self._allowed_ips
def __contains__(self, item: Union[str, IPv4Address]) -> bool:
def __contains__(self, item: str | IPv4Address) -> bool:
return self.check(item)

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any
from aiogram import Router
from aiogram.filters import Filter
@ -8,7 +8,7 @@ router = Router(name=__name__)
class HelloFilter(Filter):
def __init__(self, name: Optional[str] = None) -> None:
def __init__(self, name: str | None = None) -> None:
self.name = name
async def __call__(
@ -16,7 +16,7 @@ class HelloFilter(Filter):
message: Message,
event_from_user: User,
# Filters also can accept keyword parameters like in handlers
) -> Union[bool, Dict[str, Any]]:
) -> bool | dict[str, Any]:
if message.text.casefold() == "hello":
# Returning a dictionary that will update the context data
return {"name": event_from_user.mention_html(name=self.name)}
@ -25,6 +25,7 @@ class HelloFilter(Filter):
@router.message(HelloFilter())
async def my_handler(
message: Message, name: str # Now we can accept "name" as named parameter
message: Message,
name: str, # Now we can accept "name" as named parameter
) -> Any:
return message.answer("Hello, {name}!".format(name=name))
return message.answer(f"Hello, {name}!")

View file

@ -75,11 +75,13 @@ async def handle_set_age(message: types.Message, command: CommandObject) -> None
# To get the command arguments you can use `command.args` property.
age = command.args
if not age:
raise InvalidAge("No age provided. Please provide your age as a command argument.")
msg = "No age provided. Please provide your age as a command argument."
raise InvalidAge(msg)
# If the age is invalid, raise an exception.
if not age.isdigit():
raise InvalidAge("Age should be a number")
msg = "Age should be a number"
raise InvalidAge(msg)
# If the age is valid, send a message to the user.
age = int(age)
@ -95,7 +97,8 @@ async def handle_set_name(message: types.Message, command: CommandObject) -> Non
# To get the command arguments you can use `command.args` property.
name = command.args
if not name:
raise InvalidName("Invalid name. Please provide your name as a command argument.")
msg = "Invalid name. Please provide your name as a command argument."
raise InvalidName(msg)
# If the name is valid, send a message to the user.
await message.reply(text=f"Your name is {name}")

View file

@ -2,7 +2,7 @@ import asyncio
import logging
import sys
from os import getenv
from typing import Any, Dict
from typing import Any
from aiogram import Bot, Dispatcher, F, Router, html
from aiogram.client.default import DefaultBotProperties
@ -66,7 +66,7 @@ async def process_name(message: Message, state: FSMContext) -> None:
[
KeyboardButton(text="Yes"),
KeyboardButton(text="No"),
]
],
],
resize_keyboard=True,
),
@ -106,13 +106,13 @@ async def process_language(message: Message, state: FSMContext) -> None:
if message.text.casefold() == "python":
await message.reply(
"Python, you say? That's the language that makes my circuits light up! 😉"
"Python, you say? That's the language that makes my circuits light up! 😉",
)
await show_summary(message=message, data=data)
async def show_summary(message: Message, data: Dict[str, Any], positive: bool = True) -> None:
async def show_summary(message: Message, data: dict[str, Any], positive: bool = True) -> None:
name = data["name"]
language = data.get("language", "<something unexpected>")
text = f"I'll keep in mind that, {html.quote(name)}, "
@ -124,7 +124,7 @@ async def show_summary(message: Message, data: Dict[str, Any], positive: bool =
await message.answer(text=text, reply_markup=ReplyKeyboardRemove())
async def main():
async def main() -> None:
# Initialize Bot instance with default bot properties which will be passed to all API calls
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))

View file

@ -1,7 +1,7 @@
import logging
import sys
from os import getenv
from typing import Any, Dict, Union
from typing import Any
from aiohttp import web
from finite_state_machine import form_router
@ -34,7 +34,7 @@ REDIS_DSN = "redis://127.0.0.1:6479"
OTHER_BOTS_URL = f"{BASE_URL}{OTHER_BOTS_PATH}"
def is_bot_token(value: str) -> Union[bool, Dict[str, Any]]:
def is_bot_token(value: str) -> bool | dict[str, Any]:
try:
validate_token(value)
except TokenValidationError:
@ -54,11 +54,11 @@ async def command_add_bot(message: Message, command: CommandObject, bot: Bot) ->
return await message.answer(f"Bot @{bot_user.username} successful added")
async def on_startup(dispatcher: Dispatcher, bot: Bot):
async def on_startup(dispatcher: Dispatcher, bot: Bot) -> None:
await bot.set_webhook(f"{BASE_URL}{MAIN_BOT_PATH}")
def main():
def main() -> None:
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
session = AiohttpSession()
bot_settings = {"session": session, "parse_mode": ParseMode.HTML}

View file

@ -14,4 +14,4 @@ class MyFilter(Filter):
@router.message(MyFilter("hello"))
async def my_handler(message: Message): ...
async def my_handler(message: Message) -> None: ...

View file

@ -263,7 +263,7 @@ quiz_router.message.register(QuizScene.as_handler(), Command("quiz"))
@quiz_router.message(Command("start"))
async def command_start(message: Message, scenes: ScenesManager):
async def command_start(message: Message, scenes: ScenesManager) -> None:
await scenes.close()
await message.answer(
"Hi! This is a quiz bot. To start the quiz, use the /quiz command.",
@ -271,7 +271,7 @@ async def command_start(message: Message, scenes: ScenesManager):
)
def create_dispatcher():
def create_dispatcher() -> Dispatcher:
# Event isolation is needed to correctly handle fast user responses
dispatcher = Dispatcher(
events_isolation=SimpleEventIsolation(),
@ -288,7 +288,7 @@ def create_dispatcher():
return dispatcher
async def main():
async def main() -> None:
dp = create_dispatcher()
bot = Bot(token=TOKEN)
await dp.start_polling(bot)

View file

@ -34,11 +34,11 @@ class CancellableScene(Scene):
"""
@on.message(F.text.casefold() == BUTTON_CANCEL.text.casefold(), after=After.exit())
async def handle_cancel(self, message: Message):
async def handle_cancel(self, message: Message) -> None:
await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove())
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
async def handle_back(self, message: Message):
async def handle_back(self, message: Message) -> None:
await message.answer("Back.")
@ -48,7 +48,7 @@ class LanguageScene(CancellableScene, state="language"):
"""
@on.message.enter()
async def on_enter(self, message: Message):
async def on_enter(self, message: Message) -> None:
await message.answer(
"What language do you prefer?",
reply_markup=ReplyKeyboardMarkup(
@ -58,14 +58,14 @@ class LanguageScene(CancellableScene, state="language"):
)
@on.message(F.text.casefold() == "python", after=After.exit())
async def process_python(self, message: Message):
async def process_python(self, message: Message) -> None:
await message.answer(
"Python, you say? That's the language that makes my circuits light up! 😉"
"Python, you say? That's the language that makes my circuits light up! 😉",
)
await self.input_language(message)
@on.message(after=After.exit())
async def input_language(self, message: Message):
async def input_language(self, message: Message) -> None:
data: FSMData = await self.wizard.get_data()
await self.show_results(message, language=message.text, **data)
@ -83,7 +83,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
"""
@on.message.enter()
async def on_enter(self, message: Message):
async def on_enter(self, message: Message) -> None:
await message.answer(
"Did you like to write bots?",
reply_markup=ReplyKeyboardMarkup(
@ -96,18 +96,18 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
)
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
async def process_like_write_bots(self, message: Message):
async def process_like_write_bots(self, message: Message) -> None:
await message.reply("Cool! I'm too!")
@on.message(F.text.casefold() == "no", after=After.exit())
async def process_dont_like_write_bots(self, message: Message):
async def process_dont_like_write_bots(self, message: Message) -> None:
await message.answer(
"Not bad not terrible.\nSee you soon.",
reply_markup=ReplyKeyboardRemove(),
)
@on.message()
async def input_like_bots(self, message: Message):
async def input_like_bots(self, message: Message) -> None:
await message.answer("I don't understand you :(")
@ -117,25 +117,25 @@ class NameScene(CancellableScene, state="name"):
"""
@on.message.enter() # Marker for handler that should be called when a user enters the scene.
async def on_enter(self, message: Message):
async def on_enter(self, message: Message) -> None:
await message.answer(
"Hi there! What's your name?",
reply_markup=ReplyKeyboardMarkup(keyboard=[[BUTTON_CANCEL]], resize_keyboard=True),
)
@on.callback_query.enter() # different types of updates that start the scene also supported.
async def on_enter_callback(self, callback_query: CallbackQuery):
async def on_enter_callback(self, callback_query: CallbackQuery) -> None:
await callback_query.answer()
await self.on_enter(callback_query.message)
@on.message.leave() # Marker for handler that should be called when a user leaves the scene.
async def on_leave(self, message: Message):
async def on_leave(self, message: Message) -> None:
data: FSMData = await self.wizard.get_data()
name = data.get("name", "Anonymous")
await message.answer(f"Nice to meet you, {html.quote(name)}!")
@on.message(after=After.goto(LikeBotsScene))
async def input_name(self, message: Message):
async def input_name(self, message: Message) -> None:
await self.wizard.update_data(name=message.text)
@ -154,22 +154,22 @@ class DefaultScene(
start_demo = on.message(F.text.casefold() == "demo", after=After.goto(NameScene))
@on.message(Command("demo"))
async def demo(self, message: Message):
async def demo(self, message: Message) -> None:
await message.answer(
"Demo started",
reply_markup=InlineKeyboardMarkup(
inline_keyboard=[[InlineKeyboardButton(text="Go to form", callback_data="start")]]
inline_keyboard=[[InlineKeyboardButton(text="Go to form", callback_data="start")]],
),
)
@on.callback_query(F.data == "start", after=After.goto(NameScene))
async def demo_callback(self, callback_query: CallbackQuery):
async def demo_callback(self, callback_query: CallbackQuery) -> None:
await callback_query.answer(cache_time=0)
await callback_query.message.delete_reply_markup()
@on.message.enter() # Mark that this handler should be called when a user enters the scene.
@on.message()
async def default_handler(self, message: Message):
async def default_handler(self, message: Message) -> None:
await message.answer(
"Start demo?\nYou can also start demo via command /demo",
reply_markup=ReplyKeyboardMarkup(

View file

@ -33,7 +33,7 @@ async def command_start_handler(message: Message) -> None:
await message.answer(
f"Hello, {hbold(message.from_user.full_name)}!",
reply_markup=InlineKeyboardMarkup(
inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]]
inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]],
),
)
@ -43,7 +43,7 @@ async def chat_member_update(chat_member: ChatMemberUpdated, bot: Bot) -> None:
await bot.send_message(
chat_member.chat.id,
f"Member {hcode(chat_member.from_user.id)} was changed "
+ f"from {chat_member.old_chat_member.status} to {chat_member.new_chat_member.status}",
f"from {chat_member.old_chat_member.status} to {chat_member.new_chat_member.status}",
)

View file

@ -12,7 +12,7 @@ my_router = Router()
@my_router.message(CommandStart())
async def command_start(message: Message, bot: Bot, base_url: str):
async def command_start(message: Message, bot: Bot, base_url: str) -> None:
await bot.set_chat_menu_button(
chat_id=message.chat.id,
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")),
@ -21,28 +21,29 @@ async def command_start(message: Message, bot: Bot, base_url: str):
@my_router.message(Command("webview"))
async def command_webview(message: Message, base_url: str):
async def command_webview(message: Message, base_url: str) -> None:
await message.answer(
"Good. Now you can try to send it via Webview",
reply_markup=InlineKeyboardMarkup(
inline_keyboard=[
[
InlineKeyboardButton(
text="Open Webview", web_app=WebAppInfo(url=f"{base_url}/demo")
)
]
]
text="Open Webview",
web_app=WebAppInfo(url=f"{base_url}/demo"),
),
],
],
),
)
@my_router.message(~F.message.via_bot) # Echo to all messages except messages via bot
async def echo_all(message: Message, base_url: str):
async def echo_all(message: Message, base_url: str) -> None:
await message.answer(
"Test webview",
reply_markup=InlineKeyboardMarkup(
inline_keyboard=[
[InlineKeyboardButton(text="Open", web_app=WebAppInfo(url=f"{base_url}/demo"))]
]
[InlineKeyboardButton(text="Open", web_app=WebAppInfo(url=f"{base_url}/demo"))],
],
),
)

View file

@ -18,14 +18,14 @@ TOKEN = getenv("BOT_TOKEN")
APP_BASE_URL = getenv("APP_BASE_URL")
async def on_startup(bot: Bot, base_url: str):
async def on_startup(bot: Bot, base_url: str) -> None:
await bot.set_webhook(f"{base_url}/webhook")
await bot.set_chat_menu_button(
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo"))
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")),
)
def main():
def main() -> None:
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
dispatcher = Dispatcher()
dispatcher["base_url"] = APP_BASE_URL

View file

@ -1,10 +1,11 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from aiohttp.web_fileresponse import FileResponse
from aiohttp.web_request import Request
from aiohttp.web_response import json_response
from aiogram import Bot
from aiogram.types import (
InlineKeyboardButton,
InlineKeyboardMarkup,
@ -14,12 +15,18 @@ from aiogram.types import (
)
from aiogram.utils.web_app import check_webapp_signature, safe_parse_webapp_init_data
if TYPE_CHECKING:
from aiohttp.web_request import Request
from aiohttp.web_response import Response
async def demo_handler(request: Request):
from aiogram import Bot
async def demo_handler(request: Request) -> FileResponse:
return FileResponse(Path(__file__).parent.resolve() / "demo.html")
async def check_data_handler(request: Request):
async def check_data_handler(request: Request) -> Response:
bot: Bot = request.app["bot"]
data = await request.post()
@ -28,7 +35,7 @@ async def check_data_handler(request: Request):
return json_response({"ok": False, "err": "Unauthorized"}, status=401)
async def send_message_handler(request: Request):
async def send_message_handler(request: Request) -> Response:
bot: Bot = request.app["bot"]
data = await request.post()
try:
@ -44,11 +51,11 @@ async def send_message_handler(request: Request):
InlineKeyboardButton(
text="Open",
web_app=WebAppInfo(
url=str(request.url.with_scheme("https").with_path("demo"))
url=str(request.url.with_scheme("https").with_path("demo")),
),
)
]
]
),
],
],
)
await bot.answer_web_app_query(
web_app_query_id=web_app_init_data.query_id,

View file

@ -15,7 +15,7 @@ def create_parser() -> ArgumentParser:
return parser
async def main():
async def main() -> None:
parser = create_parser()
ns = parser.parse_args()

View file

@ -6,7 +6,7 @@ build-backend = "hatchling.build"
name = "aiogram"
description = 'Modern and fully asynchronous framework for Telegram Bot API'
readme = "README.rst"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = "MIT"
authors = [
{ name = "Alex Root Junior", email = "jroot.junior@gmail.com" },
@ -30,7 +30,6 @@ classifiers = [
"Typing :: Typed",
"Intended Audience :: Developers",
"Intended Audience :: System Administrators",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@ -60,7 +59,7 @@ fast = [
"aiodns>=3.0.0",
]
redis = [
"redis[hiredis]>=5.0.1,<5.3.0",
"redis[hiredis]>=6.2.0,<7",
]
mongo = [
"motor>=3.3.2,<3.7.0",
@ -76,7 +75,7 @@ cli = [
"aiogram-cli>=1.1.0,<2.0.0",
]
signature = [
"cryptography>=43.0.0",
"cryptography>=46.0.0",
]
test = [
"pytest~=7.4.2",
@ -88,8 +87,8 @@ test = [
"pytest-cov~=4.1.0",
"pytest-aiohttp~=1.0.5",
"aresponses~=2.1.6",
"pytz~=2023.3",
"pycryptodomex~=3.19.0",
"pytz~=2025.2",
"pycryptodomex~=3.23.0",
]
docs = [
"Sphinx~=8.0.2",
@ -105,12 +104,12 @@ docs = [
"sphinxcontrib-towncrier~=0.4.0a0",
]
dev = [
"black~=24.4.2",
"isort~=5.13.2",
"ruff~=0.5.1",
"mypy~=1.10.0",
"black~=25.9.0",
"isort~=6.1.0",
"ruff~=0.13.3",
"mypy~=1.10.1",
"toml~=0.10.2",
"pre-commit~=3.5",
"pre-commit~=4.3.0",
"packaging~=24.1",
"motor-types~=1.0.0b4",
]
@ -200,9 +199,8 @@ cov-mongo = [
]
view-cov = "google-chrome-stable reports/py{matrix:python}/coverage/index.html"
[[tool.hatch.envs.test.matrix]]
python = ["39", "310", "311", "312", "313"]
python = ["310", "311", "312", "313"]
[tool.ruff]
line-length = 99
@ -219,7 +217,6 @@ exclude = [
"scripts",
"*.egg-info",
]
target-version = "py39"
[tool.ruff.lint]
select = [
@ -227,13 +224,13 @@ select = [
"C4",
"E",
"F",
"T10",
"T20",
"Q",
"RET",
"T10",
"T20",
]
ignore = [
"F401"
"F401",
]
[tool.ruff.lint.isort]
@ -280,7 +277,7 @@ exclude_lines = [
[tool.mypy]
plugins = "pydantic.mypy"
python_version = "3.9"
python_version = "3.10"
show_error_codes = true
show_error_context = true
pretty = true
@ -315,7 +312,7 @@ disallow_untyped_defs = true
[tool.black]
line-length = 99
target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
target-version = ['py310', 'py311', 'py312', 'py313']
exclude = '''
(
\.eggs

View file

@ -1,13 +1,11 @@
version: "3.9"
services:
redis:
image: redis:6-alpine
image: redis:8-alpine
ports:
- "${REDIS_PORT-6379}:6379"
mongo:
image: mongo:7.0.6
image: mongo:8.0.14
environment:
MONGO_INITDB_ROOT_USERNAME: mongo
MONGO_INITDB_ROOT_PASSWORD: mongo

View file

@ -24,8 +24,6 @@ class TestDataclassKwargs:
@pytest.mark.parametrize(
"py_version,expected",
[
((3, 9, 0), ALL_VERSIONS),
((3, 9, 2), ALL_VERSIONS),
((3, 10, 2), PY_310),
((3, 11, 0), PY_311),
((4, 13, 0), LATEST_PY),