EOL of Py3.9 (#1726)
Some checks failed
Tests / tests (macos-latest, 3.10) (push) Has been cancelled
Tests / tests (macos-latest, 3.11) (push) Has been cancelled
Tests / tests (macos-latest, 3.12) (push) Has been cancelled
Tests / tests (macos-latest, 3.13) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / tests (windows-latest, 3.10) (push) Has been cancelled
Tests / tests (windows-latest, 3.11) (push) Has been cancelled
Tests / tests (windows-latest, 3.12) (push) Has been cancelled
Tests / tests (windows-latest, 3.13) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.11) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.11) (push) Has been cancelled

* Drop py3.9 and pypy3.9

Add pypy3.11 (testing) into `tests.yml`

Remove py3.9 from matrix in `tests.yml`

Refactor not auto-gen code to be compatible with py3.10+, droping ugly 3.9 annotation.

Replace some `from typing` imports to `from collections.abc`, due to deprecation

Add `from __future__ import annotations` and `if TYPE_CHECKING:` where possible

Add some `noqa` to calm down Ruff in some places, if Ruff will be used as default linting+formatting tool in future

Replace some relative imports to absolute

Sort `__all__` tuples in `__init__.py` and some other `.py` files

Sort `__slots__` tuples in classes

Split raises into `msg` and `raise` (`EM101`, `EM102`) to not duplicate error message in the traceback

Add `Self` from `typing_extenstion` where possible

Resolve typing problem in `aiogram/filters/command.py:18`

Concatenate nested `if` statements

Convert `HandlerContainer` into a dataclass in `aiogram/fsm/scene.py`

Bump tests docker-compose.yml `redis:6-alpine` -> `redis:8-alpine`

Bump tests docker-compose.yml `mongo:7.0.6` -> `mongo:8.0.14`

Bump pre-commit-config `black==24.4.2` -> `black==25.9.0`

Bump pre-commit-config `ruff==0.5.1` -> `ruff==0.13.3`

Update Makefile lint for ruff to show fixes

Add `make outdated` into Makefile

Use `pathlib` instead of `os.path`

Bump `redis[hiredis]>=5.0.1,<5.3.0` -> `redis[hiredis]>=6.2.0,<7`

Bump `cryptography>=43.0.0` -> `cryptography>=46.0.0` due to security reasons

Bump `pytz~=2023.3` -> `pytz~=2025.2`

Bump `pycryptodomex~=3.19.0` -> `pycryptodomex~=3.23.0` due to security reasons

Bump linting and formatting tools

* Add `1726.removal.rst`

* Update aiogram/utils/dataclass.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update aiogram/filters/callback_data.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update 1726.removal.rst

* Remove `outdated` from Makefile

* Add `__slots__` to `HandlerContainer`

* Remove unused imports

* Add `@dataclass` with `slots=True` to `HandlerContainer`

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Andrew 2025-10-06 19:19:23 +03:00 committed by GitHub
parent ab32296d07
commit df7b16d5b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
94 changed files with 1383 additions and 1215 deletions

View file

@ -19,7 +19,7 @@ jobs:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: '0' fetch-depth: '0'
- name: Set up Python 3.10 - name: Set up Python 3.12
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: "3.12" python-version: "3.12"

View file

@ -29,7 +29,6 @@ jobs:
- macos-latest - macos-latest
- windows-latest - windows-latest
python-version: python-version:
- "3.9"
- "3.10" - "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"
@ -111,8 +110,8 @@ jobs:
- macos-latest - macos-latest
# - windows-latest # - windows-latest
python-version: python-version:
- "pypy3.9"
- "pypy3.10" - "pypy3.10"
- "pypy3.11"
defaults: defaults:
# Windows sucks. Force use bash instead of PowerShell # Windows sucks. Force use bash instead of PowerShell

View file

@ -14,12 +14,12 @@ repos:
- id: "check-json" - id: "check-json"
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 24.4.2 rev: 25.9.0
hooks: hooks:
- id: black - id: black
files: &files '^(aiogram|tests|examples)' files: &files '^(aiogram|tests|examples)'
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.5.1' rev: 'v0.13.3'
hooks: hooks:
- id: ruff - id: ruff

7
CHANGES/1726.removal.rst Normal file
View file

@ -0,0 +1,7 @@
This PR updates the codebase following the end of life for Python 3.9.
Reference: https://devguide.python.org/versions/
- Updated type annotations to Python 3.10+ style, replacing deprecated ``List``, ``Set``, etc., with built-in ``list``, ``set``, and related types.
- Refactored code by simplifying nested ``if`` expressions.
- Updated several dependencies, including security-related upgrades.

View file

@ -39,7 +39,7 @@ install: clean
lint: lint:
isort --check-only $(code_dir) isort --check-only $(code_dir)
black --check --diff $(code_dir) black --check --diff $(code_dir)
ruff check $(package_dir) $(examples_dir) ruff check --show-fixes --preview $(package_dir) $(examples_dir)
mypy $(package_dir) mypy $(package_dir)
.PHONY: reformat .PHONY: reformat

View file

@ -35,7 +35,7 @@ aiogram
:alt: Codecov :alt: Codecov
**aiogram** is a modern and fully asynchronous framework for **aiogram** is a modern and fully asynchronous framework for
`Telegram Bot API <https://core.telegram.org/bots/api>`_ written in Python 3.8+ using `Telegram Bot API <https://core.telegram.org/bots/api>`_ written in Python 3.10+ using
`asyncio <https://docs.python.org/3/library/asyncio.html>`_ and `asyncio <https://docs.python.org/3/library/asyncio.html>`_ and
`aiohttp <https://github.com/aio-libs/aiohttp>`_. `aiohttp <https://github.com/aio-libs/aiohttp>`_.

View file

@ -24,18 +24,18 @@ F = MagicFilter()
flags = FlagGenerator() flags = FlagGenerator()
__all__ = ( __all__ = (
"BaseMiddleware",
"Bot",
"Dispatcher",
"F",
"Router",
"__api_version__", "__api_version__",
"__version__", "__version__",
"types",
"methods",
"enums", "enums",
"Bot", "flags",
"session",
"Dispatcher",
"Router",
"BaseMiddleware",
"F",
"html", "html",
"md", "md",
"flags", "methods",
"session",
"types",
) )

View file

@ -10,7 +10,7 @@ if TYPE_CHECKING:
class BotContextController(BaseModel): class BotContextController(BaseModel):
_bot: Optional["Bot"] = PrivateAttr() _bot: Optional["Bot"] = PrivateAttr()
def model_post_init(self, __context: Any) -> None: def model_post_init(self, __context: Any) -> None: # noqa: PYI063
self._bot = __context.get("bot") if __context else None self._bot = __context.get("bot") if __context else None
def as_(self, bot: Optional["Bot"]) -> Self: def as_(self, bot: Optional["Bot"]) -> Self:

View file

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
from aiogram.utils.dataclass import dataclass_kwargs from aiogram.utils.dataclass import dataclass_kwargs
@ -35,25 +35,25 @@ class DefaultBotProperties:
Default bot properties. Default bot properties.
""" """
parse_mode: Optional[str] = None parse_mode: str | None = None
"""Default parse mode for messages.""" """Default parse mode for messages."""
disable_notification: Optional[bool] = None disable_notification: bool | None = None
"""Sends the message silently. Users will receive a notification with no sound.""" """Sends the message silently. Users will receive a notification with no sound."""
protect_content: Optional[bool] = None protect_content: bool | None = None
"""Protects content from copying.""" """Protects content from copying."""
allow_sending_without_reply: Optional[bool] = None allow_sending_without_reply: bool | None = None
"""Allows to send messages without reply.""" """Allows to send messages without reply."""
link_preview: Optional[LinkPreviewOptions] = None link_preview: LinkPreviewOptions | None = None
"""Link preview settings.""" """Link preview settings."""
link_preview_is_disabled: Optional[bool] = None link_preview_is_disabled: bool | None = None
"""Disables link preview.""" """Disables link preview."""
link_preview_prefer_small_media: Optional[bool] = None link_preview_prefer_small_media: bool | None = None
"""Prefer small media in link preview.""" """Prefer small media in link preview."""
link_preview_prefer_large_media: Optional[bool] = None link_preview_prefer_large_media: bool | None = None
"""Prefer large media in link preview.""" """Prefer large media in link preview."""
link_preview_show_above_text: Optional[bool] = None link_preview_show_above_text: bool | None = None
"""Show link preview above text.""" """Show link preview above text."""
show_caption_above_media: Optional[bool] = None show_caption_above_media: bool | None = None
"""Show caption above media.""" """Show caption above media."""
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -63,11 +63,11 @@ class DefaultBotProperties:
self.link_preview_prefer_small_media, self.link_preview_prefer_small_media,
self.link_preview_prefer_large_media, self.link_preview_prefer_large_media,
self.link_preview_show_above_text, self.link_preview_show_above_text,
) ),
) )
if has_any_link_preview_option and self.link_preview is None: if has_any_link_preview_option and self.link_preview is None:
from ..types import LinkPreviewOptions from aiogram.types import LinkPreviewOptions
self.link_preview = LinkPreviewOptions( self.link_preview = LinkPreviewOptions(
is_disabled=self.link_preview_is_disabled, is_disabled=self.link_preview_is_disabled,

View file

@ -2,45 +2,35 @@ from __future__ import annotations
import asyncio import asyncio
import ssl import ssl
from typing import ( from collections.abc import AsyncGenerator, Iterable
TYPE_CHECKING, from typing import TYPE_CHECKING, Any, cast
Any,
AsyncGenerator,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import certifi import certifi
from aiohttp import BasicAuth, ClientError, ClientSession, FormData, TCPConnector from aiohttp import BasicAuth, ClientError, ClientSession, FormData, TCPConnector
from aiohttp.hdrs import USER_AGENT from aiohttp.hdrs import USER_AGENT
from aiohttp.http import SERVER_SOFTWARE from aiohttp.http import SERVER_SOFTWARE
from typing_extensions import Self
from aiogram.__meta__ import __version__ from aiogram.__meta__ import __version__
from aiogram.methods import TelegramMethod from aiogram.exceptions import TelegramNetworkError
from aiogram.methods.base import TelegramType
from ...exceptions import TelegramNetworkError
from ...methods.base import TelegramType
from ...types import InputFile
from .base import BaseSession from .base import BaseSession
if TYPE_CHECKING: if TYPE_CHECKING:
from ..bot import Bot from aiogram.client.bot import Bot
from aiogram.methods import TelegramMethod
from aiogram.types import InputFile
_ProxyBasic = Union[str, Tuple[str, BasicAuth]] _ProxyBasic = str | tuple[str, BasicAuth]
_ProxyChain = Iterable[_ProxyBasic] _ProxyChain = Iterable[_ProxyBasic]
_ProxyType = Union[_ProxyChain, _ProxyBasic] _ProxyType = _ProxyChain | _ProxyBasic
def _retrieve_basic(basic: _ProxyBasic) -> Dict[str, Any]: def _retrieve_basic(basic: _ProxyBasic) -> dict[str, Any]:
from aiohttp_socks.utils import parse_proxy_url from aiohttp_socks.utils import parse_proxy_url
proxy_auth: Optional[BasicAuth] = None proxy_auth: BasicAuth | None = None
if isinstance(basic, str): if isinstance(basic, str):
proxy_url = basic proxy_url = basic
@ -62,7 +52,7 @@ def _retrieve_basic(basic: _ProxyBasic) -> Dict[str, Any]:
} }
def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"], Dict[str, Any]]: def _prepare_connector(chain_or_plain: _ProxyType) -> tuple[type[TCPConnector], dict[str, Any]]:
from aiohttp_socks import ChainProxyConnector, ProxyConnector, ProxyInfo from aiohttp_socks import ChainProxyConnector, ProxyConnector, ProxyInfo
# since tuple is Iterable(compatible with _ProxyChain) object, we assume that # since tuple is Iterable(compatible with _ProxyChain) object, we assume that
@ -74,17 +64,13 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
return ProxyConnector, _retrieve_basic(chain_or_plain) return ProxyConnector, _retrieve_basic(chain_or_plain)
chain_or_plain = cast(_ProxyChain, chain_or_plain) chain_or_plain = cast(_ProxyChain, chain_or_plain)
infos: List[ProxyInfo] = [] infos: list[ProxyInfo] = [ProxyInfo(**_retrieve_basic(basic)) for basic in chain_or_plain]
for basic in chain_or_plain:
infos.append(ProxyInfo(**_retrieve_basic(basic)))
return ChainProxyConnector, {"proxy_infos": infos} return ChainProxyConnector, {"proxy_infos": infos}
class AiohttpSession(BaseSession): class AiohttpSession(BaseSession):
def __init__( def __init__(self, proxy: _ProxyType | None = None, limit: int = 100, **kwargs: Any) -> None:
self, proxy: Optional[_ProxyType] = None, limit: int = 100, **kwargs: Any
) -> None:
""" """
Client session based on aiohttp. Client session based on aiohttp.
@ -94,31 +80,32 @@ class AiohttpSession(BaseSession):
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
self._session: Optional[ClientSession] = None self._session: ClientSession | None = None
self._connector_type: Type[TCPConnector] = TCPConnector self._connector_type: type[TCPConnector] = TCPConnector
self._connector_init: Dict[str, Any] = { self._connector_init: dict[str, Any] = {
"ssl": ssl.create_default_context(cafile=certifi.where()), "ssl": ssl.create_default_context(cafile=certifi.where()),
"limit": limit, "limit": limit,
"ttl_dns_cache": 3600, # Workaround for https://github.com/aiogram/aiogram/issues/1500 "ttl_dns_cache": 3600, # Workaround for https://github.com/aiogram/aiogram/issues/1500
} }
self._should_reset_connector = True # flag determines connector state self._should_reset_connector = True # flag determines connector state
self._proxy: Optional[_ProxyType] = None self._proxy: _ProxyType | None = None
if proxy is not None: if proxy is not None:
try: try:
self._setup_proxy_connector(proxy) self._setup_proxy_connector(proxy)
except ImportError as exc: # pragma: no cover except ImportError as exc: # pragma: no cover
raise RuntimeError( msg = (
"In order to use aiohttp client for proxy requests, install " "In order to use aiohttp client for proxy requests, install "
"https://pypi.org/project/aiohttp-socks/" "https://pypi.org/project/aiohttp-socks/"
) from exc )
raise RuntimeError(msg) from exc
def _setup_proxy_connector(self, proxy: _ProxyType) -> None: def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
self._connector_type, self._connector_init = _prepare_connector(proxy) self._connector_type, self._connector_init = _prepare_connector(proxy)
self._proxy = proxy self._proxy = proxy
@property @property
def proxy(self) -> Optional[_ProxyType]: def proxy(self) -> _ProxyType | None:
return self._proxy return self._proxy
@proxy.setter @proxy.setter
@ -151,7 +138,7 @@ class AiohttpSession(BaseSession):
def build_form_data(self, bot: Bot, method: TelegramMethod[TelegramType]) -> FormData: def build_form_data(self, bot: Bot, method: TelegramMethod[TelegramType]) -> FormData:
form = FormData(quote_fields=False) form = FormData(quote_fields=False)
files: Dict[str, InputFile] = {} files: dict[str, InputFile] = {}
for key, value in method.model_dump(warnings=False).items(): for key, value in method.model_dump(warnings=False).items():
value = self.prepare_value(value, bot=bot, files=files) value = self.prepare_value(value, bot=bot, files=files)
if not value: if not value:
@ -166,7 +153,10 @@ class AiohttpSession(BaseSession):
return form return form
async def make_request( async def make_request(
self, bot: Bot, method: TelegramMethod[TelegramType], timeout: Optional[int] = None self,
bot: Bot,
method: TelegramMethod[TelegramType],
timeout: int | None = None,
) -> TelegramType: ) -> TelegramType:
session = await self.create_session() session = await self.create_session()
@ -175,7 +165,9 @@ class AiohttpSession(BaseSession):
try: try:
async with session.post( async with session.post(
url, data=form, timeout=self.timeout if timeout is None else timeout url,
data=form,
timeout=self.timeout if timeout is None else timeout,
) as resp: ) as resp:
raw_result = await resp.text() raw_result = await resp.text()
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -183,14 +175,17 @@ class AiohttpSession(BaseSession):
except ClientError as e: except ClientError as e:
raise TelegramNetworkError(method=method, message=f"{type(e).__name__}: {e}") raise TelegramNetworkError(method=method, message=f"{type(e).__name__}: {e}")
response = self.check_response( response = self.check_response(
bot=bot, method=method, status_code=resp.status, content=raw_result bot=bot,
method=method,
status_code=resp.status,
content=raw_result,
) )
return cast(TelegramType, response.result) return cast(TelegramType, response.result)
async def stream_content( async def stream_content(
self, self,
url: str, url: str,
headers: Optional[Dict[str, Any]] = None, headers: dict[str, Any] | None = None,
timeout: int = 30, timeout: int = 30,
chunk_size: int = 65536, chunk_size: int = 65536,
raise_for_status: bool = True, raise_for_status: bool = True,
@ -201,11 +196,14 @@ class AiohttpSession(BaseSession):
session = await self.create_session() session = await self.create_session()
async with session.get( async with session.get(
url, timeout=timeout, headers=headers, raise_for_status=raise_for_status url,
timeout=timeout,
headers=headers,
raise_for_status=raise_for_status,
) as resp: ) as resp:
async for chunk in resp.content.iter_chunked(chunk_size): async for chunk in resp.content.iter_chunked(chunk_size):
yield chunk yield chunk
async def __aenter__(self) -> AiohttpSession: async def __aenter__(self) -> Self:
await self.create_session() await self.create_session()
return self return self

View file

@ -4,23 +4,16 @@ import abc
import datetime import datetime
import json import json
import secrets import secrets
from collections.abc import AsyncGenerator, Callable
from enum import Enum from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from types import TracebackType from typing import TYPE_CHECKING, Any, Final, cast
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Final,
Optional,
Type,
cast,
)
from pydantic import ValidationError from pydantic import ValidationError
from typing_extensions import Self
from aiogram.client.default import Default
from aiogram.client.telegram import PRODUCTION, TelegramAPIServer
from aiogram.exceptions import ( from aiogram.exceptions import (
ClientDecodeError, ClientDecodeError,
RestartingTelegram, RestartingTelegram,
@ -35,16 +28,16 @@ from aiogram.exceptions import (
TelegramServerError, TelegramServerError,
TelegramUnauthorizedError, TelegramUnauthorizedError,
) )
from aiogram.methods import Response, TelegramMethod
from aiogram.methods.base import TelegramType
from aiogram.types import InputFile, TelegramObject
from ...methods import Response, TelegramMethod
from ...methods.base import TelegramType
from ...types import InputFile, TelegramObject
from ..default import Default
from ..telegram import PRODUCTION, TelegramAPIServer
from .middlewares.manager import RequestMiddlewareManager from .middlewares.manager import RequestMiddlewareManager
if TYPE_CHECKING: if TYPE_CHECKING:
from ..bot import Bot from types import TracebackType
from aiogram.client.bot import Bot
_JsonLoads = Callable[..., Any] _JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str] _JsonDumps = Callable[..., str]
@ -81,24 +74,30 @@ class BaseSession(abc.ABC):
self.middleware = RequestMiddlewareManager() self.middleware = RequestMiddlewareManager()
def check_response( def check_response(
self, bot: Bot, method: TelegramMethod[TelegramType], status_code: int, content: str self,
bot: Bot,
method: TelegramMethod[TelegramType],
status_code: int,
content: str,
) -> Response[TelegramType]: ) -> Response[TelegramType]:
""" """
Check response status Check response status
""" """
try: try:
json_data = self.json_loads(content) json_data = self.json_loads(content)
except Exception as e: except Exception as e: # noqa: BLE001
# Handled error type can't be classified as specific error # Handled error type can't be classified as specific error
# in due to decoder can be customized and raise any exception # in due to decoder can be customized and raise any exception
raise ClientDecodeError("Failed to decode object", e, content) msg = "Failed to decode object"
raise ClientDecodeError(msg, e, content)
try: try:
response_type = Response[method.__returning__] # type: ignore response_type = Response[method.__returning__] # type: ignore
response = response_type.model_validate(json_data, context={"bot": bot}) response = response_type.model_validate(json_data, context={"bot": bot})
except ValidationError as e: except ValidationError as e:
raise ClientDecodeError("Failed to deserialize object", e, json_data) msg = "Failed to deserialize object"
raise ClientDecodeError(msg, e, json_data)
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED and response.ok: if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED and response.ok:
return response return response
@ -108,7 +107,9 @@ class BaseSession(abc.ABC):
if parameters := response.parameters: if parameters := response.parameters:
if parameters.retry_after: if parameters.retry_after:
raise TelegramRetryAfter( raise TelegramRetryAfter(
method=method, message=description, retry_after=parameters.retry_after method=method,
message=description,
retry_after=parameters.retry_after,
) )
if parameters.migrate_to_chat_id: if parameters.migrate_to_chat_id:
raise TelegramMigrateToChat( raise TelegramMigrateToChat(
@ -143,14 +144,13 @@ class BaseSession(abc.ABC):
""" """
Close client session Close client session
""" """
pass
@abc.abstractmethod @abc.abstractmethod
async def make_request( async def make_request(
self, self,
bot: Bot, bot: Bot,
method: TelegramMethod[TelegramType], method: TelegramMethod[TelegramType],
timeout: Optional[int] = None, timeout: int | None = None,
) -> TelegramType: # pragma: no cover ) -> TelegramType: # pragma: no cover
""" """
Make request to Telegram Bot API Make request to Telegram Bot API
@ -161,13 +161,12 @@ class BaseSession(abc.ABC):
:return: :return:
:raise TelegramApiError: :raise TelegramApiError:
""" """
pass
@abc.abstractmethod @abc.abstractmethod
async def stream_content( async def stream_content(
self, self,
url: str, url: str,
headers: Optional[Dict[str, Any]] = None, headers: dict[str, Any] | None = None,
timeout: int = 30, timeout: int = 30,
chunk_size: int = 65536, chunk_size: int = 65536,
raise_for_status: bool = True, raise_for_status: bool = True,
@ -181,7 +180,7 @@ class BaseSession(abc.ABC):
self, self,
value: Any, value: Any,
bot: Bot, bot: Bot,
files: Dict[str, Any], files: dict[str, Any],
_dumps_json: bool = True, _dumps_json: bool = True,
) -> Any: ) -> Any:
""" """
@ -204,7 +203,10 @@ class BaseSession(abc.ABC):
for key, item in value.items() for key, item in value.items()
if ( if (
prepared_item := self.prepare_value( prepared_item := self.prepare_value(
item, bot=bot, files=files, _dumps_json=False item,
bot=bot,
files=files,
_dumps_json=False,
) )
) )
is not None is not None
@ -218,7 +220,10 @@ class BaseSession(abc.ABC):
for item in value for item in value
if ( if (
prepared_item := self.prepare_value( prepared_item := self.prepare_value(
item, bot=bot, files=files, _dumps_json=False item,
bot=bot,
files=files,
_dumps_json=False,
) )
) )
is not None is not None
@ -227,7 +232,7 @@ class BaseSession(abc.ABC):
return self.json_dumps(value) return self.json_dumps(value)
return value return value
if isinstance(value, datetime.timedelta): if isinstance(value, datetime.timedelta):
now = datetime.datetime.now() now = datetime.datetime.now() # noqa: DTZ005
return str(round((now + value).timestamp())) return str(round((now + value).timestamp()))
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
return str(round(value.timestamp())) return str(round(value.timestamp()))
@ -248,18 +253,18 @@ class BaseSession(abc.ABC):
self, self,
bot: Bot, bot: Bot,
method: TelegramMethod[TelegramType], method: TelegramMethod[TelegramType],
timeout: Optional[int] = None, timeout: int | None = None,
) -> TelegramType: ) -> TelegramType:
middleware = self.middleware.wrap_middlewares(self.make_request, timeout=timeout) middleware = self.middleware.wrap_middlewares(self.make_request, timeout=timeout)
return cast(TelegramType, await middleware(bot, method)) return cast(TelegramType, await middleware(bot, method))
async def __aenter__(self) -> BaseSession: async def __aenter__(self) -> Self:
return self return self
async def __aexit__( async def __aexit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_value: Optional[BaseException], exc_value: BaseException | None,
traceback: Optional[TracebackType], traceback: TracebackType | None,
) -> None: ) -> None:
await self.close() await self.close()

View file

@ -3,17 +3,17 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol from typing import TYPE_CHECKING, Protocol
from aiogram.methods import Response, TelegramMethod
from aiogram.methods.base import TelegramType from aiogram.methods.base import TelegramType
if TYPE_CHECKING: if TYPE_CHECKING:
from ...bot import Bot from aiogram.client.bot import Bot
from aiogram.methods import Response, TelegramMethod
class NextRequestMiddlewareType(Protocol[TelegramType]): # pragma: no cover class NextRequestMiddlewareType(Protocol[TelegramType]): # pragma: no cover
async def __call__( async def __call__(
self, self,
bot: "Bot", bot: Bot,
method: TelegramMethod[TelegramType], method: TelegramMethod[TelegramType],
) -> Response[TelegramType]: ) -> Response[TelegramType]:
pass pass
@ -23,7 +23,7 @@ class RequestMiddlewareType(Protocol): # pragma: no cover
async def __call__( async def __call__(
self, self,
make_request: NextRequestMiddlewareType[TelegramType], make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot", bot: Bot,
method: TelegramMethod[TelegramType], method: TelegramMethod[TelegramType],
) -> Response[TelegramType]: ) -> Response[TelegramType]:
pass pass
@ -38,7 +38,7 @@ class BaseRequestMiddleware(ABC):
async def __call__( async def __call__(
self, self,
make_request: NextRequestMiddlewareType[TelegramType], make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot", bot: Bot,
method: TelegramMethod[TelegramType], method: TelegramMethod[TelegramType],
) -> Response[TelegramType]: ) -> Response[TelegramType]:
""" """
@ -50,4 +50,3 @@ class BaseRequestMiddleware(ABC):
:return: :class:`aiogram.methods.Response` :return: :class:`aiogram.methods.Response`
""" """
pass

View file

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload from typing import Any, cast, overload
from aiogram.client.session.middlewares.base import ( from aiogram.client.session.middlewares.base import (
NextRequestMiddlewareType, NextRequestMiddlewareType,
@ -12,7 +13,7 @@ from aiogram.methods.base import TelegramType
class RequestMiddlewareManager(Sequence[RequestMiddlewareType]): class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
def __init__(self) -> None: def __init__(self) -> None:
self._middlewares: List[RequestMiddlewareType] = [] self._middlewares: list[RequestMiddlewareType] = []
def register( def register(
self, self,
@ -26,11 +27,8 @@ class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
def __call__( def __call__(
self, self,
middleware: Optional[RequestMiddlewareType] = None, middleware: RequestMiddlewareType | None = None,
) -> Union[ ) -> Callable[[RequestMiddlewareType], RequestMiddlewareType] | RequestMiddlewareType:
Callable[[RequestMiddlewareType], RequestMiddlewareType],
RequestMiddlewareType,
]:
if middleware is None: if middleware is None:
return self.register return self.register
return self.register(middleware) return self.register(middleware)
@ -44,8 +42,9 @@ class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
pass pass
def __getitem__( def __getitem__(
self, item: Union[int, slice] self,
) -> Union[RequestMiddlewareType, Sequence[RequestMiddlewareType]]: item: int | slice,
) -> RequestMiddlewareType | Sequence[RequestMiddlewareType]:
return self._middlewares[item] return self._middlewares[item]
def __len__(self) -> int: def __len__(self) -> int:

View file

@ -1,5 +1,5 @@
import logging import logging
from typing import TYPE_CHECKING, Any, List, Optional, Type from typing import TYPE_CHECKING, Any
from aiogram import loggers from aiogram import loggers
from aiogram.methods import TelegramMethod from aiogram.methods import TelegramMethod
@ -8,19 +8,19 @@ from aiogram.methods.base import Response, TelegramType
from .base import BaseRequestMiddleware, NextRequestMiddlewareType from .base import BaseRequestMiddleware, NextRequestMiddlewareType
if TYPE_CHECKING: if TYPE_CHECKING:
from ...bot import Bot from aiogram.client.bot import Bot
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RequestLogging(BaseRequestMiddleware): class RequestLogging(BaseRequestMiddleware):
def __init__(self, ignore_methods: Optional[List[Type[TelegramMethod[Any]]]] = None): def __init__(self, ignore_methods: list[type[TelegramMethod[Any]]] | None = None):
""" """
Middleware for logging outgoing requests Middleware for logging outgoing requests
:param ignore_methods: methods to ignore in logging middleware :param ignore_methods: methods to ignore in logging middleware
""" """
self.ignore_methods = ignore_methods if ignore_methods else [] self.ignore_methods = ignore_methods or []
async def __call__( async def __call__(
self, self,

View file

@ -1,24 +1,24 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Union from typing import Any
class FilesPathWrapper(ABC): class FilesPathWrapper(ABC):
@abstractmethod @abstractmethod
def to_local(self, path: Union[Path, str]) -> Union[Path, str]: def to_local(self, path: Path | str) -> Path | str:
pass pass
@abstractmethod @abstractmethod
def to_server(self, path: Union[Path, str]) -> Union[Path, str]: def to_server(self, path: Path | str) -> Path | str:
pass pass
class BareFilesPathWrapper(FilesPathWrapper): class BareFilesPathWrapper(FilesPathWrapper):
def to_local(self, path: Union[Path, str]) -> Union[Path, str]: def to_local(self, path: Path | str) -> Path | str:
return path return path
def to_server(self, path: Union[Path, str]) -> Union[Path, str]: def to_server(self, path: Path | str) -> Path | str:
return path return path
@ -29,15 +29,18 @@ class SimpleFilesPathWrapper(FilesPathWrapper):
@classmethod @classmethod
def _resolve( def _resolve(
cls, base1: Union[Path, str], base2: Union[Path, str], value: Union[Path, str] cls,
base1: Path | str,
base2: Path | str,
value: Path | str,
) -> Path: ) -> Path:
relative = Path(value).relative_to(base1) relative = Path(value).relative_to(base1)
return base2 / relative return base2 / relative
def to_local(self, path: Union[Path, str]) -> Union[Path, str]: def to_local(self, path: Path | str) -> Path | str:
return self._resolve(base1=self.server_path, base2=self.local_path, value=path) return self._resolve(base1=self.server_path, base2=self.local_path, value=path)
def to_server(self, path: Union[Path, str]) -> Union[Path, str]: def to_server(self, path: Path | str) -> Path | str:
return self._resolve(base1=self.local_path, base2=self.server_path, value=path) return self._resolve(base1=self.local_path, base2=self.server_path, value=path)
@ -54,7 +57,7 @@ class TelegramAPIServer:
is_local: bool = False is_local: bool = False
"""Mark this server is """Mark this server is
in `local mode <https://core.telegram.org/bots/api#using-a-local-bot-api-server>`_.""" in `local mode <https://core.telegram.org/bots/api#using-a-local-bot-api-server>`_."""
wrap_local_file: FilesPathWrapper = BareFilesPathWrapper() wrap_local_file: FilesPathWrapper = field(default=BareFilesPathWrapper())
"""Callback to wrap files path in local mode""" """Callback to wrap files path in local mode"""
def api_url(self, token: str, method: str) -> str: def api_url(self, token: str, method: str) -> str:
@ -67,7 +70,7 @@ class TelegramAPIServer:
""" """
return self.base.format(token=token, method=method) return self.base.format(token=token, method=method)
def file_url(self, token: str, path: Union[str, Path]) -> str: def file_url(self, token: str, path: str | Path) -> str:
""" """
Generate URL for downloading files Generate URL for downloading files

View file

@ -5,28 +5,32 @@ import contextvars
import signal import signal
import warnings import warnings
from asyncio import CancelledError, Event, Future, Lock from asyncio import CancelledError, Event, Future, Lock
from collections.abc import AsyncGenerator, Awaitable
from contextlib import suppress from contextlib import suppress
from typing import Any, AsyncGenerator, Awaitable, Dict, List, Optional, Set, Union from typing import TYPE_CHECKING, Any
from aiogram import loggers
from aiogram.exceptions import TelegramAPIError
from aiogram.fsm.middleware import FSMContextMiddleware
from aiogram.fsm.storage.base import BaseEventIsolation, BaseStorage
from aiogram.fsm.storage.memory import DisabledEventIsolation, MemoryStorage
from aiogram.fsm.strategy import FSMStrategy
from aiogram.methods import GetUpdates, TelegramMethod
from aiogram.types import Update, User
from aiogram.types.base import UNSET, UNSET_TYPE
from aiogram.types.update import UpdateTypeLookupError
from aiogram.utils.backoff import Backoff, BackoffConfig
from .. import loggers
from ..client.bot import Bot
from ..exceptions import TelegramAPIError
from ..fsm.middleware import FSMContextMiddleware
from ..fsm.storage.base import BaseEventIsolation, BaseStorage
from ..fsm.storage.memory import DisabledEventIsolation, MemoryStorage
from ..fsm.strategy import FSMStrategy
from ..methods import GetUpdates, TelegramMethod
from ..methods.base import TelegramType
from ..types import Update, User
from ..types.base import UNSET, UNSET_TYPE
from ..types.update import UpdateTypeLookupError
from ..utils.backoff import Backoff, BackoffConfig
from .event.bases import UNHANDLED, SkipHandler from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver from .event.telegram import TelegramEventObserver
from .middlewares.error import ErrorsMiddleware from .middlewares.error import ErrorsMiddleware
from .middlewares.user_context import UserContextMiddleware from .middlewares.user_context import UserContextMiddleware
from .router import Router from .router import Router
if TYPE_CHECKING:
from aiogram.client.bot import Bot
from aiogram.methods.base import TelegramType
DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1) DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1)
@ -38,11 +42,11 @@ class Dispatcher(Router):
def __init__( def __init__(
self, self,
*, # * - Preventing to pass instance of Bot to the FSM storage *, # * - Preventing to pass instance of Bot to the FSM storage
storage: Optional[BaseStorage] = None, storage: BaseStorage | None = None,
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT, fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
events_isolation: Optional[BaseEventIsolation] = None, events_isolation: BaseEventIsolation | None = None,
disable_fsm: bool = False, disable_fsm: bool = False,
name: Optional[str] = None, name: str | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -55,18 +59,18 @@ class Dispatcher(Router):
then you should not use storage and events isolation then you should not use storage and events isolation
:param kwargs: Other arguments, will be passed as keyword arguments to handlers :param kwargs: Other arguments, will be passed as keyword arguments to handlers
""" """
super(Dispatcher, self).__init__(name=name) super().__init__(name=name)
if storage and not isinstance(storage, BaseStorage): if storage and not isinstance(storage, BaseStorage):
raise TypeError( msg = f"FSM storage should be instance of 'BaseStorage' not {type(storage).__name__}"
f"FSM storage should be instance of 'BaseStorage' not {type(storage).__name__}" raise TypeError(msg)
)
# Telegram API provides originally only one event type - Update # Telegram API provides originally only one event type - Update
# For making easily interactions with events here is registered handler which helps # For making easily interactions with events here is registered handler which helps
# to separate Update to different event types like Message, CallbackQuery etc. # to separate Update to different event types like Message, CallbackQuery etc.
self.update = self.observers["update"] = TelegramEventObserver( self.update = self.observers["update"] = TelegramEventObserver(
router=self, event_name="update" router=self,
event_name="update",
) )
self.update.register(self._listen_update) self.update.register(self._listen_update)
@ -91,11 +95,11 @@ class Dispatcher(Router):
self.update.outer_middleware(self.fsm) self.update.outer_middleware(self.fsm)
self.shutdown.register(self.fsm.close) self.shutdown.register(self.fsm.close)
self.workflow_data: Dict[str, Any] = kwargs self.workflow_data: dict[str, Any] = kwargs
self._running_lock = Lock() self._running_lock = Lock()
self._stop_signal: Optional[Event] = None self._stop_signal: Event | None = None
self._stopped_signal: Optional[Event] = None self._stopped_signal: Event | None = None
self._handle_update_tasks: Set[asyncio.Task[Any]] = set() self._handle_update_tasks: set[asyncio.Task[Any]] = set()
def __getitem__(self, item: str) -> Any: def __getitem__(self, item: str) -> Any:
return self.workflow_data[item] return self.workflow_data[item]
@ -106,7 +110,7 @@ class Dispatcher(Router):
def __delitem__(self, key: str) -> None: def __delitem__(self, key: str) -> None:
del self.workflow_data[key] del self.workflow_data[key]
def get(self, key: str, /, default: Optional[Any] = None) -> Optional[Any]: def get(self, key: str, /, default: Any | None = None) -> Any | None:
return self.workflow_data.get(key, default) return self.workflow_data.get(key, default)
@property @property
@ -114,13 +118,13 @@ class Dispatcher(Router):
return self.fsm.storage return self.fsm.storage
@property @property
def parent_router(self) -> Optional[Router]: def parent_router(self) -> Router | None:
""" """
Dispatcher has no parent router and can't be included to any other routers or dispatchers Dispatcher has no parent router and can't be included to any other routers or dispatchers
:return: :return:
""" """
return None # noqa: RET501 return None
@parent_router.setter @parent_router.setter
def parent_router(self, value: Router) -> None: def parent_router(self, value: Router) -> None:
@ -130,7 +134,8 @@ class Dispatcher(Router):
:param value: :param value:
:return: :return:
""" """
raise RuntimeError("Dispatcher can not be attached to another Router.") msg = "Dispatcher can not be attached to another Router."
raise RuntimeError(msg)
async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
""" """
@ -177,7 +182,7 @@ class Dispatcher(Router):
bot.id, bot.id,
) )
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any: async def feed_raw_update(self, bot: Bot, update: dict[str, Any], **kwargs: Any) -> Any:
""" """
Main entry point for incoming updates with automatic Dict->Update serializer Main entry point for incoming updates with automatic Dict->Update serializer
@ -194,7 +199,7 @@ class Dispatcher(Router):
bot: Bot, bot: Bot,
polling_timeout: int = 30, polling_timeout: int = 30,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None, allowed_updates: list[str] | None = None,
) -> AsyncGenerator[Update, None]: ) -> AsyncGenerator[Update, None]:
""" """
Endless updates reader with correctly handling any server-side or connection errors. Endless updates reader with correctly handling any server-side or connection errors.
@ -212,7 +217,7 @@ class Dispatcher(Router):
while True: while True:
try: try:
updates = await bot(get_updates, **kwargs) updates = await bot(get_updates, **kwargs)
except Exception as e: except Exception as e: # noqa: BLE001
failed = True failed = True
# In cases when Telegram Bot API was inaccessible don't need to stop polling # In cases when Telegram Bot API was inaccessible don't need to stop polling
# process because some developers can't make auto-restarting of the script # process because some developers can't make auto-restarting of the script
@ -268,6 +273,7 @@ class Dispatcher(Router):
"installed not latest version of aiogram framework" "installed not latest version of aiogram framework"
f"\nUpdate: {update.model_dump_json(exclude_unset=True)}", f"\nUpdate: {update.model_dump_json(exclude_unset=True)}",
RuntimeWarning, RuntimeWarning,
stacklevel=2,
) )
raise SkipHandler() from e raise SkipHandler() from e
@ -294,7 +300,11 @@ class Dispatcher(Router):
loggers.event.error("Failed to make answer: %s: %s", e.__class__.__name__, e) loggers.event.error("Failed to make answer: %s: %s", e.__class__.__name__, e)
async def _process_update( async def _process_update(
self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any self,
bot: Bot,
update: Update,
call_answer: bool = True,
**kwargs: Any,
) -> bool: ) -> bool:
""" """
Propagate update to event listeners Propagate update to event listeners
@ -309,9 +319,8 @@ class Dispatcher(Router):
response = await self.feed_update(bot, update, **kwargs) response = await self.feed_update(bot, update, **kwargs)
if call_answer and isinstance(response, TelegramMethod): if call_answer and isinstance(response, TelegramMethod):
await self.silent_call_request(bot=bot, result=response) await self.silent_call_request(bot=bot, result=response)
return response is not UNHANDLED
except Exception as e: except Exception as e: # noqa: BLE001
loggers.event.exception( loggers.event.exception(
"Cause exception while process update id=%d by bot id=%d\n%s: %s", "Cause exception while process update id=%d by bot id=%d\n%s: %s",
update.update_id, update.update_id,
@ -321,8 +330,13 @@ class Dispatcher(Router):
) )
return True # because update was processed but unsuccessful return True # because update was processed but unsuccessful
else:
return response is not UNHANDLED
async def _process_with_semaphore( async def _process_with_semaphore(
self, handle_update: Awaitable[bool], semaphore: asyncio.Semaphore self,
handle_update: Awaitable[bool],
semaphore: asyncio.Semaphore,
) -> bool: ) -> bool:
""" """
Process update with semaphore to limit concurrent tasks Process update with semaphore to limit concurrent tasks
@ -342,8 +356,8 @@ class Dispatcher(Router):
polling_timeout: int = 30, polling_timeout: int = 30,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None, allowed_updates: list[str] | None = None,
tasks_concurrency_limit: Optional[int] = None, tasks_concurrency_limit: int | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -361,7 +375,10 @@ class Dispatcher(Router):
""" """
user: User = await bot.me() user: User = await bot.me()
loggers.dispatcher.info( loggers.dispatcher.info(
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name "Run polling for bot @%s id=%d - %r",
user.username,
bot.id,
user.full_name,
) )
# Create semaphore if tasks_concurrency_limit is specified # Create semaphore if tasks_concurrency_limit is specified
@ -382,7 +399,7 @@ class Dispatcher(Router):
# Use semaphore to limit concurrent tasks # Use semaphore to limit concurrent tasks
await semaphore.acquire() await semaphore.acquire()
handle_update_task = asyncio.create_task( handle_update_task = asyncio.create_task(
self._process_with_semaphore(handle_update, semaphore) self._process_with_semaphore(handle_update, semaphore),
) )
else: else:
handle_update_task = asyncio.create_task(handle_update) handle_update_task = asyncio.create_task(handle_update)
@ -393,7 +410,10 @@ class Dispatcher(Router):
await handle_update await handle_update
finally: finally:
loggers.dispatcher.info( loggers.dispatcher.info(
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name "Polling stopped for bot @%s id=%d - %r",
user.username,
bot.id,
user.full_name,
) )
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
@ -413,8 +433,12 @@ class Dispatcher(Router):
raise raise
async def feed_webhook_update( async def feed_webhook_update(
self, bot: Bot, update: Union[Update, Dict[str, Any]], _timeout: float = 55, **kwargs: Any self,
) -> Optional[TelegramMethod[TelegramType]]: bot: Bot,
update: Update | dict[str, Any],
_timeout: float = 55,
**kwargs: Any,
) -> TelegramMethod[TelegramType] | None:
if not isinstance(update, Update): # Allow to use raw updates if not isinstance(update, Update): # Allow to use raw updates
update = Update.model_validate(update, context={"bot": bot}) update = Update.model_validate(update, context={"bot": bot})
@ -429,7 +453,7 @@ class Dispatcher(Router):
timeout_handle = loop.call_later(_timeout, release_waiter) timeout_handle = loop.call_later(_timeout, release_waiter)
process_updates: Future[Any] = asyncio.ensure_future( process_updates: Future[Any] = asyncio.ensure_future(
self._feed_webhook_update(bot=bot, update=update, **kwargs) self._feed_webhook_update(bot=bot, update=update, **kwargs),
) )
process_updates.add_done_callback(release_waiter, context=ctx) process_updates.add_done_callback(release_waiter, context=ctx)
@ -440,11 +464,9 @@ class Dispatcher(Router):
"For preventing this situation response into webhook returned immediately " "For preventing this situation response into webhook returned immediately "
"and handler is moved to background and still processing update.", "and handler is moved to background and still processing update.",
RuntimeWarning, RuntimeWarning,
stacklevel=2,
) )
try: result = task.result()
result = task.result()
except Exception as e:
raise e
if isinstance(result, TelegramMethod): if isinstance(result, TelegramMethod):
asyncio.ensure_future(self.silent_call_request(bot=bot, result=result)) asyncio.ensure_future(self.silent_call_request(bot=bot, result=result))
@ -478,7 +500,8 @@ class Dispatcher(Router):
:return: :return:
""" """
if not self._running_lock.locked(): if not self._running_lock.locked():
raise RuntimeError("Polling is not started") msg = "Polling is not started"
raise RuntimeError(msg)
if not self._stop_signal or not self._stopped_signal: if not self._stop_signal or not self._stopped_signal:
return return
self._stop_signal.set() self._stop_signal.set()
@ -499,10 +522,10 @@ class Dispatcher(Router):
polling_timeout: int = 10, polling_timeout: int = 10,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET, allowed_updates: list[str] | UNSET_TYPE | None = UNSET,
handle_signals: bool = True, handle_signals: bool = True,
close_bot_session: bool = True, close_bot_session: bool = True,
tasks_concurrency_limit: Optional[int] = None, tasks_concurrency_limit: int | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -522,12 +545,14 @@ class Dispatcher(Router):
:return: :return:
""" """
if not bots: if not bots:
raise ValueError("At least one bot instance is required to start polling") msg = "At least one bot instance is required to start polling"
raise ValueError(msg)
if "bot" in kwargs: if "bot" in kwargs:
raise ValueError( msg = (
"Keyword argument 'bot' is not acceptable, " "Keyword argument 'bot' is not acceptable, "
"the bot instance should be passed as positional argument" "the bot instance should be passed as positional argument"
) )
raise ValueError(msg)
async with self._running_lock: # Prevent to run this method twice at a once async with self._running_lock: # Prevent to run this method twice at a once
if self._stop_signal is None: if self._stop_signal is None:
@ -547,10 +572,14 @@ class Dispatcher(Router):
# Signals handling is not supported on Windows # Signals handling is not supported on Windows
# It also can't be covered on Windows # It also can't be covered on Windows
loop.add_signal_handler( loop.add_signal_handler(
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM signal.SIGTERM,
self._signal_stop_polling,
signal.SIGTERM,
) )
loop.add_signal_handler( loop.add_signal_handler(
signal.SIGINT, self._signal_stop_polling, signal.SIGINT signal.SIGINT,
self._signal_stop_polling,
signal.SIGINT,
) )
workflow_data = { workflow_data = {
@ -565,7 +594,7 @@ class Dispatcher(Router):
await self.emit_startup(bot=bots[-1], **workflow_data) await self.emit_startup(bot=bots[-1], **workflow_data)
loggers.dispatcher.info("Start polling") loggers.dispatcher.info("Start polling")
try: try:
tasks: List[asyncio.Task[Any]] = [ tasks: list[asyncio.Task[Any]] = [
asyncio.create_task( asyncio.create_task(
self._polling( self._polling(
bot=bot, bot=bot,
@ -575,7 +604,7 @@ class Dispatcher(Router):
allowed_updates=allowed_updates, allowed_updates=allowed_updates,
tasks_concurrency_limit=tasks_concurrency_limit, tasks_concurrency_limit=tasks_concurrency_limit,
**workflow_data, **workflow_data,
) ),
) )
for bot in bots for bot in bots
] ]
@ -605,10 +634,10 @@ class Dispatcher(Router):
polling_timeout: int = 10, polling_timeout: int = 10,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET, allowed_updates: list[str] | UNSET_TYPE | None = UNSET,
handle_signals: bool = True, handle_signals: bool = True,
close_bot_session: bool = True, close_bot_session: bool = True,
tasks_concurrency_limit: Optional[int] = None, tasks_concurrency_limit: int | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -638,5 +667,5 @@ class Dispatcher(Router):
handle_signals=handle_signals, handle_signals=handle_signals,
close_bot_session=close_bot_session, close_bot_session=close_bot_session,
tasks_concurrency_limit=tasks_concurrency_limit, tasks_concurrency_limit=tasks_concurrency_limit,
) ),
) )

View file

@ -1,20 +1,22 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, TypeVar, Union from collections.abc import Awaitable, Callable
from typing import Any, NoReturn, TypeVar
from unittest.mock import sentinel from unittest.mock import sentinel
from ...types import TelegramObject from aiogram.dispatcher.middlewares.base import BaseMiddleware
from ..middlewares.base import BaseMiddleware from aiogram.types import TelegramObject
MiddlewareEventType = TypeVar("MiddlewareEventType", bound=TelegramObject) MiddlewareEventType = TypeVar("MiddlewareEventType", bound=TelegramObject)
NextMiddlewareType = Callable[[MiddlewareEventType, Dict[str, Any]], Awaitable[Any]] NextMiddlewareType = Callable[[MiddlewareEventType, dict[str, Any]], Awaitable[Any]]
MiddlewareType = Union[ MiddlewareType = (
BaseMiddleware, BaseMiddleware
Callable[ | Callable[
[NextMiddlewareType[MiddlewareEventType], MiddlewareEventType, Dict[str, Any]], [NextMiddlewareType[MiddlewareEventType], MiddlewareEventType, dict[str, Any]],
Awaitable[Any], Awaitable[Any],
], ]
] )
UNHANDLED = sentinel.UNHANDLED UNHANDLED = sentinel.UNHANDLED
REJECTED = sentinel.REJECTED REJECTED = sentinel.REJECTED
@ -28,7 +30,7 @@ class CancelHandler(Exception):
pass pass
def skip(message: Optional[str] = None) -> NoReturn: def skip(message: str | None = None) -> NoReturn:
""" """
Raise an SkipHandler Raise an SkipHandler
""" """

View file

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, List from collections.abc import Callable
from typing import Any
from .handler import CallbackType, HandlerObject from .handler import CallbackType, HandlerObject
@ -25,7 +26,7 @@ class EventObserver:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.handlers: List[HandlerObject] = [] self.handlers: list[HandlerObject] = []
def register(self, callback: CallbackType) -> None: def register(self, callback: CallbackType) -> None:
""" """

View file

@ -1,10 +1,10 @@
import asyncio import asyncio
import contextvars
import inspect import inspect
import warnings import warnings
from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple from typing import Any
from magic_filter.magic import MagicFilter as OriginalMagicFilter from magic_filter.magic import MagicFilter as OriginalMagicFilter
@ -21,7 +21,7 @@ CallbackType = Callable[..., Any]
class CallableObject: class CallableObject:
callback: CallbackType callback: CallbackType
awaitable: bool = field(init=False) awaitable: bool = field(init=False)
params: Set[str] = field(init=False) params: set[str] = field(init=False)
varkw: bool = field(init=False) varkw: bool = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -31,7 +31,7 @@ class CallableObject:
self.params = {*spec.args, *spec.kwonlyargs} self.params = {*spec.args, *spec.kwonlyargs}
self.varkw = spec.varkw is not None self.varkw = spec.varkw is not None
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: def _prepare_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
if self.varkw: if self.varkw:
return kwargs return kwargs
@ -46,7 +46,7 @@ class CallableObject:
@dataclass @dataclass
class FilterObject(CallableObject): class FilterObject(CallableObject):
magic: Optional[MagicFilter] = None magic: MagicFilter | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
if isinstance(self.callback, OriginalMagicFilter): if isinstance(self.callback, OriginalMagicFilter):
@ -65,7 +65,7 @@ class FilterObject(CallableObject):
stacklevel=6, stacklevel=6,
) )
super(FilterObject, self).__post_init__() super().__post_init__()
if isinstance(self.callback, Filter): if isinstance(self.callback, Filter):
self.awaitable = True self.awaitable = True
@ -73,17 +73,17 @@ class FilterObject(CallableObject):
@dataclass @dataclass
class HandlerObject(CallableObject): class HandlerObject(CallableObject):
filters: Optional[List[FilterObject]] = None filters: list[FilterObject] | None = None
flags: Dict[str, Any] = field(default_factory=dict) flags: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
super(HandlerObject, self).__post_init__() super().__post_init__()
callback = inspect.unwrap(self.callback) callback = inspect.unwrap(self.callback)
if inspect.isclass(callback) and issubclass(callback, BaseHandler): if inspect.isclass(callback) and issubclass(callback, BaseHandler):
self.awaitable = True self.awaitable = True
self.flags.update(extract_flags_from_object(callback)) self.flags.update(extract_flags_from_object(callback))
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: async def check(self, *args: Any, **kwargs: Any) -> tuple[bool, dict[str, Any]]:
if not self.filters: if not self.filters:
return True, kwargs return True, kwargs
for event_filter in self.filters: for event_filter in self.filters:

View file

@ -1,17 +1,18 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from aiogram.dispatcher.middlewares.manager import MiddlewareManager from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from aiogram.exceptions import UnsupportedKeywordArgument
from aiogram.filters.base import Filter
from ...exceptions import UnsupportedKeywordArgument
from ...filters.base import Filter
from ...types import TelegramObject
from .bases import UNHANDLED, MiddlewareType, SkipHandler from .bases import UNHANDLED, MiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, HandlerObject from .handler import CallbackType, FilterObject, HandlerObject
if TYPE_CHECKING: if TYPE_CHECKING:
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from aiogram.types import TelegramObject
class TelegramEventObserver: class TelegramEventObserver:
@ -26,7 +27,7 @@ class TelegramEventObserver:
self.router: Router = router self.router: Router = router
self.event_name: str = event_name self.event_name: str = event_name
self.handlers: List[HandlerObject] = [] self.handlers: list[HandlerObject] = []
self.middleware = MiddlewareManager() self.middleware = MiddlewareManager()
self.outer_middleware = MiddlewareManager() self.outer_middleware = MiddlewareManager()
@ -45,8 +46,8 @@ class TelegramEventObserver:
self._handler.filters = [] self._handler.filters = []
self._handler.filters.extend([FilterObject(filter_) for filter_ in filters]) self._handler.filters.extend([FilterObject(filter_) for filter_ in filters])
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]: def _resolve_middlewares(self) -> list[MiddlewareType[TelegramObject]]:
middlewares: List[MiddlewareType[TelegramObject]] = [] middlewares: list[MiddlewareType[TelegramObject]] = []
for router in reversed(tuple(self.router.chain_head)): for router in reversed(tuple(self.router.chain_head)):
observer = router.observers.get(self.event_name) observer = router.observers.get(self.event_name)
if observer: if observer:
@ -58,14 +59,14 @@ class TelegramEventObserver:
self, self,
callback: CallbackType, callback: CallbackType,
*filters: CallbackType, *filters: CallbackType,
flags: Optional[Dict[str, Any]] = None, flags: dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> CallbackType: ) -> CallbackType:
""" """
Register event handler Register event handler
""" """
if kwargs: if kwargs:
raise UnsupportedKeywordArgument( msg = (
"Passing any additional keyword arguments to the registrar method " "Passing any additional keyword arguments to the registrar method "
"is not supported.\n" "is not supported.\n"
"This error may be caused when you are trying to register filters like in 2.x " "This error may be caused when you are trying to register filters like in 2.x "
@ -73,6 +74,7 @@ class TelegramEventObserver:
"documentation pages.\n" "documentation pages.\n"
f"Please remove the {set(kwargs.keys())} arguments from this call.\n" f"Please remove the {set(kwargs.keys())} arguments from this call.\n"
) )
raise UnsupportedKeywordArgument(msg)
if flags is None: if flags is None:
flags = {} flags = {}
@ -86,13 +88,16 @@ class TelegramEventObserver:
callback=callback, callback=callback,
filters=[FilterObject(filter_) for filter_ in filters], filters=[FilterObject(filter_) for filter_ in filters],
flags=flags, flags=flags,
) ),
) )
return callback return callback
def wrap_outer_middleware( def wrap_outer_middleware(
self, callback: Any, event: TelegramObject, data: Dict[str, Any] self,
callback: Any,
event: TelegramObject,
data: dict[str, Any],
) -> Any: ) -> Any:
wrapped_outer = self.middleware.wrap_middlewares( wrapped_outer = self.middleware.wrap_middlewares(
self.outer_middleware, self.outer_middleware,
@ -127,7 +132,7 @@ class TelegramEventObserver:
def __call__( def __call__(
self, self,
*filters: CallbackType, *filters: CallbackType,
flags: Optional[Dict[str, Any]] = None, flags: dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> Callable[[CallbackType], CallbackType]: ) -> Callable[[CallbackType], CallbackType]:
""" """

View file

@ -1,5 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast, overload from typing import TYPE_CHECKING, Any, Union, cast, overload
from magic_filter import AttrDict, MagicFilter from magic_filter import AttrDict, MagicFilter
@ -39,11 +40,12 @@ class FlagDecorator:
def __call__( def __call__(
self, self,
value: Optional[Any] = None, value: Any | None = None,
**kwargs: Any, **kwargs: Any,
) -> Union[Callable[..., Any], "FlagDecorator"]: ) -> Union[Callable[..., Any], "FlagDecorator"]:
if value and kwargs: if value and kwargs:
raise ValueError("The arguments `value` and **kwargs can not be used together") msg = "The arguments `value` and **kwargs can not be used together"
raise ValueError(msg)
if value is not None and callable(value): if value is not None and callable(value):
value.aiogram_flag = { value.aiogram_flag = {
@ -70,20 +72,21 @@ if TYPE_CHECKING:
class FlagGenerator: class FlagGenerator:
def __getattr__(self, name: str) -> FlagDecorator: def __getattr__(self, name: str) -> FlagDecorator:
if name[0] == "_": if name[0] == "_":
raise AttributeError("Flag name must NOT start with underscore") msg = "Flag name must NOT start with underscore"
raise AttributeError(msg)
return FlagDecorator(Flag(name, True)) return FlagDecorator(Flag(name, True))
if TYPE_CHECKING: if TYPE_CHECKING:
chat_action: _ChatActionFlagProtocol chat_action: _ChatActionFlagProtocol
def extract_flags_from_object(obj: Any) -> Dict[str, Any]: def extract_flags_from_object(obj: Any) -> dict[str, Any]:
if not hasattr(obj, "aiogram_flag"): if not hasattr(obj, "aiogram_flag"):
return {} return {}
return cast(Dict[str, Any], obj.aiogram_flag) return cast(dict[str, Any], obj.aiogram_flag)
def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str, Any]: def extract_flags(handler: Union["HandlerObject", dict[str, Any]]) -> dict[str, Any]:
""" """
Extract flags from handler or middleware context data Extract flags from handler or middleware context data
@ -98,10 +101,10 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
def get_flag( def get_flag(
handler: Union["HandlerObject", Dict[str, Any]], handler: Union["HandlerObject", dict[str, Any]],
name: str, name: str,
*, *,
default: Optional[Any] = None, default: Any | None = None,
) -> Any: ) -> Any:
""" """
Get flag by name Get flag by name
@ -115,7 +118,7 @@ def get_flag(
return flags.get(name, default) return flags.get(name, default)
def check_flags(handler: Union["HandlerObject", Dict[str, Any]], magic: MagicFilter) -> Any: def check_flags(handler: Union["HandlerObject", dict[str, Any]], magic: MagicFilter) -> Any:
""" """
Check flags via magic filter Check flags via magic filter

View file

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, TypeVar from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from aiogram.types import TelegramObject from aiogram.types import TelegramObject
@ -14,9 +15,9 @@ class BaseMiddleware(ABC):
@abstractmethod @abstractmethod
async def __call__( async def __call__(
self, self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: # pragma: no cover ) -> Any: # pragma: no cover
""" """
Execute middleware Execute middleware
@ -26,4 +27,3 @@ class BaseMiddleware(ABC):
:param data: Contextual data. Will be mapped to handler arguments :param data: Contextual data. Will be mapped to handler arguments
:return: :class:`Any` :return: :class:`Any`
""" """
pass

View file

@ -1,14 +1,16 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, cast from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from aiogram.dispatcher.event.bases import UNHANDLED, CancelHandler, SkipHandler
from aiogram.types import TelegramObject, Update
from aiogram.types.error_event import ErrorEvent
from ...types import TelegramObject, Update
from ...types.error_event import ErrorEvent
from ..event.bases import UNHANDLED, CancelHandler, SkipHandler
from .base import BaseMiddleware from .base import BaseMiddleware
if TYPE_CHECKING: if TYPE_CHECKING:
from ..router import Router from aiogram.dispatcher.router import Router
class ErrorsMiddleware(BaseMiddleware): class ErrorsMiddleware(BaseMiddleware):
@ -17,9 +19,9 @@ class ErrorsMiddleware(BaseMiddleware):
async def __call__( async def __call__(
self, self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
try: try:
return await handler(event, data) return await handler(event, data)

View file

@ -1,5 +1,6 @@
import functools import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload from collections.abc import Callable, Sequence
from typing import Any, overload
from aiogram.dispatcher.event.bases import ( from aiogram.dispatcher.event.bases import (
MiddlewareEventType, MiddlewareEventType,
@ -12,7 +13,7 @@ from aiogram.types import TelegramObject
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]): class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
def __init__(self) -> None: def __init__(self) -> None:
self._middlewares: List[MiddlewareType[TelegramObject]] = [] self._middlewares: list[MiddlewareType[TelegramObject]] = []
def register( def register(
self, self,
@ -26,11 +27,11 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
def __call__( def __call__(
self, self,
middleware: Optional[MiddlewareType[TelegramObject]] = None, middleware: MiddlewareType[TelegramObject] | None = None,
) -> Union[ ) -> (
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]], Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]]
MiddlewareType[TelegramObject], | MiddlewareType[TelegramObject]
]: ):
if middleware is None: if middleware is None:
return self.register return self.register
return self.register(middleware) return self.register(middleware)
@ -44,8 +45,9 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
pass pass
def __getitem__( def __getitem__(
self, item: Union[int, slice] self,
) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]: item: int | slice,
) -> MiddlewareType[TelegramObject] | Sequence[MiddlewareType[TelegramObject]]:
return self._middlewares[item] return self._middlewares[item]
def __len__(self) -> int: def __len__(self) -> int:
@ -53,10 +55,11 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
@staticmethod @staticmethod
def wrap_middlewares( def wrap_middlewares(
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: CallbackType middlewares: Sequence[MiddlewareType[MiddlewareEventType]],
handler: CallbackType,
) -> NextMiddlewareType[MiddlewareEventType]: ) -> NextMiddlewareType[MiddlewareEventType]:
@functools.wraps(handler) @functools.wraps(handler)
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any: def handler_wrapper(event: TelegramObject, kwargs: dict[str, Any]) -> Any:
return handler(event, **kwargs) return handler(event, **kwargs)
middleware = handler_wrapper middleware = handler_wrapper

View file

@ -1,5 +1,6 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional from typing import Any
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import ( from aiogram.types import (
@ -20,29 +21,30 @@ EVENT_THREAD_ID_KEY = "event_thread_id"
@dataclass(frozen=True) @dataclass(frozen=True)
class EventContext: class EventContext:
chat: Optional[Chat] = None chat: Chat | None = None
user: Optional[User] = None user: User | None = None
thread_id: Optional[int] = None thread_id: int | None = None
business_connection_id: Optional[str] = None business_connection_id: str | None = None
@property @property
def user_id(self) -> Optional[int]: def user_id(self) -> int | None:
return self.user.id if self.user else None return self.user.id if self.user else None
@property @property
def chat_id(self) -> Optional[int]: def chat_id(self) -> int | None:
return self.chat.id if self.chat else None return self.chat.id if self.chat else None
class UserContextMiddleware(BaseMiddleware): class UserContextMiddleware(BaseMiddleware):
async def __call__( async def __call__(
self, self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
if not isinstance(event, Update): if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!") msg = "UserContextMiddleware got an unexpected event type!"
raise RuntimeError(msg)
event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event) event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event)
# Backward compatibility # Backward compatibility
@ -116,13 +118,15 @@ class UserContextMiddleware(BaseMiddleware):
) )
if event.my_chat_member: if event.my_chat_member:
return EventContext( return EventContext(
chat=event.my_chat_member.chat, user=event.my_chat_member.from_user chat=event.my_chat_member.chat,
user=event.my_chat_member.from_user,
) )
if event.chat_member: if event.chat_member:
return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user) return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user)
if event.chat_join_request: if event.chat_join_request:
return EventContext( return EventContext(
chat=event.chat_join_request.chat, user=event.chat_join_request.from_user chat=event.chat_join_request.chat,
user=event.chat_join_request.from_user,
) )
if event.message_reaction: if event.message_reaction:
return EventContext( return EventContext(

View file

@ -1,12 +1,15 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Final, Generator, List, Optional, Set from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Final
from ..types import TelegramObject
from .event.bases import REJECTED, UNHANDLED from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver from .event.event import EventObserver
from .event.telegram import TelegramEventObserver from .event.telegram import TelegramEventObserver
if TYPE_CHECKING:
from aiogram.types import TelegramObject
INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"}) INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"})
@ -21,31 +24,34 @@ class Router:
- By decorator - :obj:`@router.<event_type>(<filters, ...>)` - By decorator - :obj:`@router.<event_type>(<filters, ...>)`
""" """
def __init__(self, *, name: Optional[str] = None) -> None: def __init__(self, *, name: str | None = None) -> None:
""" """
:param name: Optional router name, can be useful for debugging :param name: Optional router name, can be useful for debugging
""" """
self.name = name or hex(id(self)) self.name = name or hex(id(self))
self._parent_router: Optional[Router] = None self._parent_router: Router | None = None
self.sub_routers: List[Router] = [] self.sub_routers: list[Router] = []
# Observers # Observers
self.message = TelegramEventObserver(router=self, event_name="message") self.message = TelegramEventObserver(router=self, event_name="message")
self.edited_message = TelegramEventObserver(router=self, event_name="edited_message") self.edited_message = TelegramEventObserver(router=self, event_name="edited_message")
self.channel_post = TelegramEventObserver(router=self, event_name="channel_post") self.channel_post = TelegramEventObserver(router=self, event_name="channel_post")
self.edited_channel_post = TelegramEventObserver( self.edited_channel_post = TelegramEventObserver(
router=self, event_name="edited_channel_post" router=self,
event_name="edited_channel_post",
) )
self.inline_query = TelegramEventObserver(router=self, event_name="inline_query") self.inline_query = TelegramEventObserver(router=self, event_name="inline_query")
self.chosen_inline_result = TelegramEventObserver( self.chosen_inline_result = TelegramEventObserver(
router=self, event_name="chosen_inline_result" router=self,
event_name="chosen_inline_result",
) )
self.callback_query = TelegramEventObserver(router=self, event_name="callback_query") self.callback_query = TelegramEventObserver(router=self, event_name="callback_query")
self.shipping_query = TelegramEventObserver(router=self, event_name="shipping_query") self.shipping_query = TelegramEventObserver(router=self, event_name="shipping_query")
self.pre_checkout_query = TelegramEventObserver( self.pre_checkout_query = TelegramEventObserver(
router=self, event_name="pre_checkout_query" router=self,
event_name="pre_checkout_query",
) )
self.poll = TelegramEventObserver(router=self, event_name="poll") self.poll = TelegramEventObserver(router=self, event_name="poll")
self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer") self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer")
@ -54,24 +60,30 @@ class Router:
self.chat_join_request = TelegramEventObserver(router=self, event_name="chat_join_request") self.chat_join_request = TelegramEventObserver(router=self, event_name="chat_join_request")
self.message_reaction = TelegramEventObserver(router=self, event_name="message_reaction") self.message_reaction = TelegramEventObserver(router=self, event_name="message_reaction")
self.message_reaction_count = TelegramEventObserver( self.message_reaction_count = TelegramEventObserver(
router=self, event_name="message_reaction_count" router=self,
event_name="message_reaction_count",
) )
self.chat_boost = TelegramEventObserver(router=self, event_name="chat_boost") self.chat_boost = TelegramEventObserver(router=self, event_name="chat_boost")
self.removed_chat_boost = TelegramEventObserver( self.removed_chat_boost = TelegramEventObserver(
router=self, event_name="removed_chat_boost" router=self,
event_name="removed_chat_boost",
) )
self.deleted_business_messages = TelegramEventObserver( self.deleted_business_messages = TelegramEventObserver(
router=self, event_name="deleted_business_messages" router=self,
event_name="deleted_business_messages",
) )
self.business_connection = TelegramEventObserver( self.business_connection = TelegramEventObserver(
router=self, event_name="business_connection" router=self,
event_name="business_connection",
) )
self.edited_business_message = TelegramEventObserver( self.edited_business_message = TelegramEventObserver(
router=self, event_name="edited_business_message" router=self,
event_name="edited_business_message",
) )
self.business_message = TelegramEventObserver(router=self, event_name="business_message") self.business_message = TelegramEventObserver(router=self, event_name="business_message")
self.purchased_paid_media = TelegramEventObserver( self.purchased_paid_media = TelegramEventObserver(
router=self, event_name="purchased_paid_media" router=self,
event_name="purchased_paid_media",
) )
self.errors = self.error = TelegramEventObserver(router=self, event_name="error") self.errors = self.error = TelegramEventObserver(router=self, event_name="error")
@ -79,7 +91,7 @@ class Router:
self.startup = EventObserver() self.startup = EventObserver()
self.shutdown = EventObserver() self.shutdown = EventObserver()
self.observers: Dict[str, TelegramEventObserver] = { self.observers: dict[str, TelegramEventObserver] = {
"message": self.message, "message": self.message,
"edited_message": self.edited_message, "edited_message": self.edited_message,
"channel_post": self.channel_post, "channel_post": self.channel_post,
@ -112,7 +124,7 @@ class Router:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<{self}>" return f"<{self}>"
def resolve_used_update_types(self, skip_events: Optional[Set[str]] = None) -> List[str]: def resolve_used_update_types(self, skip_events: set[str] | None = None) -> list[str]:
""" """
Resolve registered event names Resolve registered event names
@ -121,7 +133,7 @@ class Router:
:param skip_events: skip specified event names :param skip_events: skip specified event names
:return: set of registered names :return: set of registered names
""" """
handlers_in_use: Set[str] = set() handlers_in_use: set[str] = set()
if skip_events is None: if skip_events is None:
skip_events = set() skip_events = set()
skip_events = {*skip_events, *INTERNAL_UPDATE_TYPES} skip_events = {*skip_events, *INTERNAL_UPDATE_TYPES}
@ -139,7 +151,10 @@ class Router:
async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any: async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any:
return await self._propagate_event( return await self._propagate_event(
observer=observer, update_type=update_type, event=telegram_event, **data observer=observer,
update_type=update_type,
event=telegram_event,
**data,
) )
if observer: if observer:
@ -148,7 +163,7 @@ class Router:
async def _propagate_event( async def _propagate_event(
self, self,
observer: Optional[TelegramEventObserver], observer: TelegramEventObserver | None,
update_type: str, update_type: str,
event: TelegramObject, event: TelegramObject,
**kwargs: Any, **kwargs: Any,
@ -179,7 +194,7 @@ class Router:
@property @property
def chain_head(self) -> Generator[Router, None, None]: def chain_head(self) -> Generator[Router, None, None]:
router: Optional[Router] = self router: Router | None = self
while router: while router:
yield router yield router
router = router.parent_router router = router.parent_router
@ -191,7 +206,7 @@ class Router:
yield from router.chain_tail yield from router.chain_tail
@property @property
def parent_router(self) -> Optional[Router]: def parent_router(self) -> Router | None:
return self._parent_router return self._parent_router
@parent_router.setter @parent_router.setter
@ -206,16 +221,20 @@ class Router:
:param router: :param router:
""" """
if not isinstance(router, Router): if not isinstance(router, Router):
raise ValueError(f"router should be instance of Router not {type(router).__name__!r}") msg = f"router should be instance of Router not {type(router).__name__!r}"
raise ValueError(msg)
if self._parent_router: if self._parent_router:
raise RuntimeError(f"Router is already attached to {self._parent_router!r}") msg = f"Router is already attached to {self._parent_router!r}"
raise RuntimeError(msg)
if self == router: if self == router:
raise RuntimeError("Self-referencing routers is not allowed") msg = "Self-referencing routers is not allowed"
raise RuntimeError(msg)
parent: Optional[Router] = router parent: Router | None = router
while parent is not None: while parent is not None:
if parent == self: if parent == self:
raise RuntimeError("Circular referencing of Router is not allowed") msg = "Circular referencing of Router is not allowed"
raise RuntimeError(msg)
parent = parent.parent_router parent = parent.parent_router
@ -230,7 +249,8 @@ class Router:
:return: :return:
""" """
if not routers: if not routers:
raise ValueError("At least one router must be provided") msg = "At least one router must be provided"
raise ValueError(msg)
for router in routers: for router in routers:
self.include_router(router) self.include_router(router)
@ -242,9 +262,8 @@ class Router:
:return: :return:
""" """
if not isinstance(router, Router): if not isinstance(router, Router):
raise ValueError( msg = f"router should be instance of Router not {type(router).__class__.__name__}"
f"router should be instance of Router not {type(router).__class__.__name__}" raise ValueError(msg)
)
router.parent_router = self router.parent_router = self
return router return router

View file

@ -16,7 +16,7 @@ class DetailedAiogramError(AiogramError):
Base exception for all aiogram errors with detailed message. Base exception for all aiogram errors with detailed message.
""" """
url: Optional[str] = None url: str | None = None
def __init__(self, message: str) -> None: def __init__(self, message: str) -> None:
self.message = message self.message = message

View file

@ -23,29 +23,29 @@ from .state import StateFilter
BaseFilter = Filter BaseFilter = Filter
__all__ = ( __all__ = (
"Filter", "ADMINISTRATOR",
"CREATOR",
"IS_ADMIN",
"IS_MEMBER",
"IS_NOT_MEMBER",
"JOIN_TRANSITION",
"KICKED",
"LEAVE_TRANSITION",
"LEFT",
"MEMBER",
"PROMOTED_TRANSITION",
"RESTRICTED",
"BaseFilter", "BaseFilter",
"ChatMemberUpdatedFilter",
"Command", "Command",
"CommandObject", "CommandObject",
"CommandStart", "CommandStart",
"ExceptionMessageFilter", "ExceptionMessageFilter",
"ExceptionTypeFilter", "ExceptionTypeFilter",
"StateFilter", "Filter",
"MagicData", "MagicData",
"ChatMemberUpdatedFilter", "StateFilter",
"CREATOR",
"ADMINISTRATOR",
"MEMBER",
"RESTRICTED",
"LEFT",
"KICKED",
"IS_MEMBER",
"IS_ADMIN",
"PROMOTED_TRANSITION",
"IS_NOT_MEMBER",
"JOIN_TRANSITION",
"LEAVE_TRANSITION",
"and_f", "and_f",
"or_f",
"invert_f", "invert_f",
"or_f",
) )

View file

@ -1,11 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING: if TYPE_CHECKING:
from aiogram.filters.logic import _InvertFilter from aiogram.filters.logic import _InvertFilter
class Filter(ABC): class Filter(ABC): # noqa: B024
""" """
If you want to register own filters like builtin filters you will need to write subclass If you want to register own filters like builtin filters you will need to write subclass
of this class with overriding the :code:`__call__` of this class with overriding the :code:`__call__`
@ -16,11 +17,11 @@ class Filter(ABC):
# This checking type-hint is needed because mypy checks validity of overrides and raises: # This checking type-hint is needed because mypy checks validity of overrides and raises:
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override] # error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override # https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
__call__: Callable[..., Awaitable[Union[bool, Dict[str, Any]]]] __call__: Callable[..., Awaitable[bool | dict[str, Any]]]
else: # pragma: no cover else: # pragma: no cover
@abstractmethod @abstractmethod
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
""" """
This method should be overridden. This method should be overridden.
@ -28,21 +29,19 @@ class Filter(ABC):
:return: :class:`bool` or :class:`Dict[str, Any]` :return: :class:`bool` or :class:`Dict[str, Any]`
""" """
pass
def __invert__(self) -> "_InvertFilter": def __invert__(self) -> "_InvertFilter":
from aiogram.filters.logic import invert_f from aiogram.filters.logic import invert_f
return invert_f(self) return invert_f(self)
def update_handler_flags(self, flags: Dict[str, Any]) -> None: def update_handler_flags(self, flags: dict[str, Any]) -> None: # noqa: B027
""" """
Also if you want to extend handler flags with using this filter Also if you want to extend handler flags with using this filter
you should implement this method you should implement this method
:param flags: existing flags, can be updated directly :param flags: existing flags, can be updated directly
""" """
pass
def _signature_to_string(self, *args: Any, **kwargs: Any) -> str: def _signature_to_string(self, *args: Any, **kwargs: Any) -> str:
items = [repr(arg) for arg in args] items = [repr(arg) for arg in args]

View file

@ -1,40 +1,30 @@
from __future__ import annotations from __future__ import annotations
import sys
import types import types
import typing import typing
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from typing import ( from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Literal,
Optional,
Type,
TypeVar,
Union,
)
from uuid import UUID from uuid import UUID
from magic_filter import MagicFilter
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from typing_extensions import Self
from aiogram.filters.base import Filter from aiogram.filters.base import Filter
from aiogram.types import CallbackQuery from aiogram.types import CallbackQuery
if TYPE_CHECKING:
from magic_filter import MagicFilter
from pydantic.fields import FieldInfo
T = TypeVar("T", bound="CallbackData") T = TypeVar("T", bound="CallbackData")
MAX_CALLBACK_LENGTH: int = 64 MAX_CALLBACK_LENGTH: int = 64
_UNION_TYPES = {typing.Union} _UNION_TYPES = {typing.Union, types.UnionType}
if sys.version_info >= (3, 10): # pragma: no cover
_UNION_TYPES.add(types.UnionType)
class CallbackDataException(Exception): class CallbackDataException(Exception):
@ -59,17 +49,19 @@ class CallbackData(BaseModel):
def __init_subclass__(cls, **kwargs: Any) -> None: def __init_subclass__(cls, **kwargs: Any) -> None:
if "prefix" not in kwargs: if "prefix" not in kwargs:
raise ValueError( msg = (
f"prefix required, usage example: " f"prefix required, usage example: "
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`" f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
) )
raise ValueError(msg)
cls.__separator__ = kwargs.pop("sep", ":") cls.__separator__ = kwargs.pop("sep", ":")
cls.__prefix__ = kwargs.pop("prefix") cls.__prefix__ = kwargs.pop("prefix")
if cls.__separator__ in cls.__prefix__: if cls.__separator__ in cls.__prefix__:
raise ValueError( msg = (
f"Separator symbol {cls.__separator__!r} can not be used " f"Separator symbol {cls.__separator__!r} can not be used "
f"inside prefix {cls.__prefix__!r}" f"inside prefix {cls.__prefix__!r}"
) )
raise ValueError(msg)
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
def _encode_value(self, key: str, value: Any) -> str: def _encode_value(self, key: str, value: Any) -> str:
@ -83,10 +75,11 @@ class CallbackData(BaseModel):
return str(int(value)) return str(int(value))
if isinstance(value, (int, str, float, Decimal, Fraction)): if isinstance(value, (int, str, float, Decimal, Fraction)):
return str(value) return str(value)
raise ValueError( msg = (
f"Attribute {key}={value!r} of type {type(value).__name__!r}" f"Attribute {key}={value!r} of type {type(value).__name__!r}"
f" can not be packed to callback data" f" can not be packed to callback data"
) )
raise ValueError(msg)
def pack(self) -> str: def pack(self) -> str:
""" """
@ -98,21 +91,23 @@ class CallbackData(BaseModel):
for key, value in self.model_dump(mode="python").items(): for key, value in self.model_dump(mode="python").items():
encoded = self._encode_value(key, value) encoded = self._encode_value(key, value)
if self.__separator__ in encoded: if self.__separator__ in encoded:
raise ValueError( msg = (
f"Separator symbol {self.__separator__!r} can not be used " f"Separator symbol {self.__separator__!r} can not be used "
f"in value {key}={encoded!r}" f"in value {key}={encoded!r}"
) )
raise ValueError(msg)
result.append(encoded) result.append(encoded)
callback_data = self.__separator__.join(result) callback_data = self.__separator__.join(result)
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH: if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
raise ValueError( msg = (
f"Resulted callback data is too long! " f"Resulted callback data is too long! "
f"len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}" f"len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
) )
raise ValueError(msg)
return callback_data return callback_data
@classmethod @classmethod
def unpack(cls: Type[T], value: str) -> T: def unpack(cls, value: str) -> Self:
""" """
Parse callback data string Parse callback data string
@ -122,22 +117,28 @@ class CallbackData(BaseModel):
prefix, *parts = value.split(cls.__separator__) prefix, *parts = value.split(cls.__separator__)
names = cls.model_fields.keys() names = cls.model_fields.keys()
if len(parts) != len(names): if len(parts) != len(names):
raise TypeError( msg = (
f"Callback data {cls.__name__!r} takes {len(names)} arguments " f"Callback data {cls.__name__!r} takes {len(names)} arguments "
f"but {len(parts)} were given" f"but {len(parts)} were given"
) )
raise TypeError(msg)
if prefix != cls.__prefix__: if prefix != cls.__prefix__:
raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})") msg = f"Bad prefix ({prefix!r} != {cls.__prefix__!r})"
raise ValueError(msg)
payload = {} payload = {}
for k, v in zip(names, parts): # type: str, Optional[str] for k, v in zip(names, parts, strict=True): # type: str, str
if field := cls.model_fields.get(k): if (
if v == "" and _check_field_is_nullable(field) and field.default != "": (field := cls.model_fields.get(k))
v = field.default if field.default is not PydanticUndefined else None and v == ""
and _check_field_is_nullable(field)
and field.default != ""
):
v = field.default if field.default is not PydanticUndefined else None
payload[k] = v payload[k] = v
return cls(**payload) return cls(**payload)
@classmethod @classmethod
def filter(cls, rule: Optional[MagicFilter] = None) -> CallbackQueryFilter: def filter(cls, rule: MagicFilter | None = None) -> CallbackQueryFilter:
""" """
Generates a filter for callback query with rule Generates a filter for callback query with rule
@ -163,8 +164,8 @@ class CallbackQueryFilter(Filter):
def __init__( def __init__(
self, self,
*, *,
callback_data: Type[CallbackData], callback_data: type[CallbackData],
rule: Optional[MagicFilter] = None, rule: MagicFilter | None = None,
): ):
""" """
:param callback_data: Expected type of callback data :param callback_data: Expected type of callback data
@ -179,7 +180,7 @@ class CallbackQueryFilter(Filter):
rule=self.rule, rule=self.rule,
) )
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]: async def __call__(self, query: CallbackQuery) -> Literal[False] | dict[str, Any]:
if not isinstance(query, CallbackQuery) or not query.data: if not isinstance(query, CallbackQuery) or not query.data:
return False return False
try: try:
@ -204,5 +205,5 @@ def _check_field_is_nullable(field: FieldInfo) -> bool:
return True return True
return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args( return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args(
field.annotation field.annotation,
) )

View file

@ -1,4 +1,6 @@
from typing import Any, Dict, Optional, TypeVar, Union from typing import Any, TypeVar, Union
from typing_extensions import Self
from aiogram.filters.base import Filter from aiogram.filters.base import Filter
from aiogram.types import ChatMember, ChatMemberUpdated from aiogram.types import ChatMember, ChatMemberUpdated
@ -10,11 +12,11 @@ TransitionT = TypeVar("TransitionT", bound="_MemberStatusTransition")
class _MemberStatusMarker: class _MemberStatusMarker:
__slots__ = ( __slots__ = (
"name",
"is_member", "is_member",
"name",
) )
def __init__(self, name: str, *, is_member: Optional[bool] = None) -> None: def __init__(self, name: str, *, is_member: bool | None = None) -> None:
self.name = name self.name = name
self.is_member = is_member self.is_member = is_member
@ -22,53 +24,59 @@ class _MemberStatusMarker:
result = self.name.upper() result = self.name.upper()
if self.is_member is not None: if self.is_member is not None:
result = ("+" if self.is_member else "-") + result result = ("+" if self.is_member else "-") + result
return result # noqa: RET504 return result
def __pos__(self: MarkerT) -> MarkerT: def __pos__(self) -> Self:
return type(self)(name=self.name, is_member=True) return type(self)(name=self.name, is_member=True)
def __neg__(self: MarkerT) -> MarkerT: def __neg__(self) -> Self:
return type(self)(name=self.name, is_member=False) return type(self)(name=self.name, is_member=False)
def __or__( def __or__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"] self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusGroupMarker": ) -> "_MemberStatusGroupMarker":
if isinstance(other, _MemberStatusMarker): if isinstance(other, _MemberStatusMarker):
return _MemberStatusGroupMarker(self, other) return _MemberStatusGroupMarker(self, other)
if isinstance(other, _MemberStatusGroupMarker): if isinstance(other, _MemberStatusGroupMarker):
return other | self return other | self
raise TypeError( msg = (
f"unsupported operand type(s) for |: " f"unsupported operand type(s) for |: "
f"{type(self).__name__!r} and {type(other).__name__!r}" f"{type(self).__name__!r} and {type(other).__name__!r}"
) )
raise TypeError(msg)
__ror__ = __or__ __ror__ = __or__
def __rshift__( def __rshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"] self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition": ) -> "_MemberStatusTransition":
old = _MemberStatusGroupMarker(self) old = _MemberStatusGroupMarker(self)
if isinstance(other, _MemberStatusMarker): if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=old, new=_MemberStatusGroupMarker(other)) return _MemberStatusTransition(old=old, new=_MemberStatusGroupMarker(other))
if isinstance(other, _MemberStatusGroupMarker): if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=old, new=other) return _MemberStatusTransition(old=old, new=other)
raise TypeError( msg = (
f"unsupported operand type(s) for >>: " f"unsupported operand type(s) for >>: "
f"{type(self).__name__!r} and {type(other).__name__!r}" f"{type(self).__name__!r} and {type(other).__name__!r}"
) )
raise TypeError(msg)
def __lshift__( def __lshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"] self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition": ) -> "_MemberStatusTransition":
new = _MemberStatusGroupMarker(self) new = _MemberStatusGroupMarker(self)
if isinstance(other, _MemberStatusMarker): if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=new) return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=new)
if isinstance(other, _MemberStatusGroupMarker): if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=other, new=new) return _MemberStatusTransition(old=other, new=new)
raise TypeError( msg = (
f"unsupported operand type(s) for <<: " f"unsupported operand type(s) for <<: "
f"{type(self).__name__!r} and {type(other).__name__!r}" f"{type(self).__name__!r} and {type(other).__name__!r}"
) )
raise TypeError(msg)
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((self.name, self.is_member)) return hash((self.name, self.is_member))
@ -87,44 +95,51 @@ class _MemberStatusGroupMarker:
def __init__(self, *statuses: _MemberStatusMarker) -> None: def __init__(self, *statuses: _MemberStatusMarker) -> None:
if not statuses: if not statuses:
raise ValueError("Member status group should have at least one status included") msg = "Member status group should have at least one status included"
raise ValueError(msg)
self.statuses = frozenset(statuses) self.statuses = frozenset(statuses)
def __or__( def __or__(
self: MarkerGroupT, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"] self,
) -> MarkerGroupT: other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> Self:
if isinstance(other, _MemberStatusMarker): if isinstance(other, _MemberStatusMarker):
return type(self)(*self.statuses, other) return type(self)(*self.statuses, other)
if isinstance(other, _MemberStatusGroupMarker): if isinstance(other, _MemberStatusGroupMarker):
return type(self)(*self.statuses, *other.statuses) return type(self)(*self.statuses, *other.statuses)
raise TypeError( msg = (
f"unsupported operand type(s) for |: " f"unsupported operand type(s) for |: "
f"{type(self).__name__!r} and {type(other).__name__!r}" f"{type(self).__name__!r} and {type(other).__name__!r}"
) )
raise TypeError(msg)
def __rshift__( def __rshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"] self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition": ) -> "_MemberStatusTransition":
if isinstance(other, _MemberStatusMarker): if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=self, new=_MemberStatusGroupMarker(other)) return _MemberStatusTransition(old=self, new=_MemberStatusGroupMarker(other))
if isinstance(other, _MemberStatusGroupMarker): if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=self, new=other) return _MemberStatusTransition(old=self, new=other)
raise TypeError( msg = (
f"unsupported operand type(s) for >>: " f"unsupported operand type(s) for >>: "
f"{type(self).__name__!r} and {type(other).__name__!r}" f"{type(self).__name__!r} and {type(other).__name__!r}"
) )
raise TypeError(msg)
def __lshift__( def __lshift__(
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"] self,
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
) -> "_MemberStatusTransition": ) -> "_MemberStatusTransition":
if isinstance(other, _MemberStatusMarker): if isinstance(other, _MemberStatusMarker):
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=self) return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=self)
if isinstance(other, _MemberStatusGroupMarker): if isinstance(other, _MemberStatusGroupMarker):
return _MemberStatusTransition(old=other, new=self) return _MemberStatusTransition(old=other, new=self)
raise TypeError( msg = (
f"unsupported operand type(s) for <<: " f"unsupported operand type(s) for <<: "
f"{type(self).__name__!r} and {type(other).__name__!r}" f"{type(self).__name__!r} and {type(other).__name__!r}"
) )
raise TypeError(msg)
def __str__(self) -> str: def __str__(self) -> str:
result = " | ".join(map(str, sorted(self.statuses, key=str))) result = " | ".join(map(str, sorted(self.statuses, key=str)))
@ -138,8 +153,8 @@ class _MemberStatusGroupMarker:
class _MemberStatusTransition: class _MemberStatusTransition:
__slots__ = ( __slots__ = (
"old",
"new", "new",
"old",
) )
def __init__(self, *, old: _MemberStatusGroupMarker, new: _MemberStatusGroupMarker) -> None: def __init__(self, *, old: _MemberStatusGroupMarker, new: _MemberStatusGroupMarker) -> None:
@ -149,7 +164,7 @@ class _MemberStatusTransition:
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.old} >> {self.new}" return f"{self.old} >> {self.new}"
def __invert__(self: TransitionT) -> TransitionT: def __invert__(self) -> Self:
return type(self)(old=self.new, new=self.old) return type(self)(old=self.new, new=self.old)
def check(self, *, old: ChatMember, new: ChatMember) -> bool: def check(self, *, old: ChatMember, new: ChatMember) -> bool:
@ -177,11 +192,9 @@ class ChatMemberUpdatedFilter(Filter):
def __init__( def __init__(
self, self,
member_status_changed: Union[ member_status_changed: (
_MemberStatusMarker, _MemberStatusMarker | _MemberStatusGroupMarker | _MemberStatusTransition
_MemberStatusGroupMarker, ),
_MemberStatusTransition,
],
): ):
self.member_status_changed = member_status_changed self.member_status_changed = member_status_changed
@ -190,7 +203,7 @@ class ChatMemberUpdatedFilter(Filter):
member_status_changed=self.member_status_changed, member_status_changed=self.member_status_changed,
) )
async def __call__(self, member_updated: ChatMemberUpdated) -> Union[bool, Dict[str, Any]]: async def __call__(self, member_updated: ChatMemberUpdated) -> bool | dict[str, Any]:
old = member_updated.old_chat_member old = member_updated.old_chat_member
new = member_updated.new_chat_member new = member_updated.new_chat_member
rule = self.member_status_changed rule = self.member_status_changed

View file

@ -1,31 +1,21 @@
from __future__ import annotations from __future__ import annotations
import re import re
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import ( from re import Match, Pattern
TYPE_CHECKING, from typing import TYPE_CHECKING, Any, cast
Any,
Dict,
Iterable,
Match,
Optional,
Pattern,
Sequence,
Union,
cast,
)
from magic_filter import MagicFilter
from aiogram.filters.base import Filter from aiogram.filters.base import Filter
from aiogram.types import BotCommand, Message from aiogram.types import BotCommand, Message
from aiogram.utils.deep_linking import decode_payload from aiogram.utils.deep_linking import decode_payload
if TYPE_CHECKING: if TYPE_CHECKING:
from magic_filter import MagicFilter
from aiogram import Bot from aiogram import Bot
# TODO: rm type ignore after py3.8 support expiration or mypy bug fix CommandPatternType = str | re.Pattern[str] | BotCommand
CommandPatternType = Union[str, re.Pattern, BotCommand] # type: ignore[type-arg]
class CommandException(Exception): class CommandException(Exception):
@ -41,20 +31,20 @@ class Command(Filter):
__slots__ = ( __slots__ = (
"commands", "commands",
"prefix",
"ignore_case", "ignore_case",
"ignore_mention", "ignore_mention",
"magic", "magic",
"prefix",
) )
def __init__( def __init__(
self, self,
*values: CommandPatternType, *values: CommandPatternType,
commands: Optional[Union[Sequence[CommandPatternType], CommandPatternType]] = None, commands: Sequence[CommandPatternType] | CommandPatternType | None = None,
prefix: str = "/", prefix: str = "/",
ignore_case: bool = False, ignore_case: bool = False,
ignore_mention: bool = False, ignore_mention: bool = False,
magic: Optional[MagicFilter] = None, magic: MagicFilter | None = None,
): ):
""" """
List of commands (string or compiled regexp patterns) List of commands (string or compiled regexp patterns)
@ -74,26 +64,29 @@ class Command(Filter):
commands = [commands] commands = [commands]
if not isinstance(commands, Iterable): if not isinstance(commands, Iterable):
raise ValueError( msg = (
"Command filter only supports str, re.Pattern, BotCommand object" "Command filter only supports str, re.Pattern, BotCommand object"
" or their Iterable" " or their Iterable"
) )
raise ValueError(msg)
items = [] items = []
for command in (*values, *commands): for command in (*values, *commands):
if isinstance(command, BotCommand): if isinstance(command, BotCommand):
command = command.command command = command.command
if not isinstance(command, (str, re.Pattern)): if not isinstance(command, (str, re.Pattern)):
raise ValueError( msg = (
"Command filter only supports str, re.Pattern, BotCommand object" "Command filter only supports str, re.Pattern, BotCommand object"
" or their Iterable" " or their Iterable"
) )
raise ValueError(msg)
if ignore_case and isinstance(command, str): if ignore_case and isinstance(command, str):
command = command.casefold() command = command.casefold()
items.append(command) items.append(command)
if not items: if not items:
raise ValueError("At least one command should be specified") msg = "At least one command should be specified"
raise ValueError(msg)
self.commands = tuple(items) self.commands = tuple(items)
self.prefix = prefix self.prefix = prefix
@ -110,11 +103,11 @@ class Command(Filter):
magic=self.magic, magic=self.magic,
) )
def update_handler_flags(self, flags: Dict[str, Any]) -> None: def update_handler_flags(self, flags: dict[str, Any]) -> None:
commands = flags.setdefault("commands", []) commands = flags.setdefault("commands", [])
commands.append(self) commands.append(self)
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]: async def __call__(self, message: Message, bot: Bot) -> bool | dict[str, Any]:
if not isinstance(message, Message): if not isinstance(message, Message):
return False return False
@ -137,7 +130,8 @@ class Command(Filter):
try: try:
full_command, *args = text.split(maxsplit=1) full_command, *args = text.split(maxsplit=1)
except ValueError: except ValueError:
raise CommandException("not enough values to unpack") msg = "not enough values to unpack"
raise CommandException(msg)
# Separate command into valuable parts # Separate command into valuable parts
# "/command@mention" -> "/", ("command", "@", "mention") # "/command@mention" -> "/", ("command", "@", "mention")
@ -151,13 +145,15 @@ class Command(Filter):
def validate_prefix(self, command: CommandObject) -> None: def validate_prefix(self, command: CommandObject) -> None:
if command.prefix not in self.prefix: if command.prefix not in self.prefix:
raise CommandException("Invalid command prefix") msg = "Invalid command prefix"
raise CommandException(msg)
async def validate_mention(self, bot: Bot, command: CommandObject) -> None: async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
if command.mention and not self.ignore_mention: if command.mention and not self.ignore_mention:
me = await bot.me() me = await bot.me()
if me.username and command.mention.lower() != me.username.lower(): if me.username and command.mention.lower() != me.username.lower():
raise CommandException("Mention did not match") msg = "Mention did not match"
raise CommandException(msg)
def validate_command(self, command: CommandObject) -> CommandObject: def validate_command(self, command: CommandObject) -> CommandObject:
for allowed_command in cast(Sequence[CommandPatternType], self.commands): for allowed_command in cast(Sequence[CommandPatternType], self.commands):
@ -174,7 +170,8 @@ class Command(Filter):
if command_name == allowed_command: # String if command_name == allowed_command: # String
return command return command
raise CommandException("Command did not match pattern") msg = "Command did not match pattern"
raise CommandException(msg)
async def parse_command(self, text: str, bot: Bot) -> CommandObject: async def parse_command(self, text: str, bot: Bot) -> CommandObject:
""" """
@ -196,7 +193,8 @@ class Command(Filter):
return command return command
result = self.magic.resolve(command) result = self.magic.resolve(command)
if not result: if not result:
raise CommandException("Rejected via magic filter") msg = "Rejected via magic filter"
raise CommandException(msg)
return replace(command, magic_result=result) return replace(command, magic_result=result)
@ -211,13 +209,13 @@ class CommandObject:
"""Command prefix""" """Command prefix"""
command: str = "" command: str = ""
"""Command without prefix and mention""" """Command without prefix and mention"""
mention: Optional[str] = None mention: str | None = None
"""Mention (if available)""" """Mention (if available)"""
args: Optional[str] = field(repr=False, default=None) args: str | None = field(repr=False, default=None)
"""Command argument""" """Command argument"""
regexp_match: Optional[Match[str]] = field(repr=False, default=None) regexp_match: Match[str] | None = field(repr=False, default=None)
"""Will be presented match result if the command is presented as regexp in filter""" """Will be presented match result if the command is presented as regexp in filter"""
magic_result: Optional[Any] = field(repr=False, default=None) magic_result: Any | None = field(repr=False, default=None)
@property @property
def mentioned(self) -> bool: def mentioned(self) -> bool:
@ -246,7 +244,7 @@ class CommandStart(Command):
deep_link_encoded: bool = False, deep_link_encoded: bool = False,
ignore_case: bool = False, ignore_case: bool = False,
ignore_mention: bool = False, ignore_mention: bool = False,
magic: Optional[MagicFilter] = None, magic: MagicFilter | None = None,
): ):
super().__init__( super().__init__(
"start", "start",
@ -287,12 +285,14 @@ class CommandStart(Command):
if not self.deep_link: if not self.deep_link:
return command return command
if not command.args: if not command.args:
raise CommandException("Deep-link was missing") msg = "Deep-link was missing"
raise CommandException(msg)
args = command.args args = command.args
if self.deep_link_encoded: if self.deep_link_encoded:
try: try:
args = decode_payload(args) args = decode_payload(args)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
raise CommandException(f"Failed to decode Base64: {e}") msg = f"Failed to decode Base64: {e}"
raise CommandException(msg)
return replace(command, args=args) return replace(command, args=args)
return command return command

View file

@ -1,5 +1,6 @@
import re import re
from typing import Any, Dict, Pattern, Type, Union, cast from re import Pattern
from typing import Any, cast
from aiogram.filters.base import Filter from aiogram.filters.base import Filter
from aiogram.types import TelegramObject from aiogram.types import TelegramObject
@ -13,15 +14,16 @@ class ExceptionTypeFilter(Filter):
__slots__ = ("exceptions",) __slots__ = ("exceptions",)
def __init__(self, *exceptions: Type[Exception]): def __init__(self, *exceptions: type[Exception]):
""" """
:param exceptions: Exception type(s) :param exceptions: Exception type(s)
""" """
if not exceptions: if not exceptions:
raise ValueError("At least one exception type is required") msg = "At least one exception type is required"
raise ValueError(msg)
self.exceptions = exceptions self.exceptions = exceptions
async def __call__(self, obj: TelegramObject) -> Union[bool, Dict[str, Any]]: async def __call__(self, obj: TelegramObject) -> bool | dict[str, Any]:
return isinstance(cast(ErrorEvent, obj).exception, self.exceptions) return isinstance(cast(ErrorEvent, obj).exception, self.exceptions)
@ -32,7 +34,7 @@ class ExceptionMessageFilter(Filter):
__slots__ = ("pattern",) __slots__ = ("pattern",)
def __init__(self, pattern: Union[str, Pattern[str]]): def __init__(self, pattern: str | Pattern[str]):
""" """
:param pattern: Regexp pattern :param pattern: Regexp pattern
""" """
@ -48,7 +50,7 @@ class ExceptionMessageFilter(Filter):
async def __call__( async def __call__(
self, self,
obj: TelegramObject, obj: TelegramObject,
) -> Union[bool, Dict[str, Any]]: ) -> bool | dict[str, Any]:
result = self.pattern.match(str(cast(ErrorEvent, obj).exception)) result = self.pattern.match(str(cast(ErrorEvent, obj).exception))
if not result: if not result:
return False return False

View file

@ -1,5 +1,5 @@
from abc import ABC from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, Union from typing import TYPE_CHECKING, Any
from aiogram.filters import Filter from aiogram.filters import Filter
@ -17,7 +17,7 @@ class _InvertFilter(_LogicFilter):
def __init__(self, target: "FilterObject") -> None: def __init__(self, target: "FilterObject") -> None:
self.target = target self.target = target
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
return not bool(await self.target.call(*args, **kwargs)) return not bool(await self.target.call(*args, **kwargs))
@ -27,7 +27,7 @@ class _AndFilter(_LogicFilter):
def __init__(self, *targets: "FilterObject") -> None: def __init__(self, *targets: "FilterObject") -> None:
self.targets = targets self.targets = targets
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
final_result = {} final_result = {}
for target in self.targets: for target in self.targets:
@ -48,7 +48,7 @@ class _OrFilter(_LogicFilter):
def __init__(self, *targets: "FilterObject") -> None: def __init__(self, *targets: "FilterObject") -> None:
self.targets = targets self.targets = targets
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]: async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
for target in self.targets: for target in self.targets:
result = await target.call(*args, **kwargs) result = await target.call(*args, **kwargs)
if not result: if not result:

View file

@ -18,7 +18,7 @@ class MagicData(Filter):
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any: async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
return self.magic_data.resolve( return self.magic_data.resolve(
AttrDict({"event": event, **dict(enumerate(args)), **kwargs}) AttrDict({"event": event, **dict(enumerate(args)), **kwargs}),
) )
def __str__(self) -> str: def __str__(self) -> str:

View file

@ -1,11 +1,12 @@
from collections.abc import Sequence
from inspect import isclass from inspect import isclass
from typing import Any, Dict, Optional, Sequence, Type, Union, cast from typing import Any, cast
from aiogram.filters.base import Filter from aiogram.filters.base import Filter
from aiogram.fsm.state import State, StatesGroup from aiogram.fsm.state import State, StatesGroup
from aiogram.types import TelegramObject from aiogram.types import TelegramObject
StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]] StateType = str | State | StatesGroup | type[StatesGroup] | None
class StateFilter(Filter): class StateFilter(Filter):
@ -17,7 +18,8 @@ class StateFilter(Filter):
def __init__(self, *states: StateType) -> None: def __init__(self, *states: StateType) -> None:
if not states: if not states:
raise ValueError("At least one state is required") msg = "At least one state is required"
raise ValueError(msg)
self.states = states self.states = states
@ -27,17 +29,22 @@ class StateFilter(Filter):
) )
async def __call__( async def __call__(
self, obj: TelegramObject, raw_state: Optional[str] = None self,
) -> Union[bool, Dict[str, Any]]: obj: TelegramObject,
raw_state: str | None = None,
) -> bool | dict[str, Any]:
allowed_states = cast(Sequence[StateType], self.states) allowed_states = cast(Sequence[StateType], self.states)
for allowed_state in allowed_states: for allowed_state in allowed_states:
if isinstance(allowed_state, str) or allowed_state is None: if isinstance(allowed_state, str) or allowed_state is None:
if allowed_state == "*" or raw_state == allowed_state: if allowed_state in {"*", raw_state}:
return True return True
elif isinstance(allowed_state, (State, StatesGroup)): elif isinstance(allowed_state, (State, StatesGroup)):
if allowed_state(event=obj, raw_state=raw_state): if allowed_state(event=obj, raw_state=raw_state):
return True return True
elif isclass(allowed_state) and issubclass(allowed_state, StatesGroup): elif (
if allowed_state()(event=obj, raw_state=raw_state): isclass(allowed_state)
return True and issubclass(allowed_state, StatesGroup)
and allowed_state()(event=obj, raw_state=raw_state)
):
return True
return False return False

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Mapping, Optional, overload from collections.abc import Mapping
from typing import Any, overload
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
@ -11,27 +12,29 @@ class FSMContext:
async def set_state(self, state: StateType = None) -> None: async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(key=self.key, state=state) await self.storage.set_state(key=self.key, state=state)
async def get_state(self) -> Optional[str]: async def get_state(self) -> str | None:
return await self.storage.get_state(key=self.key) return await self.storage.get_state(key=self.key)
async def set_data(self, data: Mapping[str, Any]) -> None: async def set_data(self, data: Mapping[str, Any]) -> None:
await self.storage.set_data(key=self.key, data=data) await self.storage.set_data(key=self.key, data=data)
async def get_data(self) -> Dict[str, Any]: async def get_data(self) -> dict[str, Any]:
return await self.storage.get_data(key=self.key) return await self.storage.get_data(key=self.key)
@overload @overload
async def get_value(self, key: str) -> Optional[Any]: ... async def get_value(self, key: str) -> Any | None: ...
@overload @overload
async def get_value(self, key: str, default: Any) -> Any: ... async def get_value(self, key: str, default: Any) -> Any: ...
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]: async def get_value(self, key: str, default: Any | None = None) -> Any | None:
return await self.storage.get_value(storage_key=self.key, dict_key=key, default=default) return await self.storage.get_value(storage_key=self.key, dict_key=key, default=default)
async def update_data( async def update_data(
self, data: Optional[Mapping[str, Any]] = None, **kwargs: Any self,
) -> Dict[str, Any]: data: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
if data: if data:
kwargs.update(data) kwargs.update(data)
return await self.storage.update_data(key=self.key, data=kwargs) return await self.storage.update_data(key=self.key, data=kwargs)

View file

@ -1,4 +1,5 @@
from typing import Any, Awaitable, Callable, Dict, Optional, cast from collections.abc import Awaitable, Callable
from typing import Any, cast
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.base import BaseMiddleware
@ -27,9 +28,9 @@ class FSMContextMiddleware(BaseMiddleware):
async def __call__( async def __call__(
self, self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
bot: Bot = cast(Bot, data["bot"]) bot: Bot = cast(Bot, data["bot"])
context = self.resolve_event_context(bot, data) context = self.resolve_event_context(bot, data)
@ -45,9 +46,9 @@ class FSMContextMiddleware(BaseMiddleware):
def resolve_event_context( def resolve_event_context(
self, self,
bot: Bot, bot: Bot,
data: Dict[str, Any], data: dict[str, Any],
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]: ) -> FSMContext | None:
event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY)) event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
return self.resolve_context( return self.resolve_context(
bot=bot, bot=bot,
@ -61,12 +62,12 @@ class FSMContextMiddleware(BaseMiddleware):
def resolve_context( def resolve_context(
self, self,
bot: Bot, bot: Bot,
chat_id: Optional[int], chat_id: int | None,
user_id: Optional[int], user_id: int | None,
thread_id: Optional[int] = None, thread_id: int | None = None,
business_connection_id: Optional[str] = None, business_connection_id: str | None = None,
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]: ) -> FSMContext | None:
if chat_id is None: if chat_id is None:
chat_id = user_id chat_id = user_id
@ -92,8 +93,8 @@ class FSMContextMiddleware(BaseMiddleware):
bot: Bot, bot: Bot,
chat_id: int, chat_id: int,
user_id: int, user_id: int,
thread_id: Optional[int] = None, thread_id: int | None = None,
business_connection_id: Optional[str] = None, business_connection_id: str | None = None,
destiny: str = DEFAULT_DESTINY, destiny: str = DEFAULT_DESTINY,
) -> FSMContext: ) -> FSMContext:
return FSMContext( return FSMContext(

View file

@ -2,26 +2,15 @@ from __future__ import annotations
import inspect import inspect
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from enum import Enum, auto from enum import Enum, auto
from typing import ( from typing import TYPE_CHECKING, Any, ClassVar, overload
Any,
ClassVar,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
overload,
)
from typing_extensions import Self from typing_extensions import Self
from aiogram import loggers from aiogram import loggers
from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NextMiddlewareType
from aiogram.dispatcher.event.handler import CallableObject, CallbackType from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.dispatcher.flags import extract_flags_from_object from aiogram.dispatcher.flags import extract_flags_from_object
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
@ -36,16 +25,20 @@ from aiogram.utils.class_attrs_resolver import (
get_sorted_mro_attrs_resolver, get_sorted_mro_attrs_resolver,
) )
if TYPE_CHECKING:
from aiogram.dispatcher.event.bases import NextMiddlewareType
class HistoryManager: class HistoryManager:
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10): def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
self._size = size self._size = size
self._state = state self._state = state
self._history_state = FSMContext( self._history_state = FSMContext(
storage=state.storage, key=replace(state.key, destiny=destiny) storage=state.storage,
key=replace(state.key, destiny=destiny),
) )
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None: async def push(self, state: str | None, data: dict[str, Any]) -> None:
history_data = await self._history_state.get_data() history_data = await self._history_state.get_data()
history = history_data.setdefault("history", []) history = history_data.setdefault("history", [])
history.append({"state": state, "data": data}) history.append({"state": state, "data": data})
@ -55,7 +48,7 @@ class HistoryManager:
await self._history_state.update_data(history=history) await self._history_state.update_data(history=history)
async def pop(self) -> Optional[MemoryStorageRecord]: async def pop(self) -> MemoryStorageRecord | None:
history_data = await self._history_state.get_data() history_data = await self._history_state.get_data()
history = history_data.setdefault("history", []) history = history_data.setdefault("history", [])
if not history: if not history:
@ -70,14 +63,14 @@ class HistoryManager:
loggers.scene.debug("Pop state=%s data=%s from history", state, data) loggers.scene.debug("Pop state=%s data=%s from history", state, data)
return MemoryStorageRecord(state=state, data=data) return MemoryStorageRecord(state=state, data=data)
async def get(self) -> Optional[MemoryStorageRecord]: async def get(self) -> MemoryStorageRecord | None:
history_data = await self._history_state.get_data() history_data = await self._history_state.get_data()
history = history_data.setdefault("history", []) history = history_data.setdefault("history", [])
if not history: if not history:
return None return None
return MemoryStorageRecord(**history[-1]) return MemoryStorageRecord(**history[-1])
async def all(self) -> List[MemoryStorageRecord]: async def all(self) -> list[MemoryStorageRecord]:
history_data = await self._history_state.get_data() history_data = await self._history_state.get_data()
history = history_data.setdefault("history", []) history = history_data.setdefault("history", [])
return [MemoryStorageRecord(**item) for item in history] return [MemoryStorageRecord(**item) for item in history]
@ -91,11 +84,11 @@ class HistoryManager:
data = await self._state.get_data() data = await self._state.get_data()
await self.push(state, data) await self.push(state, data)
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None: async def _set_state(self, state: str | None, data: dict[str, Any]) -> None:
await self._state.set_state(state) await self._state.set_state(state)
await self._state.set_data(data) await self._state.set_data(data)
async def rollback(self) -> Optional[str]: async def rollback(self) -> str | None:
previous_state = await self.pop() previous_state = await self.pop()
if not previous_state: if not previous_state:
await self._set_state(None, {}) await self._set_state(None, {})
@ -116,14 +109,14 @@ class ObserverDecorator:
name: str, name: str,
filters: tuple[CallbackType, ...], filters: tuple[CallbackType, ...],
action: SceneAction | None = None, action: SceneAction | None = None,
after: Optional[After] = None, after: After | None = None,
) -> None: ) -> None:
self.name = name self.name = name
self.filters = filters self.filters = filters
self.action = action self.action = action
self.after = after self.after = after
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None: def _wrap_filter(self, target: type[Scene] | CallbackType) -> None:
handlers = getattr(target, "__aiogram_handler__", None) handlers = getattr(target, "__aiogram_handler__", None)
if not handlers: if not handlers:
handlers = [] handlers = []
@ -135,7 +128,7 @@ class ObserverDecorator:
handler=target, handler=target,
filters=self.filters, filters=self.filters,
after=self.after, after=self.after,
) ),
) )
def _wrap_action(self, target: CallbackType) -> None: def _wrap_action(self, target: CallbackType) -> None:
@ -154,13 +147,14 @@ class ObserverDecorator:
else: else:
self._wrap_action(target) self._wrap_action(target)
else: else:
raise TypeError("Only function or method is allowed") msg = "Only function or method is allowed"
raise TypeError(msg)
return target return target
def leave(self) -> ActionContainer: def leave(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.leave) return ActionContainer(self.name, self.filters, SceneAction.leave)
def enter(self, target: Type[Scene]) -> ActionContainer: def enter(self, target: type[Scene]) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.enter, target) return ActionContainer(self.name, self.filters, SceneAction.enter, target)
def exit(self) -> ActionContainer: def exit(self) -> ActionContainer:
@ -181,9 +175,9 @@ class ActionContainer:
def __init__( def __init__(
self, self,
name: str, name: str,
filters: Tuple[CallbackType, ...], filters: tuple[CallbackType, ...],
action: SceneAction, action: SceneAction,
target: Optional[Union[Type[Scene], State, str]] = None, target: type[Scene] | State | str | None = None,
) -> None: ) -> None:
self.name = name self.name = name
self.filters = filters self.filters = filters
@ -201,33 +195,27 @@ class ActionContainer:
await wizard.back() await wizard.back()
@dataclass(slots=True)
class HandlerContainer: class HandlerContainer:
def __init__( name: str
self, handler: CallbackType
name: str, filters: tuple[CallbackType, ...]
handler: CallbackType, after: After | None = None
filters: Tuple[CallbackType, ...],
after: Optional[After] = None,
) -> None:
self.name = name
self.handler = handler
self.filters = filters
self.after = after
@dataclass() @dataclass
class SceneConfig: class SceneConfig:
state: Optional[str] state: str | None
"""Scene state""" """Scene state"""
handlers: List[HandlerContainer] handlers: list[HandlerContainer]
"""Scene handlers""" """Scene handlers"""
actions: Dict[SceneAction, Dict[str, CallableObject]] actions: dict[SceneAction, dict[str, CallableObject]]
"""Scene actions""" """Scene actions"""
reset_data_on_enter: Optional[bool] = None reset_data_on_enter: bool | None = None
"""Reset scene data on enter""" """Reset scene data on enter"""
reset_history_on_enter: Optional[bool] = None reset_history_on_enter: bool | None = None
"""Reset scene history on enter""" """Reset scene history on enter"""
callback_query_without_state: Optional[bool] = None callback_query_without_state: bool | None = None
"""Allow callback query without state""" """Allow callback query without state"""
attrs_resolver: ClassAttrsResolver = get_sorted_mro_attrs_resolver attrs_resolver: ClassAttrsResolver = get_sorted_mro_attrs_resolver
""" """
@ -247,9 +235,9 @@ async def _empty_handler(*args: Any, **kwargs: Any) -> None:
class SceneHandlerWrapper: class SceneHandlerWrapper:
def __init__( def __init__(
self, self,
scene: Type[Scene], scene: type[Scene],
handler: CallbackType, handler: CallbackType,
after: Optional[After] = None, after: After | None = None,
) -> None: ) -> None:
self.scene = scene self.scene = scene
self.handler = CallableObject(handler) self.handler = CallableObject(handler)
@ -271,7 +259,7 @@ class SceneHandlerWrapper:
update_type=event_update.event_type, update_type=event_update.event_type,
event=event, event=event,
data=kwargs, data=kwargs,
) ),
) )
result = await self.handler.call(scene, event, **kwargs) result = await self.handler.call(scene, event, **kwargs)
@ -331,7 +319,7 @@ class Scene:
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
handlers: list[HandlerContainer] = [] handlers: list[HandlerContainer] = []
actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict) actions: defaultdict[SceneAction, dict[str, CallableObject]] = defaultdict(dict)
for base in cls.__bases__: for base in cls.__bases__:
if not issubclass(base, Scene): if not issubclass(base, Scene):
@ -353,7 +341,7 @@ class Scene:
if attrs_resolver is None: if attrs_resolver is None:
attrs_resolver = get_sorted_mro_attrs_resolver attrs_resolver = get_sorted_mro_attrs_resolver
for name, value in attrs_resolver(cls): for _name, value in attrs_resolver(cls):
if scene_handlers := getattr(value, "__aiogram_handler__", None): if scene_handlers := getattr(value, "__aiogram_handler__", None):
handlers.extend(scene_handlers) handlers.extend(scene_handlers)
if isinstance(value, ObserverDecorator): if isinstance(value, ObserverDecorator):
@ -363,7 +351,7 @@ class Scene:
_empty_handler, _empty_handler,
value.filters, value.filters,
after=value.after, after=value.after,
) ),
) )
if hasattr(value, "__aiogram_action__"): if hasattr(value, "__aiogram_action__"):
for action, action_handlers in value.__aiogram_action__.items(): for action, action_handlers in value.__aiogram_action__.items():
@ -408,7 +396,7 @@ class Scene:
router.observers[observer_name].filter(StateFilter(scene_config.state)) router.observers[observer_name].filter(StateFilter(scene_config.state))
@classmethod @classmethod
def as_router(cls, name: Optional[str] = None) -> Router: def as_router(cls, name: str | None = None) -> Router:
""" """
Returns the scene as a router. Returns the scene as a router.
@ -433,7 +421,9 @@ class Scene:
""" """
async def enter_to_scene_handler( async def enter_to_scene_handler(
event: TelegramObject, scenes: ScenesManager, **middleware_kwargs: Any event: TelegramObject,
scenes: ScenesManager,
**middleware_kwargs: Any,
) -> None: ) -> None:
await scenes.enter(cls, **{**handler_kwargs, **middleware_kwargs}) await scenes.enter(cls, **{**handler_kwargs, **middleware_kwargs})
@ -461,7 +451,7 @@ class SceneWizard:
state: FSMContext, state: FSMContext,
update_type: str, update_type: str,
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
): ):
""" """
A class that represents a wizard for managing scenes in a Telegram bot. A class that represents a wizard for managing scenes in a Telegram bot.
@ -480,7 +470,7 @@ class SceneWizard:
self.event = event self.event = event
self.data = data self.data = data
self.scene: Optional[Scene] = None self.scene: Scene | None = None
async def enter(self, **kwargs: Any) -> None: async def enter(self, **kwargs: Any) -> None:
""" """
@ -548,7 +538,7 @@ class SceneWizard:
assert self.scene_config.state is not None, "Scene state is not specified" assert self.scene_config.state is not None, "Scene state is not specified"
await self.goto(self.scene_config.state, **kwargs) await self.goto(self.scene_config.state, **kwargs)
async def goto(self, scene: Union[Type[Scene], State, str], **kwargs: Any) -> None: async def goto(self, scene: type[Scene] | State | str, **kwargs: Any) -> None:
""" """
The `goto` method transitions to a new scene. The `goto` method transitions to a new scene.
It first calls the `leave` method to perform any necessary cleanup It first calls the `leave` method to perform any necessary cleanup
@ -565,13 +555,16 @@ class SceneWizard:
async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool: async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool:
if not self.scene: if not self.scene:
raise SceneException("Scene is not initialized") msg = "Scene is not initialized"
raise SceneException(msg)
loggers.scene.debug("Call action %r in scene %r", action.name, self.scene_config.state) loggers.scene.debug("Call action %r in scene %r", action.name, self.scene_config.state)
action_config = self.scene_config.actions.get(action, {}) action_config = self.scene_config.actions.get(action, {})
if not action_config: if not action_config:
loggers.scene.debug( loggers.scene.debug(
"Action %r not found in scene %r", action.name, self.scene_config.state "Action %r not found in scene %r",
action.name,
self.scene_config.state,
) )
return False return False
@ -597,7 +590,7 @@ class SceneWizard:
""" """
await self.state.set_data(data=data) await self.state.set_data(data=data)
async def get_data(self) -> Dict[str, Any]: async def get_data(self) -> dict[str, Any]:
""" """
This method returns the data stored in the current state. This method returns the data stored in the current state.
@ -606,7 +599,7 @@ class SceneWizard:
return await self.state.get_data() return await self.state.get_data()
@overload @overload
async def get_value(self, key: str) -> Optional[Any]: async def get_value(self, key: str) -> Any | None:
""" """
This method returns the value from key in the data of the current state. This method returns the value from key in the data of the current state.
@ -614,7 +607,6 @@ class SceneWizard:
:return: A dictionary containing the data stored in the scene state. :return: A dictionary containing the data stored in the scene state.
""" """
pass
@overload @overload
async def get_value(self, key: str, default: Any) -> Any: async def get_value(self, key: str, default: Any) -> Any:
@ -626,14 +618,15 @@ class SceneWizard:
:return: A dictionary containing the data stored in the scene state. :return: A dictionary containing the data stored in the scene state.
""" """
pass
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]: async def get_value(self, key: str, default: Any | None = None) -> Any | None:
return await self.state.get_value(key, default) return await self.state.get_value(key, default)
async def update_data( async def update_data(
self, data: Optional[Mapping[str, Any]] = None, **kwargs: Any self,
) -> Dict[str, Any]: data: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
""" """
This method updates the data stored in the current state This method updates the data stored in the current state
@ -666,7 +659,7 @@ class ScenesManager:
update_type: str, update_type: str,
event: TelegramObject, event: TelegramObject,
state: FSMContext, state: FSMContext,
data: Dict[str, Any], data: dict[str, Any],
) -> None: ) -> None:
self.registry = registry self.registry = registry
self.update_type = update_type self.update_type = update_type
@ -676,7 +669,7 @@ class ScenesManager:
self.history = HistoryManager(self.state) self.history = HistoryManager(self.state)
async def _get_scene(self, scene_type: Optional[Union[Type[Scene], State, str]]) -> Scene: async def _get_scene(self, scene_type: type[Scene] | State | str | None) -> Scene:
scene_type = self.registry.get(scene_type) scene_type = self.registry.get(scene_type)
return scene_type( return scene_type(
wizard=SceneWizard( wizard=SceneWizard(
@ -689,7 +682,7 @@ class ScenesManager:
), ),
) )
async def _get_active_scene(self) -> Optional[Scene]: async def _get_active_scene(self) -> Scene | None:
state = await self.state.get_state() state = await self.state.get_state()
try: try:
return await self._get_scene(state) return await self._get_scene(state)
@ -698,7 +691,7 @@ class ScenesManager:
async def enter( async def enter(
self, self,
scene_type: Optional[Union[Type[Scene], State, str]], scene_type: type[Scene] | State | str | None,
_check_active: bool = True, _check_active: bool = True,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -753,7 +746,7 @@ class SceneRegistry:
self.router = router self.router = router
self.register_on_add = register_on_add self.register_on_add = register_on_add
self._scenes: Dict[Optional[str], Type[Scene]] = {} self._scenes: dict[str | None, type[Scene]] = {}
self._setup_middleware(router) self._setup_middleware(router)
def _setup_middleware(self, router: Router) -> None: def _setup_middleware(self, router: Router) -> None:
@ -772,7 +765,7 @@ class SceneRegistry:
self, self,
handler: NextMiddlewareType[TelegramObject], handler: NextMiddlewareType[TelegramObject],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
assert isinstance(event, Update), "Event must be an Update instance" assert isinstance(event, Update), "Event must be an Update instance"
@ -789,7 +782,7 @@ class SceneRegistry:
self, self,
handler: NextMiddlewareType[TelegramObject], handler: NextMiddlewareType[TelegramObject],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
update: Update = data["event_update"] update: Update = data["event_update"]
data["scenes"] = ScenesManager( data["scenes"] = ScenesManager(
@ -801,7 +794,7 @@ class SceneRegistry:
) )
return await handler(event, data) return await handler(event, data)
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None: def add(self, *scenes: type[Scene], router: Router | None = None) -> None:
""" """
This method adds the specified scenes to the registry This method adds the specified scenes to the registry
and optionally registers it to the router. and optionally registers it to the router.
@ -820,13 +813,13 @@ class SceneRegistry:
:return: None :return: None
""" """
if not scenes: if not scenes:
raise ValueError("At least one scene must be specified") msg = "At least one scene must be specified"
raise ValueError(msg)
for scene in scenes: for scene in scenes:
if scene.__scene_config__.state in self._scenes: if scene.__scene_config__.state in self._scenes:
raise SceneException( msg = f"Scene with state {scene.__scene_config__.state!r} already exists"
f"Scene with state {scene.__scene_config__.state!r} already exists" raise SceneException(msg)
)
self._scenes[scene.__scene_config__.state] = scene self._scenes[scene.__scene_config__.state] = scene
@ -835,7 +828,7 @@ class SceneRegistry:
elif self.register_on_add: elif self.register_on_add:
self.router.include_router(scene.as_router()) self.router.include_router(scene.as_router())
def register(self, *scenes: Type[Scene]) -> None: def register(self, *scenes: type[Scene]) -> None:
""" """
Registers one or more scenes to the SceneRegistry. Registers one or more scenes to the SceneRegistry.
@ -844,7 +837,7 @@ class SceneRegistry:
""" """
self.add(*scenes, router=self.router) self.add(*scenes, router=self.router)
def get(self, scene: Optional[Union[Type[Scene], State, str]]) -> Type[Scene]: def get(self, scene: type[Scene] | State | str | None) -> type[Scene]:
""" """
This method returns the registered Scene object for the specified scene. This method returns the registered Scene object for the specified scene.
The scene parameter can be either a Scene object, State object or a string representing The scene parameter can be either a Scene object, State object or a string representing
@ -865,18 +858,20 @@ class SceneRegistry:
if isinstance(scene, State): if isinstance(scene, State):
scene = scene.state scene = scene.state
if scene is not None and not isinstance(scene, str): if scene is not None and not isinstance(scene, str):
raise SceneException("Scene must be a subclass of Scene, State or a string") msg = "Scene must be a subclass of Scene, State or a string"
raise SceneException(msg)
try: try:
return self._scenes[scene] return self._scenes[scene]
except KeyError: except KeyError:
raise SceneException(f"Scene {scene!r} is not registered") msg = f"Scene {scene!r} is not registered"
raise SceneException(msg)
@dataclass @dataclass
class After: class After:
action: SceneAction action: SceneAction
scene: Optional[Union[Type[Scene], State, str]] = None scene: type[Scene] | State | str | None = None
@classmethod @classmethod
def exit(cls) -> After: def exit(cls) -> After:
@ -887,7 +882,7 @@ class After:
return cls(action=SceneAction.back) return cls(action=SceneAction.back)
@classmethod @classmethod
def goto(cls, scene: Optional[Union[Type[Scene], State, str]]) -> After: def goto(cls, scene: type[Scene] | State | str | None) -> After:
return cls(action=SceneAction.enter, scene=scene) return cls(action=SceneAction.enter, scene=scene)
@ -898,7 +893,7 @@ class ObserverMarker:
def __call__( def __call__(
self, self,
*filters: CallbackType, *filters: CallbackType,
after: Optional[After] = None, after: After | None = None,
) -> ObserverDecorator: ) -> ObserverDecorator:
return ObserverDecorator( return ObserverDecorator(
self.name, self.name,

View file

@ -1,5 +1,6 @@
import inspect import inspect
from typing import Any, Iterator, Optional, Tuple, Type, no_type_check from collections.abc import Iterator
from typing import Any, no_type_check
from aiogram.types import TelegramObject from aiogram.types import TelegramObject
@ -9,19 +10,20 @@ class State:
State object State object
""" """
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None) -> None: def __init__(self, state: str | None = None, group_name: str | None = None) -> None:
self._state = state self._state = state
self._group_name = group_name self._group_name = group_name
self._group: Optional[Type[StatesGroup]] = None self._group: type[StatesGroup] | None = None
@property @property
def group(self) -> "Type[StatesGroup]": def group(self) -> "type[StatesGroup]":
if not self._group: if not self._group:
raise RuntimeError("This state is not in any group.") msg = "This state is not in any group."
raise RuntimeError(msg)
return self._group return self._group
@property @property
def state(self) -> Optional[str]: def state(self) -> str | None:
if self._state is None or self._state == "*": if self._state is None or self._state == "*":
return self._state return self._state
@ -34,12 +36,13 @@ class State:
return f"{group}:{self._state}" return f"{group}:{self._state}"
def set_parent(self, group: "Type[StatesGroup]") -> None: def set_parent(self, group: "type[StatesGroup]") -> None:
if not issubclass(group, StatesGroup): if not issubclass(group, StatesGroup):
raise ValueError("Group must be subclass of StatesGroup") msg = "Group must be subclass of StatesGroup"
raise ValueError(msg)
self._group = group self._group = group
def __set_name__(self, owner: "Type[StatesGroup]", name: str) -> None: def __set_name__(self, owner: "type[StatesGroup]", name: str) -> None:
if self._state is None: if self._state is None:
self._state = name self._state = name
self.set_parent(owner) self.set_parent(owner)
@ -49,12 +52,12 @@ class State:
__repr__ = __str__ __repr__ = __str__
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool: def __call__(self, event: TelegramObject, raw_state: str | None = None) -> bool:
if self.state == "*": if self.state == "*":
return True return True
return raw_state == self.state return raw_state == self.state
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return self.state == other.state return self.state == other.state
if isinstance(other, str): if isinstance(other, str):
@ -66,13 +69,13 @@ class State:
class StatesGroupMeta(type): class StatesGroupMeta(type):
__parent__: "Optional[Type[StatesGroup]]" __parent__: type["StatesGroup"] | None
__childs__: "Tuple[Type[StatesGroup], ...]" __childs__: tuple[type["StatesGroup"], ...]
__states__: Tuple[State, ...] __states__: tuple[State, ...]
__state_names__: Tuple[str, ...] __state_names__: tuple[str, ...]
__all_childs__: Tuple[Type["StatesGroup"], ...] __all_childs__: tuple[type["StatesGroup"], ...]
__all_states__: Tuple[State, ...] __all_states__: tuple[State, ...]
__all_states_names__: Tuple[str, ...] __all_states_names__: tuple[str, ...]
@no_type_check @no_type_check
def __new__(mcs, name, bases, namespace, **kwargs): def __new__(mcs, name, bases, namespace, **kwargs):
@ -81,7 +84,7 @@ class StatesGroupMeta(type):
states = [] states = []
childs = [] childs = []
for name, arg in namespace.items(): for arg in namespace.values():
if isinstance(arg, State): if isinstance(arg, State):
states.append(arg) states.append(arg)
elif inspect.isclass(arg) and issubclass(arg, StatesGroup): elif inspect.isclass(arg) and issubclass(arg, StatesGroup):
@ -106,10 +109,10 @@ class StatesGroupMeta(type):
@property @property
def __full_group_name__(cls) -> str: def __full_group_name__(cls) -> str:
if cls.__parent__: if cls.__parent__:
return ".".join((cls.__parent__.__full_group_name__, cls.__name__)) return f"{cls.__parent__.__full_group_name__}.{cls.__name__}"
return cls.__name__ return cls.__name__
def _prepare_child(cls, child: Type["StatesGroup"]) -> Type["StatesGroup"]: def _prepare_child(cls, child: type["StatesGroup"]) -> type["StatesGroup"]:
"""Prepare child. """Prepare child.
While adding `cls` for its children, we also need to recalculate While adding `cls` for its children, we also need to recalculate
@ -123,19 +126,19 @@ class StatesGroupMeta(type):
child.__all_states_names__ = child._get_all_states_names() child.__all_states_names__ = child._get_all_states_names()
return child return child
def _get_all_childs(cls) -> Tuple[Type["StatesGroup"], ...]: def _get_all_childs(cls) -> tuple[type["StatesGroup"], ...]:
result = cls.__childs__ result = cls.__childs__
for child in cls.__childs__: for child in cls.__childs__:
result += child.__childs__ result += child.__childs__
return result return result
def _get_all_states(cls) -> Tuple[State, ...]: def _get_all_states(cls) -> tuple[State, ...]:
result = cls.__states__ result = cls.__states__
for group in cls.__childs__: for group in cls.__childs__:
result += group.__all_states__ result += group.__all_states__
return result return result
def _get_all_states_names(cls) -> Tuple[str, ...]: def _get_all_states_names(cls) -> tuple[str, ...]:
return tuple(state.state for state in cls.__all_states__ if state.state) return tuple(state.state for state in cls.__all_states__ if state.state)
def __contains__(cls, item: Any) -> bool: def __contains__(cls, item: Any) -> bool:
@ -156,12 +159,12 @@ class StatesGroupMeta(type):
class StatesGroup(metaclass=StatesGroupMeta): class StatesGroup(metaclass=StatesGroupMeta):
@classmethod @classmethod
def get_root(cls) -> Type["StatesGroup"]: def get_root(cls) -> type["StatesGroup"]:
if cls.__parent__ is None: if cls.__parent__ is None:
return cls return cls
return cls.__parent__.get_root() return cls.__parent__.get_root()
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool: def __call__(self, event: TelegramObject, raw_state: str | None = None) -> bool:
return raw_state in type(self).__all_states_names__ return raw_state in type(self).__all_states_names__
def __str__(self) -> str: def __str__(self) -> str:

View file

@ -1,20 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Mapping
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import Any, Literal, overload
Any,
AsyncGenerator,
Dict,
Literal,
Mapping,
Optional,
Union,
overload,
)
from aiogram.fsm.state import State from aiogram.fsm.state import State
StateType = Optional[Union[str, State]] StateType = str | State | None
DEFAULT_DESTINY = "default" DEFAULT_DESTINY = "default"
@ -24,8 +16,8 @@ class StorageKey:
bot_id: int bot_id: int
chat_id: int chat_id: int
user_id: int user_id: int
thread_id: Optional[int] = None thread_id: int | None = None
business_connection_id: Optional[str] = None business_connection_id: str | None = None
destiny: str = DEFAULT_DESTINY destiny: str = DEFAULT_DESTINY
@ -36,7 +28,7 @@ class KeyBuilder(ABC):
def build( def build(
self, self,
key: StorageKey, key: StorageKey,
part: Optional[Literal["data", "state", "lock"]] = None, part: Literal["data", "state", "lock"] | None = None,
) -> str: ) -> str:
""" """
Build key to be used in storage's db queries Build key to be used in storage's db queries
@ -45,7 +37,6 @@ class KeyBuilder(ABC):
:param part: part of the record :param part: part of the record
:return: key to be used in storage's db queries :return: key to be used in storage's db queries
""" """
pass
class DefaultKeyBuilder(KeyBuilder): class DefaultKeyBuilder(KeyBuilder):
@ -84,7 +75,7 @@ class DefaultKeyBuilder(KeyBuilder):
def build( def build(
self, self,
key: StorageKey, key: StorageKey,
part: Optional[Literal["data", "state", "lock"]] = None, part: Literal["data", "state", "lock"] | None = None,
) -> str: ) -> str:
parts = [self.prefix] parts = [self.prefix]
if self.with_bot_id: if self.with_bot_id:
@ -121,17 +112,15 @@ class BaseStorage(ABC):
:param key: storage key :param key: storage key
:param state: new state :param state: new state
""" """
pass
@abstractmethod @abstractmethod
async def get_state(self, key: StorageKey) -> Optional[str]: async def get_state(self, key: StorageKey) -> str | None:
""" """
Get key state Get key state
:param key: storage key :param key: storage key
:return: current state :return: current state
""" """
pass
@abstractmethod @abstractmethod
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None: async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
@ -141,20 +130,18 @@ class BaseStorage(ABC):
:param key: storage key :param key: storage key
:param data: new data :param data: new data
""" """
pass
@abstractmethod @abstractmethod
async def get_data(self, key: StorageKey) -> Dict[str, Any]: async def get_data(self, key: StorageKey) -> dict[str, Any]:
""" """
Get current data for key Get current data for key
:param key: storage key :param key: storage key
:return: current data :return: current data
""" """
pass
@overload @overload
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]: async def get_value(self, storage_key: StorageKey, dict_key: str) -> Any | None:
""" """
Get single value from data by key Get single value from data by key
@ -162,7 +149,6 @@ class BaseStorage(ABC):
:param dict_key: value key :param dict_key: value key
:return: value stored in key of dict or ``None`` :return: value stored in key of dict or ``None``
""" """
pass
@overload @overload
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any:
@ -174,15 +160,17 @@ class BaseStorage(ABC):
:param default: default value to return :param default: default value to return
:return: value stored in key of dict or default :return: value stored in key of dict or default
""" """
pass
async def get_value( async def get_value(
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None self,
) -> Optional[Any]: storage_key: StorageKey,
dict_key: str,
default: Any | None = None,
) -> Any | None:
data = await self.get_data(storage_key) data = await self.get_data(storage_key)
return data.get(dict_key, default) return data.get(dict_key, default)
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]: async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
""" """
Update date in the storage for key (like dict.update) Update date in the storage for key (like dict.update)
@ -200,7 +188,6 @@ class BaseStorage(ABC):
""" """
Close storage (database connection, file or etc.) Close storage (database connection, file or etc.)
""" """
pass
class BaseEventIsolation(ABC): class BaseEventIsolation(ABC):

View file

@ -1,18 +1,10 @@
from asyncio import Lock from asyncio import Lock
from collections import defaultdict from collections import defaultdict
from collections.abc import AsyncGenerator, Hashable, Mapping
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from copy import copy from copy import copy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ( from typing import Any, overload
Any,
AsyncGenerator,
DefaultDict,
Dict,
Hashable,
Mapping,
Optional,
overload,
)
from aiogram.exceptions import DataNotDictLikeError from aiogram.exceptions import DataNotDictLikeError
from aiogram.fsm.state import State from aiogram.fsm.state import State
@ -26,8 +18,8 @@ from aiogram.fsm.storage.base import (
@dataclass @dataclass
class MemoryStorageRecord: class MemoryStorageRecord:
data: Dict[str, Any] = field(default_factory=dict) data: dict[str, Any] = field(default_factory=dict)
state: Optional[str] = None state: str | None = None
class MemoryStorage(BaseStorage): class MemoryStorage(BaseStorage):
@ -41,8 +33,8 @@ class MemoryStorage(BaseStorage):
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict( self.storage: defaultdict[StorageKey, MemoryStorageRecord] = defaultdict(
MemoryStorageRecord MemoryStorageRecord,
) )
async def close(self) -> None: async def close(self) -> None:
@ -51,28 +43,30 @@ class MemoryStorage(BaseStorage):
async def set_state(self, key: StorageKey, state: StateType = None) -> None: async def set_state(self, key: StorageKey, state: StateType = None) -> None:
self.storage[key].state = state.state if isinstance(state, State) else state self.storage[key].state = state.state if isinstance(state, State) else state
async def get_state(self, key: StorageKey) -> Optional[str]: async def get_state(self, key: StorageKey) -> str | None:
return self.storage[key].state return self.storage[key].state
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None: async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
if not isinstance(data, dict): if not isinstance(data, dict):
raise DataNotDictLikeError( msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
f"Data must be a dict or dict-like object, got {type(data).__name__}" raise DataNotDictLikeError(msg)
)
self.storage[key].data = data.copy() self.storage[key].data = data.copy()
async def get_data(self, key: StorageKey) -> Dict[str, Any]: async def get_data(self, key: StorageKey) -> dict[str, Any]:
return self.storage[key].data.copy() return self.storage[key].data.copy()
@overload @overload
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]: ... async def get_value(self, storage_key: StorageKey, dict_key: str) -> Any | None: ...
@overload @overload
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: ... async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: ...
async def get_value( async def get_value(
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None self,
) -> Optional[Any]: storage_key: StorageKey,
dict_key: str,
default: Any | None = None,
) -> Any | None:
data = self.storage[storage_key].data data = self.storage[storage_key].data
return copy(data.get(dict_key, default)) return copy(data.get(dict_key, default))
@ -89,7 +83,7 @@ class DisabledEventIsolation(BaseEventIsolation):
class SimpleEventIsolation(BaseEventIsolation): class SimpleEventIsolation(BaseEventIsolation):
def __init__(self) -> None: def __init__(self) -> None:
# TODO: Unused locks cleaner is needed # TODO: Unused locks cleaner is needed
self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock) self._locks: defaultdict[Hashable, Lock] = defaultdict(Lock)
@asynccontextmanager @asynccontextmanager
async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]: async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Mapping, Optional, cast from collections.abc import Mapping
from typing import Any, cast
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
@ -27,7 +28,7 @@ class MongoStorage(BaseStorage):
def __init__( def __init__(
self, self,
client: AsyncIOMotorClient, client: AsyncIOMotorClient,
key_builder: Optional[KeyBuilder] = None, key_builder: KeyBuilder | None = None,
db_name: str = "aiogram_fsm", db_name: str = "aiogram_fsm",
collection_name: str = "states_and_data", collection_name: str = "states_and_data",
) -> None: ) -> None:
@ -46,7 +47,10 @@ class MongoStorage(BaseStorage):
@classmethod @classmethod
def from_url( def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any cls,
url: str,
connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> "MongoStorage": ) -> "MongoStorage":
""" """
Create an instance of :class:`MongoStorage` with specifying the connection string Create an instance of :class:`MongoStorage` with specifying the connection string
@ -65,7 +69,7 @@ class MongoStorage(BaseStorage):
"""Cleanup client resources and disconnect from MongoDB.""" """Cleanup client resources and disconnect from MongoDB."""
self._client.close() self._client.close()
def resolve_state(self, value: StateType) -> Optional[str]: def resolve_state(self, value: StateType) -> str | None:
if value is None: if value is None:
return None return None
if isinstance(value, State): if isinstance(value, State):
@ -90,7 +94,7 @@ class MongoStorage(BaseStorage):
upsert=True, upsert=True,
) )
async def get_state(self, key: StorageKey) -> Optional[str]: async def get_state(self, key: StorageKey) -> str | None:
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id}) document = await self._collection.find_one({"_id": document_id})
if document is None: if document is None:
@ -99,9 +103,8 @@ class MongoStorage(BaseStorage):
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None: async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
if not isinstance(data, dict): if not isinstance(data, dict):
raise DataNotDictLikeError( msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
f"Data must be a dict or dict-like object, got {type(data).__name__}" raise DataNotDictLikeError(msg)
)
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
if not data: if not data:
@ -120,14 +123,14 @@ class MongoStorage(BaseStorage):
upsert=True, upsert=True,
) )
async def get_data(self, key: StorageKey) -> Dict[str, Any]: async def get_data(self, key: StorageKey) -> dict[str, Any]:
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id}) document = await self._collection.find_one({"_id": document_id})
if document is None or not document.get("data"): if document is None or not document.get("data"):
return {} return {}
return cast(Dict[str, Any], document["data"]) return cast(dict[str, Any], document["data"])
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]: async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
update_with = {f"data.{key}": value for key, value in data.items()} update_with = {f"data.{key}": value for key, value in data.items()}
update_result = await self._collection.find_one_and_update( update_result = await self._collection.find_one_and_update(

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Mapping, Optional, cast from collections.abc import Mapping
from typing import Any, cast
from pymongo import AsyncMongoClient from pymongo import AsyncMongoClient
@ -21,7 +22,7 @@ class PyMongoStorage(BaseStorage):
def __init__( def __init__(
self, self,
client: AsyncMongoClient[Any], client: AsyncMongoClient[Any],
key_builder: Optional[KeyBuilder] = None, key_builder: KeyBuilder | None = None,
db_name: str = "aiogram_fsm", db_name: str = "aiogram_fsm",
collection_name: str = "states_and_data", collection_name: str = "states_and_data",
) -> None: ) -> None:
@ -40,7 +41,10 @@ class PyMongoStorage(BaseStorage):
@classmethod @classmethod
def from_url( def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any cls,
url: str,
connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> "PyMongoStorage": ) -> "PyMongoStorage":
""" """
Create an instance of :class:`PyMongoStorage` with specifying the connection string Create an instance of :class:`PyMongoStorage` with specifying the connection string
@ -59,7 +63,7 @@ class PyMongoStorage(BaseStorage):
"""Cleanup client resources and disconnect from MongoDB.""" """Cleanup client resources and disconnect from MongoDB."""
return await self._client.close() return await self._client.close()
def resolve_state(self, value: StateType) -> Optional[str]: def resolve_state(self, value: StateType) -> str | None:
if value is None: if value is None:
return None return None
if isinstance(value, State): if isinstance(value, State):
@ -84,18 +88,17 @@ class PyMongoStorage(BaseStorage):
upsert=True, upsert=True,
) )
async def get_state(self, key: StorageKey) -> Optional[str]: async def get_state(self, key: StorageKey) -> str | None:
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id}) document = await self._collection.find_one({"_id": document_id})
if document is None: if document is None:
return None return None
return cast(Optional[str], document.get("state")) return cast(str | None, document.get("state"))
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None: async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
if not isinstance(data, dict): if not isinstance(data, dict):
raise DataNotDictLikeError( msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
f"Data must be a dict or dict-like object, got {type(data).__name__}" raise DataNotDictLikeError(msg)
)
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
if not data: if not data:
@ -114,14 +117,14 @@ class PyMongoStorage(BaseStorage):
upsert=True, upsert=True,
) )
async def get_data(self, key: StorageKey) -> Dict[str, Any]: async def get_data(self, key: StorageKey) -> dict[str, Any]:
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id}) document = await self._collection.find_one({"_id": document_id})
if document is None or not document.get("data"): if document is None or not document.get("data"):
return {} return {}
return cast(Dict[str, Any], document["data"]) return cast(dict[str, Any], document["data"])
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]: async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
document_id = self._key_builder.build(key) document_id = self._key_builder.build(key)
update_with = {f"data.{key}": value for key, value in data.items()} update_with = {f"data.{key}": value for key, value in data.items()}
update_result = await self._collection.find_one_and_update( update_result = await self._collection.find_one_and_update(
@ -133,4 +136,4 @@ class PyMongoStorage(BaseStorage):
) )
if not update_result: if not update_result:
await self._collection.delete_one({"_id": document_id}) await self._collection.delete_one({"_id": document_id})
return cast(Dict[str, Any], update_result.get("data", {})) return cast(dict[str, Any], update_result.get("data", {}))

View file

@ -1,6 +1,7 @@
import json import json
from collections.abc import AsyncGenerator, Callable, Mapping
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Optional, cast from typing import Any, cast
from redis.asyncio.client import Redis from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool from redis.asyncio.connection import ConnectionPool
@ -31,9 +32,9 @@ class RedisStorage(BaseStorage):
def __init__( def __init__(
self, self,
redis: Redis, redis: Redis,
key_builder: Optional[KeyBuilder] = None, key_builder: KeyBuilder | None = None,
state_ttl: Optional[ExpiryT] = None, state_ttl: ExpiryT | None = None,
data_ttl: Optional[ExpiryT] = None, data_ttl: ExpiryT | None = None,
json_loads: _JsonLoads = json.loads, json_loads: _JsonLoads = json.loads,
json_dumps: _JsonDumps = json.dumps, json_dumps: _JsonDumps = json.dumps,
) -> None: ) -> None:
@ -54,7 +55,10 @@ class RedisStorage(BaseStorage):
@classmethod @classmethod
def from_url( def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any cls,
url: str,
connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> "RedisStorage": ) -> "RedisStorage":
""" """
Create an instance of :class:`RedisStorage` with specifying the connection string Create an instance of :class:`RedisStorage` with specifying the connection string
@ -94,12 +98,12 @@ class RedisStorage(BaseStorage):
async def get_state( async def get_state(
self, self,
key: StorageKey, key: StorageKey,
) -> Optional[str]: ) -> str | None:
redis_key = self.key_builder.build(key, "state") redis_key = self.key_builder.build(key, "state")
value = await self.redis.get(redis_key) value = await self.redis.get(redis_key)
if isinstance(value, bytes): if isinstance(value, bytes):
return value.decode("utf-8") return value.decode("utf-8")
return cast(Optional[str], value) return cast(str | None, value)
async def set_data( async def set_data(
self, self,
@ -107,9 +111,8 @@ class RedisStorage(BaseStorage):
data: Mapping[str, Any], data: Mapping[str, Any],
) -> None: ) -> None:
if not isinstance(data, dict): if not isinstance(data, dict):
raise DataNotDictLikeError( msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
f"Data must be a dict or dict-like object, got {type(data).__name__}" raise DataNotDictLikeError(msg)
)
redis_key = self.key_builder.build(key, "data") redis_key = self.key_builder.build(key, "data")
if not data: if not data:
@ -124,22 +127,22 @@ class RedisStorage(BaseStorage):
async def get_data( async def get_data(
self, self,
key: StorageKey, key: StorageKey,
) -> Dict[str, Any]: ) -> dict[str, Any]:
redis_key = self.key_builder.build(key, "data") redis_key = self.key_builder.build(key, "data")
value = await self.redis.get(redis_key) value = await self.redis.get(redis_key)
if value is None: if value is None:
return {} return {}
if isinstance(value, bytes): if isinstance(value, bytes):
value = value.decode("utf-8") value = value.decode("utf-8")
return cast(Dict[str, Any], self.json_loads(value)) return cast(dict[str, Any], self.json_loads(value))
class RedisEventIsolation(BaseEventIsolation): class RedisEventIsolation(BaseEventIsolation):
def __init__( def __init__(
self, self,
redis: Redis, redis: Redis,
key_builder: Optional[KeyBuilder] = None, key_builder: KeyBuilder | None = None,
lock_kwargs: Optional[Dict[str, Any]] = None, lock_kwargs: dict[str, Any] | None = None,
) -> None: ) -> None:
if key_builder is None: if key_builder is None:
key_builder = DefaultKeyBuilder() key_builder = DefaultKeyBuilder()
@ -153,7 +156,7 @@ class RedisEventIsolation(BaseEventIsolation):
def from_url( def from_url(
cls, cls,
url: str, url: str,
connection_kwargs: Optional[Dict[str, Any]] = None, connection_kwargs: dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> "RedisEventIsolation": ) -> "RedisEventIsolation":
if connection_kwargs is None: if connection_kwargs is None:

View file

@ -1,5 +1,4 @@
from enum import Enum, auto from enum import Enum, auto
from typing import Optional, Tuple
class FSMStrategy(Enum): class FSMStrategy(Enum):
@ -23,8 +22,8 @@ def apply_strategy(
strategy: FSMStrategy, strategy: FSMStrategy,
chat_id: int, chat_id: int,
user_id: int, user_id: int,
thread_id: Optional[int] = None, thread_id: int | None = None,
) -> Tuple[int, int, Optional[int]]: ) -> tuple[int, int, int | None]:
if strategy == FSMStrategy.CHAT: if strategy == FSMStrategy.CHAT:
return chat_id, chat_id, None return chat_id, chat_id, None
if strategy == FSMStrategy.GLOBAL_USER: if strategy == FSMStrategy.GLOBAL_USER:

View file

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Generic, TypeVar, cast from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from aiogram.types import Update from aiogram.types import Update
@ -14,7 +14,7 @@ T = TypeVar("T")
class BaseHandlerMixin(Generic[T]): class BaseHandlerMixin(Generic[T]):
if TYPE_CHECKING: if TYPE_CHECKING:
event: T event: T
data: Dict[str, Any] data: dict[str, Any]
class BaseHandler(BaseHandlerMixin[T], ABC): class BaseHandler(BaseHandlerMixin[T], ABC):
@ -24,7 +24,7 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
def __init__(self, event: T, **kwargs: Any) -> None: def __init__(self, event: T, **kwargs: Any) -> None:
self.event: T = event self.event: T = event
self.data: Dict[str, Any] = kwargs self.data: dict[str, Any] = kwargs
@property @property
def bot(self) -> Bot: def bot(self) -> Bot:
@ -32,7 +32,8 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
if "bot" in self.data: if "bot" in self.data:
return cast(Bot, self.data["bot"]) return cast(Bot, self.data["bot"])
raise RuntimeError("Bot instance not found in the context") msg = "Bot instance not found in the context"
raise RuntimeError(msg)
@property @property
def update(self) -> Update: def update(self) -> Update:

View file

@ -29,14 +29,14 @@ class CallbackQueryHandler(BaseHandler[CallbackQuery], ABC):
return self.event.from_user return self.event.from_user
@property @property
def message(self) -> Optional[MaybeInaccessibleMessage]: def message(self) -> MaybeInaccessibleMessage | None:
""" """
Is alias for `event.message` Is alias for `event.message`
""" """
return self.event.message return self.event.message
@property @property
def callback_data(self) -> Optional[str]: def callback_data(self) -> str | None:
""" """
Is alias for `event.data` Is alias for `event.data`
""" """

View file

@ -12,7 +12,7 @@ class MessageHandler(BaseHandler[Message], ABC):
""" """
@property @property
def from_user(self) -> Optional[User]: def from_user(self) -> User | None:
return self.event.from_user return self.event.from_user
@property @property
@ -22,7 +22,7 @@ class MessageHandler(BaseHandler[Message], ABC):
class MessageHandlerCommandMixin(BaseHandlerMixin[Message]): class MessageHandlerCommandMixin(BaseHandlerMixin[Message]):
@property @property
def command(self) -> Optional[CommandObject]: def command(self) -> CommandObject | None:
if "command" in self.data: if "command" in self.data:
return cast(CommandObject, self.data["command"]) return cast(CommandObject, self.data["command"])
return None return None

View file

@ -1,5 +1,4 @@
from abc import ABC from abc import ABC
from typing import List
from aiogram.handlers import BaseHandler from aiogram.handlers import BaseHandler
from aiogram.types import Poll, PollOption from aiogram.types import Poll, PollOption
@ -15,5 +14,5 @@ class PollHandler(BaseHandler[Poll], ABC):
return self.event.question return self.event.question
@property @property
def options(self) -> List[PollOption]: def options(self) -> list[PollOption]:
return self.event.options return self.event.options

View file

@ -1,6 +1,6 @@
import hashlib import hashlib
import hmac import hmac
from typing import Any, Dict from typing import Any
def check_signature(token: str, hash: str, **kwargs: Any) -> bool: def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
@ -17,12 +17,14 @@ def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
secret = hashlib.sha256(token.encode("utf-8")) secret = hashlib.sha256(token.encode("utf-8"))
check_string = "\n".join(f"{k}={kwargs[k]}" for k in sorted(kwargs)) check_string = "\n".join(f"{k}={kwargs[k]}" for k in sorted(kwargs))
hmac_string = hmac.new( hmac_string = hmac.new(
secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256 secret.digest(),
check_string.encode("utf-8"),
digestmod=hashlib.sha256,
).hexdigest() ).hexdigest()
return hmac_string == hash return hmac_string == hash
def check_integrity(token: str, data: Dict[str, Any]) -> bool: def check_integrity(token: str, data: dict[str, Any]) -> bool:
""" """
Verify the authentication and the integrity Verify the authentication and the integrity
of the data received on user's auth of the data received on user's auth

View file

@ -13,9 +13,11 @@ class BackoffConfig:
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.max_delay <= self.min_delay: if self.max_delay <= self.min_delay:
raise ValueError("`max_delay` should be greater than `min_delay`") msg = "`max_delay` should be greater than `min_delay`"
raise ValueError(msg)
if self.factor <= 1: if self.factor <= 1:
raise ValueError("`factor` should be greater than 1") msg = "`factor` should be greater than 1"
raise ValueError(msg)
class Backoff: class Backoff:

View file

@ -1,4 +1,5 @@
from typing import Any, Awaitable, Callable, Dict, Optional, Union from collections.abc import Awaitable, Callable
from typing import Any
from aiogram import BaseMiddleware, loggers from aiogram import BaseMiddleware, loggers
from aiogram.dispatcher.flags import get_flag from aiogram.dispatcher.flags import get_flag
@ -12,10 +13,10 @@ class CallbackAnswer:
self, self,
answered: bool, answered: bool,
disabled: bool = False, disabled: bool = False,
text: Optional[str] = None, text: str | None = None,
show_alert: Optional[bool] = None, show_alert: bool | None = None,
url: Optional[str] = None, url: str | None = None,
cache_time: Optional[int] = None, cache_time: int | None = None,
) -> None: ) -> None:
""" """
Callback answer configuration Callback answer configuration
@ -48,7 +49,8 @@ class CallbackAnswer:
@disabled.setter @disabled.setter
def disabled(self, value: bool) -> None: def disabled(self, value: bool) -> None:
if self._answered: if self._answered:
raise CallbackAnswerException("Can't change disabled state after answer") msg = "Can't change disabled state after answer"
raise CallbackAnswerException(msg)
self._disabled = value self._disabled = value
@property @property
@ -59,7 +61,7 @@ class CallbackAnswer:
return self._answered return self._answered
@property @property
def text(self) -> Optional[str]: def text(self) -> str | None:
""" """
Response text Response text
:return: :return:
@ -67,48 +69,52 @@ class CallbackAnswer:
return self._text return self._text
@text.setter @text.setter
def text(self, value: Optional[str]) -> None: def text(self, value: str | None) -> None:
if self._answered: if self._answered:
raise CallbackAnswerException("Can't change text after answer") msg = "Can't change text after answer"
raise CallbackAnswerException(msg)
self._text = value self._text = value
@property @property
def show_alert(self) -> Optional[bool]: def show_alert(self) -> bool | None:
""" """
Whether to display an alert Whether to display an alert
""" """
return self._show_alert return self._show_alert
@show_alert.setter @show_alert.setter
def show_alert(self, value: Optional[bool]) -> None: def show_alert(self, value: bool | None) -> None:
if self._answered: if self._answered:
raise CallbackAnswerException("Can't change show_alert after answer") msg = "Can't change show_alert after answer"
raise CallbackAnswerException(msg)
self._show_alert = value self._show_alert = value
@property @property
def url(self) -> Optional[str]: def url(self) -> str | None:
""" """
Game url Game url
""" """
return self._url return self._url
@url.setter @url.setter
def url(self, value: Optional[str]) -> None: def url(self, value: str | None) -> None:
if self._answered: if self._answered:
raise CallbackAnswerException("Can't change url after answer") msg = "Can't change url after answer"
raise CallbackAnswerException(msg)
self._url = value self._url = value
@property @property
def cache_time(self) -> Optional[int]: def cache_time(self) -> int | None:
""" """
Response cache time Response cache time
""" """
return self._cache_time return self._cache_time
@cache_time.setter @cache_time.setter
def cache_time(self, value: Optional[int]) -> None: def cache_time(self, value: int | None) -> None:
if self._answered: if self._answered:
raise CallbackAnswerException("Can't change cache_time after answer") msg = "Can't change cache_time after answer"
raise CallbackAnswerException(msg)
self._cache_time = value self._cache_time = value
def __str__(self) -> str: def __str__(self) -> str:
@ -131,10 +137,10 @@ class CallbackAnswerMiddleware(BaseMiddleware):
def __init__( def __init__(
self, self,
pre: bool = False, pre: bool = False,
text: Optional[str] = None, text: str | None = None,
show_alert: Optional[bool] = None, show_alert: bool | None = None,
url: Optional[str] = None, url: str | None = None,
cache_time: Optional[int] = None, cache_time: int | None = None,
) -> None: ) -> None:
""" """
Inner middleware for callback query handlers, can be useful in bots with a lot of callback Inner middleware for callback query handlers, can be useful in bots with a lot of callback
@ -154,15 +160,15 @@ class CallbackAnswerMiddleware(BaseMiddleware):
async def __call__( async def __call__(
self, self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
if not isinstance(event, CallbackQuery): if not isinstance(event, CallbackQuery):
return await handler(event, data) return await handler(event, data)
callback_answer = data["callback_answer"] = self.construct_callback_answer( callback_answer = data["callback_answer"] = self.construct_callback_answer(
properties=get_flag(data, "callback_answer") properties=get_flag(data, "callback_answer"),
) )
if not callback_answer.disabled and callback_answer.answered: if not callback_answer.disabled and callback_answer.answered:
@ -174,7 +180,8 @@ class CallbackAnswerMiddleware(BaseMiddleware):
await self.answer(event, callback_answer) await self.answer(event, callback_answer)
def construct_callback_answer( def construct_callback_answer(
self, properties: Optional[Union[Dict[str, Any], bool]] self,
properties: dict[str, Any] | bool | None,
) -> CallbackAnswer: ) -> CallbackAnswer:
pre, disabled, text, show_alert, url, cache_time = ( pre, disabled, text, show_alert, url, cache_time = (
self.pre, self.pre,

View file

@ -2,9 +2,10 @@ import asyncio
import logging import logging
import time import time
from asyncio import Event, Lock from asyncio import Event, Lock
from collections.abc import Awaitable, Callable
from contextlib import suppress from contextlib import suppress
from types import TracebackType from types import TracebackType
from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union from typing import Any
from aiogram import BaseMiddleware, Bot from aiogram import BaseMiddleware, Bot
from aiogram.dispatcher.flags import get_flag from aiogram.dispatcher.flags import get_flag
@ -32,8 +33,8 @@ class ChatActionSender:
self, self,
*, *,
bot: Bot, bot: Bot,
chat_id: Union[str, int], chat_id: str | int,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
action: str = "typing", action: str = "typing",
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
@ -56,7 +57,7 @@ class ChatActionSender:
self._lock = Lock() self._lock = Lock()
self._close_event = Event() self._close_event = Event()
self._closed_event = Event() self._closed_event = Event()
self._task: Optional[asyncio.Task[Any]] = None self._task: asyncio.Task[Any] | None = None
@property @property
def running(self) -> bool: def running(self) -> bool:
@ -108,7 +109,8 @@ class ChatActionSender:
self._close_event.clear() self._close_event.clear()
self._closed_event.clear() self._closed_event.clear()
if self.running: if self.running:
raise RuntimeError("Already running") msg = "Already running"
raise RuntimeError(msg)
self._task = asyncio.create_task(self._worker()) self._task = asyncio.create_task(self._worker())
async def _stop(self) -> None: async def _stop(self) -> None:
@ -126,18 +128,18 @@ class ChatActionSender:
async def __aexit__( async def __aexit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_value: Optional[BaseException], exc_value: BaseException | None,
traceback: Optional[TracebackType], traceback: TracebackType | None,
) -> Any: ) -> Any:
await self._stop() await self._stop()
@classmethod @classmethod
def typing( def typing(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -154,9 +156,9 @@ class ChatActionSender:
@classmethod @classmethod
def upload_photo( def upload_photo(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -173,9 +175,9 @@ class ChatActionSender:
@classmethod @classmethod
def record_video( def record_video(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -192,9 +194,9 @@ class ChatActionSender:
@classmethod @classmethod
def upload_video( def upload_video(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -211,9 +213,9 @@ class ChatActionSender:
@classmethod @classmethod
def record_voice( def record_voice(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -230,9 +232,9 @@ class ChatActionSender:
@classmethod @classmethod
def upload_voice( def upload_voice(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -249,9 +251,9 @@ class ChatActionSender:
@classmethod @classmethod
def upload_document( def upload_document(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -268,9 +270,9 @@ class ChatActionSender:
@classmethod @classmethod
def choose_sticker( def choose_sticker(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -287,9 +289,9 @@ class ChatActionSender:
@classmethod @classmethod
def find_location( def find_location(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -306,9 +308,9 @@ class ChatActionSender:
@classmethod @classmethod
def record_video_note( def record_video_note(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -325,9 +327,9 @@ class ChatActionSender:
@classmethod @classmethod
def upload_video_note( def upload_video_note(
cls, cls,
chat_id: Union[int, str], chat_id: int | str,
bot: Bot, bot: Bot,
message_thread_id: Optional[int] = None, message_thread_id: int | None = None,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -349,9 +351,9 @@ class ChatActionMiddleware(BaseMiddleware):
async def __call__( async def __call__(
self, self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
if not isinstance(event, Message): if not isinstance(event, Message):
return await handler(event, data) return await handler(event, data)

View file

@ -1,7 +1,6 @@
from typing import Tuple, Type, Union from typing import Annotated
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
from typing_extensions import Annotated
from aiogram.types import ( from aiogram.types import (
ChatMember, ChatMember,
@ -13,22 +12,22 @@ from aiogram.types import (
ChatMemberRestricted, ChatMemberRestricted,
) )
ChatMemberUnion = Union[ ChatMemberUnion = (
ChatMemberOwner, ChatMemberOwner
ChatMemberAdministrator, | ChatMemberAdministrator
ChatMemberMember, | ChatMemberMember
ChatMemberRestricted, | ChatMemberRestricted
ChatMemberLeft, | ChatMemberLeft
ChatMemberBanned, | ChatMemberBanned
] )
ChatMemberCollection = Tuple[Type[ChatMember], ...] ChatMemberCollection = tuple[type[ChatMember], ...]
ChatMemberAdapter: TypeAdapter[ChatMemberUnion] = TypeAdapter( ChatMemberAdapter: TypeAdapter[ChatMemberUnion] = TypeAdapter(
Annotated[ Annotated[
ChatMemberUnion, ChatMemberUnion,
Field(discriminator="status"), Field(discriminator="status"),
] ],
) )
ADMINS: ChatMemberCollection = (ChatMemberOwner, ChatMemberAdministrator) ADMINS: ChatMemberCollection = (ChatMemberOwner, ChatMemberAdministrator)

View file

@ -1,7 +1,8 @@
import inspect import inspect
from collections.abc import Generator
from dataclasses import dataclass from dataclasses import dataclass
from operator import itemgetter from operator import itemgetter
from typing import Any, Generator, NamedTuple, Protocol from typing import Any, NamedTuple, Protocol
from aiogram.utils.dataclass import dataclass_kwargs from aiogram.utils.dataclass import dataclass_kwargs

View file

@ -9,16 +9,16 @@ from typing import Any, Union
def dataclass_kwargs( def dataclass_kwargs(
init: Union[bool, None] = None, init: bool | None = None,
repr: Union[bool, None] = None, repr: bool | None = None,
eq: Union[bool, None] = None, eq: bool | None = None,
order: Union[bool, None] = None, order: bool | None = None,
unsafe_hash: Union[bool, None] = None, unsafe_hash: bool | None = None,
frozen: Union[bool, None] = None, frozen: bool | None = None,
match_args: Union[bool, None] = None, match_args: bool | None = None,
kw_only: Union[bool, None] = None, kw_only: bool | None = None,
slots: Union[bool, None] = None, slots: bool | None = None,
weakref_slot: Union[bool, None] = None, weakref_slot: bool | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Generates a dictionary of keyword arguments that can be passed to a Python Generates a dictionary of keyword arguments that can be passed to a Python
@ -48,13 +48,12 @@ def dataclass_kwargs(
params["frozen"] = frozen params["frozen"] = frozen
# Added in 3.10 # Added in 3.10
if sys.version_info >= (3, 10): if match_args is not None:
if match_args is not None: params["match_args"] = match_args
params["match_args"] = match_args if kw_only is not None:
if kw_only is not None: params["kw_only"] = kw_only
params["kw_only"] = kw_only if slots is not None:
if slots is not None: params["slots"] = slots
params["slots"] = slots
# Added in 3.11 # Added in 3.11
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):

View file

@ -1,22 +1,24 @@
from __future__ import annotations from __future__ import annotations
__all__ = [ __all__ = [
"create_start_link",
"create_startgroup_link",
"create_startapp_link",
"create_deep_link", "create_deep_link",
"create_start_link",
"create_startapp_link",
"create_startgroup_link",
"create_telegram_link", "create_telegram_link",
"encode_payload",
"decode_payload", "decode_payload",
"encode_payload",
] ]
import re import re
from typing import TYPE_CHECKING, Callable, Literal, Optional, cast from typing import TYPE_CHECKING, Literal, Optional, cast
from aiogram.utils.link import create_telegram_link from aiogram.utils.link import create_telegram_link
from aiogram.utils.payload import decode_payload, encode_payload from aiogram.utils.payload import decode_payload, encode_payload
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable
from aiogram import Bot from aiogram import Bot
BAD_PATTERN = re.compile(r"[^a-zA-Z0-9-_]") BAD_PATTERN = re.compile(r"[^a-zA-Z0-9-_]")
@ -26,7 +28,7 @@ async def create_start_link(
bot: Bot, bot: Bot,
payload: str, payload: str,
encode: bool = False, encode: bool = False,
encoder: Optional[Callable[[bytes], bytes]] = None, encoder: Callable[[bytes], bytes] | None = None,
) -> str: ) -> str:
""" """
Create 'start' deep link with your payload. Create 'start' deep link with your payload.
@ -53,7 +55,7 @@ async def create_startgroup_link(
bot: Bot, bot: Bot,
payload: str, payload: str,
encode: bool = False, encode: bool = False,
encoder: Optional[Callable[[bytes], bytes]] = None, encoder: Callable[[bytes], bytes] | None = None,
) -> str: ) -> str:
""" """
Create 'startgroup' deep link with your payload. Create 'startgroup' deep link with your payload.
@ -80,8 +82,8 @@ async def create_startapp_link(
bot: Bot, bot: Bot,
payload: str, payload: str,
encode: bool = False, encode: bool = False,
app_name: Optional[str] = None, app_name: str | None = None,
encoder: Optional[Callable[[bytes], bytes]] = None, encoder: Callable[[bytes], bytes] | None = None,
) -> str: ) -> str:
""" """
Create 'startapp' deep link with your payload. Create 'startapp' deep link with your payload.
@ -115,9 +117,9 @@ def create_deep_link(
username: str, username: str,
link_type: Literal["start", "startgroup", "startapp"], link_type: Literal["start", "startgroup", "startapp"],
payload: str, payload: str,
app_name: Optional[str] = None, app_name: str | None = None,
encode: bool = False, encode: bool = False,
encoder: Optional[Callable[[bytes], bytes]] = None, encoder: Callable[[bytes], bytes] | None = None,
) -> str: ) -> str:
""" """
Create deep link. Create deep link.
@ -137,13 +139,15 @@ def create_deep_link(
payload = encode_payload(payload, encoder=encoder) payload = encode_payload(payload, encoder=encoder)
if re.search(BAD_PATTERN, payload): if re.search(BAD_PATTERN, payload):
raise ValueError( msg = (
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. " "Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
"Pass `encode=True` or encode payload manually." "Pass `encode=True` or encode payload manually."
) )
raise ValueError(msg)
if len(payload) > 64: if len(payload) > 64:
raise ValueError("Payload must be up to 64 characters long.") msg = "Payload must be up to 64 characters long."
raise ValueError(msg)
if not app_name: if not app_name:
deep_link = create_telegram_link(username, **{cast(str, link_type): payload}) deep_link = create_telegram_link(username, **{cast(str, link_type): payload})

View file

@ -1,16 +1,8 @@
from __future__ import annotations
import textwrap import textwrap
from typing import ( from collections.abc import Generator, Iterable, Iterator
Any, from typing import TYPE_CHECKING, Any, ClassVar
ClassVar,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
)
from typing_extensions import Self from typing_extensions import Self
@ -35,7 +27,7 @@ class Text(Iterable[NodeType]):
Simple text element Simple text element
""" """
type: ClassVar[Optional[str]] = None type: ClassVar[str | None] = None
__slots__ = ("_body", "_params") __slots__ = ("_body", "_params")
@ -44,16 +36,16 @@ class Text(Iterable[NodeType]):
*body: NodeType, *body: NodeType,
**params: Any, **params: Any,
) -> None: ) -> None:
self._body: Tuple[NodeType, ...] = body self._body: tuple[NodeType, ...] = body
self._params: Dict[str, Any] = params self._params: dict[str, Any] = params
@classmethod @classmethod
def from_entities(cls, text: str, entities: List[MessageEntity]) -> "Text": def from_entities(cls, text: str, entities: list[MessageEntity]) -> Text:
return cls( return cls(
*_unparse_entities( *_unparse_entities(
text=add_surrogates(text), text=add_surrogates(text),
entities=sorted(entities, key=lambda item: item.offset) if entities else [], entities=sorted(entities, key=lambda item: item.offset) if entities else [],
) ),
) )
def render( def render(
@ -62,7 +54,7 @@ class Text(Iterable[NodeType]):
_offset: int = 0, _offset: int = 0,
_sort: bool = True, _sort: bool = True,
_collect_entities: bool = True, _collect_entities: bool = True,
) -> Tuple[str, List[MessageEntity]]: ) -> tuple[str, list[MessageEntity]]:
""" """
Render elements tree as text with entities list Render elements tree as text with entities list
@ -108,7 +100,7 @@ class Text(Iterable[NodeType]):
entities_key: str = "entities", entities_key: str = "entities",
replace_parse_mode: bool = True, replace_parse_mode: bool = True,
parse_mode_key: str = "parse_mode", parse_mode_key: str = "parse_mode",
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Render element tree as keyword arguments for usage in an API call, for example: Render element tree as keyword arguments for usage in an API call, for example:
@ -124,7 +116,7 @@ class Text(Iterable[NodeType]):
:return: :return:
""" """
text_value, entities_value = self.render() text_value, entities_value = self.render()
result: Dict[str, Any] = { result: dict[str, Any] = {
text_key: text_value, text_key: text_value,
entities_key: entities_value, entities_key: entities_value,
} }
@ -132,7 +124,7 @@ class Text(Iterable[NodeType]):
result[parse_mode_key] = None result[parse_mode_key] = None
return result return result
def as_caption_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]: def as_caption_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
""" """
Shortcut for :meth:`as_kwargs` for usage with API calls that take Shortcut for :meth:`as_kwargs` for usage with API calls that take
``caption`` as a parameter. ``caption`` as a parameter.
@ -151,7 +143,7 @@ class Text(Iterable[NodeType]):
replace_parse_mode=replace_parse_mode, replace_parse_mode=replace_parse_mode,
) )
def as_poll_question_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]: def as_poll_question_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
""" """
Shortcut for :meth:`as_kwargs` for usage with Shortcut for :meth:`as_kwargs` for usage with
method :class:`aiogram.methods.send_poll.SendPoll`. method :class:`aiogram.methods.send_poll.SendPoll`.
@ -171,7 +163,7 @@ class Text(Iterable[NodeType]):
replace_parse_mode=replace_parse_mode, replace_parse_mode=replace_parse_mode,
) )
def as_poll_explanation_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]: def as_poll_explanation_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
""" """
Shortcut for :meth:`as_kwargs` for usage with Shortcut for :meth:`as_kwargs` for usage with
method :class:`aiogram.methods.send_poll.SendPoll`. method :class:`aiogram.methods.send_poll.SendPoll`.
@ -196,7 +188,7 @@ class Text(Iterable[NodeType]):
replace_parse_mode=replace_parse_mode, replace_parse_mode=replace_parse_mode,
) )
def as_gift_text_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]: def as_gift_text_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
""" """
Shortcut for :meth:`as_kwargs` for usage with Shortcut for :meth:`as_kwargs` for usage with
method :class:`aiogram.methods.send_gift.SendGift`. method :class:`aiogram.methods.send_gift.SendGift`.
@ -252,7 +244,7 @@ class Text(Iterable[NodeType]):
args_str = textwrap.indent("\n" + args_str + "\n", " ") args_str = textwrap.indent("\n" + args_str + "\n", " ")
return f"{type(self).__name__}({args_str})" return f"{type(self).__name__}({args_str})"
def __add__(self, other: NodeType) -> "Text": def __add__(self, other: NodeType) -> Text:
if isinstance(other, Text) and other.type == self.type and self._params == other._params: if isinstance(other, Text) and other.type == self.type and self._params == other._params:
return type(self)(*self, *other, **self._params) return type(self)(*self, *other, **self._params)
if type(self) is Text and isinstance(other, str): if type(self) is Text and isinstance(other, str):
@ -266,9 +258,10 @@ class Text(Iterable[NodeType]):
text, _ = self.render(_collect_entities=False) text, _ = self.render(_collect_entities=False)
return sizeof(text) return sizeof(text)
def __getitem__(self, item: slice) -> "Text": def __getitem__(self, item: slice) -> Text:
if not isinstance(item, slice): if not isinstance(item, slice):
raise TypeError("Can only be sliced") msg = "Can only be sliced"
raise TypeError(msg)
if (item.start is None or item.start == 0) and item.stop is None: if (item.start is None or item.start == 0) and item.stop is None:
return self.replace(*self._body) return self.replace(*self._body)
start = 0 if item.start is None else item.start start = 0 if item.start is None else item.start
@ -313,9 +306,11 @@ class HashTag(Text):
def __init__(self, *body: NodeType, **params: Any) -> None: def __init__(self, *body: NodeType, **params: Any) -> None:
if len(body) != 1: if len(body) != 1:
raise ValueError("Hashtag can contain only one element") msg = "Hashtag can contain only one element"
raise ValueError(msg)
if not isinstance(body[0], str): if not isinstance(body[0], str):
raise ValueError("Hashtag can contain only string") msg = "Hashtag can contain only string"
raise ValueError(msg)
if not body[0].startswith("#"): if not body[0].startswith("#"):
body = ("#" + body[0],) body = ("#" + body[0],)
super().__init__(*body, **params) super().__init__(*body, **params)
@ -337,9 +332,11 @@ class CashTag(Text):
def __init__(self, *body: NodeType, **params: Any) -> None: def __init__(self, *body: NodeType, **params: Any) -> None:
if len(body) != 1: if len(body) != 1:
raise ValueError("Cashtag can contain only one element") msg = "Cashtag can contain only one element"
raise ValueError(msg)
if not isinstance(body[0], str): if not isinstance(body[0], str):
raise ValueError("Cashtag can contain only string") msg = "Cashtag can contain only string"
raise ValueError(msg)
if not body[0].startswith("$"): if not body[0].startswith("$"):
body = ("$" + body[0],) body = ("$" + body[0],)
super().__init__(*body, **params) super().__init__(*body, **params)
@ -469,7 +466,7 @@ class Pre(Text):
type = MessageEntityType.PRE type = MessageEntityType.PRE
def __init__(self, *body: NodeType, language: Optional[str] = None, **params: Any) -> None: def __init__(self, *body: NodeType, language: str | None = None, **params: Any) -> None:
super().__init__(*body, language=language, **params) super().__init__(*body, language=language, **params)
@ -537,7 +534,7 @@ class ExpandableBlockQuote(Text):
type = MessageEntityType.EXPANDABLE_BLOCKQUOTE type = MessageEntityType.EXPANDABLE_BLOCKQUOTE
NODE_TYPES: Dict[Optional[str], Type[Text]] = { NODE_TYPES: dict[str | None, type[Text]] = {
Text.type: Text, Text.type: Text,
HashTag.type: HashTag, HashTag.type: HashTag,
CashTag.type: CashTag, CashTag.type: CashTag,
@ -570,15 +567,16 @@ def _apply_entity(entity: MessageEntity, *nodes: NodeType) -> NodeType:
""" """
node_type = NODE_TYPES.get(entity.type, Text) node_type = NODE_TYPES.get(entity.type, Text)
return node_type( return node_type(
*nodes, **entity.model_dump(exclude={"type", "offset", "length"}, warnings=False) *nodes,
**entity.model_dump(exclude={"type", "offset", "length"}, warnings=False),
) )
def _unparse_entities( def _unparse_entities(
text: bytes, text: bytes,
entities: List[MessageEntity], entities: list[MessageEntity],
offset: Optional[int] = None, offset: int | None = None,
length: Optional[int] = None, length: int | None = None,
) -> Generator[NodeType, None, None]: ) -> Generator[NodeType, None, None]:
if offset is None: if offset is None:
offset = 0 offset = 0
@ -615,8 +613,7 @@ def as_line(*items: NodeType, end: str = "\n", sep: str = "") -> Text:
nodes = [] nodes = []
for item in items[:-1]: for item in items[:-1]:
nodes.extend([item, sep]) nodes.extend([item, sep])
nodes.append(items[-1]) nodes.extend([items[-1], end])
nodes.append(end)
else: else:
nodes = [*items, end] nodes = [*items, end]
return Text(*nodes) return Text(*nodes)

View file

@ -8,14 +8,14 @@ from .middleware import (
) )
__all__ = ( __all__ = (
"ConstI18nMiddleware",
"FSMI18nMiddleware",
"I18n", "I18n",
"I18nMiddleware", "I18nMiddleware",
"SimpleI18nMiddleware", "SimpleI18nMiddleware",
"ConstI18nMiddleware", "get_i18n",
"FSMI18nMiddleware",
"gettext", "gettext",
"lazy_gettext", "lazy_gettext",
"ngettext",
"lazy_ngettext", "lazy_ngettext",
"get_i18n", "ngettext",
) )

View file

@ -7,7 +7,8 @@ from aiogram.utils.i18n.lazy_proxy import LazyProxy
def get_i18n() -> I18n: def get_i18n() -> I18n:
i18n = I18n.get_current(no_error=True) i18n = I18n.get_current(no_error=True)
if i18n is None: if i18n is None:
raise LookupError("I18n context is not set") msg = "I18n context is not set"
raise LookupError(msg)
return i18n return i18n

View file

@ -1,23 +1,27 @@
from __future__ import annotations
import gettext import gettext
import os
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from pathlib import Path from pathlib import Path
from typing import Dict, Generator, Optional, Tuple, Union from typing import TYPE_CHECKING
from aiogram.utils.i18n.lazy_proxy import LazyProxy from aiogram.utils.i18n.lazy_proxy import LazyProxy
from aiogram.utils.mixins import ContextInstanceMixin from aiogram.utils.mixins import ContextInstanceMixin
if TYPE_CHECKING:
from collections.abc import Generator
class I18n(ContextInstanceMixin["I18n"]): class I18n(ContextInstanceMixin["I18n"]):
def __init__( def __init__(
self, self,
*, *,
path: Union[str, Path], path: str | Path,
default_locale: str = "en", default_locale: str = "en",
domain: str = "messages", domain: str = "messages",
) -> None: ) -> None:
self.path = path self.path = Path(path)
self.default_locale = default_locale self.default_locale = default_locale
self.domain = domain self.domain = domain
self.ctx_locale = ContextVar("aiogram_ctx_locale", default=default_locale) self.ctx_locale = ContextVar("aiogram_ctx_locale", default=default_locale)
@ -43,7 +47,7 @@ class I18n(ContextInstanceMixin["I18n"]):
self.ctx_locale.reset(ctx_token) self.ctx_locale.reset(ctx_token)
@contextmanager @contextmanager
def context(self) -> Generator["I18n", None, None]: def context(self) -> Generator[I18n, None, None]:
""" """
Use I18n context Use I18n context
""" """
@ -53,24 +57,25 @@ class I18n(ContextInstanceMixin["I18n"]):
finally: finally:
self.reset_current(token) self.reset_current(token)
def find_locales(self) -> Dict[str, gettext.GNUTranslations]: def find_locales(self) -> dict[str, gettext.GNUTranslations]:
""" """
Load all compiled locales from path Load all compiled locales from path
:return: dict with locales :return: dict with locales
""" """
translations: Dict[str, gettext.GNUTranslations] = {} translations: dict[str, gettext.GNUTranslations] = {}
for name in os.listdir(self.path): for name in self.path.iterdir():
if not os.path.isdir(os.path.join(self.path, name)): if not (self.path / name).is_dir():
continue continue
mo_path = os.path.join(self.path, name, "LC_MESSAGES", self.domain + ".mo") mo_path = self.path / name / "LC_MESSAGES" / (self.domain + ".mo")
if os.path.exists(mo_path): if mo_path.exists():
with open(mo_path, "rb") as fp: with mo_path.open("rb") as fp:
translations[name] = gettext.GNUTranslations(fp) translations[name.name] = gettext.GNUTranslations(fp)
elif os.path.exists(mo_path[:-2] + "po"): # pragma: no cover elif mo_path.with_suffix(".po").exists(): # pragma: no cover
raise RuntimeError(f"Found locale '{name}' but this language is not compiled!") msg = f"Found locale '{name.name}' but this language is not compiled!"
raise RuntimeError(msg)
return translations return translations
@ -81,7 +86,7 @@ class I18n(ContextInstanceMixin["I18n"]):
self.locales = self.find_locales() self.locales = self.find_locales()
@property @property
def available_locales(self) -> Tuple[str, ...]: def available_locales(self) -> tuple[str, ...]:
""" """
list of loaded locales list of loaded locales
@ -90,7 +95,11 @@ class I18n(ContextInstanceMixin["I18n"]):
return tuple(self.locales.keys()) return tuple(self.locales.keys())
def gettext( def gettext(
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None self,
singular: str,
plural: str | None = None,
n: int = 1,
locale: str | None = None,
) -> str: ) -> str:
""" """
Get text Get text
@ -107,7 +116,7 @@ class I18n(ContextInstanceMixin["I18n"]):
if locale not in self.locales: if locale not in self.locales:
if n == 1: if n == 1:
return singular return singular
return plural if plural else singular return plural or singular
translator = self.locales[locale] translator = self.locales[locale]
@ -116,8 +125,17 @@ class I18n(ContextInstanceMixin["I18n"]):
return translator.ngettext(singular, plural, n) return translator.ngettext(singular, plural, n)
def lazy_gettext( def lazy_gettext(
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None self,
singular: str,
plural: str | None = None,
n: int = 1,
locale: str | None = None,
) -> LazyProxy: ) -> LazyProxy:
return LazyProxy( return LazyProxy(
self.gettext, singular=singular, plural=plural, n=n, locale=locale, enable_cache=False self.gettext,
singular=singular,
plural=plural,
n=n,
locale=locale,
enable_cache=False,
) )

View file

@ -6,8 +6,9 @@ except ImportError: # pragma: no cover
class LazyProxy: # type: ignore class LazyProxy: # type: ignore
def __init__(self, func: Any, *args: Any, **kwargs: Any) -> None: def __init__(self, func: Any, *args: Any, **kwargs: Any) -> None:
raise RuntimeError( msg = (
"LazyProxy can be used only when Babel installed\n" "LazyProxy can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) " "Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)" "or aiogram with i18n support (`pip install aiogram[i18n]`)"
) )
raise RuntimeError(msg)

View file

@ -1,5 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, Optional, Set from typing import TYPE_CHECKING, Any
try: try:
from babel import Locale, UnknownLocaleError from babel import Locale, UnknownLocaleError
@ -11,9 +13,13 @@ except ImportError: # pragma: no cover
from aiogram import BaseMiddleware, Router from aiogram import BaseMiddleware, Router
from aiogram.fsm.context import FSMContext
from aiogram.types import TelegramObject, User if TYPE_CHECKING:
from aiogram.utils.i18n.core import I18n from collections.abc import Awaitable, Callable
from aiogram.fsm.context import FSMContext
from aiogram.types import TelegramObject, User
from aiogram.utils.i18n.core import I18n
class I18nMiddleware(BaseMiddleware, ABC): class I18nMiddleware(BaseMiddleware, ABC):
@ -24,7 +30,7 @@ class I18nMiddleware(BaseMiddleware, ABC):
def __init__( def __init__(
self, self,
i18n: I18n, i18n: I18n,
i18n_key: Optional[str] = "i18n", i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware", middleware_key: str = "i18n_middleware",
) -> None: ) -> None:
""" """
@ -39,7 +45,9 @@ class I18nMiddleware(BaseMiddleware, ABC):
self.middleware_key = middleware_key self.middleware_key = middleware_key
def setup( def setup(
self: BaseMiddleware, router: Router, exclude: Optional[Set[str]] = None self: BaseMiddleware,
router: Router,
exclude: set[str] | None = None,
) -> BaseMiddleware: ) -> BaseMiddleware:
""" """
Register middleware for all events in the Router Register middleware for all events in the Router
@ -59,9 +67,9 @@ class I18nMiddleware(BaseMiddleware, ABC):
async def __call__( async def __call__(
self, self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: dict[str, Any],
) -> Any: ) -> Any:
current_locale = await self.get_locale(event=event, data=data) or self.i18n.default_locale current_locale = await self.get_locale(event=event, data=data) or self.i18n.default_locale
@ -74,7 +82,7 @@ class I18nMiddleware(BaseMiddleware, ABC):
return await handler(event, data) return await handler(event, data)
@abstractmethod @abstractmethod
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
""" """
Detect current user locale based on event and context. Detect current user locale based on event and context.
@ -84,7 +92,6 @@ class I18nMiddleware(BaseMiddleware, ABC):
:param data: :param data:
:return: :return:
""" """
pass
class SimpleI18nMiddleware(I18nMiddleware): class SimpleI18nMiddleware(I18nMiddleware):
@ -97,27 +104,29 @@ class SimpleI18nMiddleware(I18nMiddleware):
def __init__( def __init__(
self, self,
i18n: I18n, i18n: I18n,
i18n_key: Optional[str] = "i18n", i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware", middleware_key: str = "i18n_middleware",
) -> None: ) -> None:
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key) super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
if Locale is None: # pragma: no cover if Locale is None: # pragma: no cover
raise RuntimeError( msg = (
f"{type(self).__name__} can be used only when Babel installed\n" f"{type(self).__name__} can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) " "Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)" "or aiogram with i18n support (`pip install aiogram[i18n]`)"
) )
raise RuntimeError(msg)
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
if Locale is None: # pragma: no cover if Locale is None: # pragma: no cover
raise RuntimeError( msg = (
f"{type(self).__name__} can be used only when Babel installed\n" f"{type(self).__name__} can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) " "Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)" "or aiogram with i18n support (`pip install aiogram[i18n]`)"
) )
raise RuntimeError(msg)
event_from_user: Optional[User] = data.get("event_from_user", None) event_from_user: User | None = data.get("event_from_user")
if event_from_user is None or event_from_user.language_code is None: if event_from_user is None or event_from_user.language_code is None:
return self.i18n.default_locale return self.i18n.default_locale
try: try:
@ -139,13 +148,13 @@ class ConstI18nMiddleware(I18nMiddleware):
self, self,
locale: str, locale: str,
i18n: I18n, i18n: I18n,
i18n_key: Optional[str] = "i18n", i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware", middleware_key: str = "i18n_middleware",
) -> None: ) -> None:
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key) super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
self.locale = locale self.locale = locale
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
return self.locale return self.locale
@ -158,14 +167,14 @@ class FSMI18nMiddleware(SimpleI18nMiddleware):
self, self,
i18n: I18n, i18n: I18n,
key: str = "locale", key: str = "locale",
i18n_key: Optional[str] = "i18n", i18n_key: str | None = "i18n",
middleware_key: str = "i18n_middleware", middleware_key: str = "i18n_middleware",
) -> None: ) -> None:
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key) super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
self.key = key self.key = key
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
fsm_context: Optional[FSMContext] = data.get("state") fsm_context: FSMContext | None = data.get("state")
locale = None locale = None
if fsm_context: if fsm_context:
fsm_data = await fsm_context.get_data() fsm_data = await fsm_context.get_data()

View file

@ -4,19 +4,7 @@ from abc import ABC
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from itertools import cycle as repeat_all from itertools import cycle as repeat_all
from typing import ( from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
TYPE_CHECKING,
Any,
Generator,
Generic,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
from aiogram.filters.callback_data import CallbackData from aiogram.filters.callback_data import CallbackData
from aiogram.types import ( from aiogram.types import (
@ -34,11 +22,14 @@ from aiogram.types import (
WebAppInfo, WebAppInfo,
) )
if TYPE_CHECKING:
from collections.abc import Generator, Iterable
ButtonType = TypeVar("ButtonType", InlineKeyboardButton, KeyboardButton) ButtonType = TypeVar("ButtonType", InlineKeyboardButton, KeyboardButton)
T = TypeVar("T") T = TypeVar("T")
class KeyboardBuilder(Generic[ButtonType], ABC): class KeyboardBuilder(ABC, Generic[ButtonType]):
""" """
Generic keyboard builder that helps to adjust your markup with defined shape of lines. Generic keyboard builder that helps to adjust your markup with defined shape of lines.
@ -50,16 +41,19 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
max_buttons: int = 0 max_buttons: int = 0
def __init__( def __init__(
self, button_type: Type[ButtonType], markup: Optional[List[List[ButtonType]]] = None self,
button_type: type[ButtonType],
markup: list[list[ButtonType]] | None = None,
) -> None: ) -> None:
if not issubclass(button_type, (InlineKeyboardButton, KeyboardButton)): if not issubclass(button_type, (InlineKeyboardButton, KeyboardButton)):
raise ValueError(f"Button type {button_type} are not allowed here") msg = f"Button type {button_type} are not allowed here"
self._button_type: Type[ButtonType] = button_type raise ValueError(msg)
self._button_type: type[ButtonType] = button_type
if markup: if markup:
self._validate_markup(markup) self._validate_markup(markup)
else: else:
markup = [] markup = []
self._markup: List[List[ButtonType]] = markup self._markup: list[list[ButtonType]] = markup
@property @property
def buttons(self) -> Generator[ButtonType, None, None]: def buttons(self) -> Generator[ButtonType, None, None]:
@ -79,9 +73,8 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
""" """
allowed = self._button_type allowed = self._button_type
if not isinstance(button, allowed): if not isinstance(button, allowed):
raise ValueError( msg = f"{button!r} should be type {allowed.__name__!r} not {type(button).__name__!r}"
f"{button!r} should be type {allowed.__name__!r} not {type(button).__name__!r}" raise ValueError(msg)
)
return True return True
def _validate_buttons(self, *buttons: ButtonType) -> bool: def _validate_buttons(self, *buttons: ButtonType) -> bool:
@ -93,7 +86,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
""" """
return all(map(self._validate_button, buttons)) return all(map(self._validate_button, buttons))
def _validate_row(self, row: List[ButtonType]) -> bool: def _validate_row(self, row: list[ButtonType]) -> bool:
""" """
Check that row of buttons are correct Check that row of buttons are correct
Row can be only list of allowed button types and has length 0 <= n <= 8 Row can be only list of allowed button types and has length 0 <= n <= 8
@ -102,16 +95,18 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
:return: :return:
""" """
if not isinstance(row, list): if not isinstance(row, list):
raise ValueError( msg = (
f"Row {row!r} should be type 'List[{self._button_type.__name__}]' " f"Row {row!r} should be type 'List[{self._button_type.__name__}]' "
f"not type {type(row).__name__}" f"not type {type(row).__name__}"
) )
raise ValueError(msg)
if len(row) > self.max_width: if len(row) > self.max_width:
raise ValueError(f"Row {row!r} is too long (max width: {self.max_width})") msg = f"Row {row!r} is too long (max width: {self.max_width})"
raise ValueError(msg)
self._validate_buttons(*row) self._validate_buttons(*row)
return True return True
def _validate_markup(self, markup: List[List[ButtonType]]) -> bool: def _validate_markup(self, markup: list[list[ButtonType]]) -> bool:
""" """
Check that passed markup has correct data structure Check that passed markup has correct data structure
Markup is list of lists of buttons Markup is list of lists of buttons
@ -121,15 +116,17 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
""" """
count = 0 count = 0
if not isinstance(markup, list): if not isinstance(markup, list):
raise ValueError( msg = (
f"Markup should be type 'List[List[{self._button_type.__name__}]]' " f"Markup should be type 'List[List[{self._button_type.__name__}]]' "
f"not type {type(markup).__name__!r}" f"not type {type(markup).__name__!r}"
) )
raise ValueError(msg)
for row in markup: for row in markup:
self._validate_row(row) self._validate_row(row)
count += len(row) count += len(row)
if count > self.max_buttons: if count > self.max_buttons:
raise ValueError(f"Too much buttons detected Max allowed count - {self.max_buttons}") msg = f"Too much buttons detected Max allowed count - {self.max_buttons}"
raise ValueError(msg)
return True return True
def _validate_size(self, size: Any) -> int: def _validate_size(self, size: Any) -> int:
@ -140,14 +137,14 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
:return: :return:
""" """
if not isinstance(size, int): if not isinstance(size, int):
raise ValueError("Only int sizes are allowed") msg = "Only int sizes are allowed"
raise ValueError(msg)
if size not in range(self.min_width, self.max_width + 1): if size not in range(self.min_width, self.max_width + 1):
raise ValueError( msg = f"Row size {size} is not allowed, range: [{self.min_width}, {self.max_width}]"
f"Row size {size} is not allowed, range: [{self.min_width}, {self.max_width}]" raise ValueError(msg)
)
return size return size
def export(self) -> List[List[ButtonType]]: def export(self) -> list[list[ButtonType]]:
""" """
Export configured markup as list of lists of buttons Export configured markup as list of lists of buttons
@ -161,7 +158,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
""" """
return deepcopy(self._markup) return deepcopy(self._markup)
def add(self, *buttons: ButtonType) -> "KeyboardBuilder[ButtonType]": def add(self, *buttons: ButtonType) -> KeyboardBuilder[ButtonType]:
""" """
Add one or many buttons to markup. Add one or many buttons to markup.
@ -189,9 +186,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
self._markup = markup self._markup = markup
return self return self
def row( def row(self, *buttons: ButtonType, width: int | None = None) -> KeyboardBuilder[ButtonType]:
self, *buttons: ButtonType, width: Optional[int] = None
) -> "KeyboardBuilder[ButtonType]":
""" """
Add row to markup Add row to markup
@ -211,7 +206,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
) )
return self return self
def adjust(self, *sizes: int, repeat: bool = False) -> "KeyboardBuilder[ButtonType]": def adjust(self, *sizes: int, repeat: bool = False) -> KeyboardBuilder[ButtonType]:
""" """
Adjust previously added buttons to specific row sizes. Adjust previously added buttons to specific row sizes.
@ -232,7 +227,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
size = next(sizes_iter) size = next(sizes_iter)
markup = [] markup = []
row: List[ButtonType] = [] row: list[ButtonType] = []
for button in self.buttons: for button in self.buttons:
if len(row) >= size: if len(row) >= size:
markup.append(row) markup.append(row)
@ -244,33 +239,35 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
self._markup = markup self._markup = markup
return self return self
def _button(self, **kwargs: Any) -> "KeyboardBuilder[ButtonType]": def _button(self, **kwargs: Any) -> KeyboardBuilder[ButtonType]:
""" """
Add button to markup Add button to markup
:param kwargs: :param kwargs:
:return: :return:
""" """
if isinstance(callback_data := kwargs.get("callback_data", None), CallbackData): if isinstance(callback_data := kwargs.get("callback_data"), CallbackData):
kwargs["callback_data"] = callback_data.pack() kwargs["callback_data"] = callback_data.pack()
button = self._button_type(**kwargs) button = self._button_type(**kwargs)
return self.add(button) return self.add(button)
def as_markup(self, **kwargs: Any) -> Union[InlineKeyboardMarkup, ReplyKeyboardMarkup]: def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup | ReplyKeyboardMarkup:
if self._button_type is KeyboardButton: if self._button_type is KeyboardButton:
keyboard = cast(List[List[KeyboardButton]], self.export()) # type: ignore keyboard = cast(list[list[KeyboardButton]], self.export()) # type: ignore
return ReplyKeyboardMarkup(keyboard=keyboard, **kwargs) return ReplyKeyboardMarkup(keyboard=keyboard, **kwargs)
inline_keyboard = cast(List[List[InlineKeyboardButton]], self.export()) # type: ignore inline_keyboard = cast(list[list[InlineKeyboardButton]], self.export()) # type: ignore
return InlineKeyboardMarkup(inline_keyboard=inline_keyboard) return InlineKeyboardMarkup(inline_keyboard=inline_keyboard)
def attach(self, builder: "KeyboardBuilder[ButtonType]") -> "KeyboardBuilder[ButtonType]": def attach(self, builder: KeyboardBuilder[ButtonType]) -> KeyboardBuilder[ButtonType]:
if not isinstance(builder, KeyboardBuilder): if not isinstance(builder, KeyboardBuilder):
raise ValueError(f"Only KeyboardBuilder can be attached, not {type(builder).__name__}") msg = f"Only KeyboardBuilder can be attached, not {type(builder).__name__}"
raise ValueError(msg)
if builder._button_type is not self._button_type: if builder._button_type is not self._button_type:
raise ValueError( msg = (
f"Only builders with same button type can be attached, " f"Only builders with same button type can be attached, "
f"not {self._button_type.__name__} and {builder._button_type.__name__}" f"not {self._button_type.__name__} and {builder._button_type.__name__}"
) )
raise ValueError(msg)
self._markup.extend(builder.export()) self._markup.extend(builder.export())
return self return self
@ -306,18 +303,18 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
self, self,
*, *,
text: str, text: str,
url: Optional[str] = None, url: str | None = None,
callback_data: Optional[Union[str, CallbackData]] = None, callback_data: str | CallbackData | None = None,
web_app: Optional[WebAppInfo] = None, web_app: WebAppInfo | None = None,
login_url: Optional[LoginUrl] = None, login_url: LoginUrl | None = None,
switch_inline_query: Optional[str] = None, switch_inline_query: str | None = None,
switch_inline_query_current_chat: Optional[str] = None, switch_inline_query_current_chat: str | None = None,
switch_inline_query_chosen_chat: Optional[SwitchInlineQueryChosenChat] = None, switch_inline_query_chosen_chat: SwitchInlineQueryChosenChat | None = None,
copy_text: Optional[CopyTextButton] = None, copy_text: CopyTextButton | None = None,
callback_game: Optional[CallbackGame] = None, callback_game: CallbackGame | None = None,
pay: Optional[bool] = None, pay: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> "InlineKeyboardBuilder": ) -> InlineKeyboardBuilder:
return cast( return cast(
InlineKeyboardBuilder, InlineKeyboardBuilder,
self._button( self._button(
@ -340,10 +337,10 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
"""Construct an InlineKeyboardMarkup""" """Construct an InlineKeyboardMarkup"""
return cast(InlineKeyboardMarkup, super().as_markup(**kwargs)) return cast(InlineKeyboardMarkup, super().as_markup(**kwargs))
def __init__(self, markup: Optional[List[List[InlineKeyboardButton]]] = None) -> None: def __init__(self, markup: list[list[InlineKeyboardButton]] | None = None) -> None:
super().__init__(button_type=InlineKeyboardButton, markup=markup) super().__init__(button_type=InlineKeyboardButton, markup=markup)
def copy(self: "InlineKeyboardBuilder") -> "InlineKeyboardBuilder": def copy(self: InlineKeyboardBuilder) -> InlineKeyboardBuilder:
""" """
Make full copy of current builder with markup Make full copy of current builder with markup
@ -353,8 +350,9 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
@classmethod @classmethod
def from_markup( def from_markup(
cls: Type["InlineKeyboardBuilder"], markup: InlineKeyboardMarkup cls: type[InlineKeyboardBuilder],
) -> "InlineKeyboardBuilder": markup: InlineKeyboardMarkup,
) -> InlineKeyboardBuilder:
""" """
Create builder from existing markup Create builder from existing markup
@ -377,14 +375,14 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
self, self,
*, *,
text: str, text: str,
request_users: Optional[KeyboardButtonRequestUsers] = None, request_users: KeyboardButtonRequestUsers | None = None,
request_chat: Optional[KeyboardButtonRequestChat] = None, request_chat: KeyboardButtonRequestChat | None = None,
request_contact: Optional[bool] = None, request_contact: bool | None = None,
request_location: Optional[bool] = None, request_location: bool | None = None,
request_poll: Optional[KeyboardButtonPollType] = None, request_poll: KeyboardButtonPollType | None = None,
web_app: Optional[WebAppInfo] = None, web_app: WebAppInfo | None = None,
**kwargs: Any, **kwargs: Any,
) -> "ReplyKeyboardBuilder": ) -> ReplyKeyboardBuilder:
return cast( return cast(
ReplyKeyboardBuilder, ReplyKeyboardBuilder,
self._button( self._button(
@ -403,10 +401,10 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
"""Construct a ReplyKeyboardMarkup""" """Construct a ReplyKeyboardMarkup"""
return cast(ReplyKeyboardMarkup, super().as_markup(**kwargs)) return cast(ReplyKeyboardMarkup, super().as_markup(**kwargs))
def __init__(self, markup: Optional[List[List[KeyboardButton]]] = None) -> None: def __init__(self, markup: list[list[KeyboardButton]] | None = None) -> None:
super().__init__(button_type=KeyboardButton, markup=markup) super().__init__(button_type=KeyboardButton, markup=markup)
def copy(self: "ReplyKeyboardBuilder") -> "ReplyKeyboardBuilder": def copy(self: ReplyKeyboardBuilder) -> ReplyKeyboardBuilder:
""" """
Make full copy of current builder with markup Make full copy of current builder with markup
@ -415,7 +413,7 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
return ReplyKeyboardBuilder(markup=self.export()) return ReplyKeyboardBuilder(markup=self.export())
@classmethod @classmethod
def from_markup(cls, markup: ReplyKeyboardMarkup) -> "ReplyKeyboardBuilder": def from_markup(cls, markup: ReplyKeyboardMarkup) -> ReplyKeyboardBuilder:
""" """
Create builder from existing markup Create builder from existing markup

View file

@ -7,7 +7,7 @@ BRANCH = "dev-3.x"
BASE_PAGE_URL = f"{BASE_DOCS_URL}/en/{BRANCH}/" BASE_PAGE_URL = f"{BASE_DOCS_URL}/en/{BRANCH}/"
def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query: Any) -> str: def _format_url(url: str, *path: str, fragment_: str | None = None, **query: Any) -> str:
url = urljoin(url, "/".join(path), allow_fragments=True) url = urljoin(url, "/".join(path), allow_fragments=True)
if query: if query:
url += "?" + urlencode(query) url += "?" + urlencode(query)
@ -16,7 +16,7 @@ def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query:
return url return url
def docs_url(*path: str, fragment_: Optional[str] = None, **query: Any) -> str: def docs_url(*path: str, fragment_: str | None = None, **query: Any) -> str:
return _format_url(BASE_PAGE_URL, *path, fragment_=fragment_, **query) return _format_url(BASE_PAGE_URL, *path, fragment_=fragment_, **query)
@ -30,7 +30,7 @@ def create_telegram_link(*path: str, **kwargs: Any) -> str:
def create_channel_bot_link( def create_channel_bot_link(
username: str, username: str,
parameter: Optional[str] = None, parameter: str | None = None,
change_info: bool = False, change_info: bool = False,
post_messages: bool = False, post_messages: bool = False,
edit_messages: bool = False, edit_messages: bool = False,

View file

@ -1,4 +1,5 @@
from typing import Any, Iterable from collections.abc import Iterable
from typing import Any
from magic_filter import MagicFilter as _MagicFilter from magic_filter import MagicFilter as _MagicFilter
from magic_filter import MagicT as _MagicT from magic_filter import MagicT as _MagicT

View file

@ -137,7 +137,7 @@ def strikethrough(*content: Any, sep: str = " ") -> str:
:return: :return:
""" """
return markdown_decoration.strikethrough( return markdown_decoration.strikethrough(
value=markdown_decoration.quote(_join(*content, sep=sep)) value=markdown_decoration.quote(_join(*content, sep=sep)),
) )
@ -183,7 +183,7 @@ def blockquote(*content: Any, sep: str = "\n") -> str:
:return: :return:
""" """
return markdown_decoration.blockquote( return markdown_decoration.blockquote(
value=markdown_decoration.quote(_join(*content, sep=sep)) value=markdown_decoration.quote(_join(*content, sep=sep)),
) )

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Optional, Union, overload from typing import Any, Literal, overload
from aiogram.enums import InputMediaType from aiogram.enums import InputMediaType
from aiogram.types import ( from aiogram.types import (
@ -12,12 +12,7 @@ from aiogram.types import (
MessageEntity, MessageEntity,
) )
MediaType = Union[ MediaType = InputMediaAudio | InputMediaPhoto | InputMediaVideo | InputMediaDocument
InputMediaAudio,
InputMediaPhoto,
InputMediaVideo,
InputMediaDocument,
]
MAX_MEDIA_GROUP_SIZE = 10 MAX_MEDIA_GROUP_SIZE = 10
@ -27,9 +22,9 @@ class MediaGroupBuilder:
def __init__( def __init__(
self, self,
media: Optional[List[MediaType]] = None, media: list[MediaType] | None = None,
caption: Optional[str] = None, caption: str | None = None,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
) -> None: ) -> None:
""" """
Helper class for building media groups. Helper class for building media groups.
@ -39,7 +34,7 @@ class MediaGroupBuilder:
:param caption_entities: List of special entities in the caption, :param caption_entities: List of special entities in the caption,
like usernames, URLs, etc. (optional) like usernames, URLs, etc. (optional)
""" """
self._media: List[MediaType] = [] self._media: list[MediaType] = []
self.caption = caption self.caption = caption
self.caption_entities = caption_entities self.caption_entities = caption_entities
@ -47,14 +42,16 @@ class MediaGroupBuilder:
def _add(self, media: MediaType) -> None: def _add(self, media: MediaType) -> None:
if not isinstance(media, InputMedia): if not isinstance(media, InputMedia):
raise ValueError("Media must be instance of InputMedia") msg = "Media must be instance of InputMedia"
raise ValueError(msg)
if len(self._media) >= MAX_MEDIA_GROUP_SIZE: if len(self._media) >= MAX_MEDIA_GROUP_SIZE:
raise ValueError("Media group can't contain more than 10 elements") msg = "Media group can't contain more than 10 elements"
raise ValueError(msg)
self._media.append(media) self._media.append(media)
def _extend(self, media: List[MediaType]) -> None: def _extend(self, media: list[MediaType]) -> None:
for m in media: for m in media:
self._add(m) self._add(m)
@ -63,13 +60,13 @@ class MediaGroupBuilder:
self, self,
*, *,
type: Literal[InputMediaType.AUDIO], type: Literal[InputMediaType.AUDIO],
media: Union[str, InputFile], media: str | InputFile,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
duration: Optional[int] = None, duration: int | None = None,
performer: Optional[str] = None, performer: str | None = None,
title: Optional[str] = None, title: str | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
pass pass
@ -79,11 +76,11 @@ class MediaGroupBuilder:
self, self,
*, *,
type: Literal[InputMediaType.PHOTO], type: Literal[InputMediaType.PHOTO],
media: Union[str, InputFile], media: str | InputFile,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
has_spoiler: Optional[bool] = None, has_spoiler: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
pass pass
@ -93,16 +90,16 @@ class MediaGroupBuilder:
self, self,
*, *,
type: Literal[InputMediaType.VIDEO], type: Literal[InputMediaType.VIDEO],
media: Union[str, InputFile], media: str | InputFile,
thumbnail: Optional[Union[InputFile, str]] = None, thumbnail: InputFile | str | None = None,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
width: Optional[int] = None, width: int | None = None,
height: Optional[int] = None, height: int | None = None,
duration: Optional[int] = None, duration: int | None = None,
supports_streaming: Optional[bool] = None, supports_streaming: bool | None = None,
has_spoiler: Optional[bool] = None, has_spoiler: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
pass pass
@ -112,12 +109,12 @@ class MediaGroupBuilder:
self, self,
*, *,
type: Literal[InputMediaType.DOCUMENT], type: Literal[InputMediaType.DOCUMENT],
media: Union[str, InputFile], media: str | InputFile,
thumbnail: Optional[Union[InputFile, str]] = None, thumbnail: InputFile | str | None = None,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
disable_content_type_detection: Optional[bool] = None, disable_content_type_detection: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
pass pass
@ -140,18 +137,19 @@ class MediaGroupBuilder:
elif type_ == InputMediaType.DOCUMENT: elif type_ == InputMediaType.DOCUMENT:
self.add_document(**kwargs) self.add_document(**kwargs)
else: else:
raise ValueError(f"Unknown media type: {type_!r}") msg = f"Unknown media type: {type_!r}"
raise ValueError(msg)
def add_audio( def add_audio(
self, self,
media: Union[str, InputFile], media: str | InputFile,
thumbnail: Optional[InputFile] = None, thumbnail: InputFile | None = None,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
duration: Optional[int] = None, duration: int | None = None,
performer: Optional[str] = None, performer: str | None = None,
title: Optional[str] = None, title: str | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -189,16 +187,16 @@ class MediaGroupBuilder:
performer=performer, performer=performer,
title=title, title=title,
**kwargs, **kwargs,
) ),
) )
def add_photo( def add_photo(
self, self,
media: Union[str, InputFile], media: str | InputFile,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
has_spoiler: Optional[bool] = None, has_spoiler: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -228,21 +226,21 @@ class MediaGroupBuilder:
caption_entities=caption_entities, caption_entities=caption_entities,
has_spoiler=has_spoiler, has_spoiler=has_spoiler,
**kwargs, **kwargs,
) ),
) )
def add_video( def add_video(
self, self,
media: Union[str, InputFile], media: str | InputFile,
thumbnail: Optional[InputFile] = None, thumbnail: InputFile | None = None,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
width: Optional[int] = None, width: int | None = None,
height: Optional[int] = None, height: int | None = None,
duration: Optional[int] = None, duration: int | None = None,
supports_streaming: Optional[bool] = None, supports_streaming: bool | None = None,
has_spoiler: Optional[bool] = None, has_spoiler: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -290,17 +288,17 @@ class MediaGroupBuilder:
supports_streaming=supports_streaming, supports_streaming=supports_streaming,
has_spoiler=has_spoiler, has_spoiler=has_spoiler,
**kwargs, **kwargs,
) ),
) )
def add_document( def add_document(
self, self,
media: Union[str, InputFile], media: str | InputFile,
thumbnail: Optional[InputFile] = None, thumbnail: InputFile | None = None,
caption: Optional[str] = None, caption: str | None = None,
parse_mode: Optional[str] = UNSET_PARSE_MODE, parse_mode: str | None = UNSET_PARSE_MODE,
caption_entities: Optional[List[MessageEntity]] = None, caption_entities: list[MessageEntity] | None = None,
disable_content_type_detection: Optional[bool] = None, disable_content_type_detection: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -342,10 +340,10 @@ class MediaGroupBuilder:
caption_entities=caption_entities, caption_entities=caption_entities,
disable_content_type_detection=disable_content_type_detection, disable_content_type_detection=disable_content_type_detection,
**kwargs, **kwargs,
) ),
) )
def build(self) -> List[MediaType]: def build(self) -> list[MediaType]:
""" """
Builds a list of media objects for a media group. Builds a list of media objects for a media group.
@ -353,7 +351,7 @@ class MediaGroupBuilder:
:return: List of media objects. :return: List of media objects.
""" """
update_first_media: Dict[str, Any] = {"caption": self.caption} update_first_media: dict[str, Any] = {"caption": self.caption}
if self.caption_entities is not None: if self.caption_entities is not None:
update_first_media["caption_entities"] = self.caption_entities update_first_media["caption_entities"] = self.caption_entities
update_first_media["parse_mode"] = None update_first_media["parse_mode"] = None

View file

@ -1,18 +1,18 @@
from __future__ import annotations from __future__ import annotations
import contextvars import contextvars
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar, cast, overload from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Literal from typing import Literal
__all__ = ("ContextInstanceMixin", "DataMixin") __all__ = ("ContextInstanceMixin", "DataMixin")
class DataMixin: class DataMixin:
@property @property
def data(self) -> Dict[str, Any]: def data(self) -> dict[str, Any]:
data: Optional[Dict[str, Any]] = getattr(self, "_data", None) data: dict[str, Any] | None = getattr(self, "_data", None)
if data is None: if data is None:
data = {} data = {}
setattr(self, "_data", data) setattr(self, "_data", data)
@ -30,7 +30,7 @@ class DataMixin:
def __contains__(self, key: str) -> bool: def __contains__(self, key: str) -> bool:
return key in self.data return key in self.data
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: def get(self, key: str, default: Any | None = None) -> Any | None:
return self.data.get(key, default) return self.data.get(key, default)
@ -44,36 +44,40 @@ class ContextInstanceMixin(Generic[ContextInstance]):
super().__init_subclass__() super().__init_subclass__()
cls.__context_instance = contextvars.ContextVar(f"instance_{cls.__name__}") cls.__context_instance = contextvars.ContextVar(f"instance_{cls.__name__}")
@overload # noqa: F811 @overload
@classmethod @classmethod
def get_current(cls) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811 def get_current(cls) -> ContextInstance | None: # pragma: no cover
... ...
@overload # noqa: F811 @overload
@classmethod @classmethod
def get_current( # noqa: F811 def get_current(
cls, no_error: Literal[True] cls,
) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811 no_error: Literal[True],
) -> ContextInstance | None: # pragma: no cover
... ...
@overload # noqa: F811 @overload
@classmethod @classmethod
def get_current( # noqa: F811 def get_current(
cls, no_error: Literal[False] cls,
) -> ContextInstance: # pragma: no cover # noqa: F811 no_error: Literal[False],
) -> ContextInstance: # pragma: no cover
... ...
@classmethod # noqa: F811 @classmethod
def get_current( # noqa: F811 def get_current(
cls, no_error: bool = True cls,
) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811 no_error: bool = True,
) -> ContextInstance | None: # pragma: no cover
# on mypy 0.770 I catch that contextvars.ContextVar always contextvars.ContextVar[Any] # on mypy 0.770 I catch that contextvars.ContextVar always contextvars.ContextVar[Any]
cls.__context_instance = cast( cls.__context_instance = cast(
contextvars.ContextVar[ContextInstance], cls.__context_instance contextvars.ContextVar[ContextInstance],
cls.__context_instance,
) )
try: try:
current: Optional[ContextInstance] = cls.__context_instance.get() current: ContextInstance | None = cls.__context_instance.get()
except LookupError: except LookupError:
if no_error: if no_error:
current = None current = None
@ -85,9 +89,8 @@ class ContextInstanceMixin(Generic[ContextInstance]):
@classmethod @classmethod
def set_current(cls, value: ContextInstance) -> contextvars.Token[ContextInstance]: def set_current(cls, value: ContextInstance) -> contextvars.Token[ContextInstance]:
if not isinstance(value, cls): if not isinstance(value, cls):
raise TypeError( msg = f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}"
f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}" raise TypeError(msg)
)
return cls.__context_instance.set(value) return cls.__context_instance.set(value)
@classmethod @classmethod

View file

@ -1,5 +1,10 @@
from __future__ import annotations
import functools import functools
from typing import Callable, TypeVar from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from collections.abc import Callable
T = TypeVar("T") T = TypeVar("T")

View file

@ -61,13 +61,18 @@ Encoding and decoding with your own methods:
""" """
from __future__ import annotations
from base64 import urlsafe_b64decode, urlsafe_b64encode from base64 import urlsafe_b64decode, urlsafe_b64encode
from typing import Callable, Optional from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
def encode_payload( def encode_payload(
payload: str, payload: str,
encoder: Optional[Callable[[bytes], bytes]] = None, encoder: Callable[[bytes], bytes] | None = None,
) -> str: ) -> str:
"""Encode payload with encoder. """Encode payload with encoder.
@ -85,7 +90,7 @@ def encode_payload(
def decode_payload( def decode_payload(
payload: str, payload: str,
decoder: Optional[Callable[[bytes], bytes]] = None, decoder: Callable[[bytes], bytes] | None = None,
) -> str: ) -> str:
"""Decode URL-safe base64url payload with decoder.""" """Decode URL-safe base64url payload with decoder."""
original_payload = _decode_b64(payload) original_payload = _decode_b64(payload)

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@ -9,7 +9,7 @@ from aiogram.methods import TelegramMethod
from aiogram.types import InputFile from aiogram.types import InputFile
def _get_fake_bot(default: Optional[DefaultBotProperties] = None) -> Bot: def _get_fake_bot(default: DefaultBotProperties | None = None) -> Bot:
if default is None: if default is None:
default = DefaultBotProperties() default = DefaultBotProperties()
return Bot(token="42:Fake", default=default) return Bot(token="42:Fake", default=default)
@ -28,12 +28,12 @@ class DeserializedTelegramObject:
""" """
data: Any data: Any
files: Dict[str, InputFile] files: dict[str, InputFile]
def deserialize_telegram_object( def deserialize_telegram_object(
obj: Any, obj: Any,
default: Optional[DefaultBotProperties] = None, default: DefaultBotProperties | None = None,
include_api_method_name: bool = True, include_api_method_name: bool = True,
) -> DeserializedTelegramObject: ) -> DeserializedTelegramObject:
""" """
@ -55,7 +55,7 @@ def deserialize_telegram_object(
# Fake bot is needed to exclude global defaults from the object. # Fake bot is needed to exclude global defaults from the object.
fake_bot = _get_fake_bot(default=default) fake_bot = _get_fake_bot(default=default)
files: Dict[str, InputFile] = {} files: dict[str, InputFile] = {}
prepared = fake_bot.session.prepare_value( prepared = fake_bot.session.prepare_value(
obj, obj,
bot=fake_bot, bot=fake_bot,
@ -70,7 +70,7 @@ def deserialize_telegram_object(
def deserialize_telegram_object_to_python( def deserialize_telegram_object_to_python(
obj: Any, obj: Any,
default: Optional[DefaultBotProperties] = None, default: DefaultBotProperties | None = None,
include_api_method_name: bool = True, include_api_method_name: bool = True,
) -> Any: ) -> Any:
""" """

View file

@ -3,20 +3,23 @@ from __future__ import annotations
import html import html
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generator, List, Optional, Pattern, cast from typing import TYPE_CHECKING, cast
from aiogram.enums import MessageEntityType from aiogram.enums import MessageEntityType
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator
from re import Pattern
from aiogram.types import MessageEntity from aiogram.types import MessageEntity
__all__ = ( __all__ = (
"HtmlDecoration", "HtmlDecoration",
"MarkdownDecoration", "MarkdownDecoration",
"TextDecoration", "TextDecoration",
"add_surrogates",
"html_decoration", "html_decoration",
"markdown_decoration", "markdown_decoration",
"add_surrogates",
"remove_surrogates", "remove_surrogates",
) )
@ -80,7 +83,7 @@ class TextDecoration(ABC):
# API it will be here too # API it will be here too
return self.quote(text) return self.quote(text)
def unparse(self, text: str, entities: Optional[List[MessageEntity]] = None) -> str: def unparse(self, text: str, entities: list[MessageEntity] | None = None) -> str:
""" """
Unparse message entities Unparse message entities
@ -92,15 +95,15 @@ class TextDecoration(ABC):
self._unparse_entities( self._unparse_entities(
add_surrogates(text), add_surrogates(text),
sorted(entities, key=lambda item: item.offset) if entities else [], sorted(entities, key=lambda item: item.offset) if entities else [],
) ),
) )
def _unparse_entities( def _unparse_entities(
self, self,
text: bytes, text: bytes,
entities: List[MessageEntity], entities: list[MessageEntity],
offset: Optional[int] = None, offset: int | None = None,
length: Optional[int] = None, length: int | None = None,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
if offset is None: if offset is None:
offset = 0 offset = 0
@ -115,7 +118,7 @@ class TextDecoration(ABC):
offset = entity.offset * 2 + entity.length * 2 offset = entity.offset * 2 + entity.length * 2
sub_entities = list( sub_entities = list(
filter(lambda e: e.offset * 2 < (offset or 0), entities[index + 1 :]) filter(lambda e: e.offset * 2 < (offset or 0), entities[index + 1 :]),
) )
yield self.apply_entity( yield self.apply_entity(
entity, entity,

View file

@ -5,7 +5,7 @@ class TokenValidationError(Exception):
pass pass
@lru_cache() @lru_cache
def validate_token(token: str) -> bool: def validate_token(token: str) -> bool:
""" """
Validate Telegram token Validate Telegram token
@ -14,9 +14,8 @@ def validate_token(token: str) -> bool:
:return: :return:
""" """
if not isinstance(token, str): if not isinstance(token, str):
raise TokenValidationError( msg = f"Token is invalid! It must be 'str' type instead of {type(token)} type."
f"Token is invalid! It must be 'str' type instead of {type(token)} type." raise TokenValidationError(msg)
)
if any(x.isspace() for x in token): if any(x.isspace() for x in token):
message = "Token is invalid! It can't contains spaces." message = "Token is invalid! It can't contains spaces."
@ -24,12 +23,13 @@ def validate_token(token: str) -> bool:
left, sep, right = token.partition(":") left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (not right): if (not sep) or (not left.isdigit()) or (not right):
raise TokenValidationError("Token is invalid!") msg = "Token is invalid!"
raise TokenValidationError(msg)
return True return True
@lru_cache() @lru_cache
def extract_bot_id(token: str) -> int: def extract_bot_id(token: str) -> int:
""" """
Extract bot ID from Telegram token Extract bot ID from Telegram token

View file

@ -1,9 +1,10 @@
import hashlib import hashlib
import hmac import hmac
import json import json
from collections.abc import Callable
from datetime import datetime from datetime import datetime
from operator import itemgetter from operator import itemgetter
from typing import Any, Callable, Optional from typing import Any
from urllib.parse import parse_qsl from urllib.parse import parse_qsl
from aiogram.types import TelegramObject from aiogram.types import TelegramObject
@ -25,9 +26,9 @@ class WebAppChat(TelegramObject):
"""Type of chat, can be either “group”, “supergroup” or “channel”""" """Type of chat, can be either “group”, “supergroup” or “channel”"""
title: str title: str
"""Title of the chat""" """Title of the chat"""
username: Optional[str] = None username: str | None = None
"""Username of the chat""" """Username of the chat"""
photo_url: Optional[str] = None photo_url: str | None = None
"""URL of the chats photo. The photo can be in .jpeg or .svg formats. """URL of the chats photo. The photo can be in .jpeg or .svg formats.
Only returned for Web Apps launched from the attachment menu.""" Only returned for Web Apps launched from the attachment menu."""
@ -44,23 +45,23 @@ class WebAppUser(TelegramObject):
and some programming languages may have difficulty/silent defects in interpreting it. and some programming languages may have difficulty/silent defects in interpreting it.
It has at most 52 significant bits, so a 64-bit integer or a double-precision float type It has at most 52 significant bits, so a 64-bit integer or a double-precision float type
is safe for storing this identifier.""" is safe for storing this identifier."""
is_bot: Optional[bool] = None is_bot: bool | None = None
"""True, if this user is a bot. Returns in the receiver field only.""" """True, if this user is a bot. Returns in the receiver field only."""
first_name: str first_name: str
"""First name of the user or bot.""" """First name of the user or bot."""
last_name: Optional[str] = None last_name: str | None = None
"""Last name of the user or bot.""" """Last name of the user or bot."""
username: Optional[str] = None username: str | None = None
"""Username of the user or bot.""" """Username of the user or bot."""
language_code: Optional[str] = None language_code: str | None = None
"""IETF language tag of the user's language. Returns in user field only.""" """IETF language tag of the user's language. Returns in user field only."""
is_premium: Optional[bool] = None is_premium: bool | None = None
"""True, if this user is a Telegram Premium user.""" """True, if this user is a Telegram Premium user."""
added_to_attachment_menu: Optional[bool] = None added_to_attachment_menu: bool | None = None
"""True, if this user added the bot to the attachment menu.""" """True, if this user added the bot to the attachment menu."""
allows_write_to_pm: Optional[bool] = None allows_write_to_pm: bool | None = None
"""True, if this user allowed the bot to message them.""" """True, if this user allowed the bot to message them."""
photo_url: Optional[str] = None photo_url: str | None = None
"""URL of the users profile photo. The photo can be in .jpeg or .svg formats. """URL of the users profile photo. The photo can be in .jpeg or .svg formats.
Only returned for Web Apps launched from the attachment menu.""" Only returned for Web Apps launched from the attachment menu."""
@ -73,33 +74,33 @@ class WebAppInitData(TelegramObject):
Source: https://core.telegram.org/bots/webapps#webappinitdata Source: https://core.telegram.org/bots/webapps#webappinitdata
""" """
query_id: Optional[str] = None query_id: str | None = None
"""A unique identifier for the Web App session, required for sending messages """A unique identifier for the Web App session, required for sending messages
via the answerWebAppQuery method.""" via the answerWebAppQuery method."""
user: Optional[WebAppUser] = None user: WebAppUser | None = None
"""An object containing data about the current user.""" """An object containing data about the current user."""
receiver: Optional[WebAppUser] = None receiver: WebAppUser | None = None
"""An object containing data about the chat partner of the current user in the chat where """An object containing data about the chat partner of the current user in the chat where
the bot was launched via the attachment menu. the bot was launched via the attachment menu.
Returned only for Web Apps launched via the attachment menu.""" Returned only for Web Apps launched via the attachment menu."""
chat: Optional[WebAppChat] = None chat: WebAppChat | None = None
"""An object containing data about the chat where the bot was launched via the attachment menu. """An object containing data about the chat where the bot was launched via the attachment menu.
Returned for supergroups, channels, and group chats only for Web Apps launched via the Returned for supergroups, channels, and group chats only for Web Apps launched via the
attachment menu.""" attachment menu."""
chat_type: Optional[str] = None chat_type: str | None = None
"""Type of the chat from which the Web App was opened. """Type of the chat from which the Web App was opened.
Can be either sender for a private chat with the user opening the link, Can be either sender for a private chat with the user opening the link,
private, group, supergroup, or channel. private, group, supergroup, or channel.
Returned only for Web Apps launched from direct links.""" Returned only for Web Apps launched from direct links."""
chat_instance: Optional[str] = None chat_instance: str | None = None
"""Global identifier, uniquely corresponding to the chat from which the Web App was opened. """Global identifier, uniquely corresponding to the chat from which the Web App was opened.
Returned only for Web Apps launched from a direct link.""" Returned only for Web Apps launched from a direct link."""
start_param: Optional[str] = None start_param: str | None = None
"""The value of the startattach parameter, passed via link. """The value of the startattach parameter, passed via link.
Only returned for Web Apps when launched from the attachment menu via link. Only returned for Web Apps when launched from the attachment menu via link.
The value of the start_param parameter will also be passed in the GET-parameter The value of the start_param parameter will also be passed in the GET-parameter
tgWebAppStartParam, so the Web App can load the correct interface right away.""" tgWebAppStartParam, so the Web App can load the correct interface right away."""
can_send_after: Optional[int] = None can_send_after: int | None = None
"""Time in seconds, after which a message can be sent via the answerWebAppQuery method.""" """Time in seconds, after which a message can be sent via the answerWebAppQuery method."""
auth_date: datetime auth_date: datetime
"""Unix time when the form was opened.""" """Unix time when the form was opened."""
@ -132,7 +133,9 @@ def check_webapp_signature(token: str, init_data: str) -> bool:
) )
secret_key = hmac.new(key=b"WebAppData", msg=token.encode(), digestmod=hashlib.sha256) secret_key = hmac.new(key=b"WebAppData", msg=token.encode(), digestmod=hashlib.sha256)
calculated_hash = hmac.new( calculated_hash = hmac.new(
key=secret_key.digest(), msg=data_check_string.encode(), digestmod=hashlib.sha256 key=secret_key.digest(),
msg=data_check_string.encode(),
digestmod=hashlib.sha256,
).hexdigest() ).hexdigest()
return hmac.compare_digest(calculated_hash, hash_) return hmac.compare_digest(calculated_hash, hash_)
@ -180,4 +183,5 @@ def safe_parse_webapp_init_data(
""" """
if check_webapp_signature(token, init_data): if check_webapp_signature(token, init_data):
return parse_webapp_init_data(init_data, loads=loads) return parse_webapp_init_data(init_data, loads=loads)
raise ValueError("Invalid init data signature") msg = "Invalid init data signature"
raise ValueError(msg)

View file

@ -8,13 +8,15 @@ from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from .web_app import WebAppInitData, parse_webapp_init_data from .web_app import WebAppInitData, parse_webapp_init_data
PRODUCTION_PUBLIC_KEY = bytes.fromhex( PRODUCTION_PUBLIC_KEY = bytes.fromhex(
"e7bf03a2fa4602af4580703d88dda5bb59f32ed8b02a56c187fe7d34caed242d" "e7bf03a2fa4602af4580703d88dda5bb59f32ed8b02a56c187fe7d34caed242d",
) )
TEST_PUBLIC_KEY = bytes.fromhex("40055058a4ee38156a06562e52eece92a771bcd8346a8c4615cb7376eddf72ec") TEST_PUBLIC_KEY = bytes.fromhex("40055058a4ee38156a06562e52eece92a771bcd8346a8c4615cb7376eddf72ec")
def check_webapp_signature( def check_webapp_signature(
bot_id: int, init_data: str, public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY bot_id: int,
init_data: str,
public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY,
) -> bool: ) -> bool:
""" """
Check incoming WebApp init data signature without bot token using only bot id. Check incoming WebApp init data signature without bot token using only bot id.
@ -49,13 +51,16 @@ def check_webapp_signature(
try: try:
public_key.verify(signature, message) public_key.verify(signature, message)
return True
except InvalidSignature: except InvalidSignature:
return False return False
else:
return True
def safe_check_webapp_init_data_from_signature( def safe_check_webapp_init_data_from_signature(
bot_id: int, init_data: str, public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY bot_id: int,
init_data: str,
public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY,
) -> WebAppInitData: ) -> WebAppInitData:
""" """
Validate raw WebApp init data using only bot id and return it as WebAppInitData object Validate raw WebApp init data using only bot id and return it as WebAppInitData object
@ -67,4 +72,5 @@ def safe_check_webapp_init_data_from_signature(
""" """
if check_webapp_signature(bot_id, init_data, public_key_bytes): if check_webapp_signature(bot_id, init_data, public_key_bytes):
return parse_webapp_init_data(init_data) return parse_webapp_init_data(init_data)
raise ValueError("Invalid init data signature") msg = "Invalid init data signature"
raise ValueError(msg)

View file

@ -2,7 +2,8 @@ import asyncio
import secrets import secrets
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import Transport from asyncio import Transport
from typing import Any, Awaitable, Callable, Dict, Optional, Set, Tuple, cast from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from aiohttp import JsonPayload, MultipartWriter, Payload, web from aiohttp import JsonPayload, MultipartWriter, Payload, web
from aiohttp.typedefs import Handler from aiohttp.typedefs import Handler
@ -12,9 +13,11 @@ from aiohttp.web_middlewares import middleware
from aiogram import Bot, Dispatcher, loggers from aiogram import Bot, Dispatcher, loggers
from aiogram.methods import TelegramMethod from aiogram.methods import TelegramMethod
from aiogram.methods.base import TelegramType from aiogram.methods.base import TelegramType
from aiogram.types import InputFile
from aiogram.webhook.security import IPFilter from aiogram.webhook.security import IPFilter
if TYPE_CHECKING:
from aiogram.types import InputFile
def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any) -> None: def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any) -> None:
""" """
@ -42,7 +45,7 @@ def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any
app.on_shutdown.append(on_shutdown) app.on_shutdown.append(on_shutdown)
def check_ip(ip_filter: IPFilter, request: web.Request) -> Tuple[str, bool]: def check_ip(ip_filter: IPFilter, request: web.Request) -> tuple[str, bool]:
# Try to resolve client IP over reverse proxy # Try to resolve client IP over reverse proxy
if forwarded_for := request.headers.get("X-Forwarded-For", ""): if forwarded_for := request.headers.get("X-Forwarded-For", ""):
# Get the left-most ip when there is multiple ips # Get the left-most ip when there is multiple ips
@ -98,7 +101,7 @@ class BaseRequestHandler(ABC):
self.dispatcher = dispatcher self.dispatcher = dispatcher
self.handle_in_background = handle_in_background self.handle_in_background = handle_in_background
self.data = data self.data = data
self._background_feed_update_tasks: Set[asyncio.Task[Any]] = set() self._background_feed_update_tasks: set[asyncio.Task[Any]] = set()
def register(self, app: Application, /, path: str, **kwargs: Any) -> None: def register(self, app: Application, /, path: str, **kwargs: Any) -> None:
""" """
@ -128,13 +131,12 @@ class BaseRequestHandler(ABC):
:param request: :param request:
:return: Bot instance :return: Bot instance
""" """
pass
@abstractmethod @abstractmethod
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool: def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
pass pass
async def _background_feed_update(self, bot: Bot, update: Dict[str, Any]) -> None: async def _background_feed_update(self, bot: Bot, update: dict[str, Any]) -> None:
result = await self.dispatcher.feed_raw_update(bot=bot, update=update, **self.data) result = await self.dispatcher.feed_raw_update(bot=bot, update=update, **self.data)
if isinstance(result, TelegramMethod): if isinstance(result, TelegramMethod):
await self.dispatcher.silent_call_request(bot=bot, result=result) await self.dispatcher.silent_call_request(bot=bot, result=result)
@ -142,15 +144,18 @@ class BaseRequestHandler(ABC):
async def _handle_request_background(self, bot: Bot, request: web.Request) -> web.Response: async def _handle_request_background(self, bot: Bot, request: web.Request) -> web.Response:
feed_update_task = asyncio.create_task( feed_update_task = asyncio.create_task(
self._background_feed_update( self._background_feed_update(
bot=bot, update=await request.json(loads=bot.session.json_loads) bot=bot,
) update=await request.json(loads=bot.session.json_loads),
),
) )
self._background_feed_update_tasks.add(feed_update_task) self._background_feed_update_tasks.add(feed_update_task)
feed_update_task.add_done_callback(self._background_feed_update_tasks.discard) feed_update_task.add_done_callback(self._background_feed_update_tasks.discard)
return web.json_response({}, dumps=bot.session.json_dumps) return web.json_response({}, dumps=bot.session.json_dumps)
def _build_response_writer( def _build_response_writer(
self, bot: Bot, result: Optional[TelegramMethod[TelegramType]] self,
bot: Bot,
result: TelegramMethod[TelegramType] | None,
) -> Payload: ) -> Payload:
if not result: if not result:
# we need to return something "empty" # we need to return something "empty"
@ -166,7 +171,7 @@ class BaseRequestHandler(ABC):
payload = writer.append(result.__api_method__) payload = writer.append(result.__api_method__)
payload.set_content_disposition("form-data", name="method") payload.set_content_disposition("form-data", name="method")
files: Dict[str, InputFile] = {} files: dict[str, InputFile] = {}
for key, value in result.model_dump(warnings=False).items(): for key, value in result.model_dump(warnings=False).items():
value = bot.session.prepare_value(value, bot=bot, files=files) value = bot.session.prepare_value(value, bot=bot, files=files)
if not value: if not value:
@ -185,7 +190,7 @@ class BaseRequestHandler(ABC):
return writer return writer
async def _handle_request(self, bot: Bot, request: web.Request) -> web.Response: async def _handle_request(self, bot: Bot, request: web.Request) -> web.Response:
result: Optional[TelegramMethod[Any]] = await self.dispatcher.feed_webhook_update( result: TelegramMethod[Any] | None = await self.dispatcher.feed_webhook_update(
bot, bot,
await request.json(loads=bot.session.json_loads), await request.json(loads=bot.session.json_loads),
**self.data, **self.data,
@ -209,7 +214,7 @@ class SimpleRequestHandler(BaseRequestHandler):
dispatcher: Dispatcher, dispatcher: Dispatcher,
bot: Bot, bot: Bot,
handle_in_background: bool = True, handle_in_background: bool = True,
secret_token: Optional[str] = None, secret_token: str | None = None,
**data: Any, **data: Any,
) -> None: ) -> None:
""" """
@ -244,7 +249,7 @@ class TokenBasedRequestHandler(BaseRequestHandler):
self, self,
dispatcher: Dispatcher, dispatcher: Dispatcher,
handle_in_background: bool = True, handle_in_background: bool = True,
bot_settings: Optional[Dict[str, Any]] = None, bot_settings: dict[str, Any] | None = None,
**data: Any, **data: Any,
) -> None: ) -> None:
""" """
@ -265,7 +270,7 @@ class TokenBasedRequestHandler(BaseRequestHandler):
if bot_settings is None: if bot_settings is None:
bot_settings = {} bot_settings = {}
self.bot_settings = bot_settings self.bot_settings = bot_settings
self.bots: Dict[str, Bot] = {} self.bots: dict[str, Bot] = {}
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool: def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
return True return True
@ -283,7 +288,8 @@ class TokenBasedRequestHandler(BaseRequestHandler):
:param kwargs: :param kwargs:
""" """
if "{bot_token}" not in path: if "{bot_token}" not in path:
raise ValueError("Path should contains '{bot_token}' substring") msg = "Path should contains '{bot_token}' substring"
raise ValueError(msg)
super().register(app, path=path, **kwargs) super().register(app, path=path, **kwargs)
async def resolve_bot(self, request: web.Request) -> Bot: async def resolve_bot(self, request: web.Request) -> Bot:

View file

@ -1,5 +1,5 @@
from collections.abc import Sequence
from ipaddress import IPv4Address, IPv4Network from ipaddress import IPv4Address, IPv4Network
from typing import Optional, Sequence, Set, Union
DEFAULT_TELEGRAM_NETWORKS = [ DEFAULT_TELEGRAM_NETWORKS = [
IPv4Network("149.154.160.0/20"), IPv4Network("149.154.160.0/20"),
@ -8,17 +8,17 @@ DEFAULT_TELEGRAM_NETWORKS = [
class IPFilter: class IPFilter:
def __init__(self, ips: Optional[Sequence[Union[str, IPv4Network, IPv4Address]]] = None): def __init__(self, ips: Sequence[str | IPv4Network | IPv4Address] | None = None):
self._allowed_ips: Set[IPv4Address] = set() self._allowed_ips: set[IPv4Address] = set()
if ips: if ips:
self.allow(*ips) self.allow(*ips)
def allow(self, *ips: Union[str, IPv4Network, IPv4Address]) -> None: def allow(self, *ips: str | IPv4Network | IPv4Address) -> None:
for ip in ips: for ip in ips:
self.allow_ip(ip) self.allow_ip(ip)
def allow_ip(self, ip: Union[str, IPv4Network, IPv4Address]) -> None: def allow_ip(self, ip: str | IPv4Network | IPv4Address) -> None:
if isinstance(ip, str): if isinstance(ip, str):
ip = IPv4Network(ip) if "/" in ip else IPv4Address(ip) ip = IPv4Network(ip) if "/" in ip else IPv4Address(ip)
if isinstance(ip, IPv4Address): if isinstance(ip, IPv4Address):
@ -26,16 +26,17 @@ class IPFilter:
elif isinstance(ip, IPv4Network): elif isinstance(ip, IPv4Network):
self._allowed_ips.update(ip.hosts()) self._allowed_ips.update(ip.hosts())
else: else:
raise ValueError(f"Invalid type of ipaddress: {type(ip)} ('{ip}')") msg = f"Invalid type of ipaddress: {type(ip)} ('{ip}')"
raise ValueError(msg)
@classmethod @classmethod
def default(cls) -> "IPFilter": def default(cls) -> "IPFilter":
return cls(DEFAULT_TELEGRAM_NETWORKS) return cls(DEFAULT_TELEGRAM_NETWORKS)
def check(self, ip: Union[str, IPv4Address]) -> bool: def check(self, ip: str | IPv4Address) -> bool:
if not isinstance(ip, IPv4Address): if not isinstance(ip, IPv4Address):
ip = IPv4Address(ip) ip = IPv4Address(ip)
return ip in self._allowed_ips return ip in self._allowed_ips
def __contains__(self, item: Union[str, IPv4Address]) -> bool: def __contains__(self, item: str | IPv4Address) -> bool:
return self.check(item) return self.check(item)

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union from typing import Any
from aiogram import Router from aiogram import Router
from aiogram.filters import Filter from aiogram.filters import Filter
@ -8,7 +8,7 @@ router = Router(name=__name__)
class HelloFilter(Filter): class HelloFilter(Filter):
def __init__(self, name: Optional[str] = None) -> None: def __init__(self, name: str | None = None) -> None:
self.name = name self.name = name
async def __call__( async def __call__(
@ -16,7 +16,7 @@ class HelloFilter(Filter):
message: Message, message: Message,
event_from_user: User, event_from_user: User,
# Filters also can accept keyword parameters like in handlers # Filters also can accept keyword parameters like in handlers
) -> Union[bool, Dict[str, Any]]: ) -> bool | dict[str, Any]:
if message.text.casefold() == "hello": if message.text.casefold() == "hello":
# Returning a dictionary that will update the context data # Returning a dictionary that will update the context data
return {"name": event_from_user.mention_html(name=self.name)} return {"name": event_from_user.mention_html(name=self.name)}
@ -25,6 +25,7 @@ class HelloFilter(Filter):
@router.message(HelloFilter()) @router.message(HelloFilter())
async def my_handler( async def my_handler(
message: Message, name: str # Now we can accept "name" as named parameter message: Message,
name: str, # Now we can accept "name" as named parameter
) -> Any: ) -> Any:
return message.answer("Hello, {name}!".format(name=name)) return message.answer(f"Hello, {name}!")

View file

@ -75,11 +75,13 @@ async def handle_set_age(message: types.Message, command: CommandObject) -> None
# To get the command arguments you can use `command.args` property. # To get the command arguments you can use `command.args` property.
age = command.args age = command.args
if not age: if not age:
raise InvalidAge("No age provided. Please provide your age as a command argument.") msg = "No age provided. Please provide your age as a command argument."
raise InvalidAge(msg)
# If the age is invalid, raise an exception. # If the age is invalid, raise an exception.
if not age.isdigit(): if not age.isdigit():
raise InvalidAge("Age should be a number") msg = "Age should be a number"
raise InvalidAge(msg)
# If the age is valid, send a message to the user. # If the age is valid, send a message to the user.
age = int(age) age = int(age)
@ -95,7 +97,8 @@ async def handle_set_name(message: types.Message, command: CommandObject) -> Non
# To get the command arguments you can use `command.args` property. # To get the command arguments you can use `command.args` property.
name = command.args name = command.args
if not name: if not name:
raise InvalidName("Invalid name. Please provide your name as a command argument.") msg = "Invalid name. Please provide your name as a command argument."
raise InvalidName(msg)
# If the name is valid, send a message to the user. # If the name is valid, send a message to the user.
await message.reply(text=f"Your name is {name}") await message.reply(text=f"Your name is {name}")

View file

@ -2,7 +2,7 @@ import asyncio
import logging import logging
import sys import sys
from os import getenv from os import getenv
from typing import Any, Dict from typing import Any
from aiogram import Bot, Dispatcher, F, Router, html from aiogram import Bot, Dispatcher, F, Router, html
from aiogram.client.default import DefaultBotProperties from aiogram.client.default import DefaultBotProperties
@ -66,7 +66,7 @@ async def process_name(message: Message, state: FSMContext) -> None:
[ [
KeyboardButton(text="Yes"), KeyboardButton(text="Yes"),
KeyboardButton(text="No"), KeyboardButton(text="No"),
] ],
], ],
resize_keyboard=True, resize_keyboard=True,
), ),
@ -106,13 +106,13 @@ async def process_language(message: Message, state: FSMContext) -> None:
if message.text.casefold() == "python": if message.text.casefold() == "python":
await message.reply( await message.reply(
"Python, you say? That's the language that makes my circuits light up! 😉" "Python, you say? That's the language that makes my circuits light up! 😉",
) )
await show_summary(message=message, data=data) await show_summary(message=message, data=data)
async def show_summary(message: Message, data: Dict[str, Any], positive: bool = True) -> None: async def show_summary(message: Message, data: dict[str, Any], positive: bool = True) -> None:
name = data["name"] name = data["name"]
language = data.get("language", "<something unexpected>") language = data.get("language", "<something unexpected>")
text = f"I'll keep in mind that, {html.quote(name)}, " text = f"I'll keep in mind that, {html.quote(name)}, "
@ -124,7 +124,7 @@ async def show_summary(message: Message, data: Dict[str, Any], positive: bool =
await message.answer(text=text, reply_markup=ReplyKeyboardRemove()) await message.answer(text=text, reply_markup=ReplyKeyboardRemove())
async def main(): async def main() -> None:
# Initialize Bot instance with default bot properties which will be passed to all API calls # Initialize Bot instance with default bot properties which will be passed to all API calls
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML)) bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))

View file

@ -1,7 +1,7 @@
import logging import logging
import sys import sys
from os import getenv from os import getenv
from typing import Any, Dict, Union from typing import Any
from aiohttp import web from aiohttp import web
from finite_state_machine import form_router from finite_state_machine import form_router
@ -34,7 +34,7 @@ REDIS_DSN = "redis://127.0.0.1:6479"
OTHER_BOTS_URL = f"{BASE_URL}{OTHER_BOTS_PATH}" OTHER_BOTS_URL = f"{BASE_URL}{OTHER_BOTS_PATH}"
def is_bot_token(value: str) -> Union[bool, Dict[str, Any]]: def is_bot_token(value: str) -> bool | dict[str, Any]:
try: try:
validate_token(value) validate_token(value)
except TokenValidationError: except TokenValidationError:
@ -54,11 +54,11 @@ async def command_add_bot(message: Message, command: CommandObject, bot: Bot) ->
return await message.answer(f"Bot @{bot_user.username} successful added") return await message.answer(f"Bot @{bot_user.username} successful added")
async def on_startup(dispatcher: Dispatcher, bot: Bot): async def on_startup(dispatcher: Dispatcher, bot: Bot) -> None:
await bot.set_webhook(f"{BASE_URL}{MAIN_BOT_PATH}") await bot.set_webhook(f"{BASE_URL}{MAIN_BOT_PATH}")
def main(): def main() -> None:
logging.basicConfig(level=logging.INFO, stream=sys.stdout) logging.basicConfig(level=logging.INFO, stream=sys.stdout)
session = AiohttpSession() session = AiohttpSession()
bot_settings = {"session": session, "parse_mode": ParseMode.HTML} bot_settings = {"session": session, "parse_mode": ParseMode.HTML}

View file

@ -14,4 +14,4 @@ class MyFilter(Filter):
@router.message(MyFilter("hello")) @router.message(MyFilter("hello"))
async def my_handler(message: Message): ... async def my_handler(message: Message) -> None: ...

View file

@ -263,7 +263,7 @@ quiz_router.message.register(QuizScene.as_handler(), Command("quiz"))
@quiz_router.message(Command("start")) @quiz_router.message(Command("start"))
async def command_start(message: Message, scenes: ScenesManager): async def command_start(message: Message, scenes: ScenesManager) -> None:
await scenes.close() await scenes.close()
await message.answer( await message.answer(
"Hi! This is a quiz bot. To start the quiz, use the /quiz command.", "Hi! This is a quiz bot. To start the quiz, use the /quiz command.",
@ -271,7 +271,7 @@ async def command_start(message: Message, scenes: ScenesManager):
) )
def create_dispatcher(): def create_dispatcher() -> Dispatcher:
# Event isolation is needed to correctly handle fast user responses # Event isolation is needed to correctly handle fast user responses
dispatcher = Dispatcher( dispatcher = Dispatcher(
events_isolation=SimpleEventIsolation(), events_isolation=SimpleEventIsolation(),
@ -288,7 +288,7 @@ def create_dispatcher():
return dispatcher return dispatcher
async def main(): async def main() -> None:
dp = create_dispatcher() dp = create_dispatcher()
bot = Bot(token=TOKEN) bot = Bot(token=TOKEN)
await dp.start_polling(bot) await dp.start_polling(bot)

View file

@ -34,11 +34,11 @@ class CancellableScene(Scene):
""" """
@on.message(F.text.casefold() == BUTTON_CANCEL.text.casefold(), after=After.exit()) @on.message(F.text.casefold() == BUTTON_CANCEL.text.casefold(), after=After.exit())
async def handle_cancel(self, message: Message): async def handle_cancel(self, message: Message) -> None:
await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove()) await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove())
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back()) @on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
async def handle_back(self, message: Message): async def handle_back(self, message: Message) -> None:
await message.answer("Back.") await message.answer("Back.")
@ -48,7 +48,7 @@ class LanguageScene(CancellableScene, state="language"):
""" """
@on.message.enter() @on.message.enter()
async def on_enter(self, message: Message): async def on_enter(self, message: Message) -> None:
await message.answer( await message.answer(
"What language do you prefer?", "What language do you prefer?",
reply_markup=ReplyKeyboardMarkup( reply_markup=ReplyKeyboardMarkup(
@ -58,14 +58,14 @@ class LanguageScene(CancellableScene, state="language"):
) )
@on.message(F.text.casefold() == "python", after=After.exit()) @on.message(F.text.casefold() == "python", after=After.exit())
async def process_python(self, message: Message): async def process_python(self, message: Message) -> None:
await message.answer( await message.answer(
"Python, you say? That's the language that makes my circuits light up! 😉" "Python, you say? That's the language that makes my circuits light up! 😉",
) )
await self.input_language(message) await self.input_language(message)
@on.message(after=After.exit()) @on.message(after=After.exit())
async def input_language(self, message: Message): async def input_language(self, message: Message) -> None:
data: FSMData = await self.wizard.get_data() data: FSMData = await self.wizard.get_data()
await self.show_results(message, language=message.text, **data) await self.show_results(message, language=message.text, **data)
@ -83,7 +83,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
""" """
@on.message.enter() @on.message.enter()
async def on_enter(self, message: Message): async def on_enter(self, message: Message) -> None:
await message.answer( await message.answer(
"Did you like to write bots?", "Did you like to write bots?",
reply_markup=ReplyKeyboardMarkup( reply_markup=ReplyKeyboardMarkup(
@ -96,18 +96,18 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
) )
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene)) @on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
async def process_like_write_bots(self, message: Message): async def process_like_write_bots(self, message: Message) -> None:
await message.reply("Cool! I'm too!") await message.reply("Cool! I'm too!")
@on.message(F.text.casefold() == "no", after=After.exit()) @on.message(F.text.casefold() == "no", after=After.exit())
async def process_dont_like_write_bots(self, message: Message): async def process_dont_like_write_bots(self, message: Message) -> None:
await message.answer( await message.answer(
"Not bad not terrible.\nSee you soon.", "Not bad not terrible.\nSee you soon.",
reply_markup=ReplyKeyboardRemove(), reply_markup=ReplyKeyboardRemove(),
) )
@on.message() @on.message()
async def input_like_bots(self, message: Message): async def input_like_bots(self, message: Message) -> None:
await message.answer("I don't understand you :(") await message.answer("I don't understand you :(")
@ -117,25 +117,25 @@ class NameScene(CancellableScene, state="name"):
""" """
@on.message.enter() # Marker for handler that should be called when a user enters the scene. @on.message.enter() # Marker for handler that should be called when a user enters the scene.
async def on_enter(self, message: Message): async def on_enter(self, message: Message) -> None:
await message.answer( await message.answer(
"Hi there! What's your name?", "Hi there! What's your name?",
reply_markup=ReplyKeyboardMarkup(keyboard=[[BUTTON_CANCEL]], resize_keyboard=True), reply_markup=ReplyKeyboardMarkup(keyboard=[[BUTTON_CANCEL]], resize_keyboard=True),
) )
@on.callback_query.enter() # different types of updates that start the scene also supported. @on.callback_query.enter() # different types of updates that start the scene also supported.
async def on_enter_callback(self, callback_query: CallbackQuery): async def on_enter_callback(self, callback_query: CallbackQuery) -> None:
await callback_query.answer() await callback_query.answer()
await self.on_enter(callback_query.message) await self.on_enter(callback_query.message)
@on.message.leave() # Marker for handler that should be called when a user leaves the scene. @on.message.leave() # Marker for handler that should be called when a user leaves the scene.
async def on_leave(self, message: Message): async def on_leave(self, message: Message) -> None:
data: FSMData = await self.wizard.get_data() data: FSMData = await self.wizard.get_data()
name = data.get("name", "Anonymous") name = data.get("name", "Anonymous")
await message.answer(f"Nice to meet you, {html.quote(name)}!") await message.answer(f"Nice to meet you, {html.quote(name)}!")
@on.message(after=After.goto(LikeBotsScene)) @on.message(after=After.goto(LikeBotsScene))
async def input_name(self, message: Message): async def input_name(self, message: Message) -> None:
await self.wizard.update_data(name=message.text) await self.wizard.update_data(name=message.text)
@ -154,22 +154,22 @@ class DefaultScene(
start_demo = on.message(F.text.casefold() == "demo", after=After.goto(NameScene)) start_demo = on.message(F.text.casefold() == "demo", after=After.goto(NameScene))
@on.message(Command("demo")) @on.message(Command("demo"))
async def demo(self, message: Message): async def demo(self, message: Message) -> None:
await message.answer( await message.answer(
"Demo started", "Demo started",
reply_markup=InlineKeyboardMarkup( reply_markup=InlineKeyboardMarkup(
inline_keyboard=[[InlineKeyboardButton(text="Go to form", callback_data="start")]] inline_keyboard=[[InlineKeyboardButton(text="Go to form", callback_data="start")]],
), ),
) )
@on.callback_query(F.data == "start", after=After.goto(NameScene)) @on.callback_query(F.data == "start", after=After.goto(NameScene))
async def demo_callback(self, callback_query: CallbackQuery): async def demo_callback(self, callback_query: CallbackQuery) -> None:
await callback_query.answer(cache_time=0) await callback_query.answer(cache_time=0)
await callback_query.message.delete_reply_markup() await callback_query.message.delete_reply_markup()
@on.message.enter() # Mark that this handler should be called when a user enters the scene. @on.message.enter() # Mark that this handler should be called when a user enters the scene.
@on.message() @on.message()
async def default_handler(self, message: Message): async def default_handler(self, message: Message) -> None:
await message.answer( await message.answer(
"Start demo?\nYou can also start demo via command /demo", "Start demo?\nYou can also start demo via command /demo",
reply_markup=ReplyKeyboardMarkup( reply_markup=ReplyKeyboardMarkup(

View file

@ -33,7 +33,7 @@ async def command_start_handler(message: Message) -> None:
await message.answer( await message.answer(
f"Hello, {hbold(message.from_user.full_name)}!", f"Hello, {hbold(message.from_user.full_name)}!",
reply_markup=InlineKeyboardMarkup( reply_markup=InlineKeyboardMarkup(
inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]] inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]],
), ),
) )
@ -43,7 +43,7 @@ async def chat_member_update(chat_member: ChatMemberUpdated, bot: Bot) -> None:
await bot.send_message( await bot.send_message(
chat_member.chat.id, chat_member.chat.id,
f"Member {hcode(chat_member.from_user.id)} was changed " f"Member {hcode(chat_member.from_user.id)} was changed "
+ f"from {chat_member.old_chat_member.status} to {chat_member.new_chat_member.status}", f"from {chat_member.old_chat_member.status} to {chat_member.new_chat_member.status}",
) )

View file

@ -12,7 +12,7 @@ my_router = Router()
@my_router.message(CommandStart()) @my_router.message(CommandStart())
async def command_start(message: Message, bot: Bot, base_url: str): async def command_start(message: Message, bot: Bot, base_url: str) -> None:
await bot.set_chat_menu_button( await bot.set_chat_menu_button(
chat_id=message.chat.id, chat_id=message.chat.id,
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")), menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")),
@ -21,28 +21,29 @@ async def command_start(message: Message, bot: Bot, base_url: str):
@my_router.message(Command("webview")) @my_router.message(Command("webview"))
async def command_webview(message: Message, base_url: str): async def command_webview(message: Message, base_url: str) -> None:
await message.answer( await message.answer(
"Good. Now you can try to send it via Webview", "Good. Now you can try to send it via Webview",
reply_markup=InlineKeyboardMarkup( reply_markup=InlineKeyboardMarkup(
inline_keyboard=[ inline_keyboard=[
[ [
InlineKeyboardButton( InlineKeyboardButton(
text="Open Webview", web_app=WebAppInfo(url=f"{base_url}/demo") text="Open Webview",
) web_app=WebAppInfo(url=f"{base_url}/demo"),
] ),
] ],
],
), ),
) )
@my_router.message(~F.message.via_bot) # Echo to all messages except messages via bot @my_router.message(~F.message.via_bot) # Echo to all messages except messages via bot
async def echo_all(message: Message, base_url: str): async def echo_all(message: Message, base_url: str) -> None:
await message.answer( await message.answer(
"Test webview", "Test webview",
reply_markup=InlineKeyboardMarkup( reply_markup=InlineKeyboardMarkup(
inline_keyboard=[ inline_keyboard=[
[InlineKeyboardButton(text="Open", web_app=WebAppInfo(url=f"{base_url}/demo"))] [InlineKeyboardButton(text="Open", web_app=WebAppInfo(url=f"{base_url}/demo"))],
] ],
), ),
) )

View file

@ -18,14 +18,14 @@ TOKEN = getenv("BOT_TOKEN")
APP_BASE_URL = getenv("APP_BASE_URL") APP_BASE_URL = getenv("APP_BASE_URL")
async def on_startup(bot: Bot, base_url: str): async def on_startup(bot: Bot, base_url: str) -> None:
await bot.set_webhook(f"{base_url}/webhook") await bot.set_webhook(f"{base_url}/webhook")
await bot.set_chat_menu_button( await bot.set_chat_menu_button(
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")) menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")),
) )
def main(): def main() -> None:
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML)) bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
dispatcher = Dispatcher() dispatcher = Dispatcher()
dispatcher["base_url"] = APP_BASE_URL dispatcher["base_url"] = APP_BASE_URL

View file

@ -1,10 +1,11 @@
from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
from aiohttp.web_fileresponse import FileResponse from aiohttp.web_fileresponse import FileResponse
from aiohttp.web_request import Request
from aiohttp.web_response import json_response from aiohttp.web_response import json_response
from aiogram import Bot
from aiogram.types import ( from aiogram.types import (
InlineKeyboardButton, InlineKeyboardButton,
InlineKeyboardMarkup, InlineKeyboardMarkup,
@ -14,12 +15,18 @@ from aiogram.types import (
) )
from aiogram.utils.web_app import check_webapp_signature, safe_parse_webapp_init_data from aiogram.utils.web_app import check_webapp_signature, safe_parse_webapp_init_data
if TYPE_CHECKING:
from aiohttp.web_request import Request
from aiohttp.web_response import Response
async def demo_handler(request: Request): from aiogram import Bot
async def demo_handler(request: Request) -> FileResponse:
return FileResponse(Path(__file__).parent.resolve() / "demo.html") return FileResponse(Path(__file__).parent.resolve() / "demo.html")
async def check_data_handler(request: Request): async def check_data_handler(request: Request) -> Response:
bot: Bot = request.app["bot"] bot: Bot = request.app["bot"]
data = await request.post() data = await request.post()
@ -28,7 +35,7 @@ async def check_data_handler(request: Request):
return json_response({"ok": False, "err": "Unauthorized"}, status=401) return json_response({"ok": False, "err": "Unauthorized"}, status=401)
async def send_message_handler(request: Request): async def send_message_handler(request: Request) -> Response:
bot: Bot = request.app["bot"] bot: Bot = request.app["bot"]
data = await request.post() data = await request.post()
try: try:
@ -44,11 +51,11 @@ async def send_message_handler(request: Request):
InlineKeyboardButton( InlineKeyboardButton(
text="Open", text="Open",
web_app=WebAppInfo( web_app=WebAppInfo(
url=str(request.url.with_scheme("https").with_path("demo")) url=str(request.url.with_scheme("https").with_path("demo")),
), ),
) ),
] ],
] ],
) )
await bot.answer_web_app_query( await bot.answer_web_app_query(
web_app_query_id=web_app_init_data.query_id, web_app_query_id=web_app_init_data.query_id,

View file

@ -15,7 +15,7 @@ def create_parser() -> ArgumentParser:
return parser return parser
async def main(): async def main() -> None:
parser = create_parser() parser = create_parser()
ns = parser.parse_args() ns = parser.parse_args()

View file

@ -6,7 +6,7 @@ build-backend = "hatchling.build"
name = "aiogram" name = "aiogram"
description = 'Modern and fully asynchronous framework for Telegram Bot API' description = 'Modern and fully asynchronous framework for Telegram Bot API'
readme = "README.rst" readme = "README.rst"
requires-python = ">=3.9" requires-python = ">=3.10"
license = "MIT" license = "MIT"
authors = [ authors = [
{ name = "Alex Root Junior", email = "jroot.junior@gmail.com" }, { name = "Alex Root Junior", email = "jroot.junior@gmail.com" },
@ -30,7 +30,6 @@ classifiers = [
"Typing :: Typed", "Typing :: Typed",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"Intended Audience :: System Administrators", "Intended Audience :: System Administrators",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
@ -60,7 +59,7 @@ fast = [
"aiodns>=3.0.0", "aiodns>=3.0.0",
] ]
redis = [ redis = [
"redis[hiredis]>=5.0.1,<5.3.0", "redis[hiredis]>=6.2.0,<7",
] ]
mongo = [ mongo = [
"motor>=3.3.2,<3.7.0", "motor>=3.3.2,<3.7.0",
@ -76,7 +75,7 @@ cli = [
"aiogram-cli>=1.1.0,<2.0.0", "aiogram-cli>=1.1.0,<2.0.0",
] ]
signature = [ signature = [
"cryptography>=43.0.0", "cryptography>=46.0.0",
] ]
test = [ test = [
"pytest~=7.4.2", "pytest~=7.4.2",
@ -88,8 +87,8 @@ test = [
"pytest-cov~=4.1.0", "pytest-cov~=4.1.0",
"pytest-aiohttp~=1.0.5", "pytest-aiohttp~=1.0.5",
"aresponses~=2.1.6", "aresponses~=2.1.6",
"pytz~=2023.3", "pytz~=2025.2",
"pycryptodomex~=3.19.0", "pycryptodomex~=3.23.0",
] ]
docs = [ docs = [
"Sphinx~=8.0.2", "Sphinx~=8.0.2",
@ -105,12 +104,12 @@ docs = [
"sphinxcontrib-towncrier~=0.4.0a0", "sphinxcontrib-towncrier~=0.4.0a0",
] ]
dev = [ dev = [
"black~=24.4.2", "black~=25.9.0",
"isort~=5.13.2", "isort~=6.1.0",
"ruff~=0.5.1", "ruff~=0.13.3",
"mypy~=1.10.0", "mypy~=1.10.1",
"toml~=0.10.2", "toml~=0.10.2",
"pre-commit~=3.5", "pre-commit~=4.3.0",
"packaging~=24.1", "packaging~=24.1",
"motor-types~=1.0.0b4", "motor-types~=1.0.0b4",
] ]
@ -200,9 +199,8 @@ cov-mongo = [
] ]
view-cov = "google-chrome-stable reports/py{matrix:python}/coverage/index.html" view-cov = "google-chrome-stable reports/py{matrix:python}/coverage/index.html"
[[tool.hatch.envs.test.matrix]] [[tool.hatch.envs.test.matrix]]
python = ["39", "310", "311", "312", "313"] python = ["310", "311", "312", "313"]
[tool.ruff] [tool.ruff]
line-length = 99 line-length = 99
@ -219,7 +217,6 @@ exclude = [
"scripts", "scripts",
"*.egg-info", "*.egg-info",
] ]
target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
@ -227,13 +224,13 @@ select = [
"C4", "C4",
"E", "E",
"F", "F",
"T10",
"T20",
"Q", "Q",
"RET", "RET",
"T10",
"T20",
] ]
ignore = [ ignore = [
"F401" "F401",
] ]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
@ -280,7 +277,7 @@ exclude_lines = [
[tool.mypy] [tool.mypy]
plugins = "pydantic.mypy" plugins = "pydantic.mypy"
python_version = "3.9" python_version = "3.10"
show_error_codes = true show_error_codes = true
show_error_context = true show_error_context = true
pretty = true pretty = true
@ -315,7 +312,7 @@ disallow_untyped_defs = true
[tool.black] [tool.black]
line-length = 99 line-length = 99
target-version = ['py39', 'py310', 'py311', 'py312', 'py313'] target-version = ['py310', 'py311', 'py312', 'py313']
exclude = ''' exclude = '''
( (
\.eggs \.eggs

View file

@ -1,13 +1,11 @@
version: "3.9"
services: services:
redis: redis:
image: redis:6-alpine image: redis:8-alpine
ports: ports:
- "${REDIS_PORT-6379}:6379" - "${REDIS_PORT-6379}:6379"
mongo: mongo:
image: mongo:7.0.6 image: mongo:8.0.14
environment: environment:
MONGO_INITDB_ROOT_USERNAME: mongo MONGO_INITDB_ROOT_USERNAME: mongo
MONGO_INITDB_ROOT_PASSWORD: mongo MONGO_INITDB_ROOT_PASSWORD: mongo

View file

@ -24,8 +24,6 @@ class TestDataclassKwargs:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"py_version,expected", "py_version,expected",
[ [
((3, 9, 0), ALL_VERSIONS),
((3, 9, 2), ALL_VERSIONS),
((3, 10, 2), PY_310), ((3, 10, 2), PY_310),
((3, 11, 0), PY_311), ((3, 11, 0), PY_311),
((4, 13, 0), LATEST_PY), ((4, 13, 0), LATEST_PY),