mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
EOL of Py3.9 (#1726)
Some checks failed
Tests / tests (macos-latest, 3.10) (push) Has been cancelled
Tests / tests (macos-latest, 3.11) (push) Has been cancelled
Tests / tests (macos-latest, 3.12) (push) Has been cancelled
Tests / tests (macos-latest, 3.13) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / tests (windows-latest, 3.10) (push) Has been cancelled
Tests / tests (windows-latest, 3.11) (push) Has been cancelled
Tests / tests (windows-latest, 3.12) (push) Has been cancelled
Tests / tests (windows-latest, 3.13) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.11) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.11) (push) Has been cancelled
Some checks failed
Tests / tests (macos-latest, 3.10) (push) Has been cancelled
Tests / tests (macos-latest, 3.11) (push) Has been cancelled
Tests / tests (macos-latest, 3.12) (push) Has been cancelled
Tests / tests (macos-latest, 3.13) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / tests (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / tests (windows-latest, 3.10) (push) Has been cancelled
Tests / tests (windows-latest, 3.11) (push) Has been cancelled
Tests / tests (windows-latest, 3.12) (push) Has been cancelled
Tests / tests (windows-latest, 3.13) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (macos-latest, pypy3.11) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.10) (push) Has been cancelled
Tests / pypy-tests (ubuntu-latest, pypy3.11) (push) Has been cancelled
* Drop py3.9 and pypy3.9 Add pypy3.11 (testing) into `tests.yml` Remove py3.9 from matrix in `tests.yml` Refactor not auto-gen code to be compatible with py3.10+, droping ugly 3.9 annotation. Replace some `from typing` imports to `from collections.abc`, due to deprecation Add `from __future__ import annotations` and `if TYPE_CHECKING:` where possible Add some `noqa` to calm down Ruff in some places, if Ruff will be used as default linting+formatting tool in future Replace some relative imports to absolute Sort `__all__` tuples in `__init__.py` and some other `.py` files Sort `__slots__` tuples in classes Split raises into `msg` and `raise` (`EM101`, `EM102`) to not duplicate error message in the traceback Add `Self` from `typing_extenstion` where possible Resolve typing problem in `aiogram/filters/command.py:18` Concatenate nested `if` statements Convert `HandlerContainer` into a dataclass in `aiogram/fsm/scene.py` Bump tests docker-compose.yml `redis:6-alpine` -> `redis:8-alpine` Bump tests docker-compose.yml `mongo:7.0.6` -> `mongo:8.0.14` Bump pre-commit-config `black==24.4.2` -> `black==25.9.0` Bump pre-commit-config `ruff==0.5.1` -> `ruff==0.13.3` Update Makefile lint for ruff to show fixes Add `make outdated` into Makefile Use `pathlib` instead of `os.path` Bump `redis[hiredis]>=5.0.1,<5.3.0` -> `redis[hiredis]>=6.2.0,<7` Bump `cryptography>=43.0.0` -> `cryptography>=46.0.0` due to security reasons Bump `pytz~=2023.3` -> `pytz~=2025.2` Bump `pycryptodomex~=3.19.0` -> `pycryptodomex~=3.23.0` due to security reasons Bump linting and formatting tools * Add `1726.removal.rst` * Update aiogram/utils/dataclass.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update aiogram/filters/callback_data.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update 1726.removal.rst * Remove `outdated` from Makefile * Add `__slots__` to `HandlerContainer` * Remove unused imports * Add `@dataclass` with `slots=True` to `HandlerContainer` --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
ab32296d07
commit
df7b16d5b3
94 changed files with 1383 additions and 1215 deletions
2
.github/workflows/pull_request_changelog.yml
vendored
2
.github/workflows/pull_request_changelog.yml
vendored
|
|
@ -19,7 +19,7 @@ jobs:
|
||||||
ref: ${{ github.event.pull_request.head.sha }}
|
ref: ${{ github.event.pull_request.head.sha }}
|
||||||
fetch-depth: '0'
|
fetch-depth: '0'
|
||||||
|
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.12
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.12"
|
||||||
|
|
|
||||||
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
|
|
@ -29,7 +29,6 @@ jobs:
|
||||||
- macos-latest
|
- macos-latest
|
||||||
- windows-latest
|
- windows-latest
|
||||||
python-version:
|
python-version:
|
||||||
- "3.9"
|
|
||||||
- "3.10"
|
- "3.10"
|
||||||
- "3.11"
|
- "3.11"
|
||||||
- "3.12"
|
- "3.12"
|
||||||
|
|
@ -111,8 +110,8 @@ jobs:
|
||||||
- macos-latest
|
- macos-latest
|
||||||
# - windows-latest
|
# - windows-latest
|
||||||
python-version:
|
python-version:
|
||||||
- "pypy3.9"
|
|
||||||
- "pypy3.10"
|
- "pypy3.10"
|
||||||
|
- "pypy3.11"
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
# Windows sucks. Force use bash instead of PowerShell
|
# Windows sucks. Force use bash instead of PowerShell
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ repos:
|
||||||
- id: "check-json"
|
- id: "check-json"
|
||||||
|
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 24.4.2
|
rev: 25.9.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
files: &files '^(aiogram|tests|examples)'
|
files: &files '^(aiogram|tests|examples)'
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: 'v0.5.1'
|
rev: 'v0.13.3'
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
|
|
|
||||||
7
CHANGES/1726.removal.rst
Normal file
7
CHANGES/1726.removal.rst
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
This PR updates the codebase following the end of life for Python 3.9.
|
||||||
|
|
||||||
|
Reference: https://devguide.python.org/versions/
|
||||||
|
|
||||||
|
- Updated type annotations to Python 3.10+ style, replacing deprecated ``List``, ``Set``, etc., with built-in ``list``, ``set``, and related types.
|
||||||
|
- Refactored code by simplifying nested ``if`` expressions.
|
||||||
|
- Updated several dependencies, including security-related upgrades.
|
||||||
2
Makefile
2
Makefile
|
|
@ -39,7 +39,7 @@ install: clean
|
||||||
lint:
|
lint:
|
||||||
isort --check-only $(code_dir)
|
isort --check-only $(code_dir)
|
||||||
black --check --diff $(code_dir)
|
black --check --diff $(code_dir)
|
||||||
ruff check $(package_dir) $(examples_dir)
|
ruff check --show-fixes --preview $(package_dir) $(examples_dir)
|
||||||
mypy $(package_dir)
|
mypy $(package_dir)
|
||||||
|
|
||||||
.PHONY: reformat
|
.PHONY: reformat
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ aiogram
|
||||||
:alt: Codecov
|
:alt: Codecov
|
||||||
|
|
||||||
**aiogram** is a modern and fully asynchronous framework for
|
**aiogram** is a modern and fully asynchronous framework for
|
||||||
`Telegram Bot API <https://core.telegram.org/bots/api>`_ written in Python 3.8+ using
|
`Telegram Bot API <https://core.telegram.org/bots/api>`_ written in Python 3.10+ using
|
||||||
`asyncio <https://docs.python.org/3/library/asyncio.html>`_ and
|
`asyncio <https://docs.python.org/3/library/asyncio.html>`_ and
|
||||||
`aiohttp <https://github.com/aio-libs/aiohttp>`_.
|
`aiohttp <https://github.com/aio-libs/aiohttp>`_.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,18 +24,18 @@ F = MagicFilter()
|
||||||
flags = FlagGenerator()
|
flags = FlagGenerator()
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
"BaseMiddleware",
|
||||||
|
"Bot",
|
||||||
|
"Dispatcher",
|
||||||
|
"F",
|
||||||
|
"Router",
|
||||||
"__api_version__",
|
"__api_version__",
|
||||||
"__version__",
|
"__version__",
|
||||||
"types",
|
|
||||||
"methods",
|
|
||||||
"enums",
|
"enums",
|
||||||
"Bot",
|
"flags",
|
||||||
"session",
|
|
||||||
"Dispatcher",
|
|
||||||
"Router",
|
|
||||||
"BaseMiddleware",
|
|
||||||
"F",
|
|
||||||
"html",
|
"html",
|
||||||
"md",
|
"md",
|
||||||
"flags",
|
"methods",
|
||||||
|
"session",
|
||||||
|
"types",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ if TYPE_CHECKING:
|
||||||
class BotContextController(BaseModel):
|
class BotContextController(BaseModel):
|
||||||
_bot: Optional["Bot"] = PrivateAttr()
|
_bot: Optional["Bot"] = PrivateAttr()
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
def model_post_init(self, __context: Any) -> None: # noqa: PYI063
|
||||||
self._bot = __context.get("bot") if __context else None
|
self._bot = __context.get("bot") if __context else None
|
||||||
|
|
||||||
def as_(self, bot: Optional["Bot"]) -> Self:
|
def as_(self, bot: Optional["Bot"]) -> Self:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from aiogram.utils.dataclass import dataclass_kwargs
|
from aiogram.utils.dataclass import dataclass_kwargs
|
||||||
|
|
||||||
|
|
@ -35,25 +35,25 @@ class DefaultBotProperties:
|
||||||
Default bot properties.
|
Default bot properties.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parse_mode: Optional[str] = None
|
parse_mode: str | None = None
|
||||||
"""Default parse mode for messages."""
|
"""Default parse mode for messages."""
|
||||||
disable_notification: Optional[bool] = None
|
disable_notification: bool | None = None
|
||||||
"""Sends the message silently. Users will receive a notification with no sound."""
|
"""Sends the message silently. Users will receive a notification with no sound."""
|
||||||
protect_content: Optional[bool] = None
|
protect_content: bool | None = None
|
||||||
"""Protects content from copying."""
|
"""Protects content from copying."""
|
||||||
allow_sending_without_reply: Optional[bool] = None
|
allow_sending_without_reply: bool | None = None
|
||||||
"""Allows to send messages without reply."""
|
"""Allows to send messages without reply."""
|
||||||
link_preview: Optional[LinkPreviewOptions] = None
|
link_preview: LinkPreviewOptions | None = None
|
||||||
"""Link preview settings."""
|
"""Link preview settings."""
|
||||||
link_preview_is_disabled: Optional[bool] = None
|
link_preview_is_disabled: bool | None = None
|
||||||
"""Disables link preview."""
|
"""Disables link preview."""
|
||||||
link_preview_prefer_small_media: Optional[bool] = None
|
link_preview_prefer_small_media: bool | None = None
|
||||||
"""Prefer small media in link preview."""
|
"""Prefer small media in link preview."""
|
||||||
link_preview_prefer_large_media: Optional[bool] = None
|
link_preview_prefer_large_media: bool | None = None
|
||||||
"""Prefer large media in link preview."""
|
"""Prefer large media in link preview."""
|
||||||
link_preview_show_above_text: Optional[bool] = None
|
link_preview_show_above_text: bool | None = None
|
||||||
"""Show link preview above text."""
|
"""Show link preview above text."""
|
||||||
show_caption_above_media: Optional[bool] = None
|
show_caption_above_media: bool | None = None
|
||||||
"""Show caption above media."""
|
"""Show caption above media."""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|
@ -63,11 +63,11 @@ class DefaultBotProperties:
|
||||||
self.link_preview_prefer_small_media,
|
self.link_preview_prefer_small_media,
|
||||||
self.link_preview_prefer_large_media,
|
self.link_preview_prefer_large_media,
|
||||||
self.link_preview_show_above_text,
|
self.link_preview_show_above_text,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_any_link_preview_option and self.link_preview is None:
|
if has_any_link_preview_option and self.link_preview is None:
|
||||||
from ..types import LinkPreviewOptions
|
from aiogram.types import LinkPreviewOptions
|
||||||
|
|
||||||
self.link_preview = LinkPreviewOptions(
|
self.link_preview = LinkPreviewOptions(
|
||||||
is_disabled=self.link_preview_is_disabled,
|
is_disabled=self.link_preview_is_disabled,
|
||||||
|
|
|
||||||
|
|
@ -2,45 +2,35 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import ssl
|
import ssl
|
||||||
from typing import (
|
from collections.abc import AsyncGenerator, Iterable
|
||||||
TYPE_CHECKING,
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
Any,
|
|
||||||
AsyncGenerator,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
import certifi
|
import certifi
|
||||||
from aiohttp import BasicAuth, ClientError, ClientSession, FormData, TCPConnector
|
from aiohttp import BasicAuth, ClientError, ClientSession, FormData, TCPConnector
|
||||||
from aiohttp.hdrs import USER_AGENT
|
from aiohttp.hdrs import USER_AGENT
|
||||||
from aiohttp.http import SERVER_SOFTWARE
|
from aiohttp.http import SERVER_SOFTWARE
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from aiogram.__meta__ import __version__
|
from aiogram.__meta__ import __version__
|
||||||
from aiogram.methods import TelegramMethod
|
from aiogram.exceptions import TelegramNetworkError
|
||||||
|
from aiogram.methods.base import TelegramType
|
||||||
|
|
||||||
from ...exceptions import TelegramNetworkError
|
|
||||||
from ...methods.base import TelegramType
|
|
||||||
from ...types import InputFile
|
|
||||||
from .base import BaseSession
|
from .base import BaseSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..bot import Bot
|
from aiogram.client.bot import Bot
|
||||||
|
from aiogram.methods import TelegramMethod
|
||||||
|
from aiogram.types import InputFile
|
||||||
|
|
||||||
_ProxyBasic = Union[str, Tuple[str, BasicAuth]]
|
_ProxyBasic = str | tuple[str, BasicAuth]
|
||||||
_ProxyChain = Iterable[_ProxyBasic]
|
_ProxyChain = Iterable[_ProxyBasic]
|
||||||
_ProxyType = Union[_ProxyChain, _ProxyBasic]
|
_ProxyType = _ProxyChain | _ProxyBasic
|
||||||
|
|
||||||
|
|
||||||
def _retrieve_basic(basic: _ProxyBasic) -> Dict[str, Any]:
|
def _retrieve_basic(basic: _ProxyBasic) -> dict[str, Any]:
|
||||||
from aiohttp_socks.utils import parse_proxy_url
|
from aiohttp_socks.utils import parse_proxy_url
|
||||||
|
|
||||||
proxy_auth: Optional[BasicAuth] = None
|
proxy_auth: BasicAuth | None = None
|
||||||
|
|
||||||
if isinstance(basic, str):
|
if isinstance(basic, str):
|
||||||
proxy_url = basic
|
proxy_url = basic
|
||||||
|
|
@ -62,7 +52,7 @@ def _retrieve_basic(basic: _ProxyBasic) -> Dict[str, Any]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"], Dict[str, Any]]:
|
def _prepare_connector(chain_or_plain: _ProxyType) -> tuple[type[TCPConnector], dict[str, Any]]:
|
||||||
from aiohttp_socks import ChainProxyConnector, ProxyConnector, ProxyInfo
|
from aiohttp_socks import ChainProxyConnector, ProxyConnector, ProxyInfo
|
||||||
|
|
||||||
# since tuple is Iterable(compatible with _ProxyChain) object, we assume that
|
# since tuple is Iterable(compatible with _ProxyChain) object, we assume that
|
||||||
|
|
@ -74,17 +64,13 @@ def _prepare_connector(chain_or_plain: _ProxyType) -> Tuple[Type["TCPConnector"]
|
||||||
return ProxyConnector, _retrieve_basic(chain_or_plain)
|
return ProxyConnector, _retrieve_basic(chain_or_plain)
|
||||||
|
|
||||||
chain_or_plain = cast(_ProxyChain, chain_or_plain)
|
chain_or_plain = cast(_ProxyChain, chain_or_plain)
|
||||||
infos: List[ProxyInfo] = []
|
infos: list[ProxyInfo] = [ProxyInfo(**_retrieve_basic(basic)) for basic in chain_or_plain]
|
||||||
for basic in chain_or_plain:
|
|
||||||
infos.append(ProxyInfo(**_retrieve_basic(basic)))
|
|
||||||
|
|
||||||
return ChainProxyConnector, {"proxy_infos": infos}
|
return ChainProxyConnector, {"proxy_infos": infos}
|
||||||
|
|
||||||
|
|
||||||
class AiohttpSession(BaseSession):
|
class AiohttpSession(BaseSession):
|
||||||
def __init__(
|
def __init__(self, proxy: _ProxyType | None = None, limit: int = 100, **kwargs: Any) -> None:
|
||||||
self, proxy: Optional[_ProxyType] = None, limit: int = 100, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Client session based on aiohttp.
|
Client session based on aiohttp.
|
||||||
|
|
||||||
|
|
@ -94,31 +80,32 @@ class AiohttpSession(BaseSession):
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self._session: Optional[ClientSession] = None
|
self._session: ClientSession | None = None
|
||||||
self._connector_type: Type[TCPConnector] = TCPConnector
|
self._connector_type: type[TCPConnector] = TCPConnector
|
||||||
self._connector_init: Dict[str, Any] = {
|
self._connector_init: dict[str, Any] = {
|
||||||
"ssl": ssl.create_default_context(cafile=certifi.where()),
|
"ssl": ssl.create_default_context(cafile=certifi.where()),
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
"ttl_dns_cache": 3600, # Workaround for https://github.com/aiogram/aiogram/issues/1500
|
"ttl_dns_cache": 3600, # Workaround for https://github.com/aiogram/aiogram/issues/1500
|
||||||
}
|
}
|
||||||
self._should_reset_connector = True # flag determines connector state
|
self._should_reset_connector = True # flag determines connector state
|
||||||
self._proxy: Optional[_ProxyType] = None
|
self._proxy: _ProxyType | None = None
|
||||||
|
|
||||||
if proxy is not None:
|
if proxy is not None:
|
||||||
try:
|
try:
|
||||||
self._setup_proxy_connector(proxy)
|
self._setup_proxy_connector(proxy)
|
||||||
except ImportError as exc: # pragma: no cover
|
except ImportError as exc: # pragma: no cover
|
||||||
raise RuntimeError(
|
msg = (
|
||||||
"In order to use aiohttp client for proxy requests, install "
|
"In order to use aiohttp client for proxy requests, install "
|
||||||
"https://pypi.org/project/aiohttp-socks/"
|
"https://pypi.org/project/aiohttp-socks/"
|
||||||
) from exc
|
)
|
||||||
|
raise RuntimeError(msg) from exc
|
||||||
|
|
||||||
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
|
def _setup_proxy_connector(self, proxy: _ProxyType) -> None:
|
||||||
self._connector_type, self._connector_init = _prepare_connector(proxy)
|
self._connector_type, self._connector_init = _prepare_connector(proxy)
|
||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def proxy(self) -> Optional[_ProxyType]:
|
def proxy(self) -> _ProxyType | None:
|
||||||
return self._proxy
|
return self._proxy
|
||||||
|
|
||||||
@proxy.setter
|
@proxy.setter
|
||||||
|
|
@ -151,7 +138,7 @@ class AiohttpSession(BaseSession):
|
||||||
|
|
||||||
def build_form_data(self, bot: Bot, method: TelegramMethod[TelegramType]) -> FormData:
|
def build_form_data(self, bot: Bot, method: TelegramMethod[TelegramType]) -> FormData:
|
||||||
form = FormData(quote_fields=False)
|
form = FormData(quote_fields=False)
|
||||||
files: Dict[str, InputFile] = {}
|
files: dict[str, InputFile] = {}
|
||||||
for key, value in method.model_dump(warnings=False).items():
|
for key, value in method.model_dump(warnings=False).items():
|
||||||
value = self.prepare_value(value, bot=bot, files=files)
|
value = self.prepare_value(value, bot=bot, files=files)
|
||||||
if not value:
|
if not value:
|
||||||
|
|
@ -166,7 +153,10 @@ class AiohttpSession(BaseSession):
|
||||||
return form
|
return form
|
||||||
|
|
||||||
async def make_request(
|
async def make_request(
|
||||||
self, bot: Bot, method: TelegramMethod[TelegramType], timeout: Optional[int] = None
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
method: TelegramMethod[TelegramType],
|
||||||
|
timeout: int | None = None,
|
||||||
) -> TelegramType:
|
) -> TelegramType:
|
||||||
session = await self.create_session()
|
session = await self.create_session()
|
||||||
|
|
||||||
|
|
@ -175,7 +165,9 @@ class AiohttpSession(BaseSession):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
url, data=form, timeout=self.timeout if timeout is None else timeout
|
url,
|
||||||
|
data=form,
|
||||||
|
timeout=self.timeout if timeout is None else timeout,
|
||||||
) as resp:
|
) as resp:
|
||||||
raw_result = await resp.text()
|
raw_result = await resp.text()
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|
@ -183,14 +175,17 @@ class AiohttpSession(BaseSession):
|
||||||
except ClientError as e:
|
except ClientError as e:
|
||||||
raise TelegramNetworkError(method=method, message=f"{type(e).__name__}: {e}")
|
raise TelegramNetworkError(method=method, message=f"{type(e).__name__}: {e}")
|
||||||
response = self.check_response(
|
response = self.check_response(
|
||||||
bot=bot, method=method, status_code=resp.status, content=raw_result
|
bot=bot,
|
||||||
|
method=method,
|
||||||
|
status_code=resp.status,
|
||||||
|
content=raw_result,
|
||||||
)
|
)
|
||||||
return cast(TelegramType, response.result)
|
return cast(TelegramType, response.result)
|
||||||
|
|
||||||
async def stream_content(
|
async def stream_content(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
headers: Optional[Dict[str, Any]] = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
chunk_size: int = 65536,
|
chunk_size: int = 65536,
|
||||||
raise_for_status: bool = True,
|
raise_for_status: bool = True,
|
||||||
|
|
@ -201,11 +196,14 @@ class AiohttpSession(BaseSession):
|
||||||
session = await self.create_session()
|
session = await self.create_session()
|
||||||
|
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, timeout=timeout, headers=headers, raise_for_status=raise_for_status
|
url,
|
||||||
|
timeout=timeout,
|
||||||
|
headers=headers,
|
||||||
|
raise_for_status=raise_for_status,
|
||||||
) as resp:
|
) as resp:
|
||||||
async for chunk in resp.content.iter_chunked(chunk_size):
|
async for chunk in resp.content.iter_chunked(chunk_size):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def __aenter__(self) -> AiohttpSession:
|
async def __aenter__(self) -> Self:
|
||||||
await self.create_session()
|
await self.create_session()
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
|
|
@ -4,23 +4,16 @@ import abc
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import secrets
|
import secrets
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from types import TracebackType
|
from typing import TYPE_CHECKING, Any, Final, cast
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
AsyncGenerator,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Final,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from aiogram.client.default import Default
|
||||||
|
from aiogram.client.telegram import PRODUCTION, TelegramAPIServer
|
||||||
from aiogram.exceptions import (
|
from aiogram.exceptions import (
|
||||||
ClientDecodeError,
|
ClientDecodeError,
|
||||||
RestartingTelegram,
|
RestartingTelegram,
|
||||||
|
|
@ -35,16 +28,16 @@ from aiogram.exceptions import (
|
||||||
TelegramServerError,
|
TelegramServerError,
|
||||||
TelegramUnauthorizedError,
|
TelegramUnauthorizedError,
|
||||||
)
|
)
|
||||||
|
from aiogram.methods import Response, TelegramMethod
|
||||||
|
from aiogram.methods.base import TelegramType
|
||||||
|
from aiogram.types import InputFile, TelegramObject
|
||||||
|
|
||||||
from ...methods import Response, TelegramMethod
|
|
||||||
from ...methods.base import TelegramType
|
|
||||||
from ...types import InputFile, TelegramObject
|
|
||||||
from ..default import Default
|
|
||||||
from ..telegram import PRODUCTION, TelegramAPIServer
|
|
||||||
from .middlewares.manager import RequestMiddlewareManager
|
from .middlewares.manager import RequestMiddlewareManager
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..bot import Bot
|
from types import TracebackType
|
||||||
|
|
||||||
|
from aiogram.client.bot import Bot
|
||||||
|
|
||||||
_JsonLoads = Callable[..., Any]
|
_JsonLoads = Callable[..., Any]
|
||||||
_JsonDumps = Callable[..., str]
|
_JsonDumps = Callable[..., str]
|
||||||
|
|
@ -81,24 +74,30 @@ class BaseSession(abc.ABC):
|
||||||
self.middleware = RequestMiddlewareManager()
|
self.middleware = RequestMiddlewareManager()
|
||||||
|
|
||||||
def check_response(
|
def check_response(
|
||||||
self, bot: Bot, method: TelegramMethod[TelegramType], status_code: int, content: str
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
method: TelegramMethod[TelegramType],
|
||||||
|
status_code: int,
|
||||||
|
content: str,
|
||||||
) -> Response[TelegramType]:
|
) -> Response[TelegramType]:
|
||||||
"""
|
"""
|
||||||
Check response status
|
Check response status
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
json_data = self.json_loads(content)
|
json_data = self.json_loads(content)
|
||||||
except Exception as e:
|
except Exception as e: # noqa: BLE001
|
||||||
# Handled error type can't be classified as specific error
|
# Handled error type can't be classified as specific error
|
||||||
# in due to decoder can be customized and raise any exception
|
# in due to decoder can be customized and raise any exception
|
||||||
|
|
||||||
raise ClientDecodeError("Failed to decode object", e, content)
|
msg = "Failed to decode object"
|
||||||
|
raise ClientDecodeError(msg, e, content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_type = Response[method.__returning__] # type: ignore
|
response_type = Response[method.__returning__] # type: ignore
|
||||||
response = response_type.model_validate(json_data, context={"bot": bot})
|
response = response_type.model_validate(json_data, context={"bot": bot})
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise ClientDecodeError("Failed to deserialize object", e, json_data)
|
msg = "Failed to deserialize object"
|
||||||
|
raise ClientDecodeError(msg, e, json_data)
|
||||||
|
|
||||||
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED and response.ok:
|
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED and response.ok:
|
||||||
return response
|
return response
|
||||||
|
|
@ -108,7 +107,9 @@ class BaseSession(abc.ABC):
|
||||||
if parameters := response.parameters:
|
if parameters := response.parameters:
|
||||||
if parameters.retry_after:
|
if parameters.retry_after:
|
||||||
raise TelegramRetryAfter(
|
raise TelegramRetryAfter(
|
||||||
method=method, message=description, retry_after=parameters.retry_after
|
method=method,
|
||||||
|
message=description,
|
||||||
|
retry_after=parameters.retry_after,
|
||||||
)
|
)
|
||||||
if parameters.migrate_to_chat_id:
|
if parameters.migrate_to_chat_id:
|
||||||
raise TelegramMigrateToChat(
|
raise TelegramMigrateToChat(
|
||||||
|
|
@ -143,14 +144,13 @@ class BaseSession(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Close client session
|
Close client session
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def make_request(
|
async def make_request(
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
method: TelegramMethod[TelegramType],
|
method: TelegramMethod[TelegramType],
|
||||||
timeout: Optional[int] = None,
|
timeout: int | None = None,
|
||||||
) -> TelegramType: # pragma: no cover
|
) -> TelegramType: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Make request to Telegram Bot API
|
Make request to Telegram Bot API
|
||||||
|
|
@ -161,13 +161,12 @@ class BaseSession(abc.ABC):
|
||||||
:return:
|
:return:
|
||||||
:raise TelegramApiError:
|
:raise TelegramApiError:
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def stream_content(
|
async def stream_content(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
headers: Optional[Dict[str, Any]] = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
chunk_size: int = 65536,
|
chunk_size: int = 65536,
|
||||||
raise_for_status: bool = True,
|
raise_for_status: bool = True,
|
||||||
|
|
@ -181,7 +180,7 @@ class BaseSession(abc.ABC):
|
||||||
self,
|
self,
|
||||||
value: Any,
|
value: Any,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
files: Dict[str, Any],
|
files: dict[str, Any],
|
||||||
_dumps_json: bool = True,
|
_dumps_json: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
@ -204,7 +203,10 @@ class BaseSession(abc.ABC):
|
||||||
for key, item in value.items()
|
for key, item in value.items()
|
||||||
if (
|
if (
|
||||||
prepared_item := self.prepare_value(
|
prepared_item := self.prepare_value(
|
||||||
item, bot=bot, files=files, _dumps_json=False
|
item,
|
||||||
|
bot=bot,
|
||||||
|
files=files,
|
||||||
|
_dumps_json=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
is not None
|
is not None
|
||||||
|
|
@ -218,7 +220,10 @@ class BaseSession(abc.ABC):
|
||||||
for item in value
|
for item in value
|
||||||
if (
|
if (
|
||||||
prepared_item := self.prepare_value(
|
prepared_item := self.prepare_value(
|
||||||
item, bot=bot, files=files, _dumps_json=False
|
item,
|
||||||
|
bot=bot,
|
||||||
|
files=files,
|
||||||
|
_dumps_json=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
is not None
|
is not None
|
||||||
|
|
@ -227,7 +232,7 @@ class BaseSession(abc.ABC):
|
||||||
return self.json_dumps(value)
|
return self.json_dumps(value)
|
||||||
return value
|
return value
|
||||||
if isinstance(value, datetime.timedelta):
|
if isinstance(value, datetime.timedelta):
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now() # noqa: DTZ005
|
||||||
return str(round((now + value).timestamp()))
|
return str(round((now + value).timestamp()))
|
||||||
if isinstance(value, datetime.datetime):
|
if isinstance(value, datetime.datetime):
|
||||||
return str(round(value.timestamp()))
|
return str(round(value.timestamp()))
|
||||||
|
|
@ -248,18 +253,18 @@ class BaseSession(abc.ABC):
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
method: TelegramMethod[TelegramType],
|
method: TelegramMethod[TelegramType],
|
||||||
timeout: Optional[int] = None,
|
timeout: int | None = None,
|
||||||
) -> TelegramType:
|
) -> TelegramType:
|
||||||
middleware = self.middleware.wrap_middlewares(self.make_request, timeout=timeout)
|
middleware = self.middleware.wrap_middlewares(self.make_request, timeout=timeout)
|
||||||
return cast(TelegramType, await middleware(bot, method))
|
return cast(TelegramType, await middleware(bot, method))
|
||||||
|
|
||||||
async def __aenter__(self) -> BaseSession:
|
async def __aenter__(self) -> Self:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
self,
|
self,
|
||||||
exc_type: Optional[Type[BaseException]],
|
exc_type: type[BaseException] | None,
|
||||||
exc_value: Optional[BaseException],
|
exc_value: BaseException | None,
|
||||||
traceback: Optional[TracebackType],
|
traceback: TracebackType | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
|
||||||
|
|
@ -3,17 +3,17 @@ from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Protocol
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
from aiogram.methods import Response, TelegramMethod
|
|
||||||
from aiogram.methods.base import TelegramType
|
from aiogram.methods.base import TelegramType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ...bot import Bot
|
from aiogram.client.bot import Bot
|
||||||
|
from aiogram.methods import Response, TelegramMethod
|
||||||
|
|
||||||
|
|
||||||
class NextRequestMiddlewareType(Protocol[TelegramType]): # pragma: no cover
|
class NextRequestMiddlewareType(Protocol[TelegramType]): # pragma: no cover
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
bot: "Bot",
|
bot: Bot,
|
||||||
method: TelegramMethod[TelegramType],
|
method: TelegramMethod[TelegramType],
|
||||||
) -> Response[TelegramType]:
|
) -> Response[TelegramType]:
|
||||||
pass
|
pass
|
||||||
|
|
@ -23,7 +23,7 @@ class RequestMiddlewareType(Protocol): # pragma: no cover
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
make_request: NextRequestMiddlewareType[TelegramType],
|
make_request: NextRequestMiddlewareType[TelegramType],
|
||||||
bot: "Bot",
|
bot: Bot,
|
||||||
method: TelegramMethod[TelegramType],
|
method: TelegramMethod[TelegramType],
|
||||||
) -> Response[TelegramType]:
|
) -> Response[TelegramType]:
|
||||||
pass
|
pass
|
||||||
|
|
@ -38,7 +38,7 @@ class BaseRequestMiddleware(ABC):
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
make_request: NextRequestMiddlewareType[TelegramType],
|
make_request: NextRequestMiddlewareType[TelegramType],
|
||||||
bot: "Bot",
|
bot: Bot,
|
||||||
method: TelegramMethod[TelegramType],
|
method: TelegramMethod[TelegramType],
|
||||||
) -> Response[TelegramType]:
|
) -> Response[TelegramType]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -50,4 +50,3 @@ class BaseRequestMiddleware(ABC):
|
||||||
|
|
||||||
:return: :class:`aiogram.methods.Response`
|
:return: :class:`aiogram.methods.Response`
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload
|
from typing import Any, cast, overload
|
||||||
|
|
||||||
from aiogram.client.session.middlewares.base import (
|
from aiogram.client.session.middlewares.base import (
|
||||||
NextRequestMiddlewareType,
|
NextRequestMiddlewareType,
|
||||||
|
|
@ -12,7 +13,7 @@ from aiogram.methods.base import TelegramType
|
||||||
|
|
||||||
class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
|
class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._middlewares: List[RequestMiddlewareType] = []
|
self._middlewares: list[RequestMiddlewareType] = []
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
|
|
@ -26,11 +27,8 @@ class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
middleware: Optional[RequestMiddlewareType] = None,
|
middleware: RequestMiddlewareType | None = None,
|
||||||
) -> Union[
|
) -> Callable[[RequestMiddlewareType], RequestMiddlewareType] | RequestMiddlewareType:
|
||||||
Callable[[RequestMiddlewareType], RequestMiddlewareType],
|
|
||||||
RequestMiddlewareType,
|
|
||||||
]:
|
|
||||||
if middleware is None:
|
if middleware is None:
|
||||||
return self.register
|
return self.register
|
||||||
return self.register(middleware)
|
return self.register(middleware)
|
||||||
|
|
@ -44,8 +42,9 @@ class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self, item: Union[int, slice]
|
self,
|
||||||
) -> Union[RequestMiddlewareType, Sequence[RequestMiddlewareType]]:
|
item: int | slice,
|
||||||
|
) -> RequestMiddlewareType | Sequence[RequestMiddlewareType]:
|
||||||
return self._middlewares[item]
|
return self._middlewares[item]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Type
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from aiogram import loggers
|
from aiogram import loggers
|
||||||
from aiogram.methods import TelegramMethod
|
from aiogram.methods import TelegramMethod
|
||||||
|
|
@ -8,19 +8,19 @@ from aiogram.methods.base import Response, TelegramType
|
||||||
from .base import BaseRequestMiddleware, NextRequestMiddlewareType
|
from .base import BaseRequestMiddleware, NextRequestMiddlewareType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ...bot import Bot
|
from aiogram.client.bot import Bot
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RequestLogging(BaseRequestMiddleware):
|
class RequestLogging(BaseRequestMiddleware):
|
||||||
def __init__(self, ignore_methods: Optional[List[Type[TelegramMethod[Any]]]] = None):
|
def __init__(self, ignore_methods: list[type[TelegramMethod[Any]]] | None = None):
|
||||||
"""
|
"""
|
||||||
Middleware for logging outgoing requests
|
Middleware for logging outgoing requests
|
||||||
|
|
||||||
:param ignore_methods: methods to ignore in logging middleware
|
:param ignore_methods: methods to ignore in logging middleware
|
||||||
"""
|
"""
|
||||||
self.ignore_methods = ignore_methods if ignore_methods else []
|
self.ignore_methods = ignore_methods or []
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,24 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class FilesPathWrapper(ABC):
|
class FilesPathWrapper(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_local(self, path: Union[Path, str]) -> Union[Path, str]:
|
def to_local(self, path: Path | str) -> Path | str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_server(self, path: Union[Path, str]) -> Union[Path, str]:
|
def to_server(self, path: Path | str) -> Path | str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BareFilesPathWrapper(FilesPathWrapper):
|
class BareFilesPathWrapper(FilesPathWrapper):
|
||||||
def to_local(self, path: Union[Path, str]) -> Union[Path, str]:
|
def to_local(self, path: Path | str) -> Path | str:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def to_server(self, path: Union[Path, str]) -> Union[Path, str]:
|
def to_server(self, path: Path | str) -> Path | str:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -29,15 +29,18 @@ class SimpleFilesPathWrapper(FilesPathWrapper):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve(
|
def _resolve(
|
||||||
cls, base1: Union[Path, str], base2: Union[Path, str], value: Union[Path, str]
|
cls,
|
||||||
|
base1: Path | str,
|
||||||
|
base2: Path | str,
|
||||||
|
value: Path | str,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
relative = Path(value).relative_to(base1)
|
relative = Path(value).relative_to(base1)
|
||||||
return base2 / relative
|
return base2 / relative
|
||||||
|
|
||||||
def to_local(self, path: Union[Path, str]) -> Union[Path, str]:
|
def to_local(self, path: Path | str) -> Path | str:
|
||||||
return self._resolve(base1=self.server_path, base2=self.local_path, value=path)
|
return self._resolve(base1=self.server_path, base2=self.local_path, value=path)
|
||||||
|
|
||||||
def to_server(self, path: Union[Path, str]) -> Union[Path, str]:
|
def to_server(self, path: Path | str) -> Path | str:
|
||||||
return self._resolve(base1=self.local_path, base2=self.server_path, value=path)
|
return self._resolve(base1=self.local_path, base2=self.server_path, value=path)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,7 +57,7 @@ class TelegramAPIServer:
|
||||||
is_local: bool = False
|
is_local: bool = False
|
||||||
"""Mark this server is
|
"""Mark this server is
|
||||||
in `local mode <https://core.telegram.org/bots/api#using-a-local-bot-api-server>`_."""
|
in `local mode <https://core.telegram.org/bots/api#using-a-local-bot-api-server>`_."""
|
||||||
wrap_local_file: FilesPathWrapper = BareFilesPathWrapper()
|
wrap_local_file: FilesPathWrapper = field(default=BareFilesPathWrapper())
|
||||||
"""Callback to wrap files path in local mode"""
|
"""Callback to wrap files path in local mode"""
|
||||||
|
|
||||||
def api_url(self, token: str, method: str) -> str:
|
def api_url(self, token: str, method: str) -> str:
|
||||||
|
|
@ -67,7 +70,7 @@ class TelegramAPIServer:
|
||||||
"""
|
"""
|
||||||
return self.base.format(token=token, method=method)
|
return self.base.format(token=token, method=method)
|
||||||
|
|
||||||
def file_url(self, token: str, path: Union[str, Path]) -> str:
|
def file_url(self, token: str, path: str | Path) -> str:
|
||||||
"""
|
"""
|
||||||
Generate URL for downloading files
|
Generate URL for downloading files
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,28 +5,32 @@ import contextvars
|
||||||
import signal
|
import signal
|
||||||
import warnings
|
import warnings
|
||||||
from asyncio import CancelledError, Event, Future, Lock
|
from asyncio import CancelledError, Event, Future, Lock
|
||||||
|
from collections.abc import AsyncGenerator, Awaitable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import Any, AsyncGenerator, Awaitable, Dict, List, Optional, Set, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from aiogram import loggers
|
||||||
|
from aiogram.exceptions import TelegramAPIError
|
||||||
|
from aiogram.fsm.middleware import FSMContextMiddleware
|
||||||
|
from aiogram.fsm.storage.base import BaseEventIsolation, BaseStorage
|
||||||
|
from aiogram.fsm.storage.memory import DisabledEventIsolation, MemoryStorage
|
||||||
|
from aiogram.fsm.strategy import FSMStrategy
|
||||||
|
from aiogram.methods import GetUpdates, TelegramMethod
|
||||||
|
from aiogram.types import Update, User
|
||||||
|
from aiogram.types.base import UNSET, UNSET_TYPE
|
||||||
|
from aiogram.types.update import UpdateTypeLookupError
|
||||||
|
from aiogram.utils.backoff import Backoff, BackoffConfig
|
||||||
|
|
||||||
from .. import loggers
|
|
||||||
from ..client.bot import Bot
|
|
||||||
from ..exceptions import TelegramAPIError
|
|
||||||
from ..fsm.middleware import FSMContextMiddleware
|
|
||||||
from ..fsm.storage.base import BaseEventIsolation, BaseStorage
|
|
||||||
from ..fsm.storage.memory import DisabledEventIsolation, MemoryStorage
|
|
||||||
from ..fsm.strategy import FSMStrategy
|
|
||||||
from ..methods import GetUpdates, TelegramMethod
|
|
||||||
from ..methods.base import TelegramType
|
|
||||||
from ..types import Update, User
|
|
||||||
from ..types.base import UNSET, UNSET_TYPE
|
|
||||||
from ..types.update import UpdateTypeLookupError
|
|
||||||
from ..utils.backoff import Backoff, BackoffConfig
|
|
||||||
from .event.bases import UNHANDLED, SkipHandler
|
from .event.bases import UNHANDLED, SkipHandler
|
||||||
from .event.telegram import TelegramEventObserver
|
from .event.telegram import TelegramEventObserver
|
||||||
from .middlewares.error import ErrorsMiddleware
|
from .middlewares.error import ErrorsMiddleware
|
||||||
from .middlewares.user_context import UserContextMiddleware
|
from .middlewares.user_context import UserContextMiddleware
|
||||||
from .router import Router
|
from .router import Router
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from aiogram.client.bot import Bot
|
||||||
|
from aiogram.methods.base import TelegramType
|
||||||
|
|
||||||
DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1)
|
DEFAULT_BACKOFF_CONFIG = BackoffConfig(min_delay=1.0, max_delay=5.0, factor=1.3, jitter=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,11 +42,11 @@ class Dispatcher(Router):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*, # * - Preventing to pass instance of Bot to the FSM storage
|
*, # * - Preventing to pass instance of Bot to the FSM storage
|
||||||
storage: Optional[BaseStorage] = None,
|
storage: BaseStorage | None = None,
|
||||||
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
|
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
|
||||||
events_isolation: Optional[BaseEventIsolation] = None,
|
events_isolation: BaseEventIsolation | None = None,
|
||||||
disable_fsm: bool = False,
|
disable_fsm: bool = False,
|
||||||
name: Optional[str] = None,
|
name: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -55,18 +59,18 @@ class Dispatcher(Router):
|
||||||
then you should not use storage and events isolation
|
then you should not use storage and events isolation
|
||||||
:param kwargs: Other arguments, will be passed as keyword arguments to handlers
|
:param kwargs: Other arguments, will be passed as keyword arguments to handlers
|
||||||
"""
|
"""
|
||||||
super(Dispatcher, self).__init__(name=name)
|
super().__init__(name=name)
|
||||||
|
|
||||||
if storage and not isinstance(storage, BaseStorage):
|
if storage and not isinstance(storage, BaseStorage):
|
||||||
raise TypeError(
|
msg = f"FSM storage should be instance of 'BaseStorage' not {type(storage).__name__}"
|
||||||
f"FSM storage should be instance of 'BaseStorage' not {type(storage).__name__}"
|
raise TypeError(msg)
|
||||||
)
|
|
||||||
|
|
||||||
# Telegram API provides originally only one event type - Update
|
# Telegram API provides originally only one event type - Update
|
||||||
# For making easily interactions with events here is registered handler which helps
|
# For making easily interactions with events here is registered handler which helps
|
||||||
# to separate Update to different event types like Message, CallbackQuery etc.
|
# to separate Update to different event types like Message, CallbackQuery etc.
|
||||||
self.update = self.observers["update"] = TelegramEventObserver(
|
self.update = self.observers["update"] = TelegramEventObserver(
|
||||||
router=self, event_name="update"
|
router=self,
|
||||||
|
event_name="update",
|
||||||
)
|
)
|
||||||
self.update.register(self._listen_update)
|
self.update.register(self._listen_update)
|
||||||
|
|
||||||
|
|
@ -91,11 +95,11 @@ class Dispatcher(Router):
|
||||||
self.update.outer_middleware(self.fsm)
|
self.update.outer_middleware(self.fsm)
|
||||||
self.shutdown.register(self.fsm.close)
|
self.shutdown.register(self.fsm.close)
|
||||||
|
|
||||||
self.workflow_data: Dict[str, Any] = kwargs
|
self.workflow_data: dict[str, Any] = kwargs
|
||||||
self._running_lock = Lock()
|
self._running_lock = Lock()
|
||||||
self._stop_signal: Optional[Event] = None
|
self._stop_signal: Event | None = None
|
||||||
self._stopped_signal: Optional[Event] = None
|
self._stopped_signal: Event | None = None
|
||||||
self._handle_update_tasks: Set[asyncio.Task[Any]] = set()
|
self._handle_update_tasks: set[asyncio.Task[Any]] = set()
|
||||||
|
|
||||||
def __getitem__(self, item: str) -> Any:
|
def __getitem__(self, item: str) -> Any:
|
||||||
return self.workflow_data[item]
|
return self.workflow_data[item]
|
||||||
|
|
@ -106,7 +110,7 @@ class Dispatcher(Router):
|
||||||
def __delitem__(self, key: str) -> None:
|
def __delitem__(self, key: str) -> None:
|
||||||
del self.workflow_data[key]
|
del self.workflow_data[key]
|
||||||
|
|
||||||
def get(self, key: str, /, default: Optional[Any] = None) -> Optional[Any]:
|
def get(self, key: str, /, default: Any | None = None) -> Any | None:
|
||||||
return self.workflow_data.get(key, default)
|
return self.workflow_data.get(key, default)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -114,13 +118,13 @@ class Dispatcher(Router):
|
||||||
return self.fsm.storage
|
return self.fsm.storage
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parent_router(self) -> Optional[Router]:
|
def parent_router(self) -> Router | None:
|
||||||
"""
|
"""
|
||||||
Dispatcher has no parent router and can't be included to any other routers or dispatchers
|
Dispatcher has no parent router and can't be included to any other routers or dispatchers
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return None # noqa: RET501
|
return None
|
||||||
|
|
||||||
@parent_router.setter
|
@parent_router.setter
|
||||||
def parent_router(self, value: Router) -> None:
|
def parent_router(self, value: Router) -> None:
|
||||||
|
|
@ -130,7 +134,8 @@ class Dispatcher(Router):
|
||||||
:param value:
|
:param value:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
raise RuntimeError("Dispatcher can not be attached to another Router.")
|
msg = "Dispatcher can not be attached to another Router."
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
async def feed_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
@ -177,7 +182,7 @@ class Dispatcher(Router):
|
||||||
bot.id,
|
bot.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def feed_raw_update(self, bot: Bot, update: Dict[str, Any], **kwargs: Any) -> Any:
|
async def feed_raw_update(self, bot: Bot, update: dict[str, Any], **kwargs: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
Main entry point for incoming updates with automatic Dict->Update serializer
|
Main entry point for incoming updates with automatic Dict->Update serializer
|
||||||
|
|
||||||
|
|
@ -194,7 +199,7 @@ class Dispatcher(Router):
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
polling_timeout: int = 30,
|
polling_timeout: int = 30,
|
||||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||||
allowed_updates: Optional[List[str]] = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> AsyncGenerator[Update, None]:
|
) -> AsyncGenerator[Update, None]:
|
||||||
"""
|
"""
|
||||||
Endless updates reader with correctly handling any server-side or connection errors.
|
Endless updates reader with correctly handling any server-side or connection errors.
|
||||||
|
|
@ -212,7 +217,7 @@ class Dispatcher(Router):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
updates = await bot(get_updates, **kwargs)
|
updates = await bot(get_updates, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e: # noqa: BLE001
|
||||||
failed = True
|
failed = True
|
||||||
# In cases when Telegram Bot API was inaccessible don't need to stop polling
|
# In cases when Telegram Bot API was inaccessible don't need to stop polling
|
||||||
# process because some developers can't make auto-restarting of the script
|
# process because some developers can't make auto-restarting of the script
|
||||||
|
|
@ -268,6 +273,7 @@ class Dispatcher(Router):
|
||||||
"installed not latest version of aiogram framework"
|
"installed not latest version of aiogram framework"
|
||||||
f"\nUpdate: {update.model_dump_json(exclude_unset=True)}",
|
f"\nUpdate: {update.model_dump_json(exclude_unset=True)}",
|
||||||
RuntimeWarning,
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
raise SkipHandler() from e
|
raise SkipHandler() from e
|
||||||
|
|
||||||
|
|
@ -294,7 +300,11 @@ class Dispatcher(Router):
|
||||||
loggers.event.error("Failed to make answer: %s: %s", e.__class__.__name__, e)
|
loggers.event.error("Failed to make answer: %s: %s", e.__class__.__name__, e)
|
||||||
|
|
||||||
async def _process_update(
|
async def _process_update(
|
||||||
self, bot: Bot, update: Update, call_answer: bool = True, **kwargs: Any
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
update: Update,
|
||||||
|
call_answer: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Propagate update to event listeners
|
Propagate update to event listeners
|
||||||
|
|
@ -309,9 +319,8 @@ class Dispatcher(Router):
|
||||||
response = await self.feed_update(bot, update, **kwargs)
|
response = await self.feed_update(bot, update, **kwargs)
|
||||||
if call_answer and isinstance(response, TelegramMethod):
|
if call_answer and isinstance(response, TelegramMethod):
|
||||||
await self.silent_call_request(bot=bot, result=response)
|
await self.silent_call_request(bot=bot, result=response)
|
||||||
return response is not UNHANDLED
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e: # noqa: BLE001
|
||||||
loggers.event.exception(
|
loggers.event.exception(
|
||||||
"Cause exception while process update id=%d by bot id=%d\n%s: %s",
|
"Cause exception while process update id=%d by bot id=%d\n%s: %s",
|
||||||
update.update_id,
|
update.update_id,
|
||||||
|
|
@ -321,8 +330,13 @@ class Dispatcher(Router):
|
||||||
)
|
)
|
||||||
return True # because update was processed but unsuccessful
|
return True # because update was processed but unsuccessful
|
||||||
|
|
||||||
|
else:
|
||||||
|
return response is not UNHANDLED
|
||||||
|
|
||||||
async def _process_with_semaphore(
|
async def _process_with_semaphore(
|
||||||
self, handle_update: Awaitable[bool], semaphore: asyncio.Semaphore
|
self,
|
||||||
|
handle_update: Awaitable[bool],
|
||||||
|
semaphore: asyncio.Semaphore,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Process update with semaphore to limit concurrent tasks
|
Process update with semaphore to limit concurrent tasks
|
||||||
|
|
@ -342,8 +356,8 @@ class Dispatcher(Router):
|
||||||
polling_timeout: int = 30,
|
polling_timeout: int = 30,
|
||||||
handle_as_tasks: bool = True,
|
handle_as_tasks: bool = True,
|
||||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||||
allowed_updates: Optional[List[str]] = None,
|
allowed_updates: list[str] | None = None,
|
||||||
tasks_concurrency_limit: Optional[int] = None,
|
tasks_concurrency_limit: int | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -361,7 +375,10 @@ class Dispatcher(Router):
|
||||||
"""
|
"""
|
||||||
user: User = await bot.me()
|
user: User = await bot.me()
|
||||||
loggers.dispatcher.info(
|
loggers.dispatcher.info(
|
||||||
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
"Run polling for bot @%s id=%d - %r",
|
||||||
|
user.username,
|
||||||
|
bot.id,
|
||||||
|
user.full_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create semaphore if tasks_concurrency_limit is specified
|
# Create semaphore if tasks_concurrency_limit is specified
|
||||||
|
|
@ -382,7 +399,7 @@ class Dispatcher(Router):
|
||||||
# Use semaphore to limit concurrent tasks
|
# Use semaphore to limit concurrent tasks
|
||||||
await semaphore.acquire()
|
await semaphore.acquire()
|
||||||
handle_update_task = asyncio.create_task(
|
handle_update_task = asyncio.create_task(
|
||||||
self._process_with_semaphore(handle_update, semaphore)
|
self._process_with_semaphore(handle_update, semaphore),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
handle_update_task = asyncio.create_task(handle_update)
|
handle_update_task = asyncio.create_task(handle_update)
|
||||||
|
|
@ -393,7 +410,10 @@ class Dispatcher(Router):
|
||||||
await handle_update
|
await handle_update
|
||||||
finally:
|
finally:
|
||||||
loggers.dispatcher.info(
|
loggers.dispatcher.info(
|
||||||
"Polling stopped for bot @%s id=%d - %r", user.username, bot.id, user.full_name
|
"Polling stopped for bot @%s id=%d - %r",
|
||||||
|
user.username,
|
||||||
|
bot.id,
|
||||||
|
user.full_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
|
||||||
|
|
@ -413,8 +433,12 @@ class Dispatcher(Router):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def feed_webhook_update(
|
async def feed_webhook_update(
|
||||||
self, bot: Bot, update: Union[Update, Dict[str, Any]], _timeout: float = 55, **kwargs: Any
|
self,
|
||||||
) -> Optional[TelegramMethod[TelegramType]]:
|
bot: Bot,
|
||||||
|
update: Update | dict[str, Any],
|
||||||
|
_timeout: float = 55,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> TelegramMethod[TelegramType] | None:
|
||||||
if not isinstance(update, Update): # Allow to use raw updates
|
if not isinstance(update, Update): # Allow to use raw updates
|
||||||
update = Update.model_validate(update, context={"bot": bot})
|
update = Update.model_validate(update, context={"bot": bot})
|
||||||
|
|
||||||
|
|
@ -429,7 +453,7 @@ class Dispatcher(Router):
|
||||||
timeout_handle = loop.call_later(_timeout, release_waiter)
|
timeout_handle = loop.call_later(_timeout, release_waiter)
|
||||||
|
|
||||||
process_updates: Future[Any] = asyncio.ensure_future(
|
process_updates: Future[Any] = asyncio.ensure_future(
|
||||||
self._feed_webhook_update(bot=bot, update=update, **kwargs)
|
self._feed_webhook_update(bot=bot, update=update, **kwargs),
|
||||||
)
|
)
|
||||||
process_updates.add_done_callback(release_waiter, context=ctx)
|
process_updates.add_done_callback(release_waiter, context=ctx)
|
||||||
|
|
||||||
|
|
@ -440,11 +464,9 @@ class Dispatcher(Router):
|
||||||
"For preventing this situation response into webhook returned immediately "
|
"For preventing this situation response into webhook returned immediately "
|
||||||
"and handler is moved to background and still processing update.",
|
"and handler is moved to background and still processing update.",
|
||||||
RuntimeWarning,
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
try:
|
result = task.result()
|
||||||
result = task.result()
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
if isinstance(result, TelegramMethod):
|
if isinstance(result, TelegramMethod):
|
||||||
asyncio.ensure_future(self.silent_call_request(bot=bot, result=result))
|
asyncio.ensure_future(self.silent_call_request(bot=bot, result=result))
|
||||||
|
|
||||||
|
|
@ -478,7 +500,8 @@ class Dispatcher(Router):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not self._running_lock.locked():
|
if not self._running_lock.locked():
|
||||||
raise RuntimeError("Polling is not started")
|
msg = "Polling is not started"
|
||||||
|
raise RuntimeError(msg)
|
||||||
if not self._stop_signal or not self._stopped_signal:
|
if not self._stop_signal or not self._stopped_signal:
|
||||||
return
|
return
|
||||||
self._stop_signal.set()
|
self._stop_signal.set()
|
||||||
|
|
@ -499,10 +522,10 @@ class Dispatcher(Router):
|
||||||
polling_timeout: int = 10,
|
polling_timeout: int = 10,
|
||||||
handle_as_tasks: bool = True,
|
handle_as_tasks: bool = True,
|
||||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||||
allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET,
|
allowed_updates: list[str] | UNSET_TYPE | None = UNSET,
|
||||||
handle_signals: bool = True,
|
handle_signals: bool = True,
|
||||||
close_bot_session: bool = True,
|
close_bot_session: bool = True,
|
||||||
tasks_concurrency_limit: Optional[int] = None,
|
tasks_concurrency_limit: int | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -522,12 +545,14 @@ class Dispatcher(Router):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not bots:
|
if not bots:
|
||||||
raise ValueError("At least one bot instance is required to start polling")
|
msg = "At least one bot instance is required to start polling"
|
||||||
|
raise ValueError(msg)
|
||||||
if "bot" in kwargs:
|
if "bot" in kwargs:
|
||||||
raise ValueError(
|
msg = (
|
||||||
"Keyword argument 'bot' is not acceptable, "
|
"Keyword argument 'bot' is not acceptable, "
|
||||||
"the bot instance should be passed as positional argument"
|
"the bot instance should be passed as positional argument"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
async with self._running_lock: # Prevent to run this method twice at a once
|
async with self._running_lock: # Prevent to run this method twice at a once
|
||||||
if self._stop_signal is None:
|
if self._stop_signal is None:
|
||||||
|
|
@ -547,10 +572,14 @@ class Dispatcher(Router):
|
||||||
# Signals handling is not supported on Windows
|
# Signals handling is not supported on Windows
|
||||||
# It also can't be covered on Windows
|
# It also can't be covered on Windows
|
||||||
loop.add_signal_handler(
|
loop.add_signal_handler(
|
||||||
signal.SIGTERM, self._signal_stop_polling, signal.SIGTERM
|
signal.SIGTERM,
|
||||||
|
self._signal_stop_polling,
|
||||||
|
signal.SIGTERM,
|
||||||
)
|
)
|
||||||
loop.add_signal_handler(
|
loop.add_signal_handler(
|
||||||
signal.SIGINT, self._signal_stop_polling, signal.SIGINT
|
signal.SIGINT,
|
||||||
|
self._signal_stop_polling,
|
||||||
|
signal.SIGINT,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_data = {
|
workflow_data = {
|
||||||
|
|
@ -565,7 +594,7 @@ class Dispatcher(Router):
|
||||||
await self.emit_startup(bot=bots[-1], **workflow_data)
|
await self.emit_startup(bot=bots[-1], **workflow_data)
|
||||||
loggers.dispatcher.info("Start polling")
|
loggers.dispatcher.info("Start polling")
|
||||||
try:
|
try:
|
||||||
tasks: List[asyncio.Task[Any]] = [
|
tasks: list[asyncio.Task[Any]] = [
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self._polling(
|
self._polling(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
|
|
@ -575,7 +604,7 @@ class Dispatcher(Router):
|
||||||
allowed_updates=allowed_updates,
|
allowed_updates=allowed_updates,
|
||||||
tasks_concurrency_limit=tasks_concurrency_limit,
|
tasks_concurrency_limit=tasks_concurrency_limit,
|
||||||
**workflow_data,
|
**workflow_data,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
for bot in bots
|
for bot in bots
|
||||||
]
|
]
|
||||||
|
|
@ -605,10 +634,10 @@ class Dispatcher(Router):
|
||||||
polling_timeout: int = 10,
|
polling_timeout: int = 10,
|
||||||
handle_as_tasks: bool = True,
|
handle_as_tasks: bool = True,
|
||||||
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
|
||||||
allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET,
|
allowed_updates: list[str] | UNSET_TYPE | None = UNSET,
|
||||||
handle_signals: bool = True,
|
handle_signals: bool = True,
|
||||||
close_bot_session: bool = True,
|
close_bot_session: bool = True,
|
||||||
tasks_concurrency_limit: Optional[int] = None,
|
tasks_concurrency_limit: int | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -638,5 +667,5 @@ class Dispatcher(Router):
|
||||||
handle_signals=handle_signals,
|
handle_signals=handle_signals,
|
||||||
close_bot_session=close_bot_session,
|
close_bot_session=close_bot_session,
|
||||||
tasks_concurrency_limit=tasks_concurrency_limit,
|
tasks_concurrency_limit=tasks_concurrency_limit,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,22 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, Dict, NoReturn, Optional, TypeVar, Union
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, NoReturn, TypeVar
|
||||||
from unittest.mock import sentinel
|
from unittest.mock import sentinel
|
||||||
|
|
||||||
from ...types import TelegramObject
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
from ..middlewares.base import BaseMiddleware
|
from aiogram.types import TelegramObject
|
||||||
|
|
||||||
MiddlewareEventType = TypeVar("MiddlewareEventType", bound=TelegramObject)
|
MiddlewareEventType = TypeVar("MiddlewareEventType", bound=TelegramObject)
|
||||||
NextMiddlewareType = Callable[[MiddlewareEventType, Dict[str, Any]], Awaitable[Any]]
|
NextMiddlewareType = Callable[[MiddlewareEventType, dict[str, Any]], Awaitable[Any]]
|
||||||
MiddlewareType = Union[
|
MiddlewareType = (
|
||||||
BaseMiddleware,
|
BaseMiddleware
|
||||||
Callable[
|
| Callable[
|
||||||
[NextMiddlewareType[MiddlewareEventType], MiddlewareEventType, Dict[str, Any]],
|
[NextMiddlewareType[MiddlewareEventType], MiddlewareEventType, dict[str, Any]],
|
||||||
Awaitable[Any],
|
Awaitable[Any],
|
||||||
],
|
]
|
||||||
]
|
)
|
||||||
|
|
||||||
|
|
||||||
UNHANDLED = sentinel.UNHANDLED
|
UNHANDLED = sentinel.UNHANDLED
|
||||||
REJECTED = sentinel.REJECTED
|
REJECTED = sentinel.REJECTED
|
||||||
|
|
@ -28,7 +30,7 @@ class CancelHandler(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def skip(message: Optional[str] = None) -> NoReturn:
|
def skip(message: str | None = None) -> NoReturn:
|
||||||
"""
|
"""
|
||||||
Raise an SkipHandler
|
Raise an SkipHandler
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, List
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from .handler import CallbackType, HandlerObject
|
from .handler import CallbackType, HandlerObject
|
||||||
|
|
||||||
|
|
@ -25,7 +26,7 @@ class EventObserver:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.handlers: List[HandlerObject] = []
|
self.handlers: list[HandlerObject] = []
|
||||||
|
|
||||||
def register(self, callback: CallbackType) -> None:
|
def register(self, callback: CallbackType) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
from typing import Any
|
||||||
|
|
||||||
from magic_filter.magic import MagicFilter as OriginalMagicFilter
|
from magic_filter.magic import MagicFilter as OriginalMagicFilter
|
||||||
|
|
||||||
|
|
@ -21,7 +21,7 @@ CallbackType = Callable[..., Any]
|
||||||
class CallableObject:
|
class CallableObject:
|
||||||
callback: CallbackType
|
callback: CallbackType
|
||||||
awaitable: bool = field(init=False)
|
awaitable: bool = field(init=False)
|
||||||
params: Set[str] = field(init=False)
|
params: set[str] = field(init=False)
|
||||||
varkw: bool = field(init=False)
|
varkw: bool = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|
@ -31,7 +31,7 @@ class CallableObject:
|
||||||
self.params = {*spec.args, *spec.kwonlyargs}
|
self.params = {*spec.args, *spec.kwonlyargs}
|
||||||
self.varkw = spec.varkw is not None
|
self.varkw = spec.varkw is not None
|
||||||
|
|
||||||
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
def _prepare_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||||
if self.varkw:
|
if self.varkw:
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
@ -46,7 +46,7 @@ class CallableObject:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FilterObject(CallableObject):
|
class FilterObject(CallableObject):
|
||||||
magic: Optional[MagicFilter] = None
|
magic: MagicFilter | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if isinstance(self.callback, OriginalMagicFilter):
|
if isinstance(self.callback, OriginalMagicFilter):
|
||||||
|
|
@ -65,7 +65,7 @@ class FilterObject(CallableObject):
|
||||||
stacklevel=6,
|
stacklevel=6,
|
||||||
)
|
)
|
||||||
|
|
||||||
super(FilterObject, self).__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
if isinstance(self.callback, Filter):
|
if isinstance(self.callback, Filter):
|
||||||
self.awaitable = True
|
self.awaitable = True
|
||||||
|
|
@ -73,17 +73,17 @@ class FilterObject(CallableObject):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HandlerObject(CallableObject):
|
class HandlerObject(CallableObject):
|
||||||
filters: Optional[List[FilterObject]] = None
|
filters: list[FilterObject] | None = None
|
||||||
flags: Dict[str, Any] = field(default_factory=dict)
|
flags: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super(HandlerObject, self).__post_init__()
|
super().__post_init__()
|
||||||
callback = inspect.unwrap(self.callback)
|
callback = inspect.unwrap(self.callback)
|
||||||
if inspect.isclass(callback) and issubclass(callback, BaseHandler):
|
if inspect.isclass(callback) and issubclass(callback, BaseHandler):
|
||||||
self.awaitable = True
|
self.awaitable = True
|
||||||
self.flags.update(extract_flags_from_object(callback))
|
self.flags.update(extract_flags_from_object(callback))
|
||||||
|
|
||||||
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
async def check(self, *args: Any, **kwargs: Any) -> tuple[bool, dict[str, Any]]:
|
||||||
if not self.filters:
|
if not self.filters:
|
||||||
return True, kwargs
|
return True, kwargs
|
||||||
for event_filter in self.filters:
|
for event_filter in self.filters:
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,18 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
||||||
|
from aiogram.exceptions import UnsupportedKeywordArgument
|
||||||
|
from aiogram.filters.base import Filter
|
||||||
|
|
||||||
from ...exceptions import UnsupportedKeywordArgument
|
|
||||||
from ...filters.base import Filter
|
|
||||||
from ...types import TelegramObject
|
|
||||||
from .bases import UNHANDLED, MiddlewareType, SkipHandler
|
from .bases import UNHANDLED, MiddlewareType, SkipHandler
|
||||||
from .handler import CallbackType, FilterObject, HandlerObject
|
from .handler import CallbackType, FilterObject, HandlerObject
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiogram.dispatcher.router import Router
|
from aiogram.dispatcher.router import Router
|
||||||
|
from aiogram.types import TelegramObject
|
||||||
|
|
||||||
|
|
||||||
class TelegramEventObserver:
|
class TelegramEventObserver:
|
||||||
|
|
@ -26,7 +27,7 @@ class TelegramEventObserver:
|
||||||
self.router: Router = router
|
self.router: Router = router
|
||||||
self.event_name: str = event_name
|
self.event_name: str = event_name
|
||||||
|
|
||||||
self.handlers: List[HandlerObject] = []
|
self.handlers: list[HandlerObject] = []
|
||||||
|
|
||||||
self.middleware = MiddlewareManager()
|
self.middleware = MiddlewareManager()
|
||||||
self.outer_middleware = MiddlewareManager()
|
self.outer_middleware = MiddlewareManager()
|
||||||
|
|
@ -45,8 +46,8 @@ class TelegramEventObserver:
|
||||||
self._handler.filters = []
|
self._handler.filters = []
|
||||||
self._handler.filters.extend([FilterObject(filter_) for filter_ in filters])
|
self._handler.filters.extend([FilterObject(filter_) for filter_ in filters])
|
||||||
|
|
||||||
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
|
def _resolve_middlewares(self) -> list[MiddlewareType[TelegramObject]]:
|
||||||
middlewares: List[MiddlewareType[TelegramObject]] = []
|
middlewares: list[MiddlewareType[TelegramObject]] = []
|
||||||
for router in reversed(tuple(self.router.chain_head)):
|
for router in reversed(tuple(self.router.chain_head)):
|
||||||
observer = router.observers.get(self.event_name)
|
observer = router.observers.get(self.event_name)
|
||||||
if observer:
|
if observer:
|
||||||
|
|
@ -58,14 +59,14 @@ class TelegramEventObserver:
|
||||||
self,
|
self,
|
||||||
callback: CallbackType,
|
callback: CallbackType,
|
||||||
*filters: CallbackType,
|
*filters: CallbackType,
|
||||||
flags: Optional[Dict[str, Any]] = None,
|
flags: dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> CallbackType:
|
) -> CallbackType:
|
||||||
"""
|
"""
|
||||||
Register event handler
|
Register event handler
|
||||||
"""
|
"""
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise UnsupportedKeywordArgument(
|
msg = (
|
||||||
"Passing any additional keyword arguments to the registrar method "
|
"Passing any additional keyword arguments to the registrar method "
|
||||||
"is not supported.\n"
|
"is not supported.\n"
|
||||||
"This error may be caused when you are trying to register filters like in 2.x "
|
"This error may be caused when you are trying to register filters like in 2.x "
|
||||||
|
|
@ -73,6 +74,7 @@ class TelegramEventObserver:
|
||||||
"documentation pages.\n"
|
"documentation pages.\n"
|
||||||
f"Please remove the {set(kwargs.keys())} arguments from this call.\n"
|
f"Please remove the {set(kwargs.keys())} arguments from this call.\n"
|
||||||
)
|
)
|
||||||
|
raise UnsupportedKeywordArgument(msg)
|
||||||
|
|
||||||
if flags is None:
|
if flags is None:
|
||||||
flags = {}
|
flags = {}
|
||||||
|
|
@ -86,13 +88,16 @@ class TelegramEventObserver:
|
||||||
callback=callback,
|
callback=callback,
|
||||||
filters=[FilterObject(filter_) for filter_ in filters],
|
filters=[FilterObject(filter_) for filter_ in filters],
|
||||||
flags=flags,
|
flags=flags,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
def wrap_outer_middleware(
|
def wrap_outer_middleware(
|
||||||
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
|
self,
|
||||||
|
callback: Any,
|
||||||
|
event: TelegramObject,
|
||||||
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
wrapped_outer = self.middleware.wrap_middlewares(
|
wrapped_outer = self.middleware.wrap_middlewares(
|
||||||
self.outer_middleware,
|
self.outer_middleware,
|
||||||
|
|
@ -127,7 +132,7 @@ class TelegramEventObserver:
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*filters: CallbackType,
|
*filters: CallbackType,
|
||||||
flags: Optional[Dict[str, Any]] = None,
|
flags: dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Callable[[CallbackType], CallbackType]:
|
) -> Callable[[CallbackType], CallbackType]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast, overload
|
from typing import TYPE_CHECKING, Any, Union, cast, overload
|
||||||
|
|
||||||
from magic_filter import AttrDict, MagicFilter
|
from magic_filter import AttrDict, MagicFilter
|
||||||
|
|
||||||
|
|
@ -39,11 +40,12 @@ class FlagDecorator:
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
value: Optional[Any] = None,
|
value: Any | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[Callable[..., Any], "FlagDecorator"]:
|
) -> Union[Callable[..., Any], "FlagDecorator"]:
|
||||||
if value and kwargs:
|
if value and kwargs:
|
||||||
raise ValueError("The arguments `value` and **kwargs can not be used together")
|
msg = "The arguments `value` and **kwargs can not be used together"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if value is not None and callable(value):
|
if value is not None and callable(value):
|
||||||
value.aiogram_flag = {
|
value.aiogram_flag = {
|
||||||
|
|
@ -70,20 +72,21 @@ if TYPE_CHECKING:
|
||||||
class FlagGenerator:
|
class FlagGenerator:
|
||||||
def __getattr__(self, name: str) -> FlagDecorator:
|
def __getattr__(self, name: str) -> FlagDecorator:
|
||||||
if name[0] == "_":
|
if name[0] == "_":
|
||||||
raise AttributeError("Flag name must NOT start with underscore")
|
msg = "Flag name must NOT start with underscore"
|
||||||
|
raise AttributeError(msg)
|
||||||
return FlagDecorator(Flag(name, True))
|
return FlagDecorator(Flag(name, True))
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
chat_action: _ChatActionFlagProtocol
|
chat_action: _ChatActionFlagProtocol
|
||||||
|
|
||||||
|
|
||||||
def extract_flags_from_object(obj: Any) -> Dict[str, Any]:
|
def extract_flags_from_object(obj: Any) -> dict[str, Any]:
|
||||||
if not hasattr(obj, "aiogram_flag"):
|
if not hasattr(obj, "aiogram_flag"):
|
||||||
return {}
|
return {}
|
||||||
return cast(Dict[str, Any], obj.aiogram_flag)
|
return cast(dict[str, Any], obj.aiogram_flag)
|
||||||
|
|
||||||
|
|
||||||
def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str, Any]:
|
def extract_flags(handler: Union["HandlerObject", dict[str, Any]]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Extract flags from handler or middleware context data
|
Extract flags from handler or middleware context data
|
||||||
|
|
||||||
|
|
@ -98,10 +101,10 @@ def extract_flags(handler: Union["HandlerObject", Dict[str, Any]]) -> Dict[str,
|
||||||
|
|
||||||
|
|
||||||
def get_flag(
|
def get_flag(
|
||||||
handler: Union["HandlerObject", Dict[str, Any]],
|
handler: Union["HandlerObject", dict[str, Any]],
|
||||||
name: str,
|
name: str,
|
||||||
*,
|
*,
|
||||||
default: Optional[Any] = None,
|
default: Any | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get flag by name
|
Get flag by name
|
||||||
|
|
@ -115,7 +118,7 @@ def get_flag(
|
||||||
return flags.get(name, default)
|
return flags.get(name, default)
|
||||||
|
|
||||||
|
|
||||||
def check_flags(handler: Union["HandlerObject", Dict[str, Any]], magic: MagicFilter) -> Any:
|
def check_flags(handler: Union["HandlerObject", dict[str, Any]], magic: MagicFilter) -> Any:
|
||||||
"""
|
"""
|
||||||
Check flags via magic filter
|
Check flags via magic filter
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Awaitable, Callable, Dict, TypeVar
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from aiogram.types import TelegramObject
|
from aiogram.types import TelegramObject
|
||||||
|
|
||||||
|
|
@ -14,9 +15,9 @@ class BaseMiddleware(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any: # pragma: no cover
|
) -> Any: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Execute middleware
|
Execute middleware
|
||||||
|
|
@ -26,4 +27,3 @@ class BaseMiddleware(ABC):
|
||||||
:param data: Contextual data. Will be mapped to handler arguments
|
:param data: Contextual data. Will be mapped to handler arguments
|
||||||
:return: :class:`Any`
|
:return: :class:`Any`
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,16 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, cast
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from aiogram.dispatcher.event.bases import UNHANDLED, CancelHandler, SkipHandler
|
||||||
|
from aiogram.types import TelegramObject, Update
|
||||||
|
from aiogram.types.error_event import ErrorEvent
|
||||||
|
|
||||||
from ...types import TelegramObject, Update
|
|
||||||
from ...types.error_event import ErrorEvent
|
|
||||||
from ..event.bases import UNHANDLED, CancelHandler, SkipHandler
|
|
||||||
from .base import BaseMiddleware
|
from .base import BaseMiddleware
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..router import Router
|
from aiogram.dispatcher.router import Router
|
||||||
|
|
||||||
|
|
||||||
class ErrorsMiddleware(BaseMiddleware):
|
class ErrorsMiddleware(BaseMiddleware):
|
||||||
|
|
@ -17,9 +19,9 @@ class ErrorsMiddleware(BaseMiddleware):
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
try:
|
try:
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import functools
|
import functools
|
||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import Any, overload
|
||||||
|
|
||||||
from aiogram.dispatcher.event.bases import (
|
from aiogram.dispatcher.event.bases import (
|
||||||
MiddlewareEventType,
|
MiddlewareEventType,
|
||||||
|
|
@ -12,7 +13,7 @@ from aiogram.types import TelegramObject
|
||||||
|
|
||||||
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._middlewares: List[MiddlewareType[TelegramObject]] = []
|
self._middlewares: list[MiddlewareType[TelegramObject]] = []
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
|
|
@ -26,11 +27,11 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
middleware: Optional[MiddlewareType[TelegramObject]] = None,
|
middleware: MiddlewareType[TelegramObject] | None = None,
|
||||||
) -> Union[
|
) -> (
|
||||||
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
|
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]]
|
||||||
MiddlewareType[TelegramObject],
|
| MiddlewareType[TelegramObject]
|
||||||
]:
|
):
|
||||||
if middleware is None:
|
if middleware is None:
|
||||||
return self.register
|
return self.register
|
||||||
return self.register(middleware)
|
return self.register(middleware)
|
||||||
|
|
@ -44,8 +45,9 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self, item: Union[int, slice]
|
self,
|
||||||
) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]:
|
item: int | slice,
|
||||||
|
) -> MiddlewareType[TelegramObject] | Sequence[MiddlewareType[TelegramObject]]:
|
||||||
return self._middlewares[item]
|
return self._middlewares[item]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -53,10 +55,11 @@ class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wrap_middlewares(
|
def wrap_middlewares(
|
||||||
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: CallbackType
|
middlewares: Sequence[MiddlewareType[MiddlewareEventType]],
|
||||||
|
handler: CallbackType,
|
||||||
) -> NextMiddlewareType[MiddlewareEventType]:
|
) -> NextMiddlewareType[MiddlewareEventType]:
|
||||||
@functools.wraps(handler)
|
@functools.wraps(handler)
|
||||||
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
|
def handler_wrapper(event: TelegramObject, kwargs: dict[str, Any]) -> Any:
|
||||||
return handler(event, **kwargs)
|
return handler(event, **kwargs)
|
||||||
|
|
||||||
middleware = handler_wrapper
|
middleware = handler_wrapper
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
from aiogram.types import (
|
from aiogram.types import (
|
||||||
|
|
@ -20,29 +21,30 @@ EVENT_THREAD_ID_KEY = "event_thread_id"
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class EventContext:
|
class EventContext:
|
||||||
chat: Optional[Chat] = None
|
chat: Chat | None = None
|
||||||
user: Optional[User] = None
|
user: User | None = None
|
||||||
thread_id: Optional[int] = None
|
thread_id: int | None = None
|
||||||
business_connection_id: Optional[str] = None
|
business_connection_id: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_id(self) -> Optional[int]:
|
def user_id(self) -> int | None:
|
||||||
return self.user.id if self.user else None
|
return self.user.id if self.user else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_id(self) -> Optional[int]:
|
def chat_id(self) -> int | None:
|
||||||
return self.chat.id if self.chat else None
|
return self.chat.id if self.chat else None
|
||||||
|
|
||||||
|
|
||||||
class UserContextMiddleware(BaseMiddleware):
|
class UserContextMiddleware(BaseMiddleware):
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not isinstance(event, Update):
|
if not isinstance(event, Update):
|
||||||
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
|
msg = "UserContextMiddleware got an unexpected event type!"
|
||||||
|
raise RuntimeError(msg)
|
||||||
event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event)
|
event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event)
|
||||||
|
|
||||||
# Backward compatibility
|
# Backward compatibility
|
||||||
|
|
@ -116,13 +118,15 @@ class UserContextMiddleware(BaseMiddleware):
|
||||||
)
|
)
|
||||||
if event.my_chat_member:
|
if event.my_chat_member:
|
||||||
return EventContext(
|
return EventContext(
|
||||||
chat=event.my_chat_member.chat, user=event.my_chat_member.from_user
|
chat=event.my_chat_member.chat,
|
||||||
|
user=event.my_chat_member.from_user,
|
||||||
)
|
)
|
||||||
if event.chat_member:
|
if event.chat_member:
|
||||||
return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user)
|
return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user)
|
||||||
if event.chat_join_request:
|
if event.chat_join_request:
|
||||||
return EventContext(
|
return EventContext(
|
||||||
chat=event.chat_join_request.chat, user=event.chat_join_request.from_user
|
chat=event.chat_join_request.chat,
|
||||||
|
user=event.chat_join_request.from_user,
|
||||||
)
|
)
|
||||||
if event.message_reaction:
|
if event.message_reaction:
|
||||||
return EventContext(
|
return EventContext(
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, Final, Generator, List, Optional, Set
|
from collections.abc import Generator
|
||||||
|
from typing import TYPE_CHECKING, Any, Final
|
||||||
|
|
||||||
from ..types import TelegramObject
|
|
||||||
from .event.bases import REJECTED, UNHANDLED
|
from .event.bases import REJECTED, UNHANDLED
|
||||||
from .event.event import EventObserver
|
from .event.event import EventObserver
|
||||||
from .event.telegram import TelegramEventObserver
|
from .event.telegram import TelegramEventObserver
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from aiogram.types import TelegramObject
|
||||||
|
|
||||||
INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"})
|
INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"})
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,31 +24,34 @@ class Router:
|
||||||
- By decorator - :obj:`@router.<event_type>(<filters, ...>)`
|
- By decorator - :obj:`@router.<event_type>(<filters, ...>)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, name: Optional[str] = None) -> None:
|
def __init__(self, *, name: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
:param name: Optional router name, can be useful for debugging
|
:param name: Optional router name, can be useful for debugging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.name = name or hex(id(self))
|
self.name = name or hex(id(self))
|
||||||
|
|
||||||
self._parent_router: Optional[Router] = None
|
self._parent_router: Router | None = None
|
||||||
self.sub_routers: List[Router] = []
|
self.sub_routers: list[Router] = []
|
||||||
|
|
||||||
# Observers
|
# Observers
|
||||||
self.message = TelegramEventObserver(router=self, event_name="message")
|
self.message = TelegramEventObserver(router=self, event_name="message")
|
||||||
self.edited_message = TelegramEventObserver(router=self, event_name="edited_message")
|
self.edited_message = TelegramEventObserver(router=self, event_name="edited_message")
|
||||||
self.channel_post = TelegramEventObserver(router=self, event_name="channel_post")
|
self.channel_post = TelegramEventObserver(router=self, event_name="channel_post")
|
||||||
self.edited_channel_post = TelegramEventObserver(
|
self.edited_channel_post = TelegramEventObserver(
|
||||||
router=self, event_name="edited_channel_post"
|
router=self,
|
||||||
|
event_name="edited_channel_post",
|
||||||
)
|
)
|
||||||
self.inline_query = TelegramEventObserver(router=self, event_name="inline_query")
|
self.inline_query = TelegramEventObserver(router=self, event_name="inline_query")
|
||||||
self.chosen_inline_result = TelegramEventObserver(
|
self.chosen_inline_result = TelegramEventObserver(
|
||||||
router=self, event_name="chosen_inline_result"
|
router=self,
|
||||||
|
event_name="chosen_inline_result",
|
||||||
)
|
)
|
||||||
self.callback_query = TelegramEventObserver(router=self, event_name="callback_query")
|
self.callback_query = TelegramEventObserver(router=self, event_name="callback_query")
|
||||||
self.shipping_query = TelegramEventObserver(router=self, event_name="shipping_query")
|
self.shipping_query = TelegramEventObserver(router=self, event_name="shipping_query")
|
||||||
self.pre_checkout_query = TelegramEventObserver(
|
self.pre_checkout_query = TelegramEventObserver(
|
||||||
router=self, event_name="pre_checkout_query"
|
router=self,
|
||||||
|
event_name="pre_checkout_query",
|
||||||
)
|
)
|
||||||
self.poll = TelegramEventObserver(router=self, event_name="poll")
|
self.poll = TelegramEventObserver(router=self, event_name="poll")
|
||||||
self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer")
|
self.poll_answer = TelegramEventObserver(router=self, event_name="poll_answer")
|
||||||
|
|
@ -54,24 +60,30 @@ class Router:
|
||||||
self.chat_join_request = TelegramEventObserver(router=self, event_name="chat_join_request")
|
self.chat_join_request = TelegramEventObserver(router=self, event_name="chat_join_request")
|
||||||
self.message_reaction = TelegramEventObserver(router=self, event_name="message_reaction")
|
self.message_reaction = TelegramEventObserver(router=self, event_name="message_reaction")
|
||||||
self.message_reaction_count = TelegramEventObserver(
|
self.message_reaction_count = TelegramEventObserver(
|
||||||
router=self, event_name="message_reaction_count"
|
router=self,
|
||||||
|
event_name="message_reaction_count",
|
||||||
)
|
)
|
||||||
self.chat_boost = TelegramEventObserver(router=self, event_name="chat_boost")
|
self.chat_boost = TelegramEventObserver(router=self, event_name="chat_boost")
|
||||||
self.removed_chat_boost = TelegramEventObserver(
|
self.removed_chat_boost = TelegramEventObserver(
|
||||||
router=self, event_name="removed_chat_boost"
|
router=self,
|
||||||
|
event_name="removed_chat_boost",
|
||||||
)
|
)
|
||||||
self.deleted_business_messages = TelegramEventObserver(
|
self.deleted_business_messages = TelegramEventObserver(
|
||||||
router=self, event_name="deleted_business_messages"
|
router=self,
|
||||||
|
event_name="deleted_business_messages",
|
||||||
)
|
)
|
||||||
self.business_connection = TelegramEventObserver(
|
self.business_connection = TelegramEventObserver(
|
||||||
router=self, event_name="business_connection"
|
router=self,
|
||||||
|
event_name="business_connection",
|
||||||
)
|
)
|
||||||
self.edited_business_message = TelegramEventObserver(
|
self.edited_business_message = TelegramEventObserver(
|
||||||
router=self, event_name="edited_business_message"
|
router=self,
|
||||||
|
event_name="edited_business_message",
|
||||||
)
|
)
|
||||||
self.business_message = TelegramEventObserver(router=self, event_name="business_message")
|
self.business_message = TelegramEventObserver(router=self, event_name="business_message")
|
||||||
self.purchased_paid_media = TelegramEventObserver(
|
self.purchased_paid_media = TelegramEventObserver(
|
||||||
router=self, event_name="purchased_paid_media"
|
router=self,
|
||||||
|
event_name="purchased_paid_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.errors = self.error = TelegramEventObserver(router=self, event_name="error")
|
self.errors = self.error = TelegramEventObserver(router=self, event_name="error")
|
||||||
|
|
@ -79,7 +91,7 @@ class Router:
|
||||||
self.startup = EventObserver()
|
self.startup = EventObserver()
|
||||||
self.shutdown = EventObserver()
|
self.shutdown = EventObserver()
|
||||||
|
|
||||||
self.observers: Dict[str, TelegramEventObserver] = {
|
self.observers: dict[str, TelegramEventObserver] = {
|
||||||
"message": self.message,
|
"message": self.message,
|
||||||
"edited_message": self.edited_message,
|
"edited_message": self.edited_message,
|
||||||
"channel_post": self.channel_post,
|
"channel_post": self.channel_post,
|
||||||
|
|
@ -112,7 +124,7 @@ class Router:
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<{self}>"
|
return f"<{self}>"
|
||||||
|
|
||||||
def resolve_used_update_types(self, skip_events: Optional[Set[str]] = None) -> List[str]:
|
def resolve_used_update_types(self, skip_events: set[str] | None = None) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Resolve registered event names
|
Resolve registered event names
|
||||||
|
|
||||||
|
|
@ -121,7 +133,7 @@ class Router:
|
||||||
:param skip_events: skip specified event names
|
:param skip_events: skip specified event names
|
||||||
:return: set of registered names
|
:return: set of registered names
|
||||||
"""
|
"""
|
||||||
handlers_in_use: Set[str] = set()
|
handlers_in_use: set[str] = set()
|
||||||
if skip_events is None:
|
if skip_events is None:
|
||||||
skip_events = set()
|
skip_events = set()
|
||||||
skip_events = {*skip_events, *INTERNAL_UPDATE_TYPES}
|
skip_events = {*skip_events, *INTERNAL_UPDATE_TYPES}
|
||||||
|
|
@ -139,7 +151,10 @@ class Router:
|
||||||
|
|
||||||
async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any:
|
async def _wrapped(telegram_event: TelegramObject, **data: Any) -> Any:
|
||||||
return await self._propagate_event(
|
return await self._propagate_event(
|
||||||
observer=observer, update_type=update_type, event=telegram_event, **data
|
observer=observer,
|
||||||
|
update_type=update_type,
|
||||||
|
event=telegram_event,
|
||||||
|
**data,
|
||||||
)
|
)
|
||||||
|
|
||||||
if observer:
|
if observer:
|
||||||
|
|
@ -148,7 +163,7 @@ class Router:
|
||||||
|
|
||||||
async def _propagate_event(
|
async def _propagate_event(
|
||||||
self,
|
self,
|
||||||
observer: Optional[TelegramEventObserver],
|
observer: TelegramEventObserver | None,
|
||||||
update_type: str,
|
update_type: str,
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
|
@ -179,7 +194,7 @@ class Router:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chain_head(self) -> Generator[Router, None, None]:
|
def chain_head(self) -> Generator[Router, None, None]:
|
||||||
router: Optional[Router] = self
|
router: Router | None = self
|
||||||
while router:
|
while router:
|
||||||
yield router
|
yield router
|
||||||
router = router.parent_router
|
router = router.parent_router
|
||||||
|
|
@ -191,7 +206,7 @@ class Router:
|
||||||
yield from router.chain_tail
|
yield from router.chain_tail
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parent_router(self) -> Optional[Router]:
|
def parent_router(self) -> Router | None:
|
||||||
return self._parent_router
|
return self._parent_router
|
||||||
|
|
||||||
@parent_router.setter
|
@parent_router.setter
|
||||||
|
|
@ -206,16 +221,20 @@ class Router:
|
||||||
:param router:
|
:param router:
|
||||||
"""
|
"""
|
||||||
if not isinstance(router, Router):
|
if not isinstance(router, Router):
|
||||||
raise ValueError(f"router should be instance of Router not {type(router).__name__!r}")
|
msg = f"router should be instance of Router not {type(router).__name__!r}"
|
||||||
|
raise ValueError(msg)
|
||||||
if self._parent_router:
|
if self._parent_router:
|
||||||
raise RuntimeError(f"Router is already attached to {self._parent_router!r}")
|
msg = f"Router is already attached to {self._parent_router!r}"
|
||||||
|
raise RuntimeError(msg)
|
||||||
if self == router:
|
if self == router:
|
||||||
raise RuntimeError("Self-referencing routers is not allowed")
|
msg = "Self-referencing routers is not allowed"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
parent: Optional[Router] = router
|
parent: Router | None = router
|
||||||
while parent is not None:
|
while parent is not None:
|
||||||
if parent == self:
|
if parent == self:
|
||||||
raise RuntimeError("Circular referencing of Router is not allowed")
|
msg = "Circular referencing of Router is not allowed"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
parent = parent.parent_router
|
parent = parent.parent_router
|
||||||
|
|
||||||
|
|
@ -230,7 +249,8 @@ class Router:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not routers:
|
if not routers:
|
||||||
raise ValueError("At least one router must be provided")
|
msg = "At least one router must be provided"
|
||||||
|
raise ValueError(msg)
|
||||||
for router in routers:
|
for router in routers:
|
||||||
self.include_router(router)
|
self.include_router(router)
|
||||||
|
|
||||||
|
|
@ -242,9 +262,8 @@ class Router:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not isinstance(router, Router):
|
if not isinstance(router, Router):
|
||||||
raise ValueError(
|
msg = f"router should be instance of Router not {type(router).__class__.__name__}"
|
||||||
f"router should be instance of Router not {type(router).__class__.__name__}"
|
raise ValueError(msg)
|
||||||
)
|
|
||||||
router.parent_router = self
|
router.parent_router = self
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ class DetailedAiogramError(AiogramError):
|
||||||
Base exception for all aiogram errors with detailed message.
|
Base exception for all aiogram errors with detailed message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url: Optional[str] = None
|
url: str | None = None
|
||||||
|
|
||||||
def __init__(self, message: str) -> None:
|
def __init__(self, message: str) -> None:
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
|
||||||
|
|
@ -23,29 +23,29 @@ from .state import StateFilter
|
||||||
BaseFilter = Filter
|
BaseFilter = Filter
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"Filter",
|
"ADMINISTRATOR",
|
||||||
|
"CREATOR",
|
||||||
|
"IS_ADMIN",
|
||||||
|
"IS_MEMBER",
|
||||||
|
"IS_NOT_MEMBER",
|
||||||
|
"JOIN_TRANSITION",
|
||||||
|
"KICKED",
|
||||||
|
"LEAVE_TRANSITION",
|
||||||
|
"LEFT",
|
||||||
|
"MEMBER",
|
||||||
|
"PROMOTED_TRANSITION",
|
||||||
|
"RESTRICTED",
|
||||||
"BaseFilter",
|
"BaseFilter",
|
||||||
|
"ChatMemberUpdatedFilter",
|
||||||
"Command",
|
"Command",
|
||||||
"CommandObject",
|
"CommandObject",
|
||||||
"CommandStart",
|
"CommandStart",
|
||||||
"ExceptionMessageFilter",
|
"ExceptionMessageFilter",
|
||||||
"ExceptionTypeFilter",
|
"ExceptionTypeFilter",
|
||||||
"StateFilter",
|
"Filter",
|
||||||
"MagicData",
|
"MagicData",
|
||||||
"ChatMemberUpdatedFilter",
|
"StateFilter",
|
||||||
"CREATOR",
|
|
||||||
"ADMINISTRATOR",
|
|
||||||
"MEMBER",
|
|
||||||
"RESTRICTED",
|
|
||||||
"LEFT",
|
|
||||||
"KICKED",
|
|
||||||
"IS_MEMBER",
|
|
||||||
"IS_ADMIN",
|
|
||||||
"PROMOTED_TRANSITION",
|
|
||||||
"IS_NOT_MEMBER",
|
|
||||||
"JOIN_TRANSITION",
|
|
||||||
"LEAVE_TRANSITION",
|
|
||||||
"and_f",
|
"and_f",
|
||||||
"or_f",
|
|
||||||
"invert_f",
|
"invert_f",
|
||||||
|
"or_f",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiogram.filters.logic import _InvertFilter
|
from aiogram.filters.logic import _InvertFilter
|
||||||
|
|
||||||
|
|
||||||
class Filter(ABC):
|
class Filter(ABC): # noqa: B024
|
||||||
"""
|
"""
|
||||||
If you want to register own filters like builtin filters you will need to write subclass
|
If you want to register own filters like builtin filters you will need to write subclass
|
||||||
of this class with overriding the :code:`__call__`
|
of this class with overriding the :code:`__call__`
|
||||||
|
|
@ -16,11 +17,11 @@ class Filter(ABC):
|
||||||
# This checking type-hint is needed because mypy checks validity of overrides and raises:
|
# This checking type-hint is needed because mypy checks validity of overrides and raises:
|
||||||
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
|
# error: Signature of "__call__" incompatible with supertype "BaseFilter" [override]
|
||||||
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
|
# https://mypy.readthedocs.io/en/latest/error_code_list.html#check-validity-of-overrides-override
|
||||||
__call__: Callable[..., Awaitable[Union[bool, Dict[str, Any]]]]
|
__call__: Callable[..., Awaitable[bool | dict[str, Any]]]
|
||||||
else: # pragma: no cover
|
else: # pragma: no cover
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
This method should be overridden.
|
This method should be overridden.
|
||||||
|
|
||||||
|
|
@ -28,21 +29,19 @@ class Filter(ABC):
|
||||||
|
|
||||||
:return: :class:`bool` or :class:`Dict[str, Any]`
|
:return: :class:`bool` or :class:`Dict[str, Any]`
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def __invert__(self) -> "_InvertFilter":
|
def __invert__(self) -> "_InvertFilter":
|
||||||
from aiogram.filters.logic import invert_f
|
from aiogram.filters.logic import invert_f
|
||||||
|
|
||||||
return invert_f(self)
|
return invert_f(self)
|
||||||
|
|
||||||
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
|
def update_handler_flags(self, flags: dict[str, Any]) -> None: # noqa: B027
|
||||||
"""
|
"""
|
||||||
Also if you want to extend handler flags with using this filter
|
Also if you want to extend handler flags with using this filter
|
||||||
you should implement this method
|
you should implement this method
|
||||||
|
|
||||||
:param flags: existing flags, can be updated directly
|
:param flags: existing flags, can be updated directly
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def _signature_to_string(self, *args: Any, **kwargs: Any) -> str:
|
def _signature_to_string(self, *args: Any, **kwargs: Any) -> str:
|
||||||
items = [repr(arg) for arg in args]
|
items = [repr(arg) for arg in args]
|
||||||
|
|
|
||||||
|
|
@ -1,40 +1,30 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
|
||||||
import types
|
import types
|
||||||
import typing
|
import typing
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
ClassVar,
|
|
||||||
Dict,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from magic_filter import MagicFilter
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from aiogram.filters.base import Filter
|
from aiogram.filters.base import Filter
|
||||||
from aiogram.types import CallbackQuery
|
from aiogram.types import CallbackQuery
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from magic_filter import MagicFilter
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
T = TypeVar("T", bound="CallbackData")
|
T = TypeVar("T", bound="CallbackData")
|
||||||
|
|
||||||
MAX_CALLBACK_LENGTH: int = 64
|
MAX_CALLBACK_LENGTH: int = 64
|
||||||
|
|
||||||
|
|
||||||
_UNION_TYPES = {typing.Union}
|
_UNION_TYPES = {typing.Union, types.UnionType}
|
||||||
if sys.version_info >= (3, 10): # pragma: no cover
|
|
||||||
_UNION_TYPES.add(types.UnionType)
|
|
||||||
|
|
||||||
|
|
||||||
class CallbackDataException(Exception):
|
class CallbackDataException(Exception):
|
||||||
|
|
@ -59,17 +49,19 @@ class CallbackData(BaseModel):
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
if "prefix" not in kwargs:
|
if "prefix" not in kwargs:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"prefix required, usage example: "
|
f"prefix required, usage example: "
|
||||||
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
|
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
cls.__separator__ = kwargs.pop("sep", ":")
|
cls.__separator__ = kwargs.pop("sep", ":")
|
||||||
cls.__prefix__ = kwargs.pop("prefix")
|
cls.__prefix__ = kwargs.pop("prefix")
|
||||||
if cls.__separator__ in cls.__prefix__:
|
if cls.__separator__ in cls.__prefix__:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Separator symbol {cls.__separator__!r} can not be used "
|
f"Separator symbol {cls.__separator__!r} can not be used "
|
||||||
f"inside prefix {cls.__prefix__!r}"
|
f"inside prefix {cls.__prefix__!r}"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
def _encode_value(self, key: str, value: Any) -> str:
|
def _encode_value(self, key: str, value: Any) -> str:
|
||||||
|
|
@ -83,10 +75,11 @@ class CallbackData(BaseModel):
|
||||||
return str(int(value))
|
return str(int(value))
|
||||||
if isinstance(value, (int, str, float, Decimal, Fraction)):
|
if isinstance(value, (int, str, float, Decimal, Fraction)):
|
||||||
return str(value)
|
return str(value)
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Attribute {key}={value!r} of type {type(value).__name__!r}"
|
f"Attribute {key}={value!r} of type {type(value).__name__!r}"
|
||||||
f" can not be packed to callback data"
|
f" can not be packed to callback data"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
def pack(self) -> str:
|
def pack(self) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -98,21 +91,23 @@ class CallbackData(BaseModel):
|
||||||
for key, value in self.model_dump(mode="python").items():
|
for key, value in self.model_dump(mode="python").items():
|
||||||
encoded = self._encode_value(key, value)
|
encoded = self._encode_value(key, value)
|
||||||
if self.__separator__ in encoded:
|
if self.__separator__ in encoded:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Separator symbol {self.__separator__!r} can not be used "
|
f"Separator symbol {self.__separator__!r} can not be used "
|
||||||
f"in value {key}={encoded!r}"
|
f"in value {key}={encoded!r}"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
result.append(encoded)
|
result.append(encoded)
|
||||||
callback_data = self.__separator__.join(result)
|
callback_data = self.__separator__.join(result)
|
||||||
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
|
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Resulted callback data is too long! "
|
f"Resulted callback data is too long! "
|
||||||
f"len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
|
f"len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
return callback_data
|
return callback_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def unpack(cls: Type[T], value: str) -> T:
|
def unpack(cls, value: str) -> Self:
|
||||||
"""
|
"""
|
||||||
Parse callback data string
|
Parse callback data string
|
||||||
|
|
||||||
|
|
@ -122,22 +117,28 @@ class CallbackData(BaseModel):
|
||||||
prefix, *parts = value.split(cls.__separator__)
|
prefix, *parts = value.split(cls.__separator__)
|
||||||
names = cls.model_fields.keys()
|
names = cls.model_fields.keys()
|
||||||
if len(parts) != len(names):
|
if len(parts) != len(names):
|
||||||
raise TypeError(
|
msg = (
|
||||||
f"Callback data {cls.__name__!r} takes {len(names)} arguments "
|
f"Callback data {cls.__name__!r} takes {len(names)} arguments "
|
||||||
f"but {len(parts)} were given"
|
f"but {len(parts)} were given"
|
||||||
)
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
if prefix != cls.__prefix__:
|
if prefix != cls.__prefix__:
|
||||||
raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})")
|
msg = f"Bad prefix ({prefix!r} != {cls.__prefix__!r})"
|
||||||
|
raise ValueError(msg)
|
||||||
payload = {}
|
payload = {}
|
||||||
for k, v in zip(names, parts): # type: str, Optional[str]
|
for k, v in zip(names, parts, strict=True): # type: str, str
|
||||||
if field := cls.model_fields.get(k):
|
if (
|
||||||
if v == "" and _check_field_is_nullable(field) and field.default != "":
|
(field := cls.model_fields.get(k))
|
||||||
v = field.default if field.default is not PydanticUndefined else None
|
and v == ""
|
||||||
|
and _check_field_is_nullable(field)
|
||||||
|
and field.default != ""
|
||||||
|
):
|
||||||
|
v = field.default if field.default is not PydanticUndefined else None
|
||||||
payload[k] = v
|
payload[k] = v
|
||||||
return cls(**payload)
|
return cls(**payload)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def filter(cls, rule: Optional[MagicFilter] = None) -> CallbackQueryFilter:
|
def filter(cls, rule: MagicFilter | None = None) -> CallbackQueryFilter:
|
||||||
"""
|
"""
|
||||||
Generates a filter for callback query with rule
|
Generates a filter for callback query with rule
|
||||||
|
|
||||||
|
|
@ -163,8 +164,8 @@ class CallbackQueryFilter(Filter):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
callback_data: Type[CallbackData],
|
callback_data: type[CallbackData],
|
||||||
rule: Optional[MagicFilter] = None,
|
rule: MagicFilter | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param callback_data: Expected type of callback data
|
:param callback_data: Expected type of callback data
|
||||||
|
|
@ -179,7 +180,7 @@ class CallbackQueryFilter(Filter):
|
||||||
rule=self.rule,
|
rule=self.rule,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]:
|
async def __call__(self, query: CallbackQuery) -> Literal[False] | dict[str, Any]:
|
||||||
if not isinstance(query, CallbackQuery) or not query.data:
|
if not isinstance(query, CallbackQuery) or not query.data:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
|
|
@ -204,5 +205,5 @@ def _check_field_is_nullable(field: FieldInfo) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args(
|
return typing.get_origin(field.annotation) in _UNION_TYPES and type(None) in typing.get_args(
|
||||||
field.annotation
|
field.annotation,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from typing import Any, Dict, Optional, TypeVar, Union
|
from typing import Any, TypeVar, Union
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from aiogram.filters.base import Filter
|
from aiogram.filters.base import Filter
|
||||||
from aiogram.types import ChatMember, ChatMemberUpdated
|
from aiogram.types import ChatMember, ChatMemberUpdated
|
||||||
|
|
@ -10,11 +12,11 @@ TransitionT = TypeVar("TransitionT", bound="_MemberStatusTransition")
|
||||||
|
|
||||||
class _MemberStatusMarker:
|
class _MemberStatusMarker:
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"name",
|
|
||||||
"is_member",
|
"is_member",
|
||||||
|
"name",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, name: str, *, is_member: Optional[bool] = None) -> None:
|
def __init__(self, name: str, *, is_member: bool | None = None) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.is_member = is_member
|
self.is_member = is_member
|
||||||
|
|
||||||
|
|
@ -22,53 +24,59 @@ class _MemberStatusMarker:
|
||||||
result = self.name.upper()
|
result = self.name.upper()
|
||||||
if self.is_member is not None:
|
if self.is_member is not None:
|
||||||
result = ("+" if self.is_member else "-") + result
|
result = ("+" if self.is_member else "-") + result
|
||||||
return result # noqa: RET504
|
return result
|
||||||
|
|
||||||
def __pos__(self: MarkerT) -> MarkerT:
|
def __pos__(self) -> Self:
|
||||||
return type(self)(name=self.name, is_member=True)
|
return type(self)(name=self.name, is_member=True)
|
||||||
|
|
||||||
def __neg__(self: MarkerT) -> MarkerT:
|
def __neg__(self) -> Self:
|
||||||
return type(self)(name=self.name, is_member=False)
|
return type(self)(name=self.name, is_member=False)
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
|
self,
|
||||||
|
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
|
||||||
) -> "_MemberStatusGroupMarker":
|
) -> "_MemberStatusGroupMarker":
|
||||||
if isinstance(other, _MemberStatusMarker):
|
if isinstance(other, _MemberStatusMarker):
|
||||||
return _MemberStatusGroupMarker(self, other)
|
return _MemberStatusGroupMarker(self, other)
|
||||||
if isinstance(other, _MemberStatusGroupMarker):
|
if isinstance(other, _MemberStatusGroupMarker):
|
||||||
return other | self
|
return other | self
|
||||||
raise TypeError(
|
msg = (
|
||||||
f"unsupported operand type(s) for |: "
|
f"unsupported operand type(s) for |: "
|
||||||
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
||||||
)
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
__ror__ = __or__
|
__ror__ = __or__
|
||||||
|
|
||||||
def __rshift__(
|
def __rshift__(
|
||||||
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
|
self,
|
||||||
|
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
|
||||||
) -> "_MemberStatusTransition":
|
) -> "_MemberStatusTransition":
|
||||||
old = _MemberStatusGroupMarker(self)
|
old = _MemberStatusGroupMarker(self)
|
||||||
if isinstance(other, _MemberStatusMarker):
|
if isinstance(other, _MemberStatusMarker):
|
||||||
return _MemberStatusTransition(old=old, new=_MemberStatusGroupMarker(other))
|
return _MemberStatusTransition(old=old, new=_MemberStatusGroupMarker(other))
|
||||||
if isinstance(other, _MemberStatusGroupMarker):
|
if isinstance(other, _MemberStatusGroupMarker):
|
||||||
return _MemberStatusTransition(old=old, new=other)
|
return _MemberStatusTransition(old=old, new=other)
|
||||||
raise TypeError(
|
msg = (
|
||||||
f"unsupported operand type(s) for >>: "
|
f"unsupported operand type(s) for >>: "
|
||||||
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
||||||
)
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
def __lshift__(
|
def __lshift__(
|
||||||
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
|
self,
|
||||||
|
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
|
||||||
) -> "_MemberStatusTransition":
|
) -> "_MemberStatusTransition":
|
||||||
new = _MemberStatusGroupMarker(self)
|
new = _MemberStatusGroupMarker(self)
|
||||||
if isinstance(other, _MemberStatusMarker):
|
if isinstance(other, _MemberStatusMarker):
|
||||||
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=new)
|
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=new)
|
||||||
if isinstance(other, _MemberStatusGroupMarker):
|
if isinstance(other, _MemberStatusGroupMarker):
|
||||||
return _MemberStatusTransition(old=other, new=new)
|
return _MemberStatusTransition(old=other, new=new)
|
||||||
raise TypeError(
|
msg = (
|
||||||
f"unsupported operand type(s) for <<: "
|
f"unsupported operand type(s) for <<: "
|
||||||
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
||||||
)
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash((self.name, self.is_member))
|
return hash((self.name, self.is_member))
|
||||||
|
|
@ -87,44 +95,51 @@ class _MemberStatusGroupMarker:
|
||||||
|
|
||||||
def __init__(self, *statuses: _MemberStatusMarker) -> None:
|
def __init__(self, *statuses: _MemberStatusMarker) -> None:
|
||||||
if not statuses:
|
if not statuses:
|
||||||
raise ValueError("Member status group should have at least one status included")
|
msg = "Member status group should have at least one status included"
|
||||||
|
raise ValueError(msg)
|
||||||
self.statuses = frozenset(statuses)
|
self.statuses = frozenset(statuses)
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self: MarkerGroupT, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
|
self,
|
||||||
) -> MarkerGroupT:
|
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
|
||||||
|
) -> Self:
|
||||||
if isinstance(other, _MemberStatusMarker):
|
if isinstance(other, _MemberStatusMarker):
|
||||||
return type(self)(*self.statuses, other)
|
return type(self)(*self.statuses, other)
|
||||||
if isinstance(other, _MemberStatusGroupMarker):
|
if isinstance(other, _MemberStatusGroupMarker):
|
||||||
return type(self)(*self.statuses, *other.statuses)
|
return type(self)(*self.statuses, *other.statuses)
|
||||||
raise TypeError(
|
msg = (
|
||||||
f"unsupported operand type(s) for |: "
|
f"unsupported operand type(s) for |: "
|
||||||
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
||||||
)
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
def __rshift__(
|
def __rshift__(
|
||||||
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
|
self,
|
||||||
|
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
|
||||||
) -> "_MemberStatusTransition":
|
) -> "_MemberStatusTransition":
|
||||||
if isinstance(other, _MemberStatusMarker):
|
if isinstance(other, _MemberStatusMarker):
|
||||||
return _MemberStatusTransition(old=self, new=_MemberStatusGroupMarker(other))
|
return _MemberStatusTransition(old=self, new=_MemberStatusGroupMarker(other))
|
||||||
if isinstance(other, _MemberStatusGroupMarker):
|
if isinstance(other, _MemberStatusGroupMarker):
|
||||||
return _MemberStatusTransition(old=self, new=other)
|
return _MemberStatusTransition(old=self, new=other)
|
||||||
raise TypeError(
|
msg = (
|
||||||
f"unsupported operand type(s) for >>: "
|
f"unsupported operand type(s) for >>: "
|
||||||
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
||||||
)
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
def __lshift__(
|
def __lshift__(
|
||||||
self, other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"]
|
self,
|
||||||
|
other: Union["_MemberStatusMarker", "_MemberStatusGroupMarker"],
|
||||||
) -> "_MemberStatusTransition":
|
) -> "_MemberStatusTransition":
|
||||||
if isinstance(other, _MemberStatusMarker):
|
if isinstance(other, _MemberStatusMarker):
|
||||||
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=self)
|
return _MemberStatusTransition(old=_MemberStatusGroupMarker(other), new=self)
|
||||||
if isinstance(other, _MemberStatusGroupMarker):
|
if isinstance(other, _MemberStatusGroupMarker):
|
||||||
return _MemberStatusTransition(old=other, new=self)
|
return _MemberStatusTransition(old=other, new=self)
|
||||||
raise TypeError(
|
msg = (
|
||||||
f"unsupported operand type(s) for <<: "
|
f"unsupported operand type(s) for <<: "
|
||||||
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
f"{type(self).__name__!r} and {type(other).__name__!r}"
|
||||||
)
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
result = " | ".join(map(str, sorted(self.statuses, key=str)))
|
result = " | ".join(map(str, sorted(self.statuses, key=str)))
|
||||||
|
|
@ -138,8 +153,8 @@ class _MemberStatusGroupMarker:
|
||||||
|
|
||||||
class _MemberStatusTransition:
|
class _MemberStatusTransition:
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"old",
|
|
||||||
"new",
|
"new",
|
||||||
|
"old",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, *, old: _MemberStatusGroupMarker, new: _MemberStatusGroupMarker) -> None:
|
def __init__(self, *, old: _MemberStatusGroupMarker, new: _MemberStatusGroupMarker) -> None:
|
||||||
|
|
@ -149,7 +164,7 @@ class _MemberStatusTransition:
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.old} >> {self.new}"
|
return f"{self.old} >> {self.new}"
|
||||||
|
|
||||||
def __invert__(self: TransitionT) -> TransitionT:
|
def __invert__(self) -> Self:
|
||||||
return type(self)(old=self.new, new=self.old)
|
return type(self)(old=self.new, new=self.old)
|
||||||
|
|
||||||
def check(self, *, old: ChatMember, new: ChatMember) -> bool:
|
def check(self, *, old: ChatMember, new: ChatMember) -> bool:
|
||||||
|
|
@ -177,11 +192,9 @@ class ChatMemberUpdatedFilter(Filter):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
member_status_changed: Union[
|
member_status_changed: (
|
||||||
_MemberStatusMarker,
|
_MemberStatusMarker | _MemberStatusGroupMarker | _MemberStatusTransition
|
||||||
_MemberStatusGroupMarker,
|
),
|
||||||
_MemberStatusTransition,
|
|
||||||
],
|
|
||||||
):
|
):
|
||||||
self.member_status_changed = member_status_changed
|
self.member_status_changed = member_status_changed
|
||||||
|
|
||||||
|
|
@ -190,7 +203,7 @@ class ChatMemberUpdatedFilter(Filter):
|
||||||
member_status_changed=self.member_status_changed,
|
member_status_changed=self.member_status_changed,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def __call__(self, member_updated: ChatMemberUpdated) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, member_updated: ChatMemberUpdated) -> bool | dict[str, Any]:
|
||||||
old = member_updated.old_chat_member
|
old = member_updated.old_chat_member
|
||||||
new = member_updated.new_chat_member
|
new = member_updated.new_chat_member
|
||||||
rule = self.member_status_changed
|
rule = self.member_status_changed
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,21 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from collections.abc import Iterable, Sequence
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from typing import (
|
from re import Match, Pattern
|
||||||
TYPE_CHECKING,
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
Match,
|
|
||||||
Optional,
|
|
||||||
Pattern,
|
|
||||||
Sequence,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
from magic_filter import MagicFilter
|
|
||||||
|
|
||||||
from aiogram.filters.base import Filter
|
from aiogram.filters.base import Filter
|
||||||
from aiogram.types import BotCommand, Message
|
from aiogram.types import BotCommand, Message
|
||||||
from aiogram.utils.deep_linking import decode_payload
|
from aiogram.utils.deep_linking import decode_payload
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from magic_filter import MagicFilter
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
|
|
||||||
# TODO: rm type ignore after py3.8 support expiration or mypy bug fix
|
CommandPatternType = str | re.Pattern[str] | BotCommand
|
||||||
CommandPatternType = Union[str, re.Pattern, BotCommand] # type: ignore[type-arg]
|
|
||||||
|
|
||||||
|
|
||||||
class CommandException(Exception):
|
class CommandException(Exception):
|
||||||
|
|
@ -41,20 +31,20 @@ class Command(Filter):
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"commands",
|
"commands",
|
||||||
"prefix",
|
|
||||||
"ignore_case",
|
"ignore_case",
|
||||||
"ignore_mention",
|
"ignore_mention",
|
||||||
"magic",
|
"magic",
|
||||||
|
"prefix",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*values: CommandPatternType,
|
*values: CommandPatternType,
|
||||||
commands: Optional[Union[Sequence[CommandPatternType], CommandPatternType]] = None,
|
commands: Sequence[CommandPatternType] | CommandPatternType | None = None,
|
||||||
prefix: str = "/",
|
prefix: str = "/",
|
||||||
ignore_case: bool = False,
|
ignore_case: bool = False,
|
||||||
ignore_mention: bool = False,
|
ignore_mention: bool = False,
|
||||||
magic: Optional[MagicFilter] = None,
|
magic: MagicFilter | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
List of commands (string or compiled regexp patterns)
|
List of commands (string or compiled regexp patterns)
|
||||||
|
|
@ -74,26 +64,29 @@ class Command(Filter):
|
||||||
commands = [commands]
|
commands = [commands]
|
||||||
|
|
||||||
if not isinstance(commands, Iterable):
|
if not isinstance(commands, Iterable):
|
||||||
raise ValueError(
|
msg = (
|
||||||
"Command filter only supports str, re.Pattern, BotCommand object"
|
"Command filter only supports str, re.Pattern, BotCommand object"
|
||||||
" or their Iterable"
|
" or their Iterable"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
items = []
|
items = []
|
||||||
for command in (*values, *commands):
|
for command in (*values, *commands):
|
||||||
if isinstance(command, BotCommand):
|
if isinstance(command, BotCommand):
|
||||||
command = command.command
|
command = command.command
|
||||||
if not isinstance(command, (str, re.Pattern)):
|
if not isinstance(command, (str, re.Pattern)):
|
||||||
raise ValueError(
|
msg = (
|
||||||
"Command filter only supports str, re.Pattern, BotCommand object"
|
"Command filter only supports str, re.Pattern, BotCommand object"
|
||||||
" or their Iterable"
|
" or their Iterable"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
if ignore_case and isinstance(command, str):
|
if ignore_case and isinstance(command, str):
|
||||||
command = command.casefold()
|
command = command.casefold()
|
||||||
items.append(command)
|
items.append(command)
|
||||||
|
|
||||||
if not items:
|
if not items:
|
||||||
raise ValueError("At least one command should be specified")
|
msg = "At least one command should be specified"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self.commands = tuple(items)
|
self.commands = tuple(items)
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
@ -110,11 +103,11 @@ class Command(Filter):
|
||||||
magic=self.magic,
|
magic=self.magic,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
|
def update_handler_flags(self, flags: dict[str, Any]) -> None:
|
||||||
commands = flags.setdefault("commands", [])
|
commands = flags.setdefault("commands", [])
|
||||||
commands.append(self)
|
commands.append(self)
|
||||||
|
|
||||||
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, message: Message, bot: Bot) -> bool | dict[str, Any]:
|
||||||
if not isinstance(message, Message):
|
if not isinstance(message, Message):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -137,7 +130,8 @@ class Command(Filter):
|
||||||
try:
|
try:
|
||||||
full_command, *args = text.split(maxsplit=1)
|
full_command, *args = text.split(maxsplit=1)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise CommandException("not enough values to unpack")
|
msg = "not enough values to unpack"
|
||||||
|
raise CommandException(msg)
|
||||||
|
|
||||||
# Separate command into valuable parts
|
# Separate command into valuable parts
|
||||||
# "/command@mention" -> "/", ("command", "@", "mention")
|
# "/command@mention" -> "/", ("command", "@", "mention")
|
||||||
|
|
@ -151,13 +145,15 @@ class Command(Filter):
|
||||||
|
|
||||||
def validate_prefix(self, command: CommandObject) -> None:
|
def validate_prefix(self, command: CommandObject) -> None:
|
||||||
if command.prefix not in self.prefix:
|
if command.prefix not in self.prefix:
|
||||||
raise CommandException("Invalid command prefix")
|
msg = "Invalid command prefix"
|
||||||
|
raise CommandException(msg)
|
||||||
|
|
||||||
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
|
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
|
||||||
if command.mention and not self.ignore_mention:
|
if command.mention and not self.ignore_mention:
|
||||||
me = await bot.me()
|
me = await bot.me()
|
||||||
if me.username and command.mention.lower() != me.username.lower():
|
if me.username and command.mention.lower() != me.username.lower():
|
||||||
raise CommandException("Mention did not match")
|
msg = "Mention did not match"
|
||||||
|
raise CommandException(msg)
|
||||||
|
|
||||||
def validate_command(self, command: CommandObject) -> CommandObject:
|
def validate_command(self, command: CommandObject) -> CommandObject:
|
||||||
for allowed_command in cast(Sequence[CommandPatternType], self.commands):
|
for allowed_command in cast(Sequence[CommandPatternType], self.commands):
|
||||||
|
|
@ -174,7 +170,8 @@ class Command(Filter):
|
||||||
|
|
||||||
if command_name == allowed_command: # String
|
if command_name == allowed_command: # String
|
||||||
return command
|
return command
|
||||||
raise CommandException("Command did not match pattern")
|
msg = "Command did not match pattern"
|
||||||
|
raise CommandException(msg)
|
||||||
|
|
||||||
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
|
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
|
||||||
"""
|
"""
|
||||||
|
|
@ -196,7 +193,8 @@ class Command(Filter):
|
||||||
return command
|
return command
|
||||||
result = self.magic.resolve(command)
|
result = self.magic.resolve(command)
|
||||||
if not result:
|
if not result:
|
||||||
raise CommandException("Rejected via magic filter")
|
msg = "Rejected via magic filter"
|
||||||
|
raise CommandException(msg)
|
||||||
return replace(command, magic_result=result)
|
return replace(command, magic_result=result)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -211,13 +209,13 @@ class CommandObject:
|
||||||
"""Command prefix"""
|
"""Command prefix"""
|
||||||
command: str = ""
|
command: str = ""
|
||||||
"""Command without prefix and mention"""
|
"""Command without prefix and mention"""
|
||||||
mention: Optional[str] = None
|
mention: str | None = None
|
||||||
"""Mention (if available)"""
|
"""Mention (if available)"""
|
||||||
args: Optional[str] = field(repr=False, default=None)
|
args: str | None = field(repr=False, default=None)
|
||||||
"""Command argument"""
|
"""Command argument"""
|
||||||
regexp_match: Optional[Match[str]] = field(repr=False, default=None)
|
regexp_match: Match[str] | None = field(repr=False, default=None)
|
||||||
"""Will be presented match result if the command is presented as regexp in filter"""
|
"""Will be presented match result if the command is presented as regexp in filter"""
|
||||||
magic_result: Optional[Any] = field(repr=False, default=None)
|
magic_result: Any | None = field(repr=False, default=None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mentioned(self) -> bool:
|
def mentioned(self) -> bool:
|
||||||
|
|
@ -246,7 +244,7 @@ class CommandStart(Command):
|
||||||
deep_link_encoded: bool = False,
|
deep_link_encoded: bool = False,
|
||||||
ignore_case: bool = False,
|
ignore_case: bool = False,
|
||||||
ignore_mention: bool = False,
|
ignore_mention: bool = False,
|
||||||
magic: Optional[MagicFilter] = None,
|
magic: MagicFilter | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
"start",
|
"start",
|
||||||
|
|
@ -287,12 +285,14 @@ class CommandStart(Command):
|
||||||
if not self.deep_link:
|
if not self.deep_link:
|
||||||
return command
|
return command
|
||||||
if not command.args:
|
if not command.args:
|
||||||
raise CommandException("Deep-link was missing")
|
msg = "Deep-link was missing"
|
||||||
|
raise CommandException(msg)
|
||||||
args = command.args
|
args = command.args
|
||||||
if self.deep_link_encoded:
|
if self.deep_link_encoded:
|
||||||
try:
|
try:
|
||||||
args = decode_payload(args)
|
args = decode_payload(args)
|
||||||
except UnicodeDecodeError as e:
|
except UnicodeDecodeError as e:
|
||||||
raise CommandException(f"Failed to decode Base64: {e}")
|
msg = f"Failed to decode Base64: {e}"
|
||||||
|
raise CommandException(msg)
|
||||||
return replace(command, args=args)
|
return replace(command, args=args)
|
||||||
return command
|
return command
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Pattern, Type, Union, cast
|
from re import Pattern
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from aiogram.filters.base import Filter
|
from aiogram.filters.base import Filter
|
||||||
from aiogram.types import TelegramObject
|
from aiogram.types import TelegramObject
|
||||||
|
|
@ -13,15 +14,16 @@ class ExceptionTypeFilter(Filter):
|
||||||
|
|
||||||
__slots__ = ("exceptions",)
|
__slots__ = ("exceptions",)
|
||||||
|
|
||||||
def __init__(self, *exceptions: Type[Exception]):
|
def __init__(self, *exceptions: type[Exception]):
|
||||||
"""
|
"""
|
||||||
:param exceptions: Exception type(s)
|
:param exceptions: Exception type(s)
|
||||||
"""
|
"""
|
||||||
if not exceptions:
|
if not exceptions:
|
||||||
raise ValueError("At least one exception type is required")
|
msg = "At least one exception type is required"
|
||||||
|
raise ValueError(msg)
|
||||||
self.exceptions = exceptions
|
self.exceptions = exceptions
|
||||||
|
|
||||||
async def __call__(self, obj: TelegramObject) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, obj: TelegramObject) -> bool | dict[str, Any]:
|
||||||
return isinstance(cast(ErrorEvent, obj).exception, self.exceptions)
|
return isinstance(cast(ErrorEvent, obj).exception, self.exceptions)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -32,7 +34,7 @@ class ExceptionMessageFilter(Filter):
|
||||||
|
|
||||||
__slots__ = ("pattern",)
|
__slots__ = ("pattern",)
|
||||||
|
|
||||||
def __init__(self, pattern: Union[str, Pattern[str]]):
|
def __init__(self, pattern: str | Pattern[str]):
|
||||||
"""
|
"""
|
||||||
:param pattern: Regexp pattern
|
:param pattern: Regexp pattern
|
||||||
"""
|
"""
|
||||||
|
|
@ -48,7 +50,7 @@ class ExceptionMessageFilter(Filter):
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
obj: TelegramObject,
|
obj: TelegramObject,
|
||||||
) -> Union[bool, Dict[str, Any]]:
|
) -> bool | dict[str, Any]:
|
||||||
result = self.pattern.match(str(cast(ErrorEvent, obj).exception))
|
result = self.pattern.match(str(cast(ErrorEvent, obj).exception))
|
||||||
if not result:
|
if not result:
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from aiogram.filters import Filter
|
from aiogram.filters import Filter
|
||||||
|
|
||||||
|
|
@ -17,7 +17,7 @@ class _InvertFilter(_LogicFilter):
|
||||||
def __init__(self, target: "FilterObject") -> None:
|
def __init__(self, target: "FilterObject") -> None:
|
||||||
self.target = target
|
self.target = target
|
||||||
|
|
||||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
|
||||||
return not bool(await self.target.call(*args, **kwargs))
|
return not bool(await self.target.call(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -27,7 +27,7 @@ class _AndFilter(_LogicFilter):
|
||||||
def __init__(self, *targets: "FilterObject") -> None:
|
def __init__(self, *targets: "FilterObject") -> None:
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
|
||||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
|
||||||
final_result = {}
|
final_result = {}
|
||||||
|
|
||||||
for target in self.targets:
|
for target in self.targets:
|
||||||
|
|
@ -48,7 +48,7 @@ class _OrFilter(_LogicFilter):
|
||||||
def __init__(self, *targets: "FilterObject") -> None:
|
def __init__(self, *targets: "FilterObject") -> None:
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
|
||||||
async def __call__(self, *args: Any, **kwargs: Any) -> Union[bool, Dict[str, Any]]:
|
async def __call__(self, *args: Any, **kwargs: Any) -> bool | dict[str, Any]:
|
||||||
for target in self.targets:
|
for target in self.targets:
|
||||||
result = await target.call(*args, **kwargs)
|
result = await target.call(*args, **kwargs)
|
||||||
if not result:
|
if not result:
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class MagicData(Filter):
|
||||||
|
|
||||||
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
|
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
|
||||||
return self.magic_data.resolve(
|
return self.magic_data.resolve(
|
||||||
AttrDict({"event": event, **dict(enumerate(args)), **kwargs})
|
AttrDict({"event": event, **dict(enumerate(args)), **kwargs}),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from inspect import isclass
|
from inspect import isclass
|
||||||
from typing import Any, Dict, Optional, Sequence, Type, Union, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from aiogram.filters.base import Filter
|
from aiogram.filters.base import Filter
|
||||||
from aiogram.fsm.state import State, StatesGroup
|
from aiogram.fsm.state import State, StatesGroup
|
||||||
from aiogram.types import TelegramObject
|
from aiogram.types import TelegramObject
|
||||||
|
|
||||||
StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]]
|
StateType = str | State | StatesGroup | type[StatesGroup] | None
|
||||||
|
|
||||||
|
|
||||||
class StateFilter(Filter):
|
class StateFilter(Filter):
|
||||||
|
|
@ -17,7 +18,8 @@ class StateFilter(Filter):
|
||||||
|
|
||||||
def __init__(self, *states: StateType) -> None:
|
def __init__(self, *states: StateType) -> None:
|
||||||
if not states:
|
if not states:
|
||||||
raise ValueError("At least one state is required")
|
msg = "At least one state is required"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self.states = states
|
self.states = states
|
||||||
|
|
||||||
|
|
@ -27,17 +29,22 @@ class StateFilter(Filter):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self, obj: TelegramObject, raw_state: Optional[str] = None
|
self,
|
||||||
) -> Union[bool, Dict[str, Any]]:
|
obj: TelegramObject,
|
||||||
|
raw_state: str | None = None,
|
||||||
|
) -> bool | dict[str, Any]:
|
||||||
allowed_states = cast(Sequence[StateType], self.states)
|
allowed_states = cast(Sequence[StateType], self.states)
|
||||||
for allowed_state in allowed_states:
|
for allowed_state in allowed_states:
|
||||||
if isinstance(allowed_state, str) or allowed_state is None:
|
if isinstance(allowed_state, str) or allowed_state is None:
|
||||||
if allowed_state == "*" or raw_state == allowed_state:
|
if allowed_state in {"*", raw_state}:
|
||||||
return True
|
return True
|
||||||
elif isinstance(allowed_state, (State, StatesGroup)):
|
elif isinstance(allowed_state, (State, StatesGroup)):
|
||||||
if allowed_state(event=obj, raw_state=raw_state):
|
if allowed_state(event=obj, raw_state=raw_state):
|
||||||
return True
|
return True
|
||||||
elif isclass(allowed_state) and issubclass(allowed_state, StatesGroup):
|
elif (
|
||||||
if allowed_state()(event=obj, raw_state=raw_state):
|
isclass(allowed_state)
|
||||||
return True
|
and issubclass(allowed_state, StatesGroup)
|
||||||
|
and allowed_state()(event=obj, raw_state=raw_state)
|
||||||
|
):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Dict, Mapping, Optional, overload
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, overload
|
||||||
|
|
||||||
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
|
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
|
||||||
|
|
||||||
|
|
@ -11,27 +12,29 @@ class FSMContext:
|
||||||
async def set_state(self, state: StateType = None) -> None:
|
async def set_state(self, state: StateType = None) -> None:
|
||||||
await self.storage.set_state(key=self.key, state=state)
|
await self.storage.set_state(key=self.key, state=state)
|
||||||
|
|
||||||
async def get_state(self) -> Optional[str]:
|
async def get_state(self) -> str | None:
|
||||||
return await self.storage.get_state(key=self.key)
|
return await self.storage.get_state(key=self.key)
|
||||||
|
|
||||||
async def set_data(self, data: Mapping[str, Any]) -> None:
|
async def set_data(self, data: Mapping[str, Any]) -> None:
|
||||||
await self.storage.set_data(key=self.key, data=data)
|
await self.storage.set_data(key=self.key, data=data)
|
||||||
|
|
||||||
async def get_data(self) -> Dict[str, Any]:
|
async def get_data(self) -> dict[str, Any]:
|
||||||
return await self.storage.get_data(key=self.key)
|
return await self.storage.get_data(key=self.key)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, key: str) -> Optional[Any]: ...
|
async def get_value(self, key: str) -> Any | None: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, key: str, default: Any) -> Any: ...
|
async def get_value(self, key: str, default: Any) -> Any: ...
|
||||||
|
|
||||||
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
|
async def get_value(self, key: str, default: Any | None = None) -> Any | None:
|
||||||
return await self.storage.get_value(storage_key=self.key, dict_key=key, default=default)
|
return await self.storage.get_value(storage_key=self.key, dict_key=key, default=default)
|
||||||
|
|
||||||
async def update_data(
|
async def update_data(
|
||||||
self, data: Optional[Mapping[str, Any]] = None, **kwargs: Any
|
self,
|
||||||
) -> Dict[str, Any]:
|
data: Mapping[str, Any] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict[str, Any]:
|
||||||
if data:
|
if data:
|
||||||
kwargs.update(data)
|
kwargs.update(data)
|
||||||
return await self.storage.update_data(key=self.key, data=kwargs)
|
return await self.storage.update_data(key=self.key, data=kwargs)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||||
|
|
@ -27,9 +28,9 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
bot: Bot = cast(Bot, data["bot"])
|
bot: Bot = cast(Bot, data["bot"])
|
||||||
context = self.resolve_event_context(bot, data)
|
context = self.resolve_event_context(bot, data)
|
||||||
|
|
@ -45,9 +46,9 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
def resolve_event_context(
|
def resolve_event_context(
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
destiny: str = DEFAULT_DESTINY,
|
destiny: str = DEFAULT_DESTINY,
|
||||||
) -> Optional[FSMContext]:
|
) -> FSMContext | None:
|
||||||
event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
|
event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
|
||||||
return self.resolve_context(
|
return self.resolve_context(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
|
|
@ -61,12 +62,12 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
def resolve_context(
|
def resolve_context(
|
||||||
self,
|
self,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
chat_id: Optional[int],
|
chat_id: int | None,
|
||||||
user_id: Optional[int],
|
user_id: int | None,
|
||||||
thread_id: Optional[int] = None,
|
thread_id: int | None = None,
|
||||||
business_connection_id: Optional[str] = None,
|
business_connection_id: str | None = None,
|
||||||
destiny: str = DEFAULT_DESTINY,
|
destiny: str = DEFAULT_DESTINY,
|
||||||
) -> Optional[FSMContext]:
|
) -> FSMContext | None:
|
||||||
if chat_id is None:
|
if chat_id is None:
|
||||||
chat_id = user_id
|
chat_id = user_id
|
||||||
|
|
||||||
|
|
@ -92,8 +93,8 @@ class FSMContextMiddleware(BaseMiddleware):
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
thread_id: Optional[int] = None,
|
thread_id: int | None = None,
|
||||||
business_connection_id: Optional[str] = None,
|
business_connection_id: str | None = None,
|
||||||
destiny: str = DEFAULT_DESTINY,
|
destiny: str = DEFAULT_DESTINY,
|
||||||
) -> FSMContext:
|
) -> FSMContext:
|
||||||
return FSMContext(
|
return FSMContext(
|
||||||
|
|
|
||||||
|
|
@ -2,26 +2,15 @@ from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Any, ClassVar, overload
|
||||||
Any,
|
|
||||||
ClassVar,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from aiogram import loggers
|
from aiogram import loggers
|
||||||
from aiogram.dispatcher.dispatcher import Dispatcher
|
from aiogram.dispatcher.dispatcher import Dispatcher
|
||||||
from aiogram.dispatcher.event.bases import NextMiddlewareType
|
|
||||||
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
|
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
|
||||||
from aiogram.dispatcher.flags import extract_flags_from_object
|
from aiogram.dispatcher.flags import extract_flags_from_object
|
||||||
from aiogram.dispatcher.router import Router
|
from aiogram.dispatcher.router import Router
|
||||||
|
|
@ -36,16 +25,20 @@ from aiogram.utils.class_attrs_resolver import (
|
||||||
get_sorted_mro_attrs_resolver,
|
get_sorted_mro_attrs_resolver,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from aiogram.dispatcher.event.bases import NextMiddlewareType
|
||||||
|
|
||||||
|
|
||||||
class HistoryManager:
|
class HistoryManager:
|
||||||
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
|
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
|
||||||
self._size = size
|
self._size = size
|
||||||
self._state = state
|
self._state = state
|
||||||
self._history_state = FSMContext(
|
self._history_state = FSMContext(
|
||||||
storage=state.storage, key=replace(state.key, destiny=destiny)
|
storage=state.storage,
|
||||||
|
key=replace(state.key, destiny=destiny),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
|
async def push(self, state: str | None, data: dict[str, Any]) -> None:
|
||||||
history_data = await self._history_state.get_data()
|
history_data = await self._history_state.get_data()
|
||||||
history = history_data.setdefault("history", [])
|
history = history_data.setdefault("history", [])
|
||||||
history.append({"state": state, "data": data})
|
history.append({"state": state, "data": data})
|
||||||
|
|
@ -55,7 +48,7 @@ class HistoryManager:
|
||||||
|
|
||||||
await self._history_state.update_data(history=history)
|
await self._history_state.update_data(history=history)
|
||||||
|
|
||||||
async def pop(self) -> Optional[MemoryStorageRecord]:
|
async def pop(self) -> MemoryStorageRecord | None:
|
||||||
history_data = await self._history_state.get_data()
|
history_data = await self._history_state.get_data()
|
||||||
history = history_data.setdefault("history", [])
|
history = history_data.setdefault("history", [])
|
||||||
if not history:
|
if not history:
|
||||||
|
|
@ -70,14 +63,14 @@ class HistoryManager:
|
||||||
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
|
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
|
||||||
return MemoryStorageRecord(state=state, data=data)
|
return MemoryStorageRecord(state=state, data=data)
|
||||||
|
|
||||||
async def get(self) -> Optional[MemoryStorageRecord]:
|
async def get(self) -> MemoryStorageRecord | None:
|
||||||
history_data = await self._history_state.get_data()
|
history_data = await self._history_state.get_data()
|
||||||
history = history_data.setdefault("history", [])
|
history = history_data.setdefault("history", [])
|
||||||
if not history:
|
if not history:
|
||||||
return None
|
return None
|
||||||
return MemoryStorageRecord(**history[-1])
|
return MemoryStorageRecord(**history[-1])
|
||||||
|
|
||||||
async def all(self) -> List[MemoryStorageRecord]:
|
async def all(self) -> list[MemoryStorageRecord]:
|
||||||
history_data = await self._history_state.get_data()
|
history_data = await self._history_state.get_data()
|
||||||
history = history_data.setdefault("history", [])
|
history = history_data.setdefault("history", [])
|
||||||
return [MemoryStorageRecord(**item) for item in history]
|
return [MemoryStorageRecord(**item) for item in history]
|
||||||
|
|
@ -91,11 +84,11 @@ class HistoryManager:
|
||||||
data = await self._state.get_data()
|
data = await self._state.get_data()
|
||||||
await self.push(state, data)
|
await self.push(state, data)
|
||||||
|
|
||||||
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None:
|
async def _set_state(self, state: str | None, data: dict[str, Any]) -> None:
|
||||||
await self._state.set_state(state)
|
await self._state.set_state(state)
|
||||||
await self._state.set_data(data)
|
await self._state.set_data(data)
|
||||||
|
|
||||||
async def rollback(self) -> Optional[str]:
|
async def rollback(self) -> str | None:
|
||||||
previous_state = await self.pop()
|
previous_state = await self.pop()
|
||||||
if not previous_state:
|
if not previous_state:
|
||||||
await self._set_state(None, {})
|
await self._set_state(None, {})
|
||||||
|
|
@ -116,14 +109,14 @@ class ObserverDecorator:
|
||||||
name: str,
|
name: str,
|
||||||
filters: tuple[CallbackType, ...],
|
filters: tuple[CallbackType, ...],
|
||||||
action: SceneAction | None = None,
|
action: SceneAction | None = None,
|
||||||
after: Optional[After] = None,
|
after: After | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
self.action = action
|
self.action = action
|
||||||
self.after = after
|
self.after = after
|
||||||
|
|
||||||
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None:
|
def _wrap_filter(self, target: type[Scene] | CallbackType) -> None:
|
||||||
handlers = getattr(target, "__aiogram_handler__", None)
|
handlers = getattr(target, "__aiogram_handler__", None)
|
||||||
if not handlers:
|
if not handlers:
|
||||||
handlers = []
|
handlers = []
|
||||||
|
|
@ -135,7 +128,7 @@ class ObserverDecorator:
|
||||||
handler=target,
|
handler=target,
|
||||||
filters=self.filters,
|
filters=self.filters,
|
||||||
after=self.after,
|
after=self.after,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _wrap_action(self, target: CallbackType) -> None:
|
def _wrap_action(self, target: CallbackType) -> None:
|
||||||
|
|
@ -154,13 +147,14 @@ class ObserverDecorator:
|
||||||
else:
|
else:
|
||||||
self._wrap_action(target)
|
self._wrap_action(target)
|
||||||
else:
|
else:
|
||||||
raise TypeError("Only function or method is allowed")
|
msg = "Only function or method is allowed"
|
||||||
|
raise TypeError(msg)
|
||||||
return target
|
return target
|
||||||
|
|
||||||
def leave(self) -> ActionContainer:
|
def leave(self) -> ActionContainer:
|
||||||
return ActionContainer(self.name, self.filters, SceneAction.leave)
|
return ActionContainer(self.name, self.filters, SceneAction.leave)
|
||||||
|
|
||||||
def enter(self, target: Type[Scene]) -> ActionContainer:
|
def enter(self, target: type[Scene]) -> ActionContainer:
|
||||||
return ActionContainer(self.name, self.filters, SceneAction.enter, target)
|
return ActionContainer(self.name, self.filters, SceneAction.enter, target)
|
||||||
|
|
||||||
def exit(self) -> ActionContainer:
|
def exit(self) -> ActionContainer:
|
||||||
|
|
@ -181,9 +175,9 @@ class ActionContainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
filters: Tuple[CallbackType, ...],
|
filters: tuple[CallbackType, ...],
|
||||||
action: SceneAction,
|
action: SceneAction,
|
||||||
target: Optional[Union[Type[Scene], State, str]] = None,
|
target: type[Scene] | State | str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
|
|
@ -201,33 +195,27 @@ class ActionContainer:
|
||||||
await wizard.back()
|
await wizard.back()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
class HandlerContainer:
|
class HandlerContainer:
|
||||||
def __init__(
|
name: str
|
||||||
self,
|
handler: CallbackType
|
||||||
name: str,
|
filters: tuple[CallbackType, ...]
|
||||||
handler: CallbackType,
|
after: After | None = None
|
||||||
filters: Tuple[CallbackType, ...],
|
|
||||||
after: Optional[After] = None,
|
|
||||||
) -> None:
|
|
||||||
self.name = name
|
|
||||||
self.handler = handler
|
|
||||||
self.filters = filters
|
|
||||||
self.after = after
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass
|
||||||
class SceneConfig:
|
class SceneConfig:
|
||||||
state: Optional[str]
|
state: str | None
|
||||||
"""Scene state"""
|
"""Scene state"""
|
||||||
handlers: List[HandlerContainer]
|
handlers: list[HandlerContainer]
|
||||||
"""Scene handlers"""
|
"""Scene handlers"""
|
||||||
actions: Dict[SceneAction, Dict[str, CallableObject]]
|
actions: dict[SceneAction, dict[str, CallableObject]]
|
||||||
"""Scene actions"""
|
"""Scene actions"""
|
||||||
reset_data_on_enter: Optional[bool] = None
|
reset_data_on_enter: bool | None = None
|
||||||
"""Reset scene data on enter"""
|
"""Reset scene data on enter"""
|
||||||
reset_history_on_enter: Optional[bool] = None
|
reset_history_on_enter: bool | None = None
|
||||||
"""Reset scene history on enter"""
|
"""Reset scene history on enter"""
|
||||||
callback_query_without_state: Optional[bool] = None
|
callback_query_without_state: bool | None = None
|
||||||
"""Allow callback query without state"""
|
"""Allow callback query without state"""
|
||||||
attrs_resolver: ClassAttrsResolver = get_sorted_mro_attrs_resolver
|
attrs_resolver: ClassAttrsResolver = get_sorted_mro_attrs_resolver
|
||||||
"""
|
"""
|
||||||
|
|
@ -247,9 +235,9 @@ async def _empty_handler(*args: Any, **kwargs: Any) -> None:
|
||||||
class SceneHandlerWrapper:
|
class SceneHandlerWrapper:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scene: Type[Scene],
|
scene: type[Scene],
|
||||||
handler: CallbackType,
|
handler: CallbackType,
|
||||||
after: Optional[After] = None,
|
after: After | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scene = scene
|
self.scene = scene
|
||||||
self.handler = CallableObject(handler)
|
self.handler = CallableObject(handler)
|
||||||
|
|
@ -271,7 +259,7 @@ class SceneHandlerWrapper:
|
||||||
update_type=event_update.event_type,
|
update_type=event_update.event_type,
|
||||||
event=event,
|
event=event,
|
||||||
data=kwargs,
|
data=kwargs,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self.handler.call(scene, event, **kwargs)
|
result = await self.handler.call(scene, event, **kwargs)
|
||||||
|
|
@ -331,7 +319,7 @@ class Scene:
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
handlers: list[HandlerContainer] = []
|
handlers: list[HandlerContainer] = []
|
||||||
actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict)
|
actions: defaultdict[SceneAction, dict[str, CallableObject]] = defaultdict(dict)
|
||||||
|
|
||||||
for base in cls.__bases__:
|
for base in cls.__bases__:
|
||||||
if not issubclass(base, Scene):
|
if not issubclass(base, Scene):
|
||||||
|
|
@ -353,7 +341,7 @@ class Scene:
|
||||||
if attrs_resolver is None:
|
if attrs_resolver is None:
|
||||||
attrs_resolver = get_sorted_mro_attrs_resolver
|
attrs_resolver = get_sorted_mro_attrs_resolver
|
||||||
|
|
||||||
for name, value in attrs_resolver(cls):
|
for _name, value in attrs_resolver(cls):
|
||||||
if scene_handlers := getattr(value, "__aiogram_handler__", None):
|
if scene_handlers := getattr(value, "__aiogram_handler__", None):
|
||||||
handlers.extend(scene_handlers)
|
handlers.extend(scene_handlers)
|
||||||
if isinstance(value, ObserverDecorator):
|
if isinstance(value, ObserverDecorator):
|
||||||
|
|
@ -363,7 +351,7 @@ class Scene:
|
||||||
_empty_handler,
|
_empty_handler,
|
||||||
value.filters,
|
value.filters,
|
||||||
after=value.after,
|
after=value.after,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
if hasattr(value, "__aiogram_action__"):
|
if hasattr(value, "__aiogram_action__"):
|
||||||
for action, action_handlers in value.__aiogram_action__.items():
|
for action, action_handlers in value.__aiogram_action__.items():
|
||||||
|
|
@ -408,7 +396,7 @@ class Scene:
|
||||||
router.observers[observer_name].filter(StateFilter(scene_config.state))
|
router.observers[observer_name].filter(StateFilter(scene_config.state))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def as_router(cls, name: Optional[str] = None) -> Router:
|
def as_router(cls, name: str | None = None) -> Router:
|
||||||
"""
|
"""
|
||||||
Returns the scene as a router.
|
Returns the scene as a router.
|
||||||
|
|
||||||
|
|
@ -433,7 +421,9 @@ class Scene:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def enter_to_scene_handler(
|
async def enter_to_scene_handler(
|
||||||
event: TelegramObject, scenes: ScenesManager, **middleware_kwargs: Any
|
event: TelegramObject,
|
||||||
|
scenes: ScenesManager,
|
||||||
|
**middleware_kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
await scenes.enter(cls, **{**handler_kwargs, **middleware_kwargs})
|
await scenes.enter(cls, **{**handler_kwargs, **middleware_kwargs})
|
||||||
|
|
||||||
|
|
@ -461,7 +451,7 @@ class SceneWizard:
|
||||||
state: FSMContext,
|
state: FSMContext,
|
||||||
update_type: str,
|
update_type: str,
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
A class that represents a wizard for managing scenes in a Telegram bot.
|
A class that represents a wizard for managing scenes in a Telegram bot.
|
||||||
|
|
@ -480,7 +470,7 @@ class SceneWizard:
|
||||||
self.event = event
|
self.event = event
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
self.scene: Optional[Scene] = None
|
self.scene: Scene | None = None
|
||||||
|
|
||||||
async def enter(self, **kwargs: Any) -> None:
|
async def enter(self, **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -548,7 +538,7 @@ class SceneWizard:
|
||||||
assert self.scene_config.state is not None, "Scene state is not specified"
|
assert self.scene_config.state is not None, "Scene state is not specified"
|
||||||
await self.goto(self.scene_config.state, **kwargs)
|
await self.goto(self.scene_config.state, **kwargs)
|
||||||
|
|
||||||
async def goto(self, scene: Union[Type[Scene], State, str], **kwargs: Any) -> None:
|
async def goto(self, scene: type[Scene] | State | str, **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
The `goto` method transitions to a new scene.
|
The `goto` method transitions to a new scene.
|
||||||
It first calls the `leave` method to perform any necessary cleanup
|
It first calls the `leave` method to perform any necessary cleanup
|
||||||
|
|
@ -565,13 +555,16 @@ class SceneWizard:
|
||||||
|
|
||||||
async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool:
|
async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool:
|
||||||
if not self.scene:
|
if not self.scene:
|
||||||
raise SceneException("Scene is not initialized")
|
msg = "Scene is not initialized"
|
||||||
|
raise SceneException(msg)
|
||||||
|
|
||||||
loggers.scene.debug("Call action %r in scene %r", action.name, self.scene_config.state)
|
loggers.scene.debug("Call action %r in scene %r", action.name, self.scene_config.state)
|
||||||
action_config = self.scene_config.actions.get(action, {})
|
action_config = self.scene_config.actions.get(action, {})
|
||||||
if not action_config:
|
if not action_config:
|
||||||
loggers.scene.debug(
|
loggers.scene.debug(
|
||||||
"Action %r not found in scene %r", action.name, self.scene_config.state
|
"Action %r not found in scene %r",
|
||||||
|
action.name,
|
||||||
|
self.scene_config.state,
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -597,7 +590,7 @@ class SceneWizard:
|
||||||
"""
|
"""
|
||||||
await self.state.set_data(data=data)
|
await self.state.set_data(data=data)
|
||||||
|
|
||||||
async def get_data(self) -> Dict[str, Any]:
|
async def get_data(self) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
This method returns the data stored in the current state.
|
This method returns the data stored in the current state.
|
||||||
|
|
||||||
|
|
@ -606,7 +599,7 @@ class SceneWizard:
|
||||||
return await self.state.get_data()
|
return await self.state.get_data()
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, key: str) -> Optional[Any]:
|
async def get_value(self, key: str) -> Any | None:
|
||||||
"""
|
"""
|
||||||
This method returns the value from key in the data of the current state.
|
This method returns the value from key in the data of the current state.
|
||||||
|
|
||||||
|
|
@ -614,7 +607,6 @@ class SceneWizard:
|
||||||
|
|
||||||
:return: A dictionary containing the data stored in the scene state.
|
:return: A dictionary containing the data stored in the scene state.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, key: str, default: Any) -> Any:
|
async def get_value(self, key: str, default: Any) -> Any:
|
||||||
|
|
@ -626,14 +618,15 @@ class SceneWizard:
|
||||||
|
|
||||||
:return: A dictionary containing the data stored in the scene state.
|
:return: A dictionary containing the data stored in the scene state.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
|
async def get_value(self, key: str, default: Any | None = None) -> Any | None:
|
||||||
return await self.state.get_value(key, default)
|
return await self.state.get_value(key, default)
|
||||||
|
|
||||||
async def update_data(
|
async def update_data(
|
||||||
self, data: Optional[Mapping[str, Any]] = None, **kwargs: Any
|
self,
|
||||||
) -> Dict[str, Any]:
|
data: Mapping[str, Any] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
This method updates the data stored in the current state
|
This method updates the data stored in the current state
|
||||||
|
|
||||||
|
|
@ -666,7 +659,7 @@ class ScenesManager:
|
||||||
update_type: str,
|
update_type: str,
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
state: FSMContext,
|
state: FSMContext,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.registry = registry
|
self.registry = registry
|
||||||
self.update_type = update_type
|
self.update_type = update_type
|
||||||
|
|
@ -676,7 +669,7 @@ class ScenesManager:
|
||||||
|
|
||||||
self.history = HistoryManager(self.state)
|
self.history = HistoryManager(self.state)
|
||||||
|
|
||||||
async def _get_scene(self, scene_type: Optional[Union[Type[Scene], State, str]]) -> Scene:
|
async def _get_scene(self, scene_type: type[Scene] | State | str | None) -> Scene:
|
||||||
scene_type = self.registry.get(scene_type)
|
scene_type = self.registry.get(scene_type)
|
||||||
return scene_type(
|
return scene_type(
|
||||||
wizard=SceneWizard(
|
wizard=SceneWizard(
|
||||||
|
|
@ -689,7 +682,7 @@ class ScenesManager:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_active_scene(self) -> Optional[Scene]:
|
async def _get_active_scene(self) -> Scene | None:
|
||||||
state = await self.state.get_state()
|
state = await self.state.get_state()
|
||||||
try:
|
try:
|
||||||
return await self._get_scene(state)
|
return await self._get_scene(state)
|
||||||
|
|
@ -698,7 +691,7 @@ class ScenesManager:
|
||||||
|
|
||||||
async def enter(
|
async def enter(
|
||||||
self,
|
self,
|
||||||
scene_type: Optional[Union[Type[Scene], State, str]],
|
scene_type: type[Scene] | State | str | None,
|
||||||
_check_active: bool = True,
|
_check_active: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -753,7 +746,7 @@ class SceneRegistry:
|
||||||
self.router = router
|
self.router = router
|
||||||
self.register_on_add = register_on_add
|
self.register_on_add = register_on_add
|
||||||
|
|
||||||
self._scenes: Dict[Optional[str], Type[Scene]] = {}
|
self._scenes: dict[str | None, type[Scene]] = {}
|
||||||
self._setup_middleware(router)
|
self._setup_middleware(router)
|
||||||
|
|
||||||
def _setup_middleware(self, router: Router) -> None:
|
def _setup_middleware(self, router: Router) -> None:
|
||||||
|
|
@ -772,7 +765,7 @@ class SceneRegistry:
|
||||||
self,
|
self,
|
||||||
handler: NextMiddlewareType[TelegramObject],
|
handler: NextMiddlewareType[TelegramObject],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
assert isinstance(event, Update), "Event must be an Update instance"
|
assert isinstance(event, Update), "Event must be an Update instance"
|
||||||
|
|
||||||
|
|
@ -789,7 +782,7 @@ class SceneRegistry:
|
||||||
self,
|
self,
|
||||||
handler: NextMiddlewareType[TelegramObject],
|
handler: NextMiddlewareType[TelegramObject],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
update: Update = data["event_update"]
|
update: Update = data["event_update"]
|
||||||
data["scenes"] = ScenesManager(
|
data["scenes"] = ScenesManager(
|
||||||
|
|
@ -801,7 +794,7 @@ class SceneRegistry:
|
||||||
)
|
)
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
|
|
||||||
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None:
|
def add(self, *scenes: type[Scene], router: Router | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
This method adds the specified scenes to the registry
|
This method adds the specified scenes to the registry
|
||||||
and optionally registers it to the router.
|
and optionally registers it to the router.
|
||||||
|
|
@ -820,13 +813,13 @@ class SceneRegistry:
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
if not scenes:
|
if not scenes:
|
||||||
raise ValueError("At least one scene must be specified")
|
msg = "At least one scene must be specified"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
for scene in scenes:
|
for scene in scenes:
|
||||||
if scene.__scene_config__.state in self._scenes:
|
if scene.__scene_config__.state in self._scenes:
|
||||||
raise SceneException(
|
msg = f"Scene with state {scene.__scene_config__.state!r} already exists"
|
||||||
f"Scene with state {scene.__scene_config__.state!r} already exists"
|
raise SceneException(msg)
|
||||||
)
|
|
||||||
|
|
||||||
self._scenes[scene.__scene_config__.state] = scene
|
self._scenes[scene.__scene_config__.state] = scene
|
||||||
|
|
||||||
|
|
@ -835,7 +828,7 @@ class SceneRegistry:
|
||||||
elif self.register_on_add:
|
elif self.register_on_add:
|
||||||
self.router.include_router(scene.as_router())
|
self.router.include_router(scene.as_router())
|
||||||
|
|
||||||
def register(self, *scenes: Type[Scene]) -> None:
|
def register(self, *scenes: type[Scene]) -> None:
|
||||||
"""
|
"""
|
||||||
Registers one or more scenes to the SceneRegistry.
|
Registers one or more scenes to the SceneRegistry.
|
||||||
|
|
||||||
|
|
@ -844,7 +837,7 @@ class SceneRegistry:
|
||||||
"""
|
"""
|
||||||
self.add(*scenes, router=self.router)
|
self.add(*scenes, router=self.router)
|
||||||
|
|
||||||
def get(self, scene: Optional[Union[Type[Scene], State, str]]) -> Type[Scene]:
|
def get(self, scene: type[Scene] | State | str | None) -> type[Scene]:
|
||||||
"""
|
"""
|
||||||
This method returns the registered Scene object for the specified scene.
|
This method returns the registered Scene object for the specified scene.
|
||||||
The scene parameter can be either a Scene object, State object or a string representing
|
The scene parameter can be either a Scene object, State object or a string representing
|
||||||
|
|
@ -865,18 +858,20 @@ class SceneRegistry:
|
||||||
if isinstance(scene, State):
|
if isinstance(scene, State):
|
||||||
scene = scene.state
|
scene = scene.state
|
||||||
if scene is not None and not isinstance(scene, str):
|
if scene is not None and not isinstance(scene, str):
|
||||||
raise SceneException("Scene must be a subclass of Scene, State or a string")
|
msg = "Scene must be a subclass of Scene, State or a string"
|
||||||
|
raise SceneException(msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self._scenes[scene]
|
return self._scenes[scene]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SceneException(f"Scene {scene!r} is not registered")
|
msg = f"Scene {scene!r} is not registered"
|
||||||
|
raise SceneException(msg)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class After:
|
class After:
|
||||||
action: SceneAction
|
action: SceneAction
|
||||||
scene: Optional[Union[Type[Scene], State, str]] = None
|
scene: type[Scene] | State | str | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def exit(cls) -> After:
|
def exit(cls) -> After:
|
||||||
|
|
@ -887,7 +882,7 @@ class After:
|
||||||
return cls(action=SceneAction.back)
|
return cls(action=SceneAction.back)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def goto(cls, scene: Optional[Union[Type[Scene], State, str]]) -> After:
|
def goto(cls, scene: type[Scene] | State | str | None) -> After:
|
||||||
return cls(action=SceneAction.enter, scene=scene)
|
return cls(action=SceneAction.enter, scene=scene)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -898,7 +893,7 @@ class ObserverMarker:
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*filters: CallbackType,
|
*filters: CallbackType,
|
||||||
after: Optional[After] = None,
|
after: After | None = None,
|
||||||
) -> ObserverDecorator:
|
) -> ObserverDecorator:
|
||||||
return ObserverDecorator(
|
return ObserverDecorator(
|
||||||
self.name,
|
self.name,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Iterator, Optional, Tuple, Type, no_type_check
|
from collections.abc import Iterator
|
||||||
|
from typing import Any, no_type_check
|
||||||
|
|
||||||
from aiogram.types import TelegramObject
|
from aiogram.types import TelegramObject
|
||||||
|
|
||||||
|
|
@ -9,19 +10,20 @@ class State:
|
||||||
State object
|
State object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None) -> None:
|
def __init__(self, state: str | None = None, group_name: str | None = None) -> None:
|
||||||
self._state = state
|
self._state = state
|
||||||
self._group_name = group_name
|
self._group_name = group_name
|
||||||
self._group: Optional[Type[StatesGroup]] = None
|
self._group: type[StatesGroup] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def group(self) -> "Type[StatesGroup]":
|
def group(self) -> "type[StatesGroup]":
|
||||||
if not self._group:
|
if not self._group:
|
||||||
raise RuntimeError("This state is not in any group.")
|
msg = "This state is not in any group."
|
||||||
|
raise RuntimeError(msg)
|
||||||
return self._group
|
return self._group
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> Optional[str]:
|
def state(self) -> str | None:
|
||||||
if self._state is None or self._state == "*":
|
if self._state is None or self._state == "*":
|
||||||
return self._state
|
return self._state
|
||||||
|
|
||||||
|
|
@ -34,12 +36,13 @@ class State:
|
||||||
|
|
||||||
return f"{group}:{self._state}"
|
return f"{group}:{self._state}"
|
||||||
|
|
||||||
def set_parent(self, group: "Type[StatesGroup]") -> None:
|
def set_parent(self, group: "type[StatesGroup]") -> None:
|
||||||
if not issubclass(group, StatesGroup):
|
if not issubclass(group, StatesGroup):
|
||||||
raise ValueError("Group must be subclass of StatesGroup")
|
msg = "Group must be subclass of StatesGroup"
|
||||||
|
raise ValueError(msg)
|
||||||
self._group = group
|
self._group = group
|
||||||
|
|
||||||
def __set_name__(self, owner: "Type[StatesGroup]", name: str) -> None:
|
def __set_name__(self, owner: "type[StatesGroup]", name: str) -> None:
|
||||||
if self._state is None:
|
if self._state is None:
|
||||||
self._state = name
|
self._state = name
|
||||||
self.set_parent(owner)
|
self.set_parent(owner)
|
||||||
|
|
@ -49,12 +52,12 @@ class State:
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
|
def __call__(self, event: TelegramObject, raw_state: str | None = None) -> bool:
|
||||||
if self.state == "*":
|
if self.state == "*":
|
||||||
return True
|
return True
|
||||||
return raw_state == self.state
|
return raw_state == self.state
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
return self.state == other.state
|
return self.state == other.state
|
||||||
if isinstance(other, str):
|
if isinstance(other, str):
|
||||||
|
|
@ -66,13 +69,13 @@ class State:
|
||||||
|
|
||||||
|
|
||||||
class StatesGroupMeta(type):
|
class StatesGroupMeta(type):
|
||||||
__parent__: "Optional[Type[StatesGroup]]"
|
__parent__: type["StatesGroup"] | None
|
||||||
__childs__: "Tuple[Type[StatesGroup], ...]"
|
__childs__: tuple[type["StatesGroup"], ...]
|
||||||
__states__: Tuple[State, ...]
|
__states__: tuple[State, ...]
|
||||||
__state_names__: Tuple[str, ...]
|
__state_names__: tuple[str, ...]
|
||||||
__all_childs__: Tuple[Type["StatesGroup"], ...]
|
__all_childs__: tuple[type["StatesGroup"], ...]
|
||||||
__all_states__: Tuple[State, ...]
|
__all_states__: tuple[State, ...]
|
||||||
__all_states_names__: Tuple[str, ...]
|
__all_states_names__: tuple[str, ...]
|
||||||
|
|
||||||
@no_type_check
|
@no_type_check
|
||||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||||
|
|
@ -81,7 +84,7 @@ class StatesGroupMeta(type):
|
||||||
states = []
|
states = []
|
||||||
childs = []
|
childs = []
|
||||||
|
|
||||||
for name, arg in namespace.items():
|
for arg in namespace.values():
|
||||||
if isinstance(arg, State):
|
if isinstance(arg, State):
|
||||||
states.append(arg)
|
states.append(arg)
|
||||||
elif inspect.isclass(arg) and issubclass(arg, StatesGroup):
|
elif inspect.isclass(arg) and issubclass(arg, StatesGroup):
|
||||||
|
|
@ -106,10 +109,10 @@ class StatesGroupMeta(type):
|
||||||
@property
|
@property
|
||||||
def __full_group_name__(cls) -> str:
|
def __full_group_name__(cls) -> str:
|
||||||
if cls.__parent__:
|
if cls.__parent__:
|
||||||
return ".".join((cls.__parent__.__full_group_name__, cls.__name__))
|
return f"{cls.__parent__.__full_group_name__}.{cls.__name__}"
|
||||||
return cls.__name__
|
return cls.__name__
|
||||||
|
|
||||||
def _prepare_child(cls, child: Type["StatesGroup"]) -> Type["StatesGroup"]:
|
def _prepare_child(cls, child: type["StatesGroup"]) -> type["StatesGroup"]:
|
||||||
"""Prepare child.
|
"""Prepare child.
|
||||||
|
|
||||||
While adding `cls` for its children, we also need to recalculate
|
While adding `cls` for its children, we also need to recalculate
|
||||||
|
|
@ -123,19 +126,19 @@ class StatesGroupMeta(type):
|
||||||
child.__all_states_names__ = child._get_all_states_names()
|
child.__all_states_names__ = child._get_all_states_names()
|
||||||
return child
|
return child
|
||||||
|
|
||||||
def _get_all_childs(cls) -> Tuple[Type["StatesGroup"], ...]:
|
def _get_all_childs(cls) -> tuple[type["StatesGroup"], ...]:
|
||||||
result = cls.__childs__
|
result = cls.__childs__
|
||||||
for child in cls.__childs__:
|
for child in cls.__childs__:
|
||||||
result += child.__childs__
|
result += child.__childs__
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_all_states(cls) -> Tuple[State, ...]:
|
def _get_all_states(cls) -> tuple[State, ...]:
|
||||||
result = cls.__states__
|
result = cls.__states__
|
||||||
for group in cls.__childs__:
|
for group in cls.__childs__:
|
||||||
result += group.__all_states__
|
result += group.__all_states__
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_all_states_names(cls) -> Tuple[str, ...]:
|
def _get_all_states_names(cls) -> tuple[str, ...]:
|
||||||
return tuple(state.state for state in cls.__all_states__ if state.state)
|
return tuple(state.state for state in cls.__all_states__ if state.state)
|
||||||
|
|
||||||
def __contains__(cls, item: Any) -> bool:
|
def __contains__(cls, item: Any) -> bool:
|
||||||
|
|
@ -156,12 +159,12 @@ class StatesGroupMeta(type):
|
||||||
|
|
||||||
class StatesGroup(metaclass=StatesGroupMeta):
|
class StatesGroup(metaclass=StatesGroupMeta):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_root(cls) -> Type["StatesGroup"]:
|
def get_root(cls) -> type["StatesGroup"]:
|
||||||
if cls.__parent__ is None:
|
if cls.__parent__ is None:
|
||||||
return cls
|
return cls
|
||||||
return cls.__parent__.get_root()
|
return cls.__parent__.get_root()
|
||||||
|
|
||||||
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
|
def __call__(self, event: TelegramObject, raw_state: str | None = None) -> bool:
|
||||||
return raw_state in type(self).__all_states_names__
|
return raw_state in type(self).__all_states_names__
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,12 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncGenerator, Mapping
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import Any, Literal, overload
|
||||||
Any,
|
|
||||||
AsyncGenerator,
|
|
||||||
Dict,
|
|
||||||
Literal,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
|
|
||||||
from aiogram.fsm.state import State
|
from aiogram.fsm.state import State
|
||||||
|
|
||||||
StateType = Optional[Union[str, State]]
|
StateType = str | State | None
|
||||||
|
|
||||||
DEFAULT_DESTINY = "default"
|
DEFAULT_DESTINY = "default"
|
||||||
|
|
||||||
|
|
@ -24,8 +16,8 @@ class StorageKey:
|
||||||
bot_id: int
|
bot_id: int
|
||||||
chat_id: int
|
chat_id: int
|
||||||
user_id: int
|
user_id: int
|
||||||
thread_id: Optional[int] = None
|
thread_id: int | None = None
|
||||||
business_connection_id: Optional[str] = None
|
business_connection_id: str | None = None
|
||||||
destiny: str = DEFAULT_DESTINY
|
destiny: str = DEFAULT_DESTINY
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -36,7 +28,7 @@ class KeyBuilder(ABC):
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
key: StorageKey,
|
key: StorageKey,
|
||||||
part: Optional[Literal["data", "state", "lock"]] = None,
|
part: Literal["data", "state", "lock"] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build key to be used in storage's db queries
|
Build key to be used in storage's db queries
|
||||||
|
|
@ -45,7 +37,6 @@ class KeyBuilder(ABC):
|
||||||
:param part: part of the record
|
:param part: part of the record
|
||||||
:return: key to be used in storage's db queries
|
:return: key to be used in storage's db queries
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultKeyBuilder(KeyBuilder):
|
class DefaultKeyBuilder(KeyBuilder):
|
||||||
|
|
@ -84,7 +75,7 @@ class DefaultKeyBuilder(KeyBuilder):
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
key: StorageKey,
|
key: StorageKey,
|
||||||
part: Optional[Literal["data", "state", "lock"]] = None,
|
part: Literal["data", "state", "lock"] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
parts = [self.prefix]
|
parts = [self.prefix]
|
||||||
if self.with_bot_id:
|
if self.with_bot_id:
|
||||||
|
|
@ -121,17 +112,15 @@ class BaseStorage(ABC):
|
||||||
:param key: storage key
|
:param key: storage key
|
||||||
:param state: new state
|
:param state: new state
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_state(self, key: StorageKey) -> Optional[str]:
|
async def get_state(self, key: StorageKey) -> str | None:
|
||||||
"""
|
"""
|
||||||
Get key state
|
Get key state
|
||||||
|
|
||||||
:param key: storage key
|
:param key: storage key
|
||||||
:return: current state
|
:return: current state
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
||||||
|
|
@ -141,20 +130,18 @@ class BaseStorage(ABC):
|
||||||
:param key: storage key
|
:param key: storage key
|
||||||
:param data: new data
|
:param data: new data
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
|
async def get_data(self, key: StorageKey) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get current data for key
|
Get current data for key
|
||||||
|
|
||||||
:param key: storage key
|
:param key: storage key
|
||||||
:return: current data
|
:return: current data
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]:
|
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Any | None:
|
||||||
"""
|
"""
|
||||||
Get single value from data by key
|
Get single value from data by key
|
||||||
|
|
||||||
|
|
@ -162,7 +149,6 @@ class BaseStorage(ABC):
|
||||||
:param dict_key: value key
|
:param dict_key: value key
|
||||||
:return: value stored in key of dict or ``None``
|
:return: value stored in key of dict or ``None``
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any:
|
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any:
|
||||||
|
|
@ -174,15 +160,17 @@ class BaseStorage(ABC):
|
||||||
:param default: default value to return
|
:param default: default value to return
|
||||||
:return: value stored in key of dict or default
|
:return: value stored in key of dict or default
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
async def get_value(
|
async def get_value(
|
||||||
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None
|
self,
|
||||||
) -> Optional[Any]:
|
storage_key: StorageKey,
|
||||||
|
dict_key: str,
|
||||||
|
default: Any | None = None,
|
||||||
|
) -> Any | None:
|
||||||
data = await self.get_data(storage_key)
|
data = await self.get_data(storage_key)
|
||||||
return data.get(dict_key, default)
|
return data.get(dict_key, default)
|
||||||
|
|
||||||
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]:
|
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Update date in the storage for key (like dict.update)
|
Update date in the storage for key (like dict.update)
|
||||||
|
|
||||||
|
|
@ -200,7 +188,6 @@ class BaseStorage(ABC):
|
||||||
"""
|
"""
|
||||||
Close storage (database connection, file or etc.)
|
Close storage (database connection, file or etc.)
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEventIsolation(ABC):
|
class BaseEventIsolation(ABC):
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,10 @@
|
||||||
from asyncio import Lock
|
from asyncio import Lock
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import AsyncGenerator, Hashable, Mapping
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import (
|
from typing import Any, overload
|
||||||
Any,
|
|
||||||
AsyncGenerator,
|
|
||||||
DefaultDict,
|
|
||||||
Dict,
|
|
||||||
Hashable,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
|
|
||||||
from aiogram.exceptions import DataNotDictLikeError
|
from aiogram.exceptions import DataNotDictLikeError
|
||||||
from aiogram.fsm.state import State
|
from aiogram.fsm.state import State
|
||||||
|
|
@ -26,8 +18,8 @@ from aiogram.fsm.storage.base import (
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryStorageRecord:
|
class MemoryStorageRecord:
|
||||||
data: Dict[str, Any] = field(default_factory=dict)
|
data: dict[str, Any] = field(default_factory=dict)
|
||||||
state: Optional[str] = None
|
state: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class MemoryStorage(BaseStorage):
|
class MemoryStorage(BaseStorage):
|
||||||
|
|
@ -41,8 +33,8 @@ class MemoryStorage(BaseStorage):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
|
self.storage: defaultdict[StorageKey, MemoryStorageRecord] = defaultdict(
|
||||||
MemoryStorageRecord
|
MemoryStorageRecord,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
|
@ -51,28 +43,30 @@ class MemoryStorage(BaseStorage):
|
||||||
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
|
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
|
||||||
self.storage[key].state = state.state if isinstance(state, State) else state
|
self.storage[key].state = state.state if isinstance(state, State) else state
|
||||||
|
|
||||||
async def get_state(self, key: StorageKey) -> Optional[str]:
|
async def get_state(self, key: StorageKey) -> str | None:
|
||||||
return self.storage[key].state
|
return self.storage[key].state
|
||||||
|
|
||||||
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise DataNotDictLikeError(
|
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
||||||
f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
raise DataNotDictLikeError(msg)
|
||||||
)
|
|
||||||
self.storage[key].data = data.copy()
|
self.storage[key].data = data.copy()
|
||||||
|
|
||||||
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
|
async def get_data(self, key: StorageKey) -> dict[str, Any]:
|
||||||
return self.storage[key].data.copy()
|
return self.storage[key].data.copy()
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]: ...
|
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Any | None: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: ...
|
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: ...
|
||||||
|
|
||||||
async def get_value(
|
async def get_value(
|
||||||
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None
|
self,
|
||||||
) -> Optional[Any]:
|
storage_key: StorageKey,
|
||||||
|
dict_key: str,
|
||||||
|
default: Any | None = None,
|
||||||
|
) -> Any | None:
|
||||||
data = self.storage[storage_key].data
|
data = self.storage[storage_key].data
|
||||||
return copy(data.get(dict_key, default))
|
return copy(data.get(dict_key, default))
|
||||||
|
|
||||||
|
|
@ -89,7 +83,7 @@ class DisabledEventIsolation(BaseEventIsolation):
|
||||||
class SimpleEventIsolation(BaseEventIsolation):
|
class SimpleEventIsolation(BaseEventIsolation):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# TODO: Unused locks cleaner is needed
|
# TODO: Unused locks cleaner is needed
|
||||||
self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock)
|
self._locks: defaultdict[Hashable, Lock] = defaultdict(Lock)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
|
async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Dict, Mapping, Optional, cast
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
|
@ -27,7 +28,7 @@ class MongoStorage(BaseStorage):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: AsyncIOMotorClient,
|
client: AsyncIOMotorClient,
|
||||||
key_builder: Optional[KeyBuilder] = None,
|
key_builder: KeyBuilder | None = None,
|
||||||
db_name: str = "aiogram_fsm",
|
db_name: str = "aiogram_fsm",
|
||||||
collection_name: str = "states_and_data",
|
collection_name: str = "states_and_data",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -46,7 +47,10 @@ class MongoStorage(BaseStorage):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_url(
|
def from_url(
|
||||||
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
|
cls,
|
||||||
|
url: str,
|
||||||
|
connection_kwargs: dict[str, Any] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> "MongoStorage":
|
) -> "MongoStorage":
|
||||||
"""
|
"""
|
||||||
Create an instance of :class:`MongoStorage` with specifying the connection string
|
Create an instance of :class:`MongoStorage` with specifying the connection string
|
||||||
|
|
@ -65,7 +69,7 @@ class MongoStorage(BaseStorage):
|
||||||
"""Cleanup client resources and disconnect from MongoDB."""
|
"""Cleanup client resources and disconnect from MongoDB."""
|
||||||
self._client.close()
|
self._client.close()
|
||||||
|
|
||||||
def resolve_state(self, value: StateType) -> Optional[str]:
|
def resolve_state(self, value: StateType) -> str | None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
if isinstance(value, State):
|
if isinstance(value, State):
|
||||||
|
|
@ -90,7 +94,7 @@ class MongoStorage(BaseStorage):
|
||||||
upsert=True,
|
upsert=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_state(self, key: StorageKey) -> Optional[str]:
|
async def get_state(self, key: StorageKey) -> str | None:
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
document = await self._collection.find_one({"_id": document_id})
|
document = await self._collection.find_one({"_id": document_id})
|
||||||
if document is None:
|
if document is None:
|
||||||
|
|
@ -99,9 +103,8 @@ class MongoStorage(BaseStorage):
|
||||||
|
|
||||||
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise DataNotDictLikeError(
|
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
||||||
f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
raise DataNotDictLikeError(msg)
|
||||||
)
|
|
||||||
|
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
if not data:
|
if not data:
|
||||||
|
|
@ -120,14 +123,14 @@ class MongoStorage(BaseStorage):
|
||||||
upsert=True,
|
upsert=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
|
async def get_data(self, key: StorageKey) -> dict[str, Any]:
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
document = await self._collection.find_one({"_id": document_id})
|
document = await self._collection.find_one({"_id": document_id})
|
||||||
if document is None or not document.get("data"):
|
if document is None or not document.get("data"):
|
||||||
return {}
|
return {}
|
||||||
return cast(Dict[str, Any], document["data"])
|
return cast(dict[str, Any], document["data"])
|
||||||
|
|
||||||
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]:
|
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
update_with = {f"data.{key}": value for key, value in data.items()}
|
update_with = {f"data.{key}": value for key, value in data.items()}
|
||||||
update_result = await self._collection.find_one_and_update(
|
update_result = await self._collection.find_one_and_update(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Dict, Mapping, Optional, cast
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from pymongo import AsyncMongoClient
|
from pymongo import AsyncMongoClient
|
||||||
|
|
||||||
|
|
@ -21,7 +22,7 @@ class PyMongoStorage(BaseStorage):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: AsyncMongoClient[Any],
|
client: AsyncMongoClient[Any],
|
||||||
key_builder: Optional[KeyBuilder] = None,
|
key_builder: KeyBuilder | None = None,
|
||||||
db_name: str = "aiogram_fsm",
|
db_name: str = "aiogram_fsm",
|
||||||
collection_name: str = "states_and_data",
|
collection_name: str = "states_and_data",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -40,7 +41,10 @@ class PyMongoStorage(BaseStorage):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_url(
|
def from_url(
|
||||||
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
|
cls,
|
||||||
|
url: str,
|
||||||
|
connection_kwargs: dict[str, Any] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> "PyMongoStorage":
|
) -> "PyMongoStorage":
|
||||||
"""
|
"""
|
||||||
Create an instance of :class:`PyMongoStorage` with specifying the connection string
|
Create an instance of :class:`PyMongoStorage` with specifying the connection string
|
||||||
|
|
@ -59,7 +63,7 @@ class PyMongoStorage(BaseStorage):
|
||||||
"""Cleanup client resources and disconnect from MongoDB."""
|
"""Cleanup client resources and disconnect from MongoDB."""
|
||||||
return await self._client.close()
|
return await self._client.close()
|
||||||
|
|
||||||
def resolve_state(self, value: StateType) -> Optional[str]:
|
def resolve_state(self, value: StateType) -> str | None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
if isinstance(value, State):
|
if isinstance(value, State):
|
||||||
|
|
@ -84,18 +88,17 @@ class PyMongoStorage(BaseStorage):
|
||||||
upsert=True,
|
upsert=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_state(self, key: StorageKey) -> Optional[str]:
|
async def get_state(self, key: StorageKey) -> str | None:
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
document = await self._collection.find_one({"_id": document_id})
|
document = await self._collection.find_one({"_id": document_id})
|
||||||
if document is None:
|
if document is None:
|
||||||
return None
|
return None
|
||||||
return cast(Optional[str], document.get("state"))
|
return cast(str | None, document.get("state"))
|
||||||
|
|
||||||
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise DataNotDictLikeError(
|
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
||||||
f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
raise DataNotDictLikeError(msg)
|
||||||
)
|
|
||||||
|
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
if not data:
|
if not data:
|
||||||
|
|
@ -114,14 +117,14 @@ class PyMongoStorage(BaseStorage):
|
||||||
upsert=True,
|
upsert=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
|
async def get_data(self, key: StorageKey) -> dict[str, Any]:
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
document = await self._collection.find_one({"_id": document_id})
|
document = await self._collection.find_one({"_id": document_id})
|
||||||
if document is None or not document.get("data"):
|
if document is None or not document.get("data"):
|
||||||
return {}
|
return {}
|
||||||
return cast(Dict[str, Any], document["data"])
|
return cast(dict[str, Any], document["data"])
|
||||||
|
|
||||||
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> Dict[str, Any]:
|
async def update_data(self, key: StorageKey, data: Mapping[str, Any]) -> dict[str, Any]:
|
||||||
document_id = self._key_builder.build(key)
|
document_id = self._key_builder.build(key)
|
||||||
update_with = {f"data.{key}": value for key, value in data.items()}
|
update_with = {f"data.{key}": value for key, value in data.items()}
|
||||||
update_result = await self._collection.find_one_and_update(
|
update_result = await self._collection.find_one_and_update(
|
||||||
|
|
@ -133,4 +136,4 @@ class PyMongoStorage(BaseStorage):
|
||||||
)
|
)
|
||||||
if not update_result:
|
if not update_result:
|
||||||
await self._collection.delete_one({"_id": document_id})
|
await self._collection.delete_one({"_id": document_id})
|
||||||
return cast(Dict[str, Any], update_result.get("data", {}))
|
return cast(dict[str, Any], update_result.get("data", {}))
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import json
|
import json
|
||||||
|
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Optional, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from redis.asyncio.client import Redis
|
from redis.asyncio.client import Redis
|
||||||
from redis.asyncio.connection import ConnectionPool
|
from redis.asyncio.connection import ConnectionPool
|
||||||
|
|
@ -31,9 +32,9 @@ class RedisStorage(BaseStorage):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
key_builder: Optional[KeyBuilder] = None,
|
key_builder: KeyBuilder | None = None,
|
||||||
state_ttl: Optional[ExpiryT] = None,
|
state_ttl: ExpiryT | None = None,
|
||||||
data_ttl: Optional[ExpiryT] = None,
|
data_ttl: ExpiryT | None = None,
|
||||||
json_loads: _JsonLoads = json.loads,
|
json_loads: _JsonLoads = json.loads,
|
||||||
json_dumps: _JsonDumps = json.dumps,
|
json_dumps: _JsonDumps = json.dumps,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -54,7 +55,10 @@ class RedisStorage(BaseStorage):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_url(
|
def from_url(
|
||||||
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
|
cls,
|
||||||
|
url: str,
|
||||||
|
connection_kwargs: dict[str, Any] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> "RedisStorage":
|
) -> "RedisStorage":
|
||||||
"""
|
"""
|
||||||
Create an instance of :class:`RedisStorage` with specifying the connection string
|
Create an instance of :class:`RedisStorage` with specifying the connection string
|
||||||
|
|
@ -94,12 +98,12 @@ class RedisStorage(BaseStorage):
|
||||||
async def get_state(
|
async def get_state(
|
||||||
self,
|
self,
|
||||||
key: StorageKey,
|
key: StorageKey,
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
redis_key = self.key_builder.build(key, "state")
|
redis_key = self.key_builder.build(key, "state")
|
||||||
value = await self.redis.get(redis_key)
|
value = await self.redis.get(redis_key)
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
return value.decode("utf-8")
|
return value.decode("utf-8")
|
||||||
return cast(Optional[str], value)
|
return cast(str | None, value)
|
||||||
|
|
||||||
async def set_data(
|
async def set_data(
|
||||||
self,
|
self,
|
||||||
|
|
@ -107,9 +111,8 @@ class RedisStorage(BaseStorage):
|
||||||
data: Mapping[str, Any],
|
data: Mapping[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise DataNotDictLikeError(
|
msg = f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
||||||
f"Data must be a dict or dict-like object, got {type(data).__name__}"
|
raise DataNotDictLikeError(msg)
|
||||||
)
|
|
||||||
|
|
||||||
redis_key = self.key_builder.build(key, "data")
|
redis_key = self.key_builder.build(key, "data")
|
||||||
if not data:
|
if not data:
|
||||||
|
|
@ -124,22 +127,22 @@ class RedisStorage(BaseStorage):
|
||||||
async def get_data(
|
async def get_data(
|
||||||
self,
|
self,
|
||||||
key: StorageKey,
|
key: StorageKey,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
redis_key = self.key_builder.build(key, "data")
|
redis_key = self.key_builder.build(key, "data")
|
||||||
value = await self.redis.get(redis_key)
|
value = await self.redis.get(redis_key)
|
||||||
if value is None:
|
if value is None:
|
||||||
return {}
|
return {}
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
return cast(Dict[str, Any], self.json_loads(value))
|
return cast(dict[str, Any], self.json_loads(value))
|
||||||
|
|
||||||
|
|
||||||
class RedisEventIsolation(BaseEventIsolation):
|
class RedisEventIsolation(BaseEventIsolation):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
key_builder: Optional[KeyBuilder] = None,
|
key_builder: KeyBuilder | None = None,
|
||||||
lock_kwargs: Optional[Dict[str, Any]] = None,
|
lock_kwargs: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if key_builder is None:
|
if key_builder is None:
|
||||||
key_builder = DefaultKeyBuilder()
|
key_builder = DefaultKeyBuilder()
|
||||||
|
|
@ -153,7 +156,7 @@ class RedisEventIsolation(BaseEventIsolation):
|
||||||
def from_url(
|
def from_url(
|
||||||
cls,
|
cls,
|
||||||
url: str,
|
url: str,
|
||||||
connection_kwargs: Optional[Dict[str, Any]] = None,
|
connection_kwargs: dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "RedisEventIsolation":
|
) -> "RedisEventIsolation":
|
||||||
if connection_kwargs is None:
|
if connection_kwargs is None:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
class FSMStrategy(Enum):
|
class FSMStrategy(Enum):
|
||||||
|
|
@ -23,8 +22,8 @@ def apply_strategy(
|
||||||
strategy: FSMStrategy,
|
strategy: FSMStrategy,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
thread_id: Optional[int] = None,
|
thread_id: int | None = None,
|
||||||
) -> Tuple[int, int, Optional[int]]:
|
) -> tuple[int, int, int | None]:
|
||||||
if strategy == FSMStrategy.CHAT:
|
if strategy == FSMStrategy.CHAT:
|
||||||
return chat_id, chat_id, None
|
return chat_id, chat_id, None
|
||||||
if strategy == FSMStrategy.GLOBAL_USER:
|
if strategy == FSMStrategy.GLOBAL_USER:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generic, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||||
|
|
||||||
from aiogram.types import Update
|
from aiogram.types import Update
|
||||||
|
|
||||||
|
|
@ -14,7 +14,7 @@ T = TypeVar("T")
|
||||||
class BaseHandlerMixin(Generic[T]):
|
class BaseHandlerMixin(Generic[T]):
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
event: T
|
event: T
|
||||||
data: Dict[str, Any]
|
data: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class BaseHandler(BaseHandlerMixin[T], ABC):
|
class BaseHandler(BaseHandlerMixin[T], ABC):
|
||||||
|
|
@ -24,7 +24,7 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
|
||||||
|
|
||||||
def __init__(self, event: T, **kwargs: Any) -> None:
|
def __init__(self, event: T, **kwargs: Any) -> None:
|
||||||
self.event: T = event
|
self.event: T = event
|
||||||
self.data: Dict[str, Any] = kwargs
|
self.data: dict[str, Any] = kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self) -> Bot:
|
def bot(self) -> Bot:
|
||||||
|
|
@ -32,7 +32,8 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
|
||||||
|
|
||||||
if "bot" in self.data:
|
if "bot" in self.data:
|
||||||
return cast(Bot, self.data["bot"])
|
return cast(Bot, self.data["bot"])
|
||||||
raise RuntimeError("Bot instance not found in the context")
|
msg = "Bot instance not found in the context"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def update(self) -> Update:
|
def update(self) -> Update:
|
||||||
|
|
|
||||||
|
|
@ -29,14 +29,14 @@ class CallbackQueryHandler(BaseHandler[CallbackQuery], ABC):
|
||||||
return self.event.from_user
|
return self.event.from_user
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message(self) -> Optional[MaybeInaccessibleMessage]:
|
def message(self) -> MaybeInaccessibleMessage | None:
|
||||||
"""
|
"""
|
||||||
Is alias for `event.message`
|
Is alias for `event.message`
|
||||||
"""
|
"""
|
||||||
return self.event.message
|
return self.event.message
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def callback_data(self) -> Optional[str]:
|
def callback_data(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Is alias for `event.data`
|
Is alias for `event.data`
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ class MessageHandler(BaseHandler[Message], ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def from_user(self) -> Optional[User]:
|
def from_user(self) -> User | None:
|
||||||
return self.event.from_user
|
return self.event.from_user
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -22,7 +22,7 @@ class MessageHandler(BaseHandler[Message], ABC):
|
||||||
|
|
||||||
class MessageHandlerCommandMixin(BaseHandlerMixin[Message]):
|
class MessageHandlerCommandMixin(BaseHandlerMixin[Message]):
|
||||||
@property
|
@property
|
||||||
def command(self) -> Optional[CommandObject]:
|
def command(self) -> CommandObject | None:
|
||||||
if "command" in self.data:
|
if "command" in self.data:
|
||||||
return cast(CommandObject, self.data["command"])
|
return cast(CommandObject, self.data["command"])
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from aiogram.handlers import BaseHandler
|
from aiogram.handlers import BaseHandler
|
||||||
from aiogram.types import Poll, PollOption
|
from aiogram.types import Poll, PollOption
|
||||||
|
|
@ -15,5 +14,5 @@ class PollHandler(BaseHandler[Poll], ABC):
|
||||||
return self.event.question
|
return self.event.question
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def options(self) -> List[PollOption]:
|
def options(self) -> list[PollOption]:
|
||||||
return self.event.options
|
return self.event.options
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
|
def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
|
||||||
|
|
@ -17,12 +17,14 @@ def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
|
||||||
secret = hashlib.sha256(token.encode("utf-8"))
|
secret = hashlib.sha256(token.encode("utf-8"))
|
||||||
check_string = "\n".join(f"{k}={kwargs[k]}" for k in sorted(kwargs))
|
check_string = "\n".join(f"{k}={kwargs[k]}" for k in sorted(kwargs))
|
||||||
hmac_string = hmac.new(
|
hmac_string = hmac.new(
|
||||||
secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256
|
secret.digest(),
|
||||||
|
check_string.encode("utf-8"),
|
||||||
|
digestmod=hashlib.sha256,
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
return hmac_string == hash
|
return hmac_string == hash
|
||||||
|
|
||||||
|
|
||||||
def check_integrity(token: str, data: Dict[str, Any]) -> bool:
|
def check_integrity(token: str, data: dict[str, Any]) -> bool:
|
||||||
"""
|
"""
|
||||||
Verify the authentication and the integrity
|
Verify the authentication and the integrity
|
||||||
of the data received on user's auth
|
of the data received on user's auth
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,11 @@ class BackoffConfig:
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.max_delay <= self.min_delay:
|
if self.max_delay <= self.min_delay:
|
||||||
raise ValueError("`max_delay` should be greater than `min_delay`")
|
msg = "`max_delay` should be greater than `min_delay`"
|
||||||
|
raise ValueError(msg)
|
||||||
if self.factor <= 1:
|
if self.factor <= 1:
|
||||||
raise ValueError("`factor` should be greater than 1")
|
msg = "`factor` should be greater than 1"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
class Backoff:
|
class Backoff:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, Union
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from aiogram import BaseMiddleware, loggers
|
from aiogram import BaseMiddleware, loggers
|
||||||
from aiogram.dispatcher.flags import get_flag
|
from aiogram.dispatcher.flags import get_flag
|
||||||
|
|
@ -12,10 +13,10 @@ class CallbackAnswer:
|
||||||
self,
|
self,
|
||||||
answered: bool,
|
answered: bool,
|
||||||
disabled: bool = False,
|
disabled: bool = False,
|
||||||
text: Optional[str] = None,
|
text: str | None = None,
|
||||||
show_alert: Optional[bool] = None,
|
show_alert: bool | None = None,
|
||||||
url: Optional[str] = None,
|
url: str | None = None,
|
||||||
cache_time: Optional[int] = None,
|
cache_time: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Callback answer configuration
|
Callback answer configuration
|
||||||
|
|
@ -48,7 +49,8 @@ class CallbackAnswer:
|
||||||
@disabled.setter
|
@disabled.setter
|
||||||
def disabled(self, value: bool) -> None:
|
def disabled(self, value: bool) -> None:
|
||||||
if self._answered:
|
if self._answered:
|
||||||
raise CallbackAnswerException("Can't change disabled state after answer")
|
msg = "Can't change disabled state after answer"
|
||||||
|
raise CallbackAnswerException(msg)
|
||||||
self._disabled = value
|
self._disabled = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -59,7 +61,7 @@ class CallbackAnswer:
|
||||||
return self._answered
|
return self._answered
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> Optional[str]:
|
def text(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Response text
|
Response text
|
||||||
:return:
|
:return:
|
||||||
|
|
@ -67,48 +69,52 @@ class CallbackAnswer:
|
||||||
return self._text
|
return self._text
|
||||||
|
|
||||||
@text.setter
|
@text.setter
|
||||||
def text(self, value: Optional[str]) -> None:
|
def text(self, value: str | None) -> None:
|
||||||
if self._answered:
|
if self._answered:
|
||||||
raise CallbackAnswerException("Can't change text after answer")
|
msg = "Can't change text after answer"
|
||||||
|
raise CallbackAnswerException(msg)
|
||||||
self._text = value
|
self._text = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def show_alert(self) -> Optional[bool]:
|
def show_alert(self) -> bool | None:
|
||||||
"""
|
"""
|
||||||
Whether to display an alert
|
Whether to display an alert
|
||||||
"""
|
"""
|
||||||
return self._show_alert
|
return self._show_alert
|
||||||
|
|
||||||
@show_alert.setter
|
@show_alert.setter
|
||||||
def show_alert(self, value: Optional[bool]) -> None:
|
def show_alert(self, value: bool | None) -> None:
|
||||||
if self._answered:
|
if self._answered:
|
||||||
raise CallbackAnswerException("Can't change show_alert after answer")
|
msg = "Can't change show_alert after answer"
|
||||||
|
raise CallbackAnswerException(msg)
|
||||||
self._show_alert = value
|
self._show_alert = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> Optional[str]:
|
def url(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Game url
|
Game url
|
||||||
"""
|
"""
|
||||||
return self._url
|
return self._url
|
||||||
|
|
||||||
@url.setter
|
@url.setter
|
||||||
def url(self, value: Optional[str]) -> None:
|
def url(self, value: str | None) -> None:
|
||||||
if self._answered:
|
if self._answered:
|
||||||
raise CallbackAnswerException("Can't change url after answer")
|
msg = "Can't change url after answer"
|
||||||
|
raise CallbackAnswerException(msg)
|
||||||
self._url = value
|
self._url = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache_time(self) -> Optional[int]:
|
def cache_time(self) -> int | None:
|
||||||
"""
|
"""
|
||||||
Response cache time
|
Response cache time
|
||||||
"""
|
"""
|
||||||
return self._cache_time
|
return self._cache_time
|
||||||
|
|
||||||
@cache_time.setter
|
@cache_time.setter
|
||||||
def cache_time(self, value: Optional[int]) -> None:
|
def cache_time(self, value: int | None) -> None:
|
||||||
if self._answered:
|
if self._answered:
|
||||||
raise CallbackAnswerException("Can't change cache_time after answer")
|
msg = "Can't change cache_time after answer"
|
||||||
|
raise CallbackAnswerException(msg)
|
||||||
self._cache_time = value
|
self._cache_time = value
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
@ -131,10 +137,10 @@ class CallbackAnswerMiddleware(BaseMiddleware):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pre: bool = False,
|
pre: bool = False,
|
||||||
text: Optional[str] = None,
|
text: str | None = None,
|
||||||
show_alert: Optional[bool] = None,
|
show_alert: bool | None = None,
|
||||||
url: Optional[str] = None,
|
url: str | None = None,
|
||||||
cache_time: Optional[int] = None,
|
cache_time: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Inner middleware for callback query handlers, can be useful in bots with a lot of callback
|
Inner middleware for callback query handlers, can be useful in bots with a lot of callback
|
||||||
|
|
@ -154,15 +160,15 @@ class CallbackAnswerMiddleware(BaseMiddleware):
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not isinstance(event, CallbackQuery):
|
if not isinstance(event, CallbackQuery):
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
|
|
||||||
callback_answer = data["callback_answer"] = self.construct_callback_answer(
|
callback_answer = data["callback_answer"] = self.construct_callback_answer(
|
||||||
properties=get_flag(data, "callback_answer")
|
properties=get_flag(data, "callback_answer"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not callback_answer.disabled and callback_answer.answered:
|
if not callback_answer.disabled and callback_answer.answered:
|
||||||
|
|
@ -174,7 +180,8 @@ class CallbackAnswerMiddleware(BaseMiddleware):
|
||||||
await self.answer(event, callback_answer)
|
await self.answer(event, callback_answer)
|
||||||
|
|
||||||
def construct_callback_answer(
|
def construct_callback_answer(
|
||||||
self, properties: Optional[Union[Dict[str, Any], bool]]
|
self,
|
||||||
|
properties: dict[str, Any] | bool | None,
|
||||||
) -> CallbackAnswer:
|
) -> CallbackAnswer:
|
||||||
pre, disabled, text, show_alert, url, cache_time = (
|
pre, disabled, text, show_alert, url, cache_time = (
|
||||||
self.pre,
|
self.pre,
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,10 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from asyncio import Event, Lock
|
from asyncio import Event, Lock
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union
|
from typing import Any
|
||||||
|
|
||||||
from aiogram import BaseMiddleware, Bot
|
from aiogram import BaseMiddleware, Bot
|
||||||
from aiogram.dispatcher.flags import get_flag
|
from aiogram.dispatcher.flags import get_flag
|
||||||
|
|
@ -32,8 +33,8 @@ class ChatActionSender:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
chat_id: Union[str, int],
|
chat_id: str | int,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
action: str = "typing",
|
action: str = "typing",
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
|
|
@ -56,7 +57,7 @@ class ChatActionSender:
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
self._close_event = Event()
|
self._close_event = Event()
|
||||||
self._closed_event = Event()
|
self._closed_event = Event()
|
||||||
self._task: Optional[asyncio.Task[Any]] = None
|
self._task: asyncio.Task[Any] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def running(self) -> bool:
|
def running(self) -> bool:
|
||||||
|
|
@ -108,7 +109,8 @@ class ChatActionSender:
|
||||||
self._close_event.clear()
|
self._close_event.clear()
|
||||||
self._closed_event.clear()
|
self._closed_event.clear()
|
||||||
if self.running:
|
if self.running:
|
||||||
raise RuntimeError("Already running")
|
msg = "Already running"
|
||||||
|
raise RuntimeError(msg)
|
||||||
self._task = asyncio.create_task(self._worker())
|
self._task = asyncio.create_task(self._worker())
|
||||||
|
|
||||||
async def _stop(self) -> None:
|
async def _stop(self) -> None:
|
||||||
|
|
@ -126,18 +128,18 @@ class ChatActionSender:
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
self,
|
self,
|
||||||
exc_type: Optional[Type[BaseException]],
|
exc_type: type[BaseException] | None,
|
||||||
exc_value: Optional[BaseException],
|
exc_value: BaseException | None,
|
||||||
traceback: Optional[TracebackType],
|
traceback: TracebackType | None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
await self._stop()
|
await self._stop()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def typing(
|
def typing(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -154,9 +156,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def upload_photo(
|
def upload_photo(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -173,9 +175,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def record_video(
|
def record_video(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -192,9 +194,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def upload_video(
|
def upload_video(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -211,9 +213,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def record_voice(
|
def record_voice(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -230,9 +232,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def upload_voice(
|
def upload_voice(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -249,9 +251,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def upload_document(
|
def upload_document(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -268,9 +270,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def choose_sticker(
|
def choose_sticker(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -287,9 +289,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_location(
|
def find_location(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -306,9 +308,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def record_video_note(
|
def record_video_note(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -325,9 +327,9 @@ class ChatActionSender:
|
||||||
@classmethod
|
@classmethod
|
||||||
def upload_video_note(
|
def upload_video_note(
|
||||||
cls,
|
cls,
|
||||||
chat_id: Union[int, str],
|
chat_id: int | str,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
message_thread_id: Optional[int] = None,
|
message_thread_id: int | None = None,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
|
||||||
) -> "ChatActionSender":
|
) -> "ChatActionSender":
|
||||||
|
|
@ -349,9 +351,9 @@ class ChatActionMiddleware(BaseMiddleware):
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not isinstance(event, Message):
|
if not isinstance(event, Message):
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from typing import Tuple, Type, Union
|
from typing import Annotated
|
||||||
|
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from aiogram.types import (
|
from aiogram.types import (
|
||||||
ChatMember,
|
ChatMember,
|
||||||
|
|
@ -13,22 +12,22 @@ from aiogram.types import (
|
||||||
ChatMemberRestricted,
|
ChatMemberRestricted,
|
||||||
)
|
)
|
||||||
|
|
||||||
ChatMemberUnion = Union[
|
ChatMemberUnion = (
|
||||||
ChatMemberOwner,
|
ChatMemberOwner
|
||||||
ChatMemberAdministrator,
|
| ChatMemberAdministrator
|
||||||
ChatMemberMember,
|
| ChatMemberMember
|
||||||
ChatMemberRestricted,
|
| ChatMemberRestricted
|
||||||
ChatMemberLeft,
|
| ChatMemberLeft
|
||||||
ChatMemberBanned,
|
| ChatMemberBanned
|
||||||
]
|
)
|
||||||
|
|
||||||
ChatMemberCollection = Tuple[Type[ChatMember], ...]
|
ChatMemberCollection = tuple[type[ChatMember], ...]
|
||||||
|
|
||||||
ChatMemberAdapter: TypeAdapter[ChatMemberUnion] = TypeAdapter(
|
ChatMemberAdapter: TypeAdapter[ChatMemberUnion] = TypeAdapter(
|
||||||
Annotated[
|
Annotated[
|
||||||
ChatMemberUnion,
|
ChatMemberUnion,
|
||||||
Field(discriminator="status"),
|
Field(discriminator="status"),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
ADMINS: ChatMemberCollection = (ChatMemberOwner, ChatMemberAdministrator)
|
ADMINS: ChatMemberCollection = (ChatMemberOwner, ChatMemberAdministrator)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections.abc import Generator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Generator, NamedTuple, Protocol
|
from typing import Any, NamedTuple, Protocol
|
||||||
|
|
||||||
from aiogram.utils.dataclass import dataclass_kwargs
|
from aiogram.utils.dataclass import dataclass_kwargs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,16 +9,16 @@ from typing import Any, Union
|
||||||
|
|
||||||
|
|
||||||
def dataclass_kwargs(
|
def dataclass_kwargs(
|
||||||
init: Union[bool, None] = None,
|
init: bool | None = None,
|
||||||
repr: Union[bool, None] = None,
|
repr: bool | None = None,
|
||||||
eq: Union[bool, None] = None,
|
eq: bool | None = None,
|
||||||
order: Union[bool, None] = None,
|
order: bool | None = None,
|
||||||
unsafe_hash: Union[bool, None] = None,
|
unsafe_hash: bool | None = None,
|
||||||
frozen: Union[bool, None] = None,
|
frozen: bool | None = None,
|
||||||
match_args: Union[bool, None] = None,
|
match_args: bool | None = None,
|
||||||
kw_only: Union[bool, None] = None,
|
kw_only: bool | None = None,
|
||||||
slots: Union[bool, None] = None,
|
slots: bool | None = None,
|
||||||
weakref_slot: Union[bool, None] = None,
|
weakref_slot: bool | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generates a dictionary of keyword arguments that can be passed to a Python
|
Generates a dictionary of keyword arguments that can be passed to a Python
|
||||||
|
|
@ -48,13 +48,12 @@ def dataclass_kwargs(
|
||||||
params["frozen"] = frozen
|
params["frozen"] = frozen
|
||||||
|
|
||||||
# Added in 3.10
|
# Added in 3.10
|
||||||
if sys.version_info >= (3, 10):
|
if match_args is not None:
|
||||||
if match_args is not None:
|
params["match_args"] = match_args
|
||||||
params["match_args"] = match_args
|
if kw_only is not None:
|
||||||
if kw_only is not None:
|
params["kw_only"] = kw_only
|
||||||
params["kw_only"] = kw_only
|
if slots is not None:
|
||||||
if slots is not None:
|
params["slots"] = slots
|
||||||
params["slots"] = slots
|
|
||||||
|
|
||||||
# Added in 3.11
|
# Added in 3.11
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,24 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"create_start_link",
|
|
||||||
"create_startgroup_link",
|
|
||||||
"create_startapp_link",
|
|
||||||
"create_deep_link",
|
"create_deep_link",
|
||||||
|
"create_start_link",
|
||||||
|
"create_startapp_link",
|
||||||
|
"create_startgroup_link",
|
||||||
"create_telegram_link",
|
"create_telegram_link",
|
||||||
"encode_payload",
|
|
||||||
"decode_payload",
|
"decode_payload",
|
||||||
|
"encode_payload",
|
||||||
]
|
]
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, cast
|
from typing import TYPE_CHECKING, Literal, Optional, cast
|
||||||
|
|
||||||
from aiogram.utils.link import create_telegram_link
|
from aiogram.utils.link import create_telegram_link
|
||||||
from aiogram.utils.payload import decode_payload, encode_payload
|
from aiogram.utils.payload import decode_payload, encode_payload
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
|
|
||||||
BAD_PATTERN = re.compile(r"[^a-zA-Z0-9-_]")
|
BAD_PATTERN = re.compile(r"[^a-zA-Z0-9-_]")
|
||||||
|
|
@ -26,7 +28,7 @@ async def create_start_link(
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
payload: str,
|
payload: str,
|
||||||
encode: bool = False,
|
encode: bool = False,
|
||||||
encoder: Optional[Callable[[bytes], bytes]] = None,
|
encoder: Callable[[bytes], bytes] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create 'start' deep link with your payload.
|
Create 'start' deep link with your payload.
|
||||||
|
|
@ -53,7 +55,7 @@ async def create_startgroup_link(
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
payload: str,
|
payload: str,
|
||||||
encode: bool = False,
|
encode: bool = False,
|
||||||
encoder: Optional[Callable[[bytes], bytes]] = None,
|
encoder: Callable[[bytes], bytes] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create 'startgroup' deep link with your payload.
|
Create 'startgroup' deep link with your payload.
|
||||||
|
|
@ -80,8 +82,8 @@ async def create_startapp_link(
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
payload: str,
|
payload: str,
|
||||||
encode: bool = False,
|
encode: bool = False,
|
||||||
app_name: Optional[str] = None,
|
app_name: str | None = None,
|
||||||
encoder: Optional[Callable[[bytes], bytes]] = None,
|
encoder: Callable[[bytes], bytes] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create 'startapp' deep link with your payload.
|
Create 'startapp' deep link with your payload.
|
||||||
|
|
@ -115,9 +117,9 @@ def create_deep_link(
|
||||||
username: str,
|
username: str,
|
||||||
link_type: Literal["start", "startgroup", "startapp"],
|
link_type: Literal["start", "startgroup", "startapp"],
|
||||||
payload: str,
|
payload: str,
|
||||||
app_name: Optional[str] = None,
|
app_name: str | None = None,
|
||||||
encode: bool = False,
|
encode: bool = False,
|
||||||
encoder: Optional[Callable[[bytes], bytes]] = None,
|
encoder: Callable[[bytes], bytes] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create deep link.
|
Create deep link.
|
||||||
|
|
@ -137,13 +139,15 @@ def create_deep_link(
|
||||||
payload = encode_payload(payload, encoder=encoder)
|
payload = encode_payload(payload, encoder=encoder)
|
||||||
|
|
||||||
if re.search(BAD_PATTERN, payload):
|
if re.search(BAD_PATTERN, payload):
|
||||||
raise ValueError(
|
msg = (
|
||||||
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
|
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
|
||||||
"Pass `encode=True` or encode payload manually."
|
"Pass `encode=True` or encode payload manually."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if len(payload) > 64:
|
if len(payload) > 64:
|
||||||
raise ValueError("Payload must be up to 64 characters long.")
|
msg = "Payload must be up to 64 characters long."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if not app_name:
|
if not app_name:
|
||||||
deep_link = create_telegram_link(username, **{cast(str, link_type): payload})
|
deep_link = create_telegram_link(username, **{cast(str, link_type): payload})
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,8 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import (
|
from collections.abc import Generator, Iterable, Iterator
|
||||||
Any,
|
from typing import TYPE_CHECKING, Any, ClassVar
|
||||||
ClassVar,
|
|
||||||
Dict,
|
|
||||||
Generator,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
)
|
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
@ -35,7 +27,7 @@ class Text(Iterable[NodeType]):
|
||||||
Simple text element
|
Simple text element
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: ClassVar[Optional[str]] = None
|
type: ClassVar[str | None] = None
|
||||||
|
|
||||||
__slots__ = ("_body", "_params")
|
__slots__ = ("_body", "_params")
|
||||||
|
|
||||||
|
|
@ -44,16 +36,16 @@ class Text(Iterable[NodeType]):
|
||||||
*body: NodeType,
|
*body: NodeType,
|
||||||
**params: Any,
|
**params: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._body: Tuple[NodeType, ...] = body
|
self._body: tuple[NodeType, ...] = body
|
||||||
self._params: Dict[str, Any] = params
|
self._params: dict[str, Any] = params
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_entities(cls, text: str, entities: List[MessageEntity]) -> "Text":
|
def from_entities(cls, text: str, entities: list[MessageEntity]) -> Text:
|
||||||
return cls(
|
return cls(
|
||||||
*_unparse_entities(
|
*_unparse_entities(
|
||||||
text=add_surrogates(text),
|
text=add_surrogates(text),
|
||||||
entities=sorted(entities, key=lambda item: item.offset) if entities else [],
|
entities=sorted(entities, key=lambda item: item.offset) if entities else [],
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def render(
|
def render(
|
||||||
|
|
@ -62,7 +54,7 @@ class Text(Iterable[NodeType]):
|
||||||
_offset: int = 0,
|
_offset: int = 0,
|
||||||
_sort: bool = True,
|
_sort: bool = True,
|
||||||
_collect_entities: bool = True,
|
_collect_entities: bool = True,
|
||||||
) -> Tuple[str, List[MessageEntity]]:
|
) -> tuple[str, list[MessageEntity]]:
|
||||||
"""
|
"""
|
||||||
Render elements tree as text with entities list
|
Render elements tree as text with entities list
|
||||||
|
|
||||||
|
|
@ -108,7 +100,7 @@ class Text(Iterable[NodeType]):
|
||||||
entities_key: str = "entities",
|
entities_key: str = "entities",
|
||||||
replace_parse_mode: bool = True,
|
replace_parse_mode: bool = True,
|
||||||
parse_mode_key: str = "parse_mode",
|
parse_mode_key: str = "parse_mode",
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Render element tree as keyword arguments for usage in an API call, for example:
|
Render element tree as keyword arguments for usage in an API call, for example:
|
||||||
|
|
||||||
|
|
@ -124,7 +116,7 @@ class Text(Iterable[NodeType]):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
text_value, entities_value = self.render()
|
text_value, entities_value = self.render()
|
||||||
result: Dict[str, Any] = {
|
result: dict[str, Any] = {
|
||||||
text_key: text_value,
|
text_key: text_value,
|
||||||
entities_key: entities_value,
|
entities_key: entities_value,
|
||||||
}
|
}
|
||||||
|
|
@ -132,7 +124,7 @@ class Text(Iterable[NodeType]):
|
||||||
result[parse_mode_key] = None
|
result[parse_mode_key] = None
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def as_caption_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
|
def as_caption_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Shortcut for :meth:`as_kwargs` for usage with API calls that take
|
Shortcut for :meth:`as_kwargs` for usage with API calls that take
|
||||||
``caption`` as a parameter.
|
``caption`` as a parameter.
|
||||||
|
|
@ -151,7 +143,7 @@ class Text(Iterable[NodeType]):
|
||||||
replace_parse_mode=replace_parse_mode,
|
replace_parse_mode=replace_parse_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_poll_question_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
|
def as_poll_question_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Shortcut for :meth:`as_kwargs` for usage with
|
Shortcut for :meth:`as_kwargs` for usage with
|
||||||
method :class:`aiogram.methods.send_poll.SendPoll`.
|
method :class:`aiogram.methods.send_poll.SendPoll`.
|
||||||
|
|
@ -171,7 +163,7 @@ class Text(Iterable[NodeType]):
|
||||||
replace_parse_mode=replace_parse_mode,
|
replace_parse_mode=replace_parse_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_poll_explanation_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
|
def as_poll_explanation_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Shortcut for :meth:`as_kwargs` for usage with
|
Shortcut for :meth:`as_kwargs` for usage with
|
||||||
method :class:`aiogram.methods.send_poll.SendPoll`.
|
method :class:`aiogram.methods.send_poll.SendPoll`.
|
||||||
|
|
@ -196,7 +188,7 @@ class Text(Iterable[NodeType]):
|
||||||
replace_parse_mode=replace_parse_mode,
|
replace_parse_mode=replace_parse_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_gift_text_kwargs(self, *, replace_parse_mode: bool = True) -> Dict[str, Any]:
|
def as_gift_text_kwargs(self, *, replace_parse_mode: bool = True) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Shortcut for :meth:`as_kwargs` for usage with
|
Shortcut for :meth:`as_kwargs` for usage with
|
||||||
method :class:`aiogram.methods.send_gift.SendGift`.
|
method :class:`aiogram.methods.send_gift.SendGift`.
|
||||||
|
|
@ -252,7 +244,7 @@ class Text(Iterable[NodeType]):
|
||||||
args_str = textwrap.indent("\n" + args_str + "\n", " ")
|
args_str = textwrap.indent("\n" + args_str + "\n", " ")
|
||||||
return f"{type(self).__name__}({args_str})"
|
return f"{type(self).__name__}({args_str})"
|
||||||
|
|
||||||
def __add__(self, other: NodeType) -> "Text":
|
def __add__(self, other: NodeType) -> Text:
|
||||||
if isinstance(other, Text) and other.type == self.type and self._params == other._params:
|
if isinstance(other, Text) and other.type == self.type and self._params == other._params:
|
||||||
return type(self)(*self, *other, **self._params)
|
return type(self)(*self, *other, **self._params)
|
||||||
if type(self) is Text and isinstance(other, str):
|
if type(self) is Text and isinstance(other, str):
|
||||||
|
|
@ -266,9 +258,10 @@ class Text(Iterable[NodeType]):
|
||||||
text, _ = self.render(_collect_entities=False)
|
text, _ = self.render(_collect_entities=False)
|
||||||
return sizeof(text)
|
return sizeof(text)
|
||||||
|
|
||||||
def __getitem__(self, item: slice) -> "Text":
|
def __getitem__(self, item: slice) -> Text:
|
||||||
if not isinstance(item, slice):
|
if not isinstance(item, slice):
|
||||||
raise TypeError("Can only be sliced")
|
msg = "Can only be sliced"
|
||||||
|
raise TypeError(msg)
|
||||||
if (item.start is None or item.start == 0) and item.stop is None:
|
if (item.start is None or item.start == 0) and item.stop is None:
|
||||||
return self.replace(*self._body)
|
return self.replace(*self._body)
|
||||||
start = 0 if item.start is None else item.start
|
start = 0 if item.start is None else item.start
|
||||||
|
|
@ -313,9 +306,11 @@ class HashTag(Text):
|
||||||
|
|
||||||
def __init__(self, *body: NodeType, **params: Any) -> None:
|
def __init__(self, *body: NodeType, **params: Any) -> None:
|
||||||
if len(body) != 1:
|
if len(body) != 1:
|
||||||
raise ValueError("Hashtag can contain only one element")
|
msg = "Hashtag can contain only one element"
|
||||||
|
raise ValueError(msg)
|
||||||
if not isinstance(body[0], str):
|
if not isinstance(body[0], str):
|
||||||
raise ValueError("Hashtag can contain only string")
|
msg = "Hashtag can contain only string"
|
||||||
|
raise ValueError(msg)
|
||||||
if not body[0].startswith("#"):
|
if not body[0].startswith("#"):
|
||||||
body = ("#" + body[0],)
|
body = ("#" + body[0],)
|
||||||
super().__init__(*body, **params)
|
super().__init__(*body, **params)
|
||||||
|
|
@ -337,9 +332,11 @@ class CashTag(Text):
|
||||||
|
|
||||||
def __init__(self, *body: NodeType, **params: Any) -> None:
|
def __init__(self, *body: NodeType, **params: Any) -> None:
|
||||||
if len(body) != 1:
|
if len(body) != 1:
|
||||||
raise ValueError("Cashtag can contain only one element")
|
msg = "Cashtag can contain only one element"
|
||||||
|
raise ValueError(msg)
|
||||||
if not isinstance(body[0], str):
|
if not isinstance(body[0], str):
|
||||||
raise ValueError("Cashtag can contain only string")
|
msg = "Cashtag can contain only string"
|
||||||
|
raise ValueError(msg)
|
||||||
if not body[0].startswith("$"):
|
if not body[0].startswith("$"):
|
||||||
body = ("$" + body[0],)
|
body = ("$" + body[0],)
|
||||||
super().__init__(*body, **params)
|
super().__init__(*body, **params)
|
||||||
|
|
@ -469,7 +466,7 @@ class Pre(Text):
|
||||||
|
|
||||||
type = MessageEntityType.PRE
|
type = MessageEntityType.PRE
|
||||||
|
|
||||||
def __init__(self, *body: NodeType, language: Optional[str] = None, **params: Any) -> None:
|
def __init__(self, *body: NodeType, language: str | None = None, **params: Any) -> None:
|
||||||
super().__init__(*body, language=language, **params)
|
super().__init__(*body, language=language, **params)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -537,7 +534,7 @@ class ExpandableBlockQuote(Text):
|
||||||
type = MessageEntityType.EXPANDABLE_BLOCKQUOTE
|
type = MessageEntityType.EXPANDABLE_BLOCKQUOTE
|
||||||
|
|
||||||
|
|
||||||
NODE_TYPES: Dict[Optional[str], Type[Text]] = {
|
NODE_TYPES: dict[str | None, type[Text]] = {
|
||||||
Text.type: Text,
|
Text.type: Text,
|
||||||
HashTag.type: HashTag,
|
HashTag.type: HashTag,
|
||||||
CashTag.type: CashTag,
|
CashTag.type: CashTag,
|
||||||
|
|
@ -570,15 +567,16 @@ def _apply_entity(entity: MessageEntity, *nodes: NodeType) -> NodeType:
|
||||||
"""
|
"""
|
||||||
node_type = NODE_TYPES.get(entity.type, Text)
|
node_type = NODE_TYPES.get(entity.type, Text)
|
||||||
return node_type(
|
return node_type(
|
||||||
*nodes, **entity.model_dump(exclude={"type", "offset", "length"}, warnings=False)
|
*nodes,
|
||||||
|
**entity.model_dump(exclude={"type", "offset", "length"}, warnings=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _unparse_entities(
|
def _unparse_entities(
|
||||||
text: bytes,
|
text: bytes,
|
||||||
entities: List[MessageEntity],
|
entities: list[MessageEntity],
|
||||||
offset: Optional[int] = None,
|
offset: int | None = None,
|
||||||
length: Optional[int] = None,
|
length: int | None = None,
|
||||||
) -> Generator[NodeType, None, None]:
|
) -> Generator[NodeType, None, None]:
|
||||||
if offset is None:
|
if offset is None:
|
||||||
offset = 0
|
offset = 0
|
||||||
|
|
@ -615,8 +613,7 @@ def as_line(*items: NodeType, end: str = "\n", sep: str = "") -> Text:
|
||||||
nodes = []
|
nodes = []
|
||||||
for item in items[:-1]:
|
for item in items[:-1]:
|
||||||
nodes.extend([item, sep])
|
nodes.extend([item, sep])
|
||||||
nodes.append(items[-1])
|
nodes.extend([items[-1], end])
|
||||||
nodes.append(end)
|
|
||||||
else:
|
else:
|
||||||
nodes = [*items, end]
|
nodes = [*items, end]
|
||||||
return Text(*nodes)
|
return Text(*nodes)
|
||||||
|
|
|
||||||
|
|
@ -8,14 +8,14 @@ from .middleware import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
"ConstI18nMiddleware",
|
||||||
|
"FSMI18nMiddleware",
|
||||||
"I18n",
|
"I18n",
|
||||||
"I18nMiddleware",
|
"I18nMiddleware",
|
||||||
"SimpleI18nMiddleware",
|
"SimpleI18nMiddleware",
|
||||||
"ConstI18nMiddleware",
|
"get_i18n",
|
||||||
"FSMI18nMiddleware",
|
|
||||||
"gettext",
|
"gettext",
|
||||||
"lazy_gettext",
|
"lazy_gettext",
|
||||||
"ngettext",
|
|
||||||
"lazy_ngettext",
|
"lazy_ngettext",
|
||||||
"get_i18n",
|
"ngettext",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,8 @@ from aiogram.utils.i18n.lazy_proxy import LazyProxy
|
||||||
def get_i18n() -> I18n:
|
def get_i18n() -> I18n:
|
||||||
i18n = I18n.get_current(no_error=True)
|
i18n = I18n.get_current(no_error=True)
|
||||||
if i18n is None:
|
if i18n is None:
|
||||||
raise LookupError("I18n context is not set")
|
msg = "I18n context is not set"
|
||||||
|
raise LookupError(msg)
|
||||||
return i18n
|
return i18n
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,27 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import gettext
|
import gettext
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Generator, Optional, Tuple, Union
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from aiogram.utils.i18n.lazy_proxy import LazyProxy
|
from aiogram.utils.i18n.lazy_proxy import LazyProxy
|
||||||
from aiogram.utils.mixins import ContextInstanceMixin
|
from aiogram.utils.mixins import ContextInstanceMixin
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
|
||||||
class I18n(ContextInstanceMixin["I18n"]):
|
class I18n(ContextInstanceMixin["I18n"]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
path: Union[str, Path],
|
path: str | Path,
|
||||||
default_locale: str = "en",
|
default_locale: str = "en",
|
||||||
domain: str = "messages",
|
domain: str = "messages",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = path
|
self.path = Path(path)
|
||||||
self.default_locale = default_locale
|
self.default_locale = default_locale
|
||||||
self.domain = domain
|
self.domain = domain
|
||||||
self.ctx_locale = ContextVar("aiogram_ctx_locale", default=default_locale)
|
self.ctx_locale = ContextVar("aiogram_ctx_locale", default=default_locale)
|
||||||
|
|
@ -43,7 +47,7 @@ class I18n(ContextInstanceMixin["I18n"]):
|
||||||
self.ctx_locale.reset(ctx_token)
|
self.ctx_locale.reset(ctx_token)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def context(self) -> Generator["I18n", None, None]:
|
def context(self) -> Generator[I18n, None, None]:
|
||||||
"""
|
"""
|
||||||
Use I18n context
|
Use I18n context
|
||||||
"""
|
"""
|
||||||
|
|
@ -53,24 +57,25 @@ class I18n(ContextInstanceMixin["I18n"]):
|
||||||
finally:
|
finally:
|
||||||
self.reset_current(token)
|
self.reset_current(token)
|
||||||
|
|
||||||
def find_locales(self) -> Dict[str, gettext.GNUTranslations]:
|
def find_locales(self) -> dict[str, gettext.GNUTranslations]:
|
||||||
"""
|
"""
|
||||||
Load all compiled locales from path
|
Load all compiled locales from path
|
||||||
|
|
||||||
:return: dict with locales
|
:return: dict with locales
|
||||||
"""
|
"""
|
||||||
translations: Dict[str, gettext.GNUTranslations] = {}
|
translations: dict[str, gettext.GNUTranslations] = {}
|
||||||
|
|
||||||
for name in os.listdir(self.path):
|
for name in self.path.iterdir():
|
||||||
if not os.path.isdir(os.path.join(self.path, name)):
|
if not (self.path / name).is_dir():
|
||||||
continue
|
continue
|
||||||
mo_path = os.path.join(self.path, name, "LC_MESSAGES", self.domain + ".mo")
|
mo_path = self.path / name / "LC_MESSAGES" / (self.domain + ".mo")
|
||||||
|
|
||||||
if os.path.exists(mo_path):
|
if mo_path.exists():
|
||||||
with open(mo_path, "rb") as fp:
|
with mo_path.open("rb") as fp:
|
||||||
translations[name] = gettext.GNUTranslations(fp)
|
translations[name.name] = gettext.GNUTranslations(fp)
|
||||||
elif os.path.exists(mo_path[:-2] + "po"): # pragma: no cover
|
elif mo_path.with_suffix(".po").exists(): # pragma: no cover
|
||||||
raise RuntimeError(f"Found locale '{name}' but this language is not compiled!")
|
msg = f"Found locale '{name.name}' but this language is not compiled!"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
return translations
|
return translations
|
||||||
|
|
||||||
|
|
@ -81,7 +86,7 @@ class I18n(ContextInstanceMixin["I18n"]):
|
||||||
self.locales = self.find_locales()
|
self.locales = self.find_locales()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available_locales(self) -> Tuple[str, ...]:
|
def available_locales(self) -> tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
list of loaded locales
|
list of loaded locales
|
||||||
|
|
||||||
|
|
@ -90,7 +95,11 @@ class I18n(ContextInstanceMixin["I18n"]):
|
||||||
return tuple(self.locales.keys())
|
return tuple(self.locales.keys())
|
||||||
|
|
||||||
def gettext(
|
def gettext(
|
||||||
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None
|
self,
|
||||||
|
singular: str,
|
||||||
|
plural: str | None = None,
|
||||||
|
n: int = 1,
|
||||||
|
locale: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get text
|
Get text
|
||||||
|
|
@ -107,7 +116,7 @@ class I18n(ContextInstanceMixin["I18n"]):
|
||||||
if locale not in self.locales:
|
if locale not in self.locales:
|
||||||
if n == 1:
|
if n == 1:
|
||||||
return singular
|
return singular
|
||||||
return plural if plural else singular
|
return plural or singular
|
||||||
|
|
||||||
translator = self.locales[locale]
|
translator = self.locales[locale]
|
||||||
|
|
||||||
|
|
@ -116,8 +125,17 @@ class I18n(ContextInstanceMixin["I18n"]):
|
||||||
return translator.ngettext(singular, plural, n)
|
return translator.ngettext(singular, plural, n)
|
||||||
|
|
||||||
def lazy_gettext(
|
def lazy_gettext(
|
||||||
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None
|
self,
|
||||||
|
singular: str,
|
||||||
|
plural: str | None = None,
|
||||||
|
n: int = 1,
|
||||||
|
locale: str | None = None,
|
||||||
) -> LazyProxy:
|
) -> LazyProxy:
|
||||||
return LazyProxy(
|
return LazyProxy(
|
||||||
self.gettext, singular=singular, plural=plural, n=n, locale=locale, enable_cache=False
|
self.gettext,
|
||||||
|
singular=singular,
|
||||||
|
plural=plural,
|
||||||
|
n=n,
|
||||||
|
locale=locale,
|
||||||
|
enable_cache=False,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,9 @@ except ImportError: # pragma: no cover
|
||||||
|
|
||||||
class LazyProxy: # type: ignore
|
class LazyProxy: # type: ignore
|
||||||
def __init__(self, func: Any, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, func: Any, *args: Any, **kwargs: Any) -> None:
|
||||||
raise RuntimeError(
|
msg = (
|
||||||
"LazyProxy can be used only when Babel installed\n"
|
"LazyProxy can be used only when Babel installed\n"
|
||||||
"Just install Babel (`pip install Babel`) "
|
"Just install Babel (`pip install Babel`) "
|
||||||
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
|
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
|
||||||
)
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, Set
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from babel import Locale, UnknownLocaleError
|
from babel import Locale, UnknownLocaleError
|
||||||
|
|
@ -11,9 +13,13 @@ except ImportError: # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
from aiogram import BaseMiddleware, Router
|
from aiogram import BaseMiddleware, Router
|
||||||
from aiogram.fsm.context import FSMContext
|
|
||||||
from aiogram.types import TelegramObject, User
|
if TYPE_CHECKING:
|
||||||
from aiogram.utils.i18n.core import I18n
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from aiogram.fsm.context import FSMContext
|
||||||
|
from aiogram.types import TelegramObject, User
|
||||||
|
from aiogram.utils.i18n.core import I18n
|
||||||
|
|
||||||
|
|
||||||
class I18nMiddleware(BaseMiddleware, ABC):
|
class I18nMiddleware(BaseMiddleware, ABC):
|
||||||
|
|
@ -24,7 +30,7 @@ class I18nMiddleware(BaseMiddleware, ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
i18n: I18n,
|
i18n: I18n,
|
||||||
i18n_key: Optional[str] = "i18n",
|
i18n_key: str | None = "i18n",
|
||||||
middleware_key: str = "i18n_middleware",
|
middleware_key: str = "i18n_middleware",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -39,7 +45,9 @@ class I18nMiddleware(BaseMiddleware, ABC):
|
||||||
self.middleware_key = middleware_key
|
self.middleware_key = middleware_key
|
||||||
|
|
||||||
def setup(
|
def setup(
|
||||||
self: BaseMiddleware, router: Router, exclude: Optional[Set[str]] = None
|
self: BaseMiddleware,
|
||||||
|
router: Router,
|
||||||
|
exclude: set[str] | None = None,
|
||||||
) -> BaseMiddleware:
|
) -> BaseMiddleware:
|
||||||
"""
|
"""
|
||||||
Register middleware for all events in the Router
|
Register middleware for all events in the Router
|
||||||
|
|
@ -59,9 +67,9 @@ class I18nMiddleware(BaseMiddleware, ABC):
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
|
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||||
event: TelegramObject,
|
event: TelegramObject,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
current_locale = await self.get_locale(event=event, data=data) or self.i18n.default_locale
|
current_locale = await self.get_locale(event=event, data=data) or self.i18n.default_locale
|
||||||
|
|
||||||
|
|
@ -74,7 +82,7 @@ class I18nMiddleware(BaseMiddleware, ABC):
|
||||||
return await handler(event, data)
|
return await handler(event, data)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
|
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
|
||||||
"""
|
"""
|
||||||
Detect current user locale based on event and context.
|
Detect current user locale based on event and context.
|
||||||
|
|
||||||
|
|
@ -84,7 +92,6 @@ class I18nMiddleware(BaseMiddleware, ABC):
|
||||||
:param data:
|
:param data:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleI18nMiddleware(I18nMiddleware):
|
class SimpleI18nMiddleware(I18nMiddleware):
|
||||||
|
|
@ -97,27 +104,29 @@ class SimpleI18nMiddleware(I18nMiddleware):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
i18n: I18n,
|
i18n: I18n,
|
||||||
i18n_key: Optional[str] = "i18n",
|
i18n_key: str | None = "i18n",
|
||||||
middleware_key: str = "i18n_middleware",
|
middleware_key: str = "i18n_middleware",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
|
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
|
||||||
|
|
||||||
if Locale is None: # pragma: no cover
|
if Locale is None: # pragma: no cover
|
||||||
raise RuntimeError(
|
msg = (
|
||||||
f"{type(self).__name__} can be used only when Babel installed\n"
|
f"{type(self).__name__} can be used only when Babel installed\n"
|
||||||
"Just install Babel (`pip install Babel`) "
|
"Just install Babel (`pip install Babel`) "
|
||||||
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
|
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
|
||||||
)
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
|
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
|
||||||
if Locale is None: # pragma: no cover
|
if Locale is None: # pragma: no cover
|
||||||
raise RuntimeError(
|
msg = (
|
||||||
f"{type(self).__name__} can be used only when Babel installed\n"
|
f"{type(self).__name__} can be used only when Babel installed\n"
|
||||||
"Just install Babel (`pip install Babel`) "
|
"Just install Babel (`pip install Babel`) "
|
||||||
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
|
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
|
||||||
)
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
event_from_user: Optional[User] = data.get("event_from_user", None)
|
event_from_user: User | None = data.get("event_from_user")
|
||||||
if event_from_user is None or event_from_user.language_code is None:
|
if event_from_user is None or event_from_user.language_code is None:
|
||||||
return self.i18n.default_locale
|
return self.i18n.default_locale
|
||||||
try:
|
try:
|
||||||
|
|
@ -139,13 +148,13 @@ class ConstI18nMiddleware(I18nMiddleware):
|
||||||
self,
|
self,
|
||||||
locale: str,
|
locale: str,
|
||||||
i18n: I18n,
|
i18n: I18n,
|
||||||
i18n_key: Optional[str] = "i18n",
|
i18n_key: str | None = "i18n",
|
||||||
middleware_key: str = "i18n_middleware",
|
middleware_key: str = "i18n_middleware",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
|
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
|
||||||
self.locale = locale
|
self.locale = locale
|
||||||
|
|
||||||
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
|
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
|
||||||
return self.locale
|
return self.locale
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -158,14 +167,14 @@ class FSMI18nMiddleware(SimpleI18nMiddleware):
|
||||||
self,
|
self,
|
||||||
i18n: I18n,
|
i18n: I18n,
|
||||||
key: str = "locale",
|
key: str = "locale",
|
||||||
i18n_key: Optional[str] = "i18n",
|
i18n_key: str | None = "i18n",
|
||||||
middleware_key: str = "i18n_middleware",
|
middleware_key: str = "i18n_middleware",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
|
super().__init__(i18n=i18n, i18n_key=i18n_key, middleware_key=middleware_key)
|
||||||
self.key = key
|
self.key = key
|
||||||
|
|
||||||
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
|
async def get_locale(self, event: TelegramObject, data: dict[str, Any]) -> str:
|
||||||
fsm_context: Optional[FSMContext] = data.get("state")
|
fsm_context: FSMContext | None = data.get("state")
|
||||||
locale = None
|
locale = None
|
||||||
if fsm_context:
|
if fsm_context:
|
||||||
fsm_data = await fsm_context.get_data()
|
fsm_data = await fsm_context.get_data()
|
||||||
|
|
|
||||||
|
|
@ -4,19 +4,7 @@ from abc import ABC
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from itertools import cycle as repeat_all
|
from itertools import cycle as repeat_all
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Generator,
|
|
||||||
Generic,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
from aiogram.filters.callback_data import CallbackData
|
from aiogram.filters.callback_data import CallbackData
|
||||||
from aiogram.types import (
|
from aiogram.types import (
|
||||||
|
|
@ -34,11 +22,14 @@ from aiogram.types import (
|
||||||
WebAppInfo,
|
WebAppInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator, Iterable
|
||||||
|
|
||||||
ButtonType = TypeVar("ButtonType", InlineKeyboardButton, KeyboardButton)
|
ButtonType = TypeVar("ButtonType", InlineKeyboardButton, KeyboardButton)
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class KeyboardBuilder(Generic[ButtonType], ABC):
|
class KeyboardBuilder(ABC, Generic[ButtonType]):
|
||||||
"""
|
"""
|
||||||
Generic keyboard builder that helps to adjust your markup with defined shape of lines.
|
Generic keyboard builder that helps to adjust your markup with defined shape of lines.
|
||||||
|
|
||||||
|
|
@ -50,16 +41,19 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
max_buttons: int = 0
|
max_buttons: int = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, button_type: Type[ButtonType], markup: Optional[List[List[ButtonType]]] = None
|
self,
|
||||||
|
button_type: type[ButtonType],
|
||||||
|
markup: list[list[ButtonType]] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not issubclass(button_type, (InlineKeyboardButton, KeyboardButton)):
|
if not issubclass(button_type, (InlineKeyboardButton, KeyboardButton)):
|
||||||
raise ValueError(f"Button type {button_type} are not allowed here")
|
msg = f"Button type {button_type} are not allowed here"
|
||||||
self._button_type: Type[ButtonType] = button_type
|
raise ValueError(msg)
|
||||||
|
self._button_type: type[ButtonType] = button_type
|
||||||
if markup:
|
if markup:
|
||||||
self._validate_markup(markup)
|
self._validate_markup(markup)
|
||||||
else:
|
else:
|
||||||
markup = []
|
markup = []
|
||||||
self._markup: List[List[ButtonType]] = markup
|
self._markup: list[list[ButtonType]] = markup
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def buttons(self) -> Generator[ButtonType, None, None]:
|
def buttons(self) -> Generator[ButtonType, None, None]:
|
||||||
|
|
@ -79,9 +73,8 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
"""
|
"""
|
||||||
allowed = self._button_type
|
allowed = self._button_type
|
||||||
if not isinstance(button, allowed):
|
if not isinstance(button, allowed):
|
||||||
raise ValueError(
|
msg = f"{button!r} should be type {allowed.__name__!r} not {type(button).__name__!r}"
|
||||||
f"{button!r} should be type {allowed.__name__!r} not {type(button).__name__!r}"
|
raise ValueError(msg)
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _validate_buttons(self, *buttons: ButtonType) -> bool:
|
def _validate_buttons(self, *buttons: ButtonType) -> bool:
|
||||||
|
|
@ -93,7 +86,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
"""
|
"""
|
||||||
return all(map(self._validate_button, buttons))
|
return all(map(self._validate_button, buttons))
|
||||||
|
|
||||||
def _validate_row(self, row: List[ButtonType]) -> bool:
|
def _validate_row(self, row: list[ButtonType]) -> bool:
|
||||||
"""
|
"""
|
||||||
Check that row of buttons are correct
|
Check that row of buttons are correct
|
||||||
Row can be only list of allowed button types and has length 0 <= n <= 8
|
Row can be only list of allowed button types and has length 0 <= n <= 8
|
||||||
|
|
@ -102,16 +95,18 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not isinstance(row, list):
|
if not isinstance(row, list):
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Row {row!r} should be type 'List[{self._button_type.__name__}]' "
|
f"Row {row!r} should be type 'List[{self._button_type.__name__}]' "
|
||||||
f"not type {type(row).__name__}"
|
f"not type {type(row).__name__}"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
if len(row) > self.max_width:
|
if len(row) > self.max_width:
|
||||||
raise ValueError(f"Row {row!r} is too long (max width: {self.max_width})")
|
msg = f"Row {row!r} is too long (max width: {self.max_width})"
|
||||||
|
raise ValueError(msg)
|
||||||
self._validate_buttons(*row)
|
self._validate_buttons(*row)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _validate_markup(self, markup: List[List[ButtonType]]) -> bool:
|
def _validate_markup(self, markup: list[list[ButtonType]]) -> bool:
|
||||||
"""
|
"""
|
||||||
Check that passed markup has correct data structure
|
Check that passed markup has correct data structure
|
||||||
Markup is list of lists of buttons
|
Markup is list of lists of buttons
|
||||||
|
|
@ -121,15 +116,17 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
"""
|
"""
|
||||||
count = 0
|
count = 0
|
||||||
if not isinstance(markup, list):
|
if not isinstance(markup, list):
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Markup should be type 'List[List[{self._button_type.__name__}]]' "
|
f"Markup should be type 'List[List[{self._button_type.__name__}]]' "
|
||||||
f"not type {type(markup).__name__!r}"
|
f"not type {type(markup).__name__!r}"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
for row in markup:
|
for row in markup:
|
||||||
self._validate_row(row)
|
self._validate_row(row)
|
||||||
count += len(row)
|
count += len(row)
|
||||||
if count > self.max_buttons:
|
if count > self.max_buttons:
|
||||||
raise ValueError(f"Too much buttons detected Max allowed count - {self.max_buttons}")
|
msg = f"Too much buttons detected Max allowed count - {self.max_buttons}"
|
||||||
|
raise ValueError(msg)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _validate_size(self, size: Any) -> int:
|
def _validate_size(self, size: Any) -> int:
|
||||||
|
|
@ -140,14 +137,14 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not isinstance(size, int):
|
if not isinstance(size, int):
|
||||||
raise ValueError("Only int sizes are allowed")
|
msg = "Only int sizes are allowed"
|
||||||
|
raise ValueError(msg)
|
||||||
if size not in range(self.min_width, self.max_width + 1):
|
if size not in range(self.min_width, self.max_width + 1):
|
||||||
raise ValueError(
|
msg = f"Row size {size} is not allowed, range: [{self.min_width}, {self.max_width}]"
|
||||||
f"Row size {size} is not allowed, range: [{self.min_width}, {self.max_width}]"
|
raise ValueError(msg)
|
||||||
)
|
|
||||||
return size
|
return size
|
||||||
|
|
||||||
def export(self) -> List[List[ButtonType]]:
|
def export(self) -> list[list[ButtonType]]:
|
||||||
"""
|
"""
|
||||||
Export configured markup as list of lists of buttons
|
Export configured markup as list of lists of buttons
|
||||||
|
|
||||||
|
|
@ -161,7 +158,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
"""
|
"""
|
||||||
return deepcopy(self._markup)
|
return deepcopy(self._markup)
|
||||||
|
|
||||||
def add(self, *buttons: ButtonType) -> "KeyboardBuilder[ButtonType]":
|
def add(self, *buttons: ButtonType) -> KeyboardBuilder[ButtonType]:
|
||||||
"""
|
"""
|
||||||
Add one or many buttons to markup.
|
Add one or many buttons to markup.
|
||||||
|
|
||||||
|
|
@ -189,9 +186,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
self._markup = markup
|
self._markup = markup
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def row(
|
def row(self, *buttons: ButtonType, width: int | None = None) -> KeyboardBuilder[ButtonType]:
|
||||||
self, *buttons: ButtonType, width: Optional[int] = None
|
|
||||||
) -> "KeyboardBuilder[ButtonType]":
|
|
||||||
"""
|
"""
|
||||||
Add row to markup
|
Add row to markup
|
||||||
|
|
||||||
|
|
@ -211,7 +206,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def adjust(self, *sizes: int, repeat: bool = False) -> "KeyboardBuilder[ButtonType]":
|
def adjust(self, *sizes: int, repeat: bool = False) -> KeyboardBuilder[ButtonType]:
|
||||||
"""
|
"""
|
||||||
Adjust previously added buttons to specific row sizes.
|
Adjust previously added buttons to specific row sizes.
|
||||||
|
|
||||||
|
|
@ -232,7 +227,7 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
size = next(sizes_iter)
|
size = next(sizes_iter)
|
||||||
|
|
||||||
markup = []
|
markup = []
|
||||||
row: List[ButtonType] = []
|
row: list[ButtonType] = []
|
||||||
for button in self.buttons:
|
for button in self.buttons:
|
||||||
if len(row) >= size:
|
if len(row) >= size:
|
||||||
markup.append(row)
|
markup.append(row)
|
||||||
|
|
@ -244,33 +239,35 @@ class KeyboardBuilder(Generic[ButtonType], ABC):
|
||||||
self._markup = markup
|
self._markup = markup
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _button(self, **kwargs: Any) -> "KeyboardBuilder[ButtonType]":
|
def _button(self, **kwargs: Any) -> KeyboardBuilder[ButtonType]:
|
||||||
"""
|
"""
|
||||||
Add button to markup
|
Add button to markup
|
||||||
|
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if isinstance(callback_data := kwargs.get("callback_data", None), CallbackData):
|
if isinstance(callback_data := kwargs.get("callback_data"), CallbackData):
|
||||||
kwargs["callback_data"] = callback_data.pack()
|
kwargs["callback_data"] = callback_data.pack()
|
||||||
button = self._button_type(**kwargs)
|
button = self._button_type(**kwargs)
|
||||||
return self.add(button)
|
return self.add(button)
|
||||||
|
|
||||||
def as_markup(self, **kwargs: Any) -> Union[InlineKeyboardMarkup, ReplyKeyboardMarkup]:
|
def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup | ReplyKeyboardMarkup:
|
||||||
if self._button_type is KeyboardButton:
|
if self._button_type is KeyboardButton:
|
||||||
keyboard = cast(List[List[KeyboardButton]], self.export()) # type: ignore
|
keyboard = cast(list[list[KeyboardButton]], self.export()) # type: ignore
|
||||||
return ReplyKeyboardMarkup(keyboard=keyboard, **kwargs)
|
return ReplyKeyboardMarkup(keyboard=keyboard, **kwargs)
|
||||||
inline_keyboard = cast(List[List[InlineKeyboardButton]], self.export()) # type: ignore
|
inline_keyboard = cast(list[list[InlineKeyboardButton]], self.export()) # type: ignore
|
||||||
return InlineKeyboardMarkup(inline_keyboard=inline_keyboard)
|
return InlineKeyboardMarkup(inline_keyboard=inline_keyboard)
|
||||||
|
|
||||||
def attach(self, builder: "KeyboardBuilder[ButtonType]") -> "KeyboardBuilder[ButtonType]":
|
def attach(self, builder: KeyboardBuilder[ButtonType]) -> KeyboardBuilder[ButtonType]:
|
||||||
if not isinstance(builder, KeyboardBuilder):
|
if not isinstance(builder, KeyboardBuilder):
|
||||||
raise ValueError(f"Only KeyboardBuilder can be attached, not {type(builder).__name__}")
|
msg = f"Only KeyboardBuilder can be attached, not {type(builder).__name__}"
|
||||||
|
raise ValueError(msg)
|
||||||
if builder._button_type is not self._button_type:
|
if builder._button_type is not self._button_type:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Only builders with same button type can be attached, "
|
f"Only builders with same button type can be attached, "
|
||||||
f"not {self._button_type.__name__} and {builder._button_type.__name__}"
|
f"not {self._button_type.__name__} and {builder._button_type.__name__}"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
self._markup.extend(builder.export())
|
self._markup.extend(builder.export())
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -306,18 +303,18 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
text: str,
|
text: str,
|
||||||
url: Optional[str] = None,
|
url: str | None = None,
|
||||||
callback_data: Optional[Union[str, CallbackData]] = None,
|
callback_data: str | CallbackData | None = None,
|
||||||
web_app: Optional[WebAppInfo] = None,
|
web_app: WebAppInfo | None = None,
|
||||||
login_url: Optional[LoginUrl] = None,
|
login_url: LoginUrl | None = None,
|
||||||
switch_inline_query: Optional[str] = None,
|
switch_inline_query: str | None = None,
|
||||||
switch_inline_query_current_chat: Optional[str] = None,
|
switch_inline_query_current_chat: str | None = None,
|
||||||
switch_inline_query_chosen_chat: Optional[SwitchInlineQueryChosenChat] = None,
|
switch_inline_query_chosen_chat: SwitchInlineQueryChosenChat | None = None,
|
||||||
copy_text: Optional[CopyTextButton] = None,
|
copy_text: CopyTextButton | None = None,
|
||||||
callback_game: Optional[CallbackGame] = None,
|
callback_game: CallbackGame | None = None,
|
||||||
pay: Optional[bool] = None,
|
pay: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "InlineKeyboardBuilder":
|
) -> InlineKeyboardBuilder:
|
||||||
return cast(
|
return cast(
|
||||||
InlineKeyboardBuilder,
|
InlineKeyboardBuilder,
|
||||||
self._button(
|
self._button(
|
||||||
|
|
@ -340,10 +337,10 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
|
||||||
"""Construct an InlineKeyboardMarkup"""
|
"""Construct an InlineKeyboardMarkup"""
|
||||||
return cast(InlineKeyboardMarkup, super().as_markup(**kwargs))
|
return cast(InlineKeyboardMarkup, super().as_markup(**kwargs))
|
||||||
|
|
||||||
def __init__(self, markup: Optional[List[List[InlineKeyboardButton]]] = None) -> None:
|
def __init__(self, markup: list[list[InlineKeyboardButton]] | None = None) -> None:
|
||||||
super().__init__(button_type=InlineKeyboardButton, markup=markup)
|
super().__init__(button_type=InlineKeyboardButton, markup=markup)
|
||||||
|
|
||||||
def copy(self: "InlineKeyboardBuilder") -> "InlineKeyboardBuilder":
|
def copy(self: InlineKeyboardBuilder) -> InlineKeyboardBuilder:
|
||||||
"""
|
"""
|
||||||
Make full copy of current builder with markup
|
Make full copy of current builder with markup
|
||||||
|
|
||||||
|
|
@ -353,8 +350,9 @@ class InlineKeyboardBuilder(KeyboardBuilder[InlineKeyboardButton]):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_markup(
|
def from_markup(
|
||||||
cls: Type["InlineKeyboardBuilder"], markup: InlineKeyboardMarkup
|
cls: type[InlineKeyboardBuilder],
|
||||||
) -> "InlineKeyboardBuilder":
|
markup: InlineKeyboardMarkup,
|
||||||
|
) -> InlineKeyboardBuilder:
|
||||||
"""
|
"""
|
||||||
Create builder from existing markup
|
Create builder from existing markup
|
||||||
|
|
||||||
|
|
@ -377,14 +375,14 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
text: str,
|
text: str,
|
||||||
request_users: Optional[KeyboardButtonRequestUsers] = None,
|
request_users: KeyboardButtonRequestUsers | None = None,
|
||||||
request_chat: Optional[KeyboardButtonRequestChat] = None,
|
request_chat: KeyboardButtonRequestChat | None = None,
|
||||||
request_contact: Optional[bool] = None,
|
request_contact: bool | None = None,
|
||||||
request_location: Optional[bool] = None,
|
request_location: bool | None = None,
|
||||||
request_poll: Optional[KeyboardButtonPollType] = None,
|
request_poll: KeyboardButtonPollType | None = None,
|
||||||
web_app: Optional[WebAppInfo] = None,
|
web_app: WebAppInfo | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "ReplyKeyboardBuilder":
|
) -> ReplyKeyboardBuilder:
|
||||||
return cast(
|
return cast(
|
||||||
ReplyKeyboardBuilder,
|
ReplyKeyboardBuilder,
|
||||||
self._button(
|
self._button(
|
||||||
|
|
@ -403,10 +401,10 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
|
||||||
"""Construct a ReplyKeyboardMarkup"""
|
"""Construct a ReplyKeyboardMarkup"""
|
||||||
return cast(ReplyKeyboardMarkup, super().as_markup(**kwargs))
|
return cast(ReplyKeyboardMarkup, super().as_markup(**kwargs))
|
||||||
|
|
||||||
def __init__(self, markup: Optional[List[List[KeyboardButton]]] = None) -> None:
|
def __init__(self, markup: list[list[KeyboardButton]] | None = None) -> None:
|
||||||
super().__init__(button_type=KeyboardButton, markup=markup)
|
super().__init__(button_type=KeyboardButton, markup=markup)
|
||||||
|
|
||||||
def copy(self: "ReplyKeyboardBuilder") -> "ReplyKeyboardBuilder":
|
def copy(self: ReplyKeyboardBuilder) -> ReplyKeyboardBuilder:
|
||||||
"""
|
"""
|
||||||
Make full copy of current builder with markup
|
Make full copy of current builder with markup
|
||||||
|
|
||||||
|
|
@ -415,7 +413,7 @@ class ReplyKeyboardBuilder(KeyboardBuilder[KeyboardButton]):
|
||||||
return ReplyKeyboardBuilder(markup=self.export())
|
return ReplyKeyboardBuilder(markup=self.export())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_markup(cls, markup: ReplyKeyboardMarkup) -> "ReplyKeyboardBuilder":
|
def from_markup(cls, markup: ReplyKeyboardMarkup) -> ReplyKeyboardBuilder:
|
||||||
"""
|
"""
|
||||||
Create builder from existing markup
|
Create builder from existing markup
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ BRANCH = "dev-3.x"
|
||||||
BASE_PAGE_URL = f"{BASE_DOCS_URL}/en/{BRANCH}/"
|
BASE_PAGE_URL = f"{BASE_DOCS_URL}/en/{BRANCH}/"
|
||||||
|
|
||||||
|
|
||||||
def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query: Any) -> str:
|
def _format_url(url: str, *path: str, fragment_: str | None = None, **query: Any) -> str:
|
||||||
url = urljoin(url, "/".join(path), allow_fragments=True)
|
url = urljoin(url, "/".join(path), allow_fragments=True)
|
||||||
if query:
|
if query:
|
||||||
url += "?" + urlencode(query)
|
url += "?" + urlencode(query)
|
||||||
|
|
@ -16,7 +16,7 @@ def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query:
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
def docs_url(*path: str, fragment_: Optional[str] = None, **query: Any) -> str:
|
def docs_url(*path: str, fragment_: str | None = None, **query: Any) -> str:
|
||||||
return _format_url(BASE_PAGE_URL, *path, fragment_=fragment_, **query)
|
return _format_url(BASE_PAGE_URL, *path, fragment_=fragment_, **query)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,7 +30,7 @@ def create_telegram_link(*path: str, **kwargs: Any) -> str:
|
||||||
|
|
||||||
def create_channel_bot_link(
|
def create_channel_bot_link(
|
||||||
username: str,
|
username: str,
|
||||||
parameter: Optional[str] = None,
|
parameter: str | None = None,
|
||||||
change_info: bool = False,
|
change_info: bool = False,
|
||||||
post_messages: bool = False,
|
post_messages: bool = False,
|
||||||
edit_messages: bool = False,
|
edit_messages: bool = False,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Iterable
|
from collections.abc import Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from magic_filter import MagicFilter as _MagicFilter
|
from magic_filter import MagicFilter as _MagicFilter
|
||||||
from magic_filter import MagicT as _MagicT
|
from magic_filter import MagicT as _MagicT
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ def strikethrough(*content: Any, sep: str = " ") -> str:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return markdown_decoration.strikethrough(
|
return markdown_decoration.strikethrough(
|
||||||
value=markdown_decoration.quote(_join(*content, sep=sep))
|
value=markdown_decoration.quote(_join(*content, sep=sep)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -183,7 +183,7 @@ def blockquote(*content: Any, sep: str = "\n") -> str:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return markdown_decoration.blockquote(
|
return markdown_decoration.blockquote(
|
||||||
value=markdown_decoration.quote(_join(*content, sep=sep))
|
value=markdown_decoration.quote(_join(*content, sep=sep)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union, overload
|
from typing import Any, Literal, overload
|
||||||
|
|
||||||
from aiogram.enums import InputMediaType
|
from aiogram.enums import InputMediaType
|
||||||
from aiogram.types import (
|
from aiogram.types import (
|
||||||
|
|
@ -12,12 +12,7 @@ from aiogram.types import (
|
||||||
MessageEntity,
|
MessageEntity,
|
||||||
)
|
)
|
||||||
|
|
||||||
MediaType = Union[
|
MediaType = InputMediaAudio | InputMediaPhoto | InputMediaVideo | InputMediaDocument
|
||||||
InputMediaAudio,
|
|
||||||
InputMediaPhoto,
|
|
||||||
InputMediaVideo,
|
|
||||||
InputMediaDocument,
|
|
||||||
]
|
|
||||||
|
|
||||||
MAX_MEDIA_GROUP_SIZE = 10
|
MAX_MEDIA_GROUP_SIZE = 10
|
||||||
|
|
||||||
|
|
@ -27,9 +22,9 @@ class MediaGroupBuilder:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
media: Optional[List[MediaType]] = None,
|
media: list[MediaType] | None = None,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Helper class for building media groups.
|
Helper class for building media groups.
|
||||||
|
|
@ -39,7 +34,7 @@ class MediaGroupBuilder:
|
||||||
:param caption_entities: List of special entities in the caption,
|
:param caption_entities: List of special entities in the caption,
|
||||||
like usernames, URLs, etc. (optional)
|
like usernames, URLs, etc. (optional)
|
||||||
"""
|
"""
|
||||||
self._media: List[MediaType] = []
|
self._media: list[MediaType] = []
|
||||||
self.caption = caption
|
self.caption = caption
|
||||||
self.caption_entities = caption_entities
|
self.caption_entities = caption_entities
|
||||||
|
|
||||||
|
|
@ -47,14 +42,16 @@ class MediaGroupBuilder:
|
||||||
|
|
||||||
def _add(self, media: MediaType) -> None:
|
def _add(self, media: MediaType) -> None:
|
||||||
if not isinstance(media, InputMedia):
|
if not isinstance(media, InputMedia):
|
||||||
raise ValueError("Media must be instance of InputMedia")
|
msg = "Media must be instance of InputMedia"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if len(self._media) >= MAX_MEDIA_GROUP_SIZE:
|
if len(self._media) >= MAX_MEDIA_GROUP_SIZE:
|
||||||
raise ValueError("Media group can't contain more than 10 elements")
|
msg = "Media group can't contain more than 10 elements"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self._media.append(media)
|
self._media.append(media)
|
||||||
|
|
||||||
def _extend(self, media: List[MediaType]) -> None:
|
def _extend(self, media: list[MediaType]) -> None:
|
||||||
for m in media:
|
for m in media:
|
||||||
self._add(m)
|
self._add(m)
|
||||||
|
|
||||||
|
|
@ -63,13 +60,13 @@ class MediaGroupBuilder:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
type: Literal[InputMediaType.AUDIO],
|
type: Literal[InputMediaType.AUDIO],
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
duration: Optional[int] = None,
|
duration: int | None = None,
|
||||||
performer: Optional[str] = None,
|
performer: str | None = None,
|
||||||
title: Optional[str] = None,
|
title: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
@ -79,11 +76,11 @@ class MediaGroupBuilder:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
type: Literal[InputMediaType.PHOTO],
|
type: Literal[InputMediaType.PHOTO],
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
has_spoiler: Optional[bool] = None,
|
has_spoiler: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
@ -93,16 +90,16 @@ class MediaGroupBuilder:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
type: Literal[InputMediaType.VIDEO],
|
type: Literal[InputMediaType.VIDEO],
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
thumbnail: Optional[Union[InputFile, str]] = None,
|
thumbnail: InputFile | str | None = None,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
width: Optional[int] = None,
|
width: int | None = None,
|
||||||
height: Optional[int] = None,
|
height: int | None = None,
|
||||||
duration: Optional[int] = None,
|
duration: int | None = None,
|
||||||
supports_streaming: Optional[bool] = None,
|
supports_streaming: bool | None = None,
|
||||||
has_spoiler: Optional[bool] = None,
|
has_spoiler: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
@ -112,12 +109,12 @@ class MediaGroupBuilder:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
type: Literal[InputMediaType.DOCUMENT],
|
type: Literal[InputMediaType.DOCUMENT],
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
thumbnail: Optional[Union[InputFile, str]] = None,
|
thumbnail: InputFile | str | None = None,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
disable_content_type_detection: Optional[bool] = None,
|
disable_content_type_detection: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
@ -140,18 +137,19 @@ class MediaGroupBuilder:
|
||||||
elif type_ == InputMediaType.DOCUMENT:
|
elif type_ == InputMediaType.DOCUMENT:
|
||||||
self.add_document(**kwargs)
|
self.add_document(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown media type: {type_!r}")
|
msg = f"Unknown media type: {type_!r}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
def add_audio(
|
def add_audio(
|
||||||
self,
|
self,
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
thumbnail: Optional[InputFile] = None,
|
thumbnail: InputFile | None = None,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
duration: Optional[int] = None,
|
duration: int | None = None,
|
||||||
performer: Optional[str] = None,
|
performer: str | None = None,
|
||||||
title: Optional[str] = None,
|
title: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -189,16 +187,16 @@ class MediaGroupBuilder:
|
||||||
performer=performer,
|
performer=performer,
|
||||||
title=title,
|
title=title,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_photo(
|
def add_photo(
|
||||||
self,
|
self,
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
has_spoiler: Optional[bool] = None,
|
has_spoiler: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -228,21 +226,21 @@ class MediaGroupBuilder:
|
||||||
caption_entities=caption_entities,
|
caption_entities=caption_entities,
|
||||||
has_spoiler=has_spoiler,
|
has_spoiler=has_spoiler,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_video(
|
def add_video(
|
||||||
self,
|
self,
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
thumbnail: Optional[InputFile] = None,
|
thumbnail: InputFile | None = None,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
width: Optional[int] = None,
|
width: int | None = None,
|
||||||
height: Optional[int] = None,
|
height: int | None = None,
|
||||||
duration: Optional[int] = None,
|
duration: int | None = None,
|
||||||
supports_streaming: Optional[bool] = None,
|
supports_streaming: bool | None = None,
|
||||||
has_spoiler: Optional[bool] = None,
|
has_spoiler: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -290,17 +288,17 @@ class MediaGroupBuilder:
|
||||||
supports_streaming=supports_streaming,
|
supports_streaming=supports_streaming,
|
||||||
has_spoiler=has_spoiler,
|
has_spoiler=has_spoiler,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_document(
|
def add_document(
|
||||||
self,
|
self,
|
||||||
media: Union[str, InputFile],
|
media: str | InputFile,
|
||||||
thumbnail: Optional[InputFile] = None,
|
thumbnail: InputFile | None = None,
|
||||||
caption: Optional[str] = None,
|
caption: str | None = None,
|
||||||
parse_mode: Optional[str] = UNSET_PARSE_MODE,
|
parse_mode: str | None = UNSET_PARSE_MODE,
|
||||||
caption_entities: Optional[List[MessageEntity]] = None,
|
caption_entities: list[MessageEntity] | None = None,
|
||||||
disable_content_type_detection: Optional[bool] = None,
|
disable_content_type_detection: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -342,10 +340,10 @@ class MediaGroupBuilder:
|
||||||
caption_entities=caption_entities,
|
caption_entities=caption_entities,
|
||||||
disable_content_type_detection=disable_content_type_detection,
|
disable_content_type_detection=disable_content_type_detection,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def build(self) -> List[MediaType]:
|
def build(self) -> list[MediaType]:
|
||||||
"""
|
"""
|
||||||
Builds a list of media objects for a media group.
|
Builds a list of media objects for a media group.
|
||||||
|
|
||||||
|
|
@ -353,7 +351,7 @@ class MediaGroupBuilder:
|
||||||
|
|
||||||
:return: List of media objects.
|
:return: List of media objects.
|
||||||
"""
|
"""
|
||||||
update_first_media: Dict[str, Any] = {"caption": self.caption}
|
update_first_media: dict[str, Any] = {"caption": self.caption}
|
||||||
if self.caption_entities is not None:
|
if self.caption_entities is not None:
|
||||||
update_first_media["caption_entities"] = self.caption_entities
|
update_first_media["caption_entities"] = self.caption_entities
|
||||||
update_first_media["parse_mode"] = None
|
update_first_media["parse_mode"] = None
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextvars
|
import contextvars
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar, cast, overload
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing_extensions import Literal
|
from typing import Literal
|
||||||
|
|
||||||
__all__ = ("ContextInstanceMixin", "DataMixin")
|
__all__ = ("ContextInstanceMixin", "DataMixin")
|
||||||
|
|
||||||
|
|
||||||
class DataMixin:
|
class DataMixin:
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
data: Optional[Dict[str, Any]] = getattr(self, "_data", None)
|
data: dict[str, Any] | None = getattr(self, "_data", None)
|
||||||
if data is None:
|
if data is None:
|
||||||
data = {}
|
data = {}
|
||||||
setattr(self, "_data", data)
|
setattr(self, "_data", data)
|
||||||
|
|
@ -30,7 +30,7 @@ class DataMixin:
|
||||||
def __contains__(self, key: str) -> bool:
|
def __contains__(self, key: str) -> bool:
|
||||||
return key in self.data
|
return key in self.data
|
||||||
|
|
||||||
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
|
def get(self, key: str, default: Any | None = None) -> Any | None:
|
||||||
return self.data.get(key, default)
|
return self.data.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,36 +44,40 @@ class ContextInstanceMixin(Generic[ContextInstance]):
|
||||||
super().__init_subclass__()
|
super().__init_subclass__()
|
||||||
cls.__context_instance = contextvars.ContextVar(f"instance_{cls.__name__}")
|
cls.__context_instance = contextvars.ContextVar(f"instance_{cls.__name__}")
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_current(cls) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811
|
def get_current(cls) -> ContextInstance | None: # pragma: no cover
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_current( # noqa: F811
|
def get_current(
|
||||||
cls, no_error: Literal[True]
|
cls,
|
||||||
) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811
|
no_error: Literal[True],
|
||||||
|
) -> ContextInstance | None: # pragma: no cover
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_current( # noqa: F811
|
def get_current(
|
||||||
cls, no_error: Literal[False]
|
cls,
|
||||||
) -> ContextInstance: # pragma: no cover # noqa: F811
|
no_error: Literal[False],
|
||||||
|
) -> ContextInstance: # pragma: no cover
|
||||||
...
|
...
|
||||||
|
|
||||||
@classmethod # noqa: F811
|
@classmethod
|
||||||
def get_current( # noqa: F811
|
def get_current(
|
||||||
cls, no_error: bool = True
|
cls,
|
||||||
) -> Optional[ContextInstance]: # pragma: no cover # noqa: F811
|
no_error: bool = True,
|
||||||
|
) -> ContextInstance | None: # pragma: no cover
|
||||||
# on mypy 0.770 I catch that contextvars.ContextVar always contextvars.ContextVar[Any]
|
# on mypy 0.770 I catch that contextvars.ContextVar always contextvars.ContextVar[Any]
|
||||||
cls.__context_instance = cast(
|
cls.__context_instance = cast(
|
||||||
contextvars.ContextVar[ContextInstance], cls.__context_instance
|
contextvars.ContextVar[ContextInstance],
|
||||||
|
cls.__context_instance,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current: Optional[ContextInstance] = cls.__context_instance.get()
|
current: ContextInstance | None = cls.__context_instance.get()
|
||||||
except LookupError:
|
except LookupError:
|
||||||
if no_error:
|
if no_error:
|
||||||
current = None
|
current = None
|
||||||
|
|
@ -85,9 +89,8 @@ class ContextInstanceMixin(Generic[ContextInstance]):
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_current(cls, value: ContextInstance) -> contextvars.Token[ContextInstance]:
|
def set_current(cls, value: ContextInstance) -> contextvars.Token[ContextInstance]:
|
||||||
if not isinstance(value, cls):
|
if not isinstance(value, cls):
|
||||||
raise TypeError(
|
msg = f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}"
|
||||||
f"Value should be instance of {cls.__name__!r} not {type(value).__name__!r}"
|
raise TypeError(msg)
|
||||||
)
|
|
||||||
return cls.__context_instance.set(value)
|
return cls.__context_instance.set(value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,10 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Callable, TypeVar
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,13 +61,18 @@ Encoding and decoding with your own methods:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
||||||
from typing import Callable, Optional
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
|
||||||
def encode_payload(
|
def encode_payload(
|
||||||
payload: str,
|
payload: str,
|
||||||
encoder: Optional[Callable[[bytes], bytes]] = None,
|
encoder: Callable[[bytes], bytes] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Encode payload with encoder.
|
"""Encode payload with encoder.
|
||||||
|
|
||||||
|
|
@ -85,7 +90,7 @@ def encode_payload(
|
||||||
|
|
||||||
def decode_payload(
|
def decode_payload(
|
||||||
payload: str,
|
payload: str,
|
||||||
decoder: Optional[Callable[[bytes], bytes]] = None,
|
decoder: Callable[[bytes], bytes] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Decode URL-safe base64url payload with decoder."""
|
"""Decode URL-safe base64url payload with decoder."""
|
||||||
original_payload = _decode_b64(payload)
|
original_payload = _decode_b64(payload)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -9,7 +9,7 @@ from aiogram.methods import TelegramMethod
|
||||||
from aiogram.types import InputFile
|
from aiogram.types import InputFile
|
||||||
|
|
||||||
|
|
||||||
def _get_fake_bot(default: Optional[DefaultBotProperties] = None) -> Bot:
|
def _get_fake_bot(default: DefaultBotProperties | None = None) -> Bot:
|
||||||
if default is None:
|
if default is None:
|
||||||
default = DefaultBotProperties()
|
default = DefaultBotProperties()
|
||||||
return Bot(token="42:Fake", default=default)
|
return Bot(token="42:Fake", default=default)
|
||||||
|
|
@ -28,12 +28,12 @@ class DeserializedTelegramObject:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: Any
|
data: Any
|
||||||
files: Dict[str, InputFile]
|
files: dict[str, InputFile]
|
||||||
|
|
||||||
|
|
||||||
def deserialize_telegram_object(
|
def deserialize_telegram_object(
|
||||||
obj: Any,
|
obj: Any,
|
||||||
default: Optional[DefaultBotProperties] = None,
|
default: DefaultBotProperties | None = None,
|
||||||
include_api_method_name: bool = True,
|
include_api_method_name: bool = True,
|
||||||
) -> DeserializedTelegramObject:
|
) -> DeserializedTelegramObject:
|
||||||
"""
|
"""
|
||||||
|
|
@ -55,7 +55,7 @@ def deserialize_telegram_object(
|
||||||
# Fake bot is needed to exclude global defaults from the object.
|
# Fake bot is needed to exclude global defaults from the object.
|
||||||
fake_bot = _get_fake_bot(default=default)
|
fake_bot = _get_fake_bot(default=default)
|
||||||
|
|
||||||
files: Dict[str, InputFile] = {}
|
files: dict[str, InputFile] = {}
|
||||||
prepared = fake_bot.session.prepare_value(
|
prepared = fake_bot.session.prepare_value(
|
||||||
obj,
|
obj,
|
||||||
bot=fake_bot,
|
bot=fake_bot,
|
||||||
|
|
@ -70,7 +70,7 @@ def deserialize_telegram_object(
|
||||||
|
|
||||||
def deserialize_telegram_object_to_python(
|
def deserialize_telegram_object_to_python(
|
||||||
obj: Any,
|
obj: Any,
|
||||||
default: Optional[DefaultBotProperties] = None,
|
default: DefaultBotProperties | None = None,
|
||||||
include_api_method_name: bool = True,
|
include_api_method_name: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -3,20 +3,23 @@ from __future__ import annotations
|
||||||
import html
|
import html
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Generator, List, Optional, Pattern, cast
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
from aiogram.enums import MessageEntityType
|
from aiogram.enums import MessageEntityType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
from re import Pattern
|
||||||
|
|
||||||
from aiogram.types import MessageEntity
|
from aiogram.types import MessageEntity
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"HtmlDecoration",
|
"HtmlDecoration",
|
||||||
"MarkdownDecoration",
|
"MarkdownDecoration",
|
||||||
"TextDecoration",
|
"TextDecoration",
|
||||||
|
"add_surrogates",
|
||||||
"html_decoration",
|
"html_decoration",
|
||||||
"markdown_decoration",
|
"markdown_decoration",
|
||||||
"add_surrogates",
|
|
||||||
"remove_surrogates",
|
"remove_surrogates",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -80,7 +83,7 @@ class TextDecoration(ABC):
|
||||||
# API it will be here too
|
# API it will be here too
|
||||||
return self.quote(text)
|
return self.quote(text)
|
||||||
|
|
||||||
def unparse(self, text: str, entities: Optional[List[MessageEntity]] = None) -> str:
|
def unparse(self, text: str, entities: list[MessageEntity] | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
Unparse message entities
|
Unparse message entities
|
||||||
|
|
||||||
|
|
@ -92,15 +95,15 @@ class TextDecoration(ABC):
|
||||||
self._unparse_entities(
|
self._unparse_entities(
|
||||||
add_surrogates(text),
|
add_surrogates(text),
|
||||||
sorted(entities, key=lambda item: item.offset) if entities else [],
|
sorted(entities, key=lambda item: item.offset) if entities else [],
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _unparse_entities(
|
def _unparse_entities(
|
||||||
self,
|
self,
|
||||||
text: bytes,
|
text: bytes,
|
||||||
entities: List[MessageEntity],
|
entities: list[MessageEntity],
|
||||||
offset: Optional[int] = None,
|
offset: int | None = None,
|
||||||
length: Optional[int] = None,
|
length: int | None = None,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
if offset is None:
|
if offset is None:
|
||||||
offset = 0
|
offset = 0
|
||||||
|
|
@ -115,7 +118,7 @@ class TextDecoration(ABC):
|
||||||
offset = entity.offset * 2 + entity.length * 2
|
offset = entity.offset * 2 + entity.length * 2
|
||||||
|
|
||||||
sub_entities = list(
|
sub_entities = list(
|
||||||
filter(lambda e: e.offset * 2 < (offset or 0), entities[index + 1 :])
|
filter(lambda e: e.offset * 2 < (offset or 0), entities[index + 1 :]),
|
||||||
)
|
)
|
||||||
yield self.apply_entity(
|
yield self.apply_entity(
|
||||||
entity,
|
entity,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ class TokenValidationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache
|
||||||
def validate_token(token: str) -> bool:
|
def validate_token(token: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate Telegram token
|
Validate Telegram token
|
||||||
|
|
@ -14,9 +14,8 @@ def validate_token(token: str) -> bool:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not isinstance(token, str):
|
if not isinstance(token, str):
|
||||||
raise TokenValidationError(
|
msg = f"Token is invalid! It must be 'str' type instead of {type(token)} type."
|
||||||
f"Token is invalid! It must be 'str' type instead of {type(token)} type."
|
raise TokenValidationError(msg)
|
||||||
)
|
|
||||||
|
|
||||||
if any(x.isspace() for x in token):
|
if any(x.isspace() for x in token):
|
||||||
message = "Token is invalid! It can't contains spaces."
|
message = "Token is invalid! It can't contains spaces."
|
||||||
|
|
@ -24,12 +23,13 @@ def validate_token(token: str) -> bool:
|
||||||
|
|
||||||
left, sep, right = token.partition(":")
|
left, sep, right = token.partition(":")
|
||||||
if (not sep) or (not left.isdigit()) or (not right):
|
if (not sep) or (not left.isdigit()) or (not right):
|
||||||
raise TokenValidationError("Token is invalid!")
|
msg = "Token is invalid!"
|
||||||
|
raise TokenValidationError(msg)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache
|
||||||
def extract_bot_id(token: str) -> int:
|
def extract_bot_id(token: str) -> int:
|
||||||
"""
|
"""
|
||||||
Extract bot ID from Telegram token
|
Extract bot ID from Telegram token
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
|
from collections.abc import Callable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any
|
||||||
from urllib.parse import parse_qsl
|
from urllib.parse import parse_qsl
|
||||||
|
|
||||||
from aiogram.types import TelegramObject
|
from aiogram.types import TelegramObject
|
||||||
|
|
@ -25,9 +26,9 @@ class WebAppChat(TelegramObject):
|
||||||
"""Type of chat, can be either “group”, “supergroup” or “channel”"""
|
"""Type of chat, can be either “group”, “supergroup” or “channel”"""
|
||||||
title: str
|
title: str
|
||||||
"""Title of the chat"""
|
"""Title of the chat"""
|
||||||
username: Optional[str] = None
|
username: str | None = None
|
||||||
"""Username of the chat"""
|
"""Username of the chat"""
|
||||||
photo_url: Optional[str] = None
|
photo_url: str | None = None
|
||||||
"""URL of the chat’s photo. The photo can be in .jpeg or .svg formats.
|
"""URL of the chat’s photo. The photo can be in .jpeg or .svg formats.
|
||||||
Only returned for Web Apps launched from the attachment menu."""
|
Only returned for Web Apps launched from the attachment menu."""
|
||||||
|
|
||||||
|
|
@ -44,23 +45,23 @@ class WebAppUser(TelegramObject):
|
||||||
and some programming languages may have difficulty/silent defects in interpreting it.
|
and some programming languages may have difficulty/silent defects in interpreting it.
|
||||||
It has at most 52 significant bits, so a 64-bit integer or a double-precision float type
|
It has at most 52 significant bits, so a 64-bit integer or a double-precision float type
|
||||||
is safe for storing this identifier."""
|
is safe for storing this identifier."""
|
||||||
is_bot: Optional[bool] = None
|
is_bot: bool | None = None
|
||||||
"""True, if this user is a bot. Returns in the receiver field only."""
|
"""True, if this user is a bot. Returns in the receiver field only."""
|
||||||
first_name: str
|
first_name: str
|
||||||
"""First name of the user or bot."""
|
"""First name of the user or bot."""
|
||||||
last_name: Optional[str] = None
|
last_name: str | None = None
|
||||||
"""Last name of the user or bot."""
|
"""Last name of the user or bot."""
|
||||||
username: Optional[str] = None
|
username: str | None = None
|
||||||
"""Username of the user or bot."""
|
"""Username of the user or bot."""
|
||||||
language_code: Optional[str] = None
|
language_code: str | None = None
|
||||||
"""IETF language tag of the user's language. Returns in user field only."""
|
"""IETF language tag of the user's language. Returns in user field only."""
|
||||||
is_premium: Optional[bool] = None
|
is_premium: bool | None = None
|
||||||
"""True, if this user is a Telegram Premium user."""
|
"""True, if this user is a Telegram Premium user."""
|
||||||
added_to_attachment_menu: Optional[bool] = None
|
added_to_attachment_menu: bool | None = None
|
||||||
"""True, if this user added the bot to the attachment menu."""
|
"""True, if this user added the bot to the attachment menu."""
|
||||||
allows_write_to_pm: Optional[bool] = None
|
allows_write_to_pm: bool | None = None
|
||||||
"""True, if this user allowed the bot to message them."""
|
"""True, if this user allowed the bot to message them."""
|
||||||
photo_url: Optional[str] = None
|
photo_url: str | None = None
|
||||||
"""URL of the user’s profile photo. The photo can be in .jpeg or .svg formats.
|
"""URL of the user’s profile photo. The photo can be in .jpeg or .svg formats.
|
||||||
Only returned for Web Apps launched from the attachment menu."""
|
Only returned for Web Apps launched from the attachment menu."""
|
||||||
|
|
||||||
|
|
@ -73,33 +74,33 @@ class WebAppInitData(TelegramObject):
|
||||||
Source: https://core.telegram.org/bots/webapps#webappinitdata
|
Source: https://core.telegram.org/bots/webapps#webappinitdata
|
||||||
"""
|
"""
|
||||||
|
|
||||||
query_id: Optional[str] = None
|
query_id: str | None = None
|
||||||
"""A unique identifier for the Web App session, required for sending messages
|
"""A unique identifier for the Web App session, required for sending messages
|
||||||
via the answerWebAppQuery method."""
|
via the answerWebAppQuery method."""
|
||||||
user: Optional[WebAppUser] = None
|
user: WebAppUser | None = None
|
||||||
"""An object containing data about the current user."""
|
"""An object containing data about the current user."""
|
||||||
receiver: Optional[WebAppUser] = None
|
receiver: WebAppUser | None = None
|
||||||
"""An object containing data about the chat partner of the current user in the chat where
|
"""An object containing data about the chat partner of the current user in the chat where
|
||||||
the bot was launched via the attachment menu.
|
the bot was launched via the attachment menu.
|
||||||
Returned only for Web Apps launched via the attachment menu."""
|
Returned only for Web Apps launched via the attachment menu."""
|
||||||
chat: Optional[WebAppChat] = None
|
chat: WebAppChat | None = None
|
||||||
"""An object containing data about the chat where the bot was launched via the attachment menu.
|
"""An object containing data about the chat where the bot was launched via the attachment menu.
|
||||||
Returned for supergroups, channels, and group chats – only for Web Apps launched via the
|
Returned for supergroups, channels, and group chats – only for Web Apps launched via the
|
||||||
attachment menu."""
|
attachment menu."""
|
||||||
chat_type: Optional[str] = None
|
chat_type: str | None = None
|
||||||
"""Type of the chat from which the Web App was opened.
|
"""Type of the chat from which the Web App was opened.
|
||||||
Can be either “sender” for a private chat with the user opening the link,
|
Can be either “sender” for a private chat with the user opening the link,
|
||||||
“private”, “group”, “supergroup”, or “channel”.
|
“private”, “group”, “supergroup”, or “channel”.
|
||||||
Returned only for Web Apps launched from direct links."""
|
Returned only for Web Apps launched from direct links."""
|
||||||
chat_instance: Optional[str] = None
|
chat_instance: str | None = None
|
||||||
"""Global identifier, uniquely corresponding to the chat from which the Web App was opened.
|
"""Global identifier, uniquely corresponding to the chat from which the Web App was opened.
|
||||||
Returned only for Web Apps launched from a direct link."""
|
Returned only for Web Apps launched from a direct link."""
|
||||||
start_param: Optional[str] = None
|
start_param: str | None = None
|
||||||
"""The value of the startattach parameter, passed via link.
|
"""The value of the startattach parameter, passed via link.
|
||||||
Only returned for Web Apps when launched from the attachment menu via link.
|
Only returned for Web Apps when launched from the attachment menu via link.
|
||||||
The value of the start_param parameter will also be passed in the GET-parameter
|
The value of the start_param parameter will also be passed in the GET-parameter
|
||||||
tgWebAppStartParam, so the Web App can load the correct interface right away."""
|
tgWebAppStartParam, so the Web App can load the correct interface right away."""
|
||||||
can_send_after: Optional[int] = None
|
can_send_after: int | None = None
|
||||||
"""Time in seconds, after which a message can be sent via the answerWebAppQuery method."""
|
"""Time in seconds, after which a message can be sent via the answerWebAppQuery method."""
|
||||||
auth_date: datetime
|
auth_date: datetime
|
||||||
"""Unix time when the form was opened."""
|
"""Unix time when the form was opened."""
|
||||||
|
|
@ -132,7 +133,9 @@ def check_webapp_signature(token: str, init_data: str) -> bool:
|
||||||
)
|
)
|
||||||
secret_key = hmac.new(key=b"WebAppData", msg=token.encode(), digestmod=hashlib.sha256)
|
secret_key = hmac.new(key=b"WebAppData", msg=token.encode(), digestmod=hashlib.sha256)
|
||||||
calculated_hash = hmac.new(
|
calculated_hash = hmac.new(
|
||||||
key=secret_key.digest(), msg=data_check_string.encode(), digestmod=hashlib.sha256
|
key=secret_key.digest(),
|
||||||
|
msg=data_check_string.encode(),
|
||||||
|
digestmod=hashlib.sha256,
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
return hmac.compare_digest(calculated_hash, hash_)
|
return hmac.compare_digest(calculated_hash, hash_)
|
||||||
|
|
||||||
|
|
@ -180,4 +183,5 @@ def safe_parse_webapp_init_data(
|
||||||
"""
|
"""
|
||||||
if check_webapp_signature(token, init_data):
|
if check_webapp_signature(token, init_data):
|
||||||
return parse_webapp_init_data(init_data, loads=loads)
|
return parse_webapp_init_data(init_data, loads=loads)
|
||||||
raise ValueError("Invalid init data signature")
|
msg = "Invalid init data signature"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,15 @@ from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
|
||||||
from .web_app import WebAppInitData, parse_webapp_init_data
|
from .web_app import WebAppInitData, parse_webapp_init_data
|
||||||
|
|
||||||
PRODUCTION_PUBLIC_KEY = bytes.fromhex(
|
PRODUCTION_PUBLIC_KEY = bytes.fromhex(
|
||||||
"e7bf03a2fa4602af4580703d88dda5bb59f32ed8b02a56c187fe7d34caed242d"
|
"e7bf03a2fa4602af4580703d88dda5bb59f32ed8b02a56c187fe7d34caed242d",
|
||||||
)
|
)
|
||||||
TEST_PUBLIC_KEY = bytes.fromhex("40055058a4ee38156a06562e52eece92a771bcd8346a8c4615cb7376eddf72ec")
|
TEST_PUBLIC_KEY = bytes.fromhex("40055058a4ee38156a06562e52eece92a771bcd8346a8c4615cb7376eddf72ec")
|
||||||
|
|
||||||
|
|
||||||
def check_webapp_signature(
|
def check_webapp_signature(
|
||||||
bot_id: int, init_data: str, public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY
|
bot_id: int,
|
||||||
|
init_data: str,
|
||||||
|
public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check incoming WebApp init data signature without bot token using only bot id.
|
Check incoming WebApp init data signature without bot token using only bot id.
|
||||||
|
|
@ -49,13 +51,16 @@ def check_webapp_signature(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
public_key.verify(signature, message)
|
public_key.verify(signature, message)
|
||||||
return True
|
|
||||||
except InvalidSignature:
|
except InvalidSignature:
|
||||||
return False
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def safe_check_webapp_init_data_from_signature(
|
def safe_check_webapp_init_data_from_signature(
|
||||||
bot_id: int, init_data: str, public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY
|
bot_id: int,
|
||||||
|
init_data: str,
|
||||||
|
public_key_bytes: bytes = PRODUCTION_PUBLIC_KEY,
|
||||||
) -> WebAppInitData:
|
) -> WebAppInitData:
|
||||||
"""
|
"""
|
||||||
Validate raw WebApp init data using only bot id and return it as WebAppInitData object
|
Validate raw WebApp init data using only bot id and return it as WebAppInitData object
|
||||||
|
|
@ -67,4 +72,5 @@ def safe_check_webapp_init_data_from_signature(
|
||||||
"""
|
"""
|
||||||
if check_webapp_signature(bot_id, init_data, public_key_bytes):
|
if check_webapp_signature(bot_id, init_data, public_key_bytes):
|
||||||
return parse_webapp_init_data(init_data)
|
return parse_webapp_init_data(init_data)
|
||||||
raise ValueError("Invalid init data signature")
|
msg = "Invalid init data signature"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ import asyncio
|
||||||
import secrets
|
import secrets
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio import Transport
|
from asyncio import Transport
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, Set, Tuple, cast
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from aiohttp import JsonPayload, MultipartWriter, Payload, web
|
from aiohttp import JsonPayload, MultipartWriter, Payload, web
|
||||||
from aiohttp.typedefs import Handler
|
from aiohttp.typedefs import Handler
|
||||||
|
|
@ -12,9 +13,11 @@ from aiohttp.web_middlewares import middleware
|
||||||
from aiogram import Bot, Dispatcher, loggers
|
from aiogram import Bot, Dispatcher, loggers
|
||||||
from aiogram.methods import TelegramMethod
|
from aiogram.methods import TelegramMethod
|
||||||
from aiogram.methods.base import TelegramType
|
from aiogram.methods.base import TelegramType
|
||||||
from aiogram.types import InputFile
|
|
||||||
from aiogram.webhook.security import IPFilter
|
from aiogram.webhook.security import IPFilter
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from aiogram.types import InputFile
|
||||||
|
|
||||||
|
|
||||||
def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any) -> None:
|
def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -42,7 +45,7 @@ def setup_application(app: Application, dispatcher: Dispatcher, /, **kwargs: Any
|
||||||
app.on_shutdown.append(on_shutdown)
|
app.on_shutdown.append(on_shutdown)
|
||||||
|
|
||||||
|
|
||||||
def check_ip(ip_filter: IPFilter, request: web.Request) -> Tuple[str, bool]:
|
def check_ip(ip_filter: IPFilter, request: web.Request) -> tuple[str, bool]:
|
||||||
# Try to resolve client IP over reverse proxy
|
# Try to resolve client IP over reverse proxy
|
||||||
if forwarded_for := request.headers.get("X-Forwarded-For", ""):
|
if forwarded_for := request.headers.get("X-Forwarded-For", ""):
|
||||||
# Get the left-most ip when there is multiple ips
|
# Get the left-most ip when there is multiple ips
|
||||||
|
|
@ -98,7 +101,7 @@ class BaseRequestHandler(ABC):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
self.handle_in_background = handle_in_background
|
self.handle_in_background = handle_in_background
|
||||||
self.data = data
|
self.data = data
|
||||||
self._background_feed_update_tasks: Set[asyncio.Task[Any]] = set()
|
self._background_feed_update_tasks: set[asyncio.Task[Any]] = set()
|
||||||
|
|
||||||
def register(self, app: Application, /, path: str, **kwargs: Any) -> None:
|
def register(self, app: Application, /, path: str, **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -128,13 +131,12 @@ class BaseRequestHandler(ABC):
|
||||||
:param request:
|
:param request:
|
||||||
:return: Bot instance
|
:return: Bot instance
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
|
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _background_feed_update(self, bot: Bot, update: Dict[str, Any]) -> None:
|
async def _background_feed_update(self, bot: Bot, update: dict[str, Any]) -> None:
|
||||||
result = await self.dispatcher.feed_raw_update(bot=bot, update=update, **self.data)
|
result = await self.dispatcher.feed_raw_update(bot=bot, update=update, **self.data)
|
||||||
if isinstance(result, TelegramMethod):
|
if isinstance(result, TelegramMethod):
|
||||||
await self.dispatcher.silent_call_request(bot=bot, result=result)
|
await self.dispatcher.silent_call_request(bot=bot, result=result)
|
||||||
|
|
@ -142,15 +144,18 @@ class BaseRequestHandler(ABC):
|
||||||
async def _handle_request_background(self, bot: Bot, request: web.Request) -> web.Response:
|
async def _handle_request_background(self, bot: Bot, request: web.Request) -> web.Response:
|
||||||
feed_update_task = asyncio.create_task(
|
feed_update_task = asyncio.create_task(
|
||||||
self._background_feed_update(
|
self._background_feed_update(
|
||||||
bot=bot, update=await request.json(loads=bot.session.json_loads)
|
bot=bot,
|
||||||
)
|
update=await request.json(loads=bot.session.json_loads),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self._background_feed_update_tasks.add(feed_update_task)
|
self._background_feed_update_tasks.add(feed_update_task)
|
||||||
feed_update_task.add_done_callback(self._background_feed_update_tasks.discard)
|
feed_update_task.add_done_callback(self._background_feed_update_tasks.discard)
|
||||||
return web.json_response({}, dumps=bot.session.json_dumps)
|
return web.json_response({}, dumps=bot.session.json_dumps)
|
||||||
|
|
||||||
def _build_response_writer(
|
def _build_response_writer(
|
||||||
self, bot: Bot, result: Optional[TelegramMethod[TelegramType]]
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
result: TelegramMethod[TelegramType] | None,
|
||||||
) -> Payload:
|
) -> Payload:
|
||||||
if not result:
|
if not result:
|
||||||
# we need to return something "empty"
|
# we need to return something "empty"
|
||||||
|
|
@ -166,7 +171,7 @@ class BaseRequestHandler(ABC):
|
||||||
payload = writer.append(result.__api_method__)
|
payload = writer.append(result.__api_method__)
|
||||||
payload.set_content_disposition("form-data", name="method")
|
payload.set_content_disposition("form-data", name="method")
|
||||||
|
|
||||||
files: Dict[str, InputFile] = {}
|
files: dict[str, InputFile] = {}
|
||||||
for key, value in result.model_dump(warnings=False).items():
|
for key, value in result.model_dump(warnings=False).items():
|
||||||
value = bot.session.prepare_value(value, bot=bot, files=files)
|
value = bot.session.prepare_value(value, bot=bot, files=files)
|
||||||
if not value:
|
if not value:
|
||||||
|
|
@ -185,7 +190,7 @@ class BaseRequestHandler(ABC):
|
||||||
return writer
|
return writer
|
||||||
|
|
||||||
async def _handle_request(self, bot: Bot, request: web.Request) -> web.Response:
|
async def _handle_request(self, bot: Bot, request: web.Request) -> web.Response:
|
||||||
result: Optional[TelegramMethod[Any]] = await self.dispatcher.feed_webhook_update(
|
result: TelegramMethod[Any] | None = await self.dispatcher.feed_webhook_update(
|
||||||
bot,
|
bot,
|
||||||
await request.json(loads=bot.session.json_loads),
|
await request.json(loads=bot.session.json_loads),
|
||||||
**self.data,
|
**self.data,
|
||||||
|
|
@ -209,7 +214,7 @@ class SimpleRequestHandler(BaseRequestHandler):
|
||||||
dispatcher: Dispatcher,
|
dispatcher: Dispatcher,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
handle_in_background: bool = True,
|
handle_in_background: bool = True,
|
||||||
secret_token: Optional[str] = None,
|
secret_token: str | None = None,
|
||||||
**data: Any,
|
**data: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -244,7 +249,7 @@ class TokenBasedRequestHandler(BaseRequestHandler):
|
||||||
self,
|
self,
|
||||||
dispatcher: Dispatcher,
|
dispatcher: Dispatcher,
|
||||||
handle_in_background: bool = True,
|
handle_in_background: bool = True,
|
||||||
bot_settings: Optional[Dict[str, Any]] = None,
|
bot_settings: dict[str, Any] | None = None,
|
||||||
**data: Any,
|
**data: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -265,7 +270,7 @@ class TokenBasedRequestHandler(BaseRequestHandler):
|
||||||
if bot_settings is None:
|
if bot_settings is None:
|
||||||
bot_settings = {}
|
bot_settings = {}
|
||||||
self.bot_settings = bot_settings
|
self.bot_settings = bot_settings
|
||||||
self.bots: Dict[str, Bot] = {}
|
self.bots: dict[str, Bot] = {}
|
||||||
|
|
||||||
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
|
def verify_secret(self, telegram_secret_token: str, bot: Bot) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
@ -283,7 +288,8 @@ class TokenBasedRequestHandler(BaseRequestHandler):
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
"""
|
"""
|
||||||
if "{bot_token}" not in path:
|
if "{bot_token}" not in path:
|
||||||
raise ValueError("Path should contains '{bot_token}' substring")
|
msg = "Path should contains '{bot_token}' substring"
|
||||||
|
raise ValueError(msg)
|
||||||
super().register(app, path=path, **kwargs)
|
super().register(app, path=path, **kwargs)
|
||||||
|
|
||||||
async def resolve_bot(self, request: web.Request) -> Bot:
|
async def resolve_bot(self, request: web.Request) -> Bot:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from ipaddress import IPv4Address, IPv4Network
|
from ipaddress import IPv4Address, IPv4Network
|
||||||
from typing import Optional, Sequence, Set, Union
|
|
||||||
|
|
||||||
DEFAULT_TELEGRAM_NETWORKS = [
|
DEFAULT_TELEGRAM_NETWORKS = [
|
||||||
IPv4Network("149.154.160.0/20"),
|
IPv4Network("149.154.160.0/20"),
|
||||||
|
|
@ -8,17 +8,17 @@ DEFAULT_TELEGRAM_NETWORKS = [
|
||||||
|
|
||||||
|
|
||||||
class IPFilter:
|
class IPFilter:
|
||||||
def __init__(self, ips: Optional[Sequence[Union[str, IPv4Network, IPv4Address]]] = None):
|
def __init__(self, ips: Sequence[str | IPv4Network | IPv4Address] | None = None):
|
||||||
self._allowed_ips: Set[IPv4Address] = set()
|
self._allowed_ips: set[IPv4Address] = set()
|
||||||
|
|
||||||
if ips:
|
if ips:
|
||||||
self.allow(*ips)
|
self.allow(*ips)
|
||||||
|
|
||||||
def allow(self, *ips: Union[str, IPv4Network, IPv4Address]) -> None:
|
def allow(self, *ips: str | IPv4Network | IPv4Address) -> None:
|
||||||
for ip in ips:
|
for ip in ips:
|
||||||
self.allow_ip(ip)
|
self.allow_ip(ip)
|
||||||
|
|
||||||
def allow_ip(self, ip: Union[str, IPv4Network, IPv4Address]) -> None:
|
def allow_ip(self, ip: str | IPv4Network | IPv4Address) -> None:
|
||||||
if isinstance(ip, str):
|
if isinstance(ip, str):
|
||||||
ip = IPv4Network(ip) if "/" in ip else IPv4Address(ip)
|
ip = IPv4Network(ip) if "/" in ip else IPv4Address(ip)
|
||||||
if isinstance(ip, IPv4Address):
|
if isinstance(ip, IPv4Address):
|
||||||
|
|
@ -26,16 +26,17 @@ class IPFilter:
|
||||||
elif isinstance(ip, IPv4Network):
|
elif isinstance(ip, IPv4Network):
|
||||||
self._allowed_ips.update(ip.hosts())
|
self._allowed_ips.update(ip.hosts())
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid type of ipaddress: {type(ip)} ('{ip}')")
|
msg = f"Invalid type of ipaddress: {type(ip)} ('{ip}')"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls) -> "IPFilter":
|
def default(cls) -> "IPFilter":
|
||||||
return cls(DEFAULT_TELEGRAM_NETWORKS)
|
return cls(DEFAULT_TELEGRAM_NETWORKS)
|
||||||
|
|
||||||
def check(self, ip: Union[str, IPv4Address]) -> bool:
|
def check(self, ip: str | IPv4Address) -> bool:
|
||||||
if not isinstance(ip, IPv4Address):
|
if not isinstance(ip, IPv4Address):
|
||||||
ip = IPv4Address(ip)
|
ip = IPv4Address(ip)
|
||||||
return ip in self._allowed_ips
|
return ip in self._allowed_ips
|
||||||
|
|
||||||
def __contains__(self, item: Union[str, IPv4Address]) -> bool:
|
def __contains__(self, item: str | IPv4Address) -> bool:
|
||||||
return self.check(item)
|
return self.check(item)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from aiogram import Router
|
from aiogram import Router
|
||||||
from aiogram.filters import Filter
|
from aiogram.filters import Filter
|
||||||
|
|
@ -8,7 +8,7 @@ router = Router(name=__name__)
|
||||||
|
|
||||||
|
|
||||||
class HelloFilter(Filter):
|
class HelloFilter(Filter):
|
||||||
def __init__(self, name: Optional[str] = None) -> None:
|
def __init__(self, name: str | None = None) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
|
|
@ -16,7 +16,7 @@ class HelloFilter(Filter):
|
||||||
message: Message,
|
message: Message,
|
||||||
event_from_user: User,
|
event_from_user: User,
|
||||||
# Filters also can accept keyword parameters like in handlers
|
# Filters also can accept keyword parameters like in handlers
|
||||||
) -> Union[bool, Dict[str, Any]]:
|
) -> bool | dict[str, Any]:
|
||||||
if message.text.casefold() == "hello":
|
if message.text.casefold() == "hello":
|
||||||
# Returning a dictionary that will update the context data
|
# Returning a dictionary that will update the context data
|
||||||
return {"name": event_from_user.mention_html(name=self.name)}
|
return {"name": event_from_user.mention_html(name=self.name)}
|
||||||
|
|
@ -25,6 +25,7 @@ class HelloFilter(Filter):
|
||||||
|
|
||||||
@router.message(HelloFilter())
|
@router.message(HelloFilter())
|
||||||
async def my_handler(
|
async def my_handler(
|
||||||
message: Message, name: str # Now we can accept "name" as named parameter
|
message: Message,
|
||||||
|
name: str, # Now we can accept "name" as named parameter
|
||||||
) -> Any:
|
) -> Any:
|
||||||
return message.answer("Hello, {name}!".format(name=name))
|
return message.answer(f"Hello, {name}!")
|
||||||
|
|
|
||||||
|
|
@ -75,11 +75,13 @@ async def handle_set_age(message: types.Message, command: CommandObject) -> None
|
||||||
# To get the command arguments you can use `command.args` property.
|
# To get the command arguments you can use `command.args` property.
|
||||||
age = command.args
|
age = command.args
|
||||||
if not age:
|
if not age:
|
||||||
raise InvalidAge("No age provided. Please provide your age as a command argument.")
|
msg = "No age provided. Please provide your age as a command argument."
|
||||||
|
raise InvalidAge(msg)
|
||||||
|
|
||||||
# If the age is invalid, raise an exception.
|
# If the age is invalid, raise an exception.
|
||||||
if not age.isdigit():
|
if not age.isdigit():
|
||||||
raise InvalidAge("Age should be a number")
|
msg = "Age should be a number"
|
||||||
|
raise InvalidAge(msg)
|
||||||
|
|
||||||
# If the age is valid, send a message to the user.
|
# If the age is valid, send a message to the user.
|
||||||
age = int(age)
|
age = int(age)
|
||||||
|
|
@ -95,7 +97,8 @@ async def handle_set_name(message: types.Message, command: CommandObject) -> Non
|
||||||
# To get the command arguments you can use `command.args` property.
|
# To get the command arguments you can use `command.args` property.
|
||||||
name = command.args
|
name = command.args
|
||||||
if not name:
|
if not name:
|
||||||
raise InvalidName("Invalid name. Please provide your name as a command argument.")
|
msg = "Invalid name. Please provide your name as a command argument."
|
||||||
|
raise InvalidName(msg)
|
||||||
|
|
||||||
# If the name is valid, send a message to the user.
|
# If the name is valid, send a message to the user.
|
||||||
await message.reply(text=f"Your name is {name}")
|
await message.reply(text=f"Your name is {name}")
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from os import getenv
|
from os import getenv
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from aiogram import Bot, Dispatcher, F, Router, html
|
from aiogram import Bot, Dispatcher, F, Router, html
|
||||||
from aiogram.client.default import DefaultBotProperties
|
from aiogram.client.default import DefaultBotProperties
|
||||||
|
|
@ -66,7 +66,7 @@ async def process_name(message: Message, state: FSMContext) -> None:
|
||||||
[
|
[
|
||||||
KeyboardButton(text="Yes"),
|
KeyboardButton(text="Yes"),
|
||||||
KeyboardButton(text="No"),
|
KeyboardButton(text="No"),
|
||||||
]
|
],
|
||||||
],
|
],
|
||||||
resize_keyboard=True,
|
resize_keyboard=True,
|
||||||
),
|
),
|
||||||
|
|
@ -106,13 +106,13 @@ async def process_language(message: Message, state: FSMContext) -> None:
|
||||||
|
|
||||||
if message.text.casefold() == "python":
|
if message.text.casefold() == "python":
|
||||||
await message.reply(
|
await message.reply(
|
||||||
"Python, you say? That's the language that makes my circuits light up! 😉"
|
"Python, you say? That's the language that makes my circuits light up! 😉",
|
||||||
)
|
)
|
||||||
|
|
||||||
await show_summary(message=message, data=data)
|
await show_summary(message=message, data=data)
|
||||||
|
|
||||||
|
|
||||||
async def show_summary(message: Message, data: Dict[str, Any], positive: bool = True) -> None:
|
async def show_summary(message: Message, data: dict[str, Any], positive: bool = True) -> None:
|
||||||
name = data["name"]
|
name = data["name"]
|
||||||
language = data.get("language", "<something unexpected>")
|
language = data.get("language", "<something unexpected>")
|
||||||
text = f"I'll keep in mind that, {html.quote(name)}, "
|
text = f"I'll keep in mind that, {html.quote(name)}, "
|
||||||
|
|
@ -124,7 +124,7 @@ async def show_summary(message: Message, data: Dict[str, Any], positive: bool =
|
||||||
await message.answer(text=text, reply_markup=ReplyKeyboardRemove())
|
await message.answer(text=text, reply_markup=ReplyKeyboardRemove())
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main() -> None:
|
||||||
# Initialize Bot instance with default bot properties which will be passed to all API calls
|
# Initialize Bot instance with default bot properties which will be passed to all API calls
|
||||||
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from os import getenv
|
from os import getenv
|
||||||
from typing import Any, Dict, Union
|
from typing import Any
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from finite_state_machine import form_router
|
from finite_state_machine import form_router
|
||||||
|
|
@ -34,7 +34,7 @@ REDIS_DSN = "redis://127.0.0.1:6479"
|
||||||
OTHER_BOTS_URL = f"{BASE_URL}{OTHER_BOTS_PATH}"
|
OTHER_BOTS_URL = f"{BASE_URL}{OTHER_BOTS_PATH}"
|
||||||
|
|
||||||
|
|
||||||
def is_bot_token(value: str) -> Union[bool, Dict[str, Any]]:
|
def is_bot_token(value: str) -> bool | dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
validate_token(value)
|
validate_token(value)
|
||||||
except TokenValidationError:
|
except TokenValidationError:
|
||||||
|
|
@ -54,11 +54,11 @@ async def command_add_bot(message: Message, command: CommandObject, bot: Bot) ->
|
||||||
return await message.answer(f"Bot @{bot_user.username} successful added")
|
return await message.answer(f"Bot @{bot_user.username} successful added")
|
||||||
|
|
||||||
|
|
||||||
async def on_startup(dispatcher: Dispatcher, bot: Bot):
|
async def on_startup(dispatcher: Dispatcher, bot: Bot) -> None:
|
||||||
await bot.set_webhook(f"{BASE_URL}{MAIN_BOT_PATH}")
|
await bot.set_webhook(f"{BASE_URL}{MAIN_BOT_PATH}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
||||||
session = AiohttpSession()
|
session = AiohttpSession()
|
||||||
bot_settings = {"session": session, "parse_mode": ParseMode.HTML}
|
bot_settings = {"session": session, "parse_mode": ParseMode.HTML}
|
||||||
|
|
|
||||||
|
|
@ -14,4 +14,4 @@ class MyFilter(Filter):
|
||||||
|
|
||||||
|
|
||||||
@router.message(MyFilter("hello"))
|
@router.message(MyFilter("hello"))
|
||||||
async def my_handler(message: Message): ...
|
async def my_handler(message: Message) -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -263,7 +263,7 @@ quiz_router.message.register(QuizScene.as_handler(), Command("quiz"))
|
||||||
|
|
||||||
|
|
||||||
@quiz_router.message(Command("start"))
|
@quiz_router.message(Command("start"))
|
||||||
async def command_start(message: Message, scenes: ScenesManager):
|
async def command_start(message: Message, scenes: ScenesManager) -> None:
|
||||||
await scenes.close()
|
await scenes.close()
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Hi! This is a quiz bot. To start the quiz, use the /quiz command.",
|
"Hi! This is a quiz bot. To start the quiz, use the /quiz command.",
|
||||||
|
|
@ -271,7 +271,7 @@ async def command_start(message: Message, scenes: ScenesManager):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_dispatcher():
|
def create_dispatcher() -> Dispatcher:
|
||||||
# Event isolation is needed to correctly handle fast user responses
|
# Event isolation is needed to correctly handle fast user responses
|
||||||
dispatcher = Dispatcher(
|
dispatcher = Dispatcher(
|
||||||
events_isolation=SimpleEventIsolation(),
|
events_isolation=SimpleEventIsolation(),
|
||||||
|
|
@ -288,7 +288,7 @@ def create_dispatcher():
|
||||||
return dispatcher
|
return dispatcher
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main() -> None:
|
||||||
dp = create_dispatcher()
|
dp = create_dispatcher()
|
||||||
bot = Bot(token=TOKEN)
|
bot = Bot(token=TOKEN)
|
||||||
await dp.start_polling(bot)
|
await dp.start_polling(bot)
|
||||||
|
|
|
||||||
|
|
@ -34,11 +34,11 @@ class CancellableScene(Scene):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@on.message(F.text.casefold() == BUTTON_CANCEL.text.casefold(), after=After.exit())
|
@on.message(F.text.casefold() == BUTTON_CANCEL.text.casefold(), after=After.exit())
|
||||||
async def handle_cancel(self, message: Message):
|
async def handle_cancel(self, message: Message) -> None:
|
||||||
await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove())
|
await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove())
|
||||||
|
|
||||||
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
|
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
|
||||||
async def handle_back(self, message: Message):
|
async def handle_back(self, message: Message) -> None:
|
||||||
await message.answer("Back.")
|
await message.answer("Back.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -48,7 +48,7 @@ class LanguageScene(CancellableScene, state="language"):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@on.message.enter()
|
@on.message.enter()
|
||||||
async def on_enter(self, message: Message):
|
async def on_enter(self, message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"What language do you prefer?",
|
"What language do you prefer?",
|
||||||
reply_markup=ReplyKeyboardMarkup(
|
reply_markup=ReplyKeyboardMarkup(
|
||||||
|
|
@ -58,14 +58,14 @@ class LanguageScene(CancellableScene, state="language"):
|
||||||
)
|
)
|
||||||
|
|
||||||
@on.message(F.text.casefold() == "python", after=After.exit())
|
@on.message(F.text.casefold() == "python", after=After.exit())
|
||||||
async def process_python(self, message: Message):
|
async def process_python(self, message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Python, you say? That's the language that makes my circuits light up! 😉"
|
"Python, you say? That's the language that makes my circuits light up! 😉",
|
||||||
)
|
)
|
||||||
await self.input_language(message)
|
await self.input_language(message)
|
||||||
|
|
||||||
@on.message(after=After.exit())
|
@on.message(after=After.exit())
|
||||||
async def input_language(self, message: Message):
|
async def input_language(self, message: Message) -> None:
|
||||||
data: FSMData = await self.wizard.get_data()
|
data: FSMData = await self.wizard.get_data()
|
||||||
await self.show_results(message, language=message.text, **data)
|
await self.show_results(message, language=message.text, **data)
|
||||||
|
|
||||||
|
|
@ -83,7 +83,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@on.message.enter()
|
@on.message.enter()
|
||||||
async def on_enter(self, message: Message):
|
async def on_enter(self, message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Did you like to write bots?",
|
"Did you like to write bots?",
|
||||||
reply_markup=ReplyKeyboardMarkup(
|
reply_markup=ReplyKeyboardMarkup(
|
||||||
|
|
@ -96,18 +96,18 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
|
||||||
)
|
)
|
||||||
|
|
||||||
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
|
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
|
||||||
async def process_like_write_bots(self, message: Message):
|
async def process_like_write_bots(self, message: Message) -> None:
|
||||||
await message.reply("Cool! I'm too!")
|
await message.reply("Cool! I'm too!")
|
||||||
|
|
||||||
@on.message(F.text.casefold() == "no", after=After.exit())
|
@on.message(F.text.casefold() == "no", after=After.exit())
|
||||||
async def process_dont_like_write_bots(self, message: Message):
|
async def process_dont_like_write_bots(self, message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Not bad not terrible.\nSee you soon.",
|
"Not bad not terrible.\nSee you soon.",
|
||||||
reply_markup=ReplyKeyboardRemove(),
|
reply_markup=ReplyKeyboardRemove(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@on.message()
|
@on.message()
|
||||||
async def input_like_bots(self, message: Message):
|
async def input_like_bots(self, message: Message) -> None:
|
||||||
await message.answer("I don't understand you :(")
|
await message.answer("I don't understand you :(")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -117,25 +117,25 @@ class NameScene(CancellableScene, state="name"):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@on.message.enter() # Marker for handler that should be called when a user enters the scene.
|
@on.message.enter() # Marker for handler that should be called when a user enters the scene.
|
||||||
async def on_enter(self, message: Message):
|
async def on_enter(self, message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Hi there! What's your name?",
|
"Hi there! What's your name?",
|
||||||
reply_markup=ReplyKeyboardMarkup(keyboard=[[BUTTON_CANCEL]], resize_keyboard=True),
|
reply_markup=ReplyKeyboardMarkup(keyboard=[[BUTTON_CANCEL]], resize_keyboard=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
@on.callback_query.enter() # different types of updates that start the scene also supported.
|
@on.callback_query.enter() # different types of updates that start the scene also supported.
|
||||||
async def on_enter_callback(self, callback_query: CallbackQuery):
|
async def on_enter_callback(self, callback_query: CallbackQuery) -> None:
|
||||||
await callback_query.answer()
|
await callback_query.answer()
|
||||||
await self.on_enter(callback_query.message)
|
await self.on_enter(callback_query.message)
|
||||||
|
|
||||||
@on.message.leave() # Marker for handler that should be called when a user leaves the scene.
|
@on.message.leave() # Marker for handler that should be called when a user leaves the scene.
|
||||||
async def on_leave(self, message: Message):
|
async def on_leave(self, message: Message) -> None:
|
||||||
data: FSMData = await self.wizard.get_data()
|
data: FSMData = await self.wizard.get_data()
|
||||||
name = data.get("name", "Anonymous")
|
name = data.get("name", "Anonymous")
|
||||||
await message.answer(f"Nice to meet you, {html.quote(name)}!")
|
await message.answer(f"Nice to meet you, {html.quote(name)}!")
|
||||||
|
|
||||||
@on.message(after=After.goto(LikeBotsScene))
|
@on.message(after=After.goto(LikeBotsScene))
|
||||||
async def input_name(self, message: Message):
|
async def input_name(self, message: Message) -> None:
|
||||||
await self.wizard.update_data(name=message.text)
|
await self.wizard.update_data(name=message.text)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -154,22 +154,22 @@ class DefaultScene(
|
||||||
start_demo = on.message(F.text.casefold() == "demo", after=After.goto(NameScene))
|
start_demo = on.message(F.text.casefold() == "demo", after=After.goto(NameScene))
|
||||||
|
|
||||||
@on.message(Command("demo"))
|
@on.message(Command("demo"))
|
||||||
async def demo(self, message: Message):
|
async def demo(self, message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Demo started",
|
"Demo started",
|
||||||
reply_markup=InlineKeyboardMarkup(
|
reply_markup=InlineKeyboardMarkup(
|
||||||
inline_keyboard=[[InlineKeyboardButton(text="Go to form", callback_data="start")]]
|
inline_keyboard=[[InlineKeyboardButton(text="Go to form", callback_data="start")]],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@on.callback_query(F.data == "start", after=After.goto(NameScene))
|
@on.callback_query(F.data == "start", after=After.goto(NameScene))
|
||||||
async def demo_callback(self, callback_query: CallbackQuery):
|
async def demo_callback(self, callback_query: CallbackQuery) -> None:
|
||||||
await callback_query.answer(cache_time=0)
|
await callback_query.answer(cache_time=0)
|
||||||
await callback_query.message.delete_reply_markup()
|
await callback_query.message.delete_reply_markup()
|
||||||
|
|
||||||
@on.message.enter() # Mark that this handler should be called when a user enters the scene.
|
@on.message.enter() # Mark that this handler should be called when a user enters the scene.
|
||||||
@on.message()
|
@on.message()
|
||||||
async def default_handler(self, message: Message):
|
async def default_handler(self, message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Start demo?\nYou can also start demo via command /demo",
|
"Start demo?\nYou can also start demo via command /demo",
|
||||||
reply_markup=ReplyKeyboardMarkup(
|
reply_markup=ReplyKeyboardMarkup(
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ async def command_start_handler(message: Message) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
f"Hello, {hbold(message.from_user.full_name)}!",
|
f"Hello, {hbold(message.from_user.full_name)}!",
|
||||||
reply_markup=InlineKeyboardMarkup(
|
reply_markup=InlineKeyboardMarkup(
|
||||||
inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]]
|
inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -43,7 +43,7 @@ async def chat_member_update(chat_member: ChatMemberUpdated, bot: Bot) -> None:
|
||||||
await bot.send_message(
|
await bot.send_message(
|
||||||
chat_member.chat.id,
|
chat_member.chat.id,
|
||||||
f"Member {hcode(chat_member.from_user.id)} was changed "
|
f"Member {hcode(chat_member.from_user.id)} was changed "
|
||||||
+ f"from {chat_member.old_chat_member.status} to {chat_member.new_chat_member.status}",
|
f"from {chat_member.old_chat_member.status} to {chat_member.new_chat_member.status}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ my_router = Router()
|
||||||
|
|
||||||
|
|
||||||
@my_router.message(CommandStart())
|
@my_router.message(CommandStart())
|
||||||
async def command_start(message: Message, bot: Bot, base_url: str):
|
async def command_start(message: Message, bot: Bot, base_url: str) -> None:
|
||||||
await bot.set_chat_menu_button(
|
await bot.set_chat_menu_button(
|
||||||
chat_id=message.chat.id,
|
chat_id=message.chat.id,
|
||||||
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")),
|
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")),
|
||||||
|
|
@ -21,28 +21,29 @@ async def command_start(message: Message, bot: Bot, base_url: str):
|
||||||
|
|
||||||
|
|
||||||
@my_router.message(Command("webview"))
|
@my_router.message(Command("webview"))
|
||||||
async def command_webview(message: Message, base_url: str):
|
async def command_webview(message: Message, base_url: str) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Good. Now you can try to send it via Webview",
|
"Good. Now you can try to send it via Webview",
|
||||||
reply_markup=InlineKeyboardMarkup(
|
reply_markup=InlineKeyboardMarkup(
|
||||||
inline_keyboard=[
|
inline_keyboard=[
|
||||||
[
|
[
|
||||||
InlineKeyboardButton(
|
InlineKeyboardButton(
|
||||||
text="Open Webview", web_app=WebAppInfo(url=f"{base_url}/demo")
|
text="Open Webview",
|
||||||
)
|
web_app=WebAppInfo(url=f"{base_url}/demo"),
|
||||||
]
|
),
|
||||||
]
|
],
|
||||||
|
],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@my_router.message(~F.message.via_bot) # Echo to all messages except messages via bot
|
@my_router.message(~F.message.via_bot) # Echo to all messages except messages via bot
|
||||||
async def echo_all(message: Message, base_url: str):
|
async def echo_all(message: Message, base_url: str) -> None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Test webview",
|
"Test webview",
|
||||||
reply_markup=InlineKeyboardMarkup(
|
reply_markup=InlineKeyboardMarkup(
|
||||||
inline_keyboard=[
|
inline_keyboard=[
|
||||||
[InlineKeyboardButton(text="Open", web_app=WebAppInfo(url=f"{base_url}/demo"))]
|
[InlineKeyboardButton(text="Open", web_app=WebAppInfo(url=f"{base_url}/demo"))],
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,14 @@ TOKEN = getenv("BOT_TOKEN")
|
||||||
APP_BASE_URL = getenv("APP_BASE_URL")
|
APP_BASE_URL = getenv("APP_BASE_URL")
|
||||||
|
|
||||||
|
|
||||||
async def on_startup(bot: Bot, base_url: str):
|
async def on_startup(bot: Bot, base_url: str) -> None:
|
||||||
await bot.set_webhook(f"{base_url}/webhook")
|
await bot.set_webhook(f"{base_url}/webhook")
|
||||||
await bot.set_chat_menu_button(
|
await bot.set_chat_menu_button(
|
||||||
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo"))
|
menu_button=MenuButtonWebApp(text="Open Menu", web_app=WebAppInfo(url=f"{base_url}/demo")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
bot = Bot(token=TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
||||||
dispatcher = Dispatcher()
|
dispatcher = Dispatcher()
|
||||||
dispatcher["base_url"] = APP_BASE_URL
|
dispatcher["base_url"] = APP_BASE_URL
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from aiohttp.web_fileresponse import FileResponse
|
from aiohttp.web_fileresponse import FileResponse
|
||||||
from aiohttp.web_request import Request
|
|
||||||
from aiohttp.web_response import json_response
|
from aiohttp.web_response import json_response
|
||||||
|
|
||||||
from aiogram import Bot
|
|
||||||
from aiogram.types import (
|
from aiogram.types import (
|
||||||
InlineKeyboardButton,
|
InlineKeyboardButton,
|
||||||
InlineKeyboardMarkup,
|
InlineKeyboardMarkup,
|
||||||
|
|
@ -14,12 +15,18 @@ from aiogram.types import (
|
||||||
)
|
)
|
||||||
from aiogram.utils.web_app import check_webapp_signature, safe_parse_webapp_init_data
|
from aiogram.utils.web_app import check_webapp_signature, safe_parse_webapp_init_data
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from aiohttp.web_request import Request
|
||||||
|
from aiohttp.web_response import Response
|
||||||
|
|
||||||
async def demo_handler(request: Request):
|
from aiogram import Bot
|
||||||
|
|
||||||
|
|
||||||
|
async def demo_handler(request: Request) -> FileResponse:
|
||||||
return FileResponse(Path(__file__).parent.resolve() / "demo.html")
|
return FileResponse(Path(__file__).parent.resolve() / "demo.html")
|
||||||
|
|
||||||
|
|
||||||
async def check_data_handler(request: Request):
|
async def check_data_handler(request: Request) -> Response:
|
||||||
bot: Bot = request.app["bot"]
|
bot: Bot = request.app["bot"]
|
||||||
|
|
||||||
data = await request.post()
|
data = await request.post()
|
||||||
|
|
@ -28,7 +35,7 @@ async def check_data_handler(request: Request):
|
||||||
return json_response({"ok": False, "err": "Unauthorized"}, status=401)
|
return json_response({"ok": False, "err": "Unauthorized"}, status=401)
|
||||||
|
|
||||||
|
|
||||||
async def send_message_handler(request: Request):
|
async def send_message_handler(request: Request) -> Response:
|
||||||
bot: Bot = request.app["bot"]
|
bot: Bot = request.app["bot"]
|
||||||
data = await request.post()
|
data = await request.post()
|
||||||
try:
|
try:
|
||||||
|
|
@ -44,11 +51,11 @@ async def send_message_handler(request: Request):
|
||||||
InlineKeyboardButton(
|
InlineKeyboardButton(
|
||||||
text="Open",
|
text="Open",
|
||||||
web_app=WebAppInfo(
|
web_app=WebAppInfo(
|
||||||
url=str(request.url.with_scheme("https").with_path("demo"))
|
url=str(request.url.with_scheme("https").with_path("demo")),
|
||||||
),
|
),
|
||||||
)
|
),
|
||||||
]
|
],
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
await bot.answer_web_app_query(
|
await bot.answer_web_app_query(
|
||||||
web_app_query_id=web_app_init_data.query_id,
|
web_app_query_id=web_app_init_data.query_id,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def create_parser() -> ArgumentParser:
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main() -> None:
|
||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
ns = parser.parse_args()
|
ns = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ build-backend = "hatchling.build"
|
||||||
name = "aiogram"
|
name = "aiogram"
|
||||||
description = 'Modern and fully asynchronous framework for Telegram Bot API'
|
description = 'Modern and fully asynchronous framework for Telegram Bot API'
|
||||||
readme = "README.rst"
|
readme = "README.rst"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.10"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Alex Root Junior", email = "jroot.junior@gmail.com" },
|
{ name = "Alex Root Junior", email = "jroot.junior@gmail.com" },
|
||||||
|
|
@ -30,7 +30,6 @@ classifiers = [
|
||||||
"Typing :: Typed",
|
"Typing :: Typed",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Intended Audience :: System Administrators",
|
"Intended Audience :: System Administrators",
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
|
|
@ -60,7 +59,7 @@ fast = [
|
||||||
"aiodns>=3.0.0",
|
"aiodns>=3.0.0",
|
||||||
]
|
]
|
||||||
redis = [
|
redis = [
|
||||||
"redis[hiredis]>=5.0.1,<5.3.0",
|
"redis[hiredis]>=6.2.0,<7",
|
||||||
]
|
]
|
||||||
mongo = [
|
mongo = [
|
||||||
"motor>=3.3.2,<3.7.0",
|
"motor>=3.3.2,<3.7.0",
|
||||||
|
|
@ -76,7 +75,7 @@ cli = [
|
||||||
"aiogram-cli>=1.1.0,<2.0.0",
|
"aiogram-cli>=1.1.0,<2.0.0",
|
||||||
]
|
]
|
||||||
signature = [
|
signature = [
|
||||||
"cryptography>=43.0.0",
|
"cryptography>=46.0.0",
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
"pytest~=7.4.2",
|
"pytest~=7.4.2",
|
||||||
|
|
@ -88,8 +87,8 @@ test = [
|
||||||
"pytest-cov~=4.1.0",
|
"pytest-cov~=4.1.0",
|
||||||
"pytest-aiohttp~=1.0.5",
|
"pytest-aiohttp~=1.0.5",
|
||||||
"aresponses~=2.1.6",
|
"aresponses~=2.1.6",
|
||||||
"pytz~=2023.3",
|
"pytz~=2025.2",
|
||||||
"pycryptodomex~=3.19.0",
|
"pycryptodomex~=3.23.0",
|
||||||
]
|
]
|
||||||
docs = [
|
docs = [
|
||||||
"Sphinx~=8.0.2",
|
"Sphinx~=8.0.2",
|
||||||
|
|
@ -105,12 +104,12 @@ docs = [
|
||||||
"sphinxcontrib-towncrier~=0.4.0a0",
|
"sphinxcontrib-towncrier~=0.4.0a0",
|
||||||
]
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"black~=24.4.2",
|
"black~=25.9.0",
|
||||||
"isort~=5.13.2",
|
"isort~=6.1.0",
|
||||||
"ruff~=0.5.1",
|
"ruff~=0.13.3",
|
||||||
"mypy~=1.10.0",
|
"mypy~=1.10.1",
|
||||||
"toml~=0.10.2",
|
"toml~=0.10.2",
|
||||||
"pre-commit~=3.5",
|
"pre-commit~=4.3.0",
|
||||||
"packaging~=24.1",
|
"packaging~=24.1",
|
||||||
"motor-types~=1.0.0b4",
|
"motor-types~=1.0.0b4",
|
||||||
]
|
]
|
||||||
|
|
@ -200,9 +199,8 @@ cov-mongo = [
|
||||||
]
|
]
|
||||||
view-cov = "google-chrome-stable reports/py{matrix:python}/coverage/index.html"
|
view-cov = "google-chrome-stable reports/py{matrix:python}/coverage/index.html"
|
||||||
|
|
||||||
|
|
||||||
[[tool.hatch.envs.test.matrix]]
|
[[tool.hatch.envs.test.matrix]]
|
||||||
python = ["39", "310", "311", "312", "313"]
|
python = ["310", "311", "312", "313"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 99
|
line-length = 99
|
||||||
|
|
@ -219,7 +217,6 @@ exclude = [
|
||||||
"scripts",
|
"scripts",
|
||||||
"*.egg-info",
|
"*.egg-info",
|
||||||
]
|
]
|
||||||
target-version = "py39"
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
|
|
@ -227,13 +224,13 @@ select = [
|
||||||
"C4",
|
"C4",
|
||||||
"E",
|
"E",
|
||||||
"F",
|
"F",
|
||||||
"T10",
|
|
||||||
"T20",
|
|
||||||
"Q",
|
"Q",
|
||||||
"RET",
|
"RET",
|
||||||
|
"T10",
|
||||||
|
"T20",
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
"F401"
|
"F401",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
[tool.ruff.lint.isort]
|
||||||
|
|
@ -280,7 +277,7 @@ exclude_lines = [
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
plugins = "pydantic.mypy"
|
plugins = "pydantic.mypy"
|
||||||
python_version = "3.9"
|
python_version = "3.10"
|
||||||
show_error_codes = true
|
show_error_codes = true
|
||||||
show_error_context = true
|
show_error_context = true
|
||||||
pretty = true
|
pretty = true
|
||||||
|
|
@ -315,7 +312,7 @@ disallow_untyped_defs = true
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 99
|
line-length = 99
|
||||||
target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
|
target-version = ['py310', 'py311', 'py312', 'py313']
|
||||||
exclude = '''
|
exclude = '''
|
||||||
(
|
(
|
||||||
\.eggs
|
\.eggs
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
version: "3.9"
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
redis:
|
redis:
|
||||||
image: redis:6-alpine
|
image: redis:8-alpine
|
||||||
ports:
|
ports:
|
||||||
- "${REDIS_PORT-6379}:6379"
|
- "${REDIS_PORT-6379}:6379"
|
||||||
|
|
||||||
mongo:
|
mongo:
|
||||||
image: mongo:7.0.6
|
image: mongo:8.0.14
|
||||||
environment:
|
environment:
|
||||||
MONGO_INITDB_ROOT_USERNAME: mongo
|
MONGO_INITDB_ROOT_USERNAME: mongo
|
||||||
MONGO_INITDB_ROOT_PASSWORD: mongo
|
MONGO_INITDB_ROOT_PASSWORD: mongo
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,6 @@ class TestDataclassKwargs:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"py_version,expected",
|
"py_version,expected",
|
||||||
[
|
[
|
||||||
((3, 9, 0), ALL_VERSIONS),
|
|
||||||
((3, 9, 2), ALL_VERSIONS),
|
|
||||||
((3, 10, 2), PY_310),
|
((3, 10, 2), PY_310),
|
||||||
((3, 11, 0), PY_311),
|
((3, 11, 0), PY_311),
|
||||||
((4, 13, 0), LATEST_PY),
|
((4, 13, 0), LATEST_PY),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue