mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Beta 3 (#884)
* Rework middlewares, separate management to `MiddlewareManager` class * Rework middlewares * Added changes description for redis * Added changes description for redis * Fixed tests with Redis // aioredis replacement * Changed msg.<html/md>_text attributes behaviour * Added changelog for spoilers * Added possibility to get command magic result as handler arguments
This commit is contained in:
parent
930bca0876
commit
286cf39c8a
51 changed files with 1380 additions and 804 deletions
|
|
@ -2,20 +2,20 @@
|
|||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.2.0
|
||||
rev: v4.2.0
|
||||
hooks:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- id: check-merge-conflict
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 21.8b0
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
files: &files '^(aiogram|tests|examples)'
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-isort
|
||||
rev: v5.9.3
|
||||
rev: v5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
additional_dependencies: [ toml ]
|
||||
|
|
|
|||
1
CHANGES/865.bugfix.rst
Normal file
1
CHANGES/865.bugfix.rst
Normal file
|
|
@ -0,0 +1 @@
|
|||
Added parsing of spoiler message entity
|
||||
2
CHANGES/874.misc.rst
Normal file
2
CHANGES/874.misc.rst
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
Changed :code:`Message.html_text` and :code:`Message.md_text` attributes behaviour when message has no text.
|
||||
The empty string will be used instead of raising error.
|
||||
1
CHANGES/882.misc.rst
Normal file
1
CHANGES/882.misc.rst
Normal file
|
|
@ -0,0 +1 @@
|
|||
Used `redis-py` instead of `aioredis` package in due to this packages was merged into single one
|
||||
3
CHANGES/883.misc.rst
Normal file
3
CHANGES/883.misc.rst
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
Solved common naming problem with middlewares that confusing too much developers
|
||||
- now you can't see the `middleware` and `middlewares` attributes at the same point
|
||||
because this functionality encapsulated to special interface.
|
||||
1
CHANGES/885.bugfix.rst
Normal file
1
CHANGES/885.bugfix.rst
Normal file
|
|
@ -0,0 +1 @@
|
|||
Fixed CallbackData factory parsing IntEnum's
|
||||
1
CHANGES/889.feature.rst
Normal file
1
CHANGES/889.feature.rst
Normal file
|
|
@ -0,0 +1 @@
|
|||
Added possibility to get command magic result as handler argument
|
||||
5
Makefile
5
Makefile
|
|
@ -47,7 +47,7 @@ help:
|
|||
|
||||
.PHONY: install
|
||||
install:
|
||||
poetry install -E fast -E redis -E proxy -E i18n -E docs
|
||||
poetry install -E fast -E redis -E proxy -E i18n -E docs --remove-untracked
|
||||
$(py) pre-commit install
|
||||
|
||||
.PHONY: clean
|
||||
|
|
@ -94,9 +94,6 @@ test: test-run-services
|
|||
test-coverage: test-run-services
|
||||
mkdir -p $(reports_dir)/tests/
|
||||
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ --redis $(redis_connection)
|
||||
|
||||
.PHONY: test-coverage-report
|
||||
test-coverage-report:
|
||||
$(py) coverage html -d $(reports_dir)/coverage
|
||||
|
||||
.PHONY: test-coverage-view
|
||||
|
|
|
|||
|
|
@ -3,22 +3,9 @@ from __future__ import annotations
|
|||
import abc
|
||||
import datetime
|
||||
import json
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Final,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Final, Optional, Type, Union, cast
|
||||
|
||||
from aiogram.exceptions import (
|
||||
RestartingTelegram,
|
||||
|
|
@ -36,26 +23,15 @@ from aiogram.exceptions import (
|
|||
|
||||
from ...methods import Response, TelegramMethod
|
||||
from ...methods.base import TelegramType
|
||||
from ...types import UNSET, TelegramObject
|
||||
from ...types import UNSET
|
||||
from ..telegram import PRODUCTION, TelegramAPIServer
|
||||
from .middlewares.base import BaseRequestMiddleware
|
||||
from .middlewares.manager import RequestMiddlewareManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bot import Bot
|
||||
|
||||
_JsonLoads = Callable[..., Any]
|
||||
_JsonDumps = Callable[..., str]
|
||||
NextRequestMiddlewareType = Callable[
|
||||
["Bot", TelegramMethod[TelegramObject]], Awaitable[Response[TelegramObject]]
|
||||
]
|
||||
|
||||
RequestMiddlewareType = Union[
|
||||
BaseRequestMiddleware,
|
||||
Callable[
|
||||
[NextRequestMiddlewareType, "Bot", TelegramMethod[TelegramType]],
|
||||
Awaitable[Response[TelegramType]],
|
||||
],
|
||||
]
|
||||
|
||||
DEFAULT_TIMEOUT: Final[float] = 60.0
|
||||
|
||||
|
|
@ -80,7 +56,7 @@ class BaseSession(abc.ABC):
|
|||
self.json_dumps = json_dumps
|
||||
self.timeout = timeout
|
||||
|
||||
self.middlewares: List[RequestMiddlewareType[TelegramObject]] = []
|
||||
self.middleware = RequestMiddlewareManager()
|
||||
|
||||
def check_response(
|
||||
self, method: TelegramMethod[TelegramType], status_code: int, content: str
|
||||
|
|
@ -185,19 +161,11 @@ class BaseSession(abc.ABC):
|
|||
return {k: self.clean_json(v) for k, v in value.items() if v is not None}
|
||||
return value
|
||||
|
||||
def middleware(
|
||||
self, middleware: RequestMiddlewareType[TelegramObject]
|
||||
) -> RequestMiddlewareType[TelegramObject]:
|
||||
self.middlewares.append(middleware)
|
||||
return middleware
|
||||
|
||||
async def __call__(
|
||||
self, bot: Bot, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET
|
||||
) -> TelegramType:
|
||||
middleware = partial(self.make_request, timeout=timeout)
|
||||
for m in reversed(self.middlewares):
|
||||
middleware = partial(m, middleware) # type: ignore
|
||||
return await middleware(bot, method)
|
||||
middleware = self.middleware.wrap_middlewares(self.make_request, timeout=timeout)
|
||||
return cast(TelegramType, await middleware(bot, method))
|
||||
|
||||
async def __aenter__(self) -> BaseSession:
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -1,15 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Union
|
||||
|
||||
from aiogram.methods import Response, TelegramMethod
|
||||
from aiogram.types import TelegramObject
|
||||
from aiogram.methods.base import TelegramType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...bot import Bot
|
||||
|
||||
|
||||
NextRequestMiddlewareType = Callable[
|
||||
["Bot", TelegramMethod[TelegramObject]], Awaitable[Response[TelegramObject]]
|
||||
["Bot", TelegramMethod[TelegramType]], Awaitable[Response[TelegramType]]
|
||||
]
|
||||
RequestMiddlewareType = Union[
|
||||
"BaseRequestMiddleware",
|
||||
Callable[
|
||||
[NextRequestMiddlewareType[TelegramType], "Bot", TelegramMethod[TelegramType]],
|
||||
Awaitable[Response[TelegramType]],
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -21,10 +29,10 @@ class BaseRequestMiddleware(ABC):
|
|||
@abstractmethod
|
||||
async def __call__(
|
||||
self,
|
||||
make_request: NextRequestMiddlewareType,
|
||||
make_request: NextRequestMiddlewareType[TelegramType],
|
||||
bot: "Bot",
|
||||
method: TelegramMethod[TelegramObject],
|
||||
) -> Response[TelegramObject]:
|
||||
method: TelegramMethod[TelegramType],
|
||||
) -> Response[TelegramType]:
|
||||
"""
|
||||
Execute middleware
|
||||
|
||||
|
|
|
|||
79
aiogram/client/session/middlewares/manager.py
Normal file
79
aiogram/client/session/middlewares/manager.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from aiogram.client.session.middlewares.base import (
|
||||
NextRequestMiddlewareType,
|
||||
RequestMiddlewareType,
|
||||
)
|
||||
from aiogram.methods import Response
|
||||
from aiogram.methods.base import TelegramMethod, TelegramType
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram import Bot
|
||||
|
||||
|
||||
class RequestMiddlewareManager(Sequence[RequestMiddlewareType[TelegramObject]]):
|
||||
def __init__(self) -> None:
|
||||
self._middlewares: List[RequestMiddlewareType[TelegramObject]] = []
|
||||
|
||||
def register(
|
||||
self,
|
||||
middleware: RequestMiddlewareType[TelegramObject],
|
||||
) -> RequestMiddlewareType[TelegramObject]:
|
||||
self._middlewares.append(middleware)
|
||||
return middleware
|
||||
|
||||
def unregister(self, middleware: RequestMiddlewareType[TelegramObject]) -> None:
|
||||
self._middlewares.remove(middleware)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
middleware: Optional[RequestMiddlewareType[TelegramObject]] = None,
|
||||
) -> Union[
|
||||
Callable[[RequestMiddlewareType[TelegramObject]], RequestMiddlewareType[TelegramObject]],
|
||||
RequestMiddlewareType[TelegramObject],
|
||||
]:
|
||||
if middleware is None:
|
||||
return self.register
|
||||
return self.register(middleware)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: int) -> RequestMiddlewareType[TelegramObject]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: slice) -> Sequence[RequestMiddlewareType[TelegramObject]]:
|
||||
pass
|
||||
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice]
|
||||
) -> Union[
|
||||
RequestMiddlewareType[TelegramObject], Sequence[RequestMiddlewareType[TelegramObject]]
|
||||
]:
|
||||
return self._middlewares[item]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._middlewares)
|
||||
|
||||
def wrap_middlewares(
|
||||
self,
|
||||
callback: Callable[[Bot, TelegramMethod[TelegramType]], Awaitable[Response[TelegramType]]],
|
||||
**kwargs: Any,
|
||||
) -> NextRequestMiddlewareType[TelegramType]:
|
||||
middleware = partial(callback, **kwargs)
|
||||
for m in reversed(self._middlewares):
|
||||
middleware = partial(m, middleware) # type: ignore
|
||||
return middleware
|
||||
|
|
@ -3,8 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Type
|
|||
|
||||
from aiogram import loggers
|
||||
from aiogram.methods import TelegramMethod
|
||||
from aiogram.methods.base import Response
|
||||
from aiogram.types import TelegramObject
|
||||
from aiogram.methods.base import Response, TelegramType
|
||||
|
||||
from .base import BaseRequestMiddleware, NextRequestMiddlewareType
|
||||
|
||||
|
|
@ -25,10 +24,10 @@ class RequestLogging(BaseRequestMiddleware):
|
|||
|
||||
async def __call__(
|
||||
self,
|
||||
make_request: NextRequestMiddlewareType,
|
||||
make_request: NextRequestMiddlewareType[TelegramType],
|
||||
bot: "Bot",
|
||||
method: TelegramMethod[TelegramObject],
|
||||
) -> Response[TelegramObject]:
|
||||
method: TelegramMethod[TelegramType],
|
||||
) -> Response[TelegramType]:
|
||||
if type(method) not in self.ignore_methods:
|
||||
loggers.middlewares.info(
|
||||
"Make request with method=%r by bot id=%d",
|
||||
|
|
|
|||
|
|
@ -36,8 +36,19 @@ class Dispatcher(Router):
|
|||
storage: Optional[BaseStorage] = None,
|
||||
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
|
||||
events_isolation: Optional[BaseEventIsolation] = None,
|
||||
disable_fsm: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Root router
|
||||
|
||||
:param storage: Storage for FSM
|
||||
:param fsm_strategy: FSM strategy
|
||||
:param events_isolation: Events isolation
|
||||
:param disable_fsm: Disable FSM, note that if you disable FSM
|
||||
then you should not use storage and events isolation
|
||||
:param kwargs: Other arguments, will be passed as keyword arguments to handlers
|
||||
"""
|
||||
super(Dispatcher, self).__init__(**kwargs)
|
||||
|
||||
# Telegram API provides originally only one event type - Update
|
||||
|
|
@ -48,7 +59,8 @@ class Dispatcher(Router):
|
|||
)
|
||||
self.update.register(self._listen_update)
|
||||
|
||||
# Error handlers should work is out of all other functions and be registered before all others middlewares
|
||||
# Error handlers should work is out of all other functions
|
||||
# and should be registered before all others middlewares
|
||||
self.update.outer_middleware(ErrorsMiddleware(self))
|
||||
|
||||
# User context middleware makes small optimization for all other builtin
|
||||
|
|
@ -62,11 +74,31 @@ class Dispatcher(Router):
|
|||
strategy=fsm_strategy,
|
||||
events_isolation=events_isolation if events_isolation else DisabledEventIsolation(),
|
||||
)
|
||||
self.update.outer_middleware(self.fsm)
|
||||
if not disable_fsm:
|
||||
# Note that when FSM middleware is disabled, the event isolation is also disabled
|
||||
# Because the isolation mechanism is a part of the FSM
|
||||
self.update.outer_middleware(self.fsm)
|
||||
self.shutdown.register(self.fsm.close)
|
||||
|
||||
self._data: Dict[str, Any] = {}
|
||||
self._running_lock = Lock()
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
return self._data[item]
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
self._data[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._data[key]
|
||||
|
||||
def get(self, key: str, /, default: Optional[Any] = None) -> Optional[Any]:
|
||||
return self._data.get(key, default)
|
||||
|
||||
@property
|
||||
def storage(self) -> BaseStorage:
|
||||
return self.fsm.storage
|
||||
|
||||
@property
|
||||
def parent_router(self) -> None:
|
||||
"""
|
||||
|
|
@ -100,8 +132,15 @@ class Dispatcher(Router):
|
|||
|
||||
token = Bot.set_current(bot)
|
||||
try:
|
||||
kwargs.update(bot=bot)
|
||||
response = await self.update.wrap_outer_middleware(self.update.trigger, update, kwargs)
|
||||
response = await self.update.wrap_outer_middleware(
|
||||
self.update.trigger,
|
||||
update,
|
||||
{
|
||||
**self._data,
|
||||
**kwargs,
|
||||
"bot": bot,
|
||||
},
|
||||
)
|
||||
handled = response is not UNHANDLED
|
||||
return response
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -1,33 +1,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from inspect import isclass
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
||||
|
||||
from ...exceptions import FiltersResolveError
|
||||
from ...types import TelegramObject
|
||||
from ..filters.base import BaseFilter
|
||||
from .bases import (
|
||||
REJECTED,
|
||||
UNHANDLED,
|
||||
MiddlewareEventType,
|
||||
MiddlewareType,
|
||||
NextMiddlewareType,
|
||||
SkipHandler,
|
||||
)
|
||||
from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -48,8 +32,9 @@ class TelegramEventObserver:
|
|||
|
||||
self.handlers: List[HandlerObject] = []
|
||||
self.filters: List[Type[BaseFilter]] = []
|
||||
self.outer_middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
self.middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
|
||||
self.middleware = MiddlewareManager()
|
||||
self.outer_middleware = MiddlewareManager()
|
||||
|
||||
# Re-used filters check method from already implemented handler object
|
||||
# with dummy callback which never will be used
|
||||
|
|
@ -75,7 +60,11 @@ class TelegramEventObserver:
|
|||
|
||||
:param bound_filter:
|
||||
"""
|
||||
if not issubclass(bound_filter, BaseFilter):
|
||||
# TODO: This functionality should be deprecated in the future
|
||||
# in due to bound filter has uncontrollable ordering and
|
||||
# makes debugging process is harder that explicit using filters
|
||||
|
||||
if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter):
|
||||
raise TypeError(
|
||||
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
|
||||
)
|
||||
|
|
@ -97,18 +86,11 @@ class TelegramEventObserver:
|
|||
yield filter_
|
||||
registry.append(filter_)
|
||||
|
||||
def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType[TelegramObject]]:
|
||||
"""
|
||||
Get all middlewares in a tree
|
||||
:param *:
|
||||
"""
|
||||
middlewares = []
|
||||
if outer:
|
||||
middlewares.extend(self.outer_middlewares)
|
||||
else:
|
||||
for router in reversed(tuple(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middlewares)
|
||||
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
|
||||
middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
for router in reversed(tuple(self.router.chain_head)):
|
||||
observer = router.observers[self.event_name]
|
||||
middlewares.extend(observer.middleware)
|
||||
|
||||
return middlewares
|
||||
|
||||
|
|
@ -198,23 +180,13 @@ class TelegramEventObserver:
|
|||
)
|
||||
return callback
|
||||
|
||||
@classmethod
|
||||
def _wrap_middleware(
|
||||
cls, middlewares: List[MiddlewareType[MiddlewareEventType]], handler: HandlerType
|
||||
) -> NextMiddlewareType[MiddlewareEventType]:
|
||||
@functools.wraps(handler)
|
||||
def mapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
|
||||
return handler(event, **kwargs)
|
||||
|
||||
middleware = mapper
|
||||
for m in reversed(middlewares):
|
||||
middleware = functools.partial(m, middleware)
|
||||
return middleware
|
||||
|
||||
def wrap_outer_middleware(
|
||||
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback)
|
||||
wrapped_outer = self.middleware.wrap_middlewares(
|
||||
self.outer_middleware,
|
||||
callback,
|
||||
)
|
||||
return wrapped_outer(event, data)
|
||||
|
||||
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
|
||||
|
|
@ -233,8 +205,9 @@ class TelegramEventObserver:
|
|||
if result:
|
||||
kwargs.update(data, handler=handler)
|
||||
try:
|
||||
wrapped_inner = self._wrap_middleware(
|
||||
self._resolve_middlewares(), handler.call
|
||||
wrapped_inner = self.outer_middleware.wrap_middlewares(
|
||||
self._resolve_middlewares(),
|
||||
handler.call,
|
||||
)
|
||||
return await wrapped_inner(event, kwargs)
|
||||
except SkipHandler:
|
||||
|
|
@ -254,71 +227,3 @@ class TelegramEventObserver:
|
|||
return callback
|
||||
|
||||
return wrapper
|
||||
|
||||
def middleware(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType[TelegramObject]] = None,
|
||||
) -> Union[
|
||||
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
|
||||
MiddlewareType[TelegramObject],
|
||||
]:
|
||||
"""
|
||||
Decorator for registering inner middlewares
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.middleware() # via decorator (variant 1)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.middleware # via decorator (variant 2)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async def my_middleware(handler, event, data): ...
|
||||
<event>.middleware(my_middleware) # via method
|
||||
"""
|
||||
|
||||
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
|
||||
self.middlewares.append(m)
|
||||
return m
|
||||
|
||||
if middleware is None:
|
||||
return wrapper
|
||||
return wrapper(middleware)
|
||||
|
||||
def outer_middleware(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType[TelegramObject]] = None,
|
||||
) -> Union[
|
||||
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
|
||||
MiddlewareType[TelegramObject],
|
||||
]:
|
||||
"""
|
||||
Decorator for registering outer middlewares
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.outer_middleware() # via decorator (variant 1)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@<event>.outer_middleware # via decorator (variant 2)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async def my_middleware(handler, event, data): ...
|
||||
<event>.outer_middleware(my_middleware) # via method
|
||||
"""
|
||||
|
||||
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
|
||||
self.outer_middlewares.append(m)
|
||||
return m
|
||||
|
||||
if middleware is None:
|
||||
return wrapper
|
||||
return wrapper(middleware)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Type, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Literal, Optional, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from magic_filter import MagicFilter
|
||||
|
|
@ -22,9 +22,20 @@ class CallbackDataException(Exception):
|
|||
|
||||
|
||||
class CallbackData(BaseModel):
|
||||
"""
|
||||
Base class for callback data wrapper
|
||||
|
||||
This class should be used as super-class of user-defined callbacks.
|
||||
|
||||
The class-keyword :code:`prefix` is required to define prefix
|
||||
and also the argument :code:`sep` can be passed to define separator (default is :code:`:`).
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
sep: str
|
||||
prefix: str
|
||||
__separator__: ClassVar[str]
|
||||
"""Data separator (default is :code:`:`)"""
|
||||
__prefix__: ClassVar[str]
|
||||
"""Callback prefix"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
if "prefix" not in kwargs:
|
||||
|
|
@ -32,12 +43,14 @@ class CallbackData(BaseModel):
|
|||
f"prefix required, usage example: "
|
||||
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
|
||||
)
|
||||
cls.sep = kwargs.pop("sep", ":")
|
||||
cls.prefix = kwargs.pop("prefix")
|
||||
if cls.sep in cls.prefix:
|
||||
cls.__separator__ = kwargs.pop("sep", ":")
|
||||
cls.__prefix__ = kwargs.pop("prefix")
|
||||
if cls.__separator__ in cls.__prefix__:
|
||||
raise ValueError(
|
||||
f"Separator symbol {cls.sep!r} can not be used inside prefix {cls.prefix!r}"
|
||||
f"Separator symbol {cls.__separator__!r} can not be used "
|
||||
f"inside prefix {cls.__prefix__!r}"
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def _encode_value(self, key: str, value: Any) -> str:
|
||||
if value is None:
|
||||
|
|
@ -52,31 +65,45 @@ class CallbackData(BaseModel):
|
|||
)
|
||||
|
||||
def pack(self) -> str:
|
||||
result = [self.prefix]
|
||||
"""
|
||||
Generate callback data string
|
||||
|
||||
:return: valid callback data for Telegram Bot API
|
||||
"""
|
||||
result = [self.__prefix__]
|
||||
for key, value in self.dict().items():
|
||||
encoded = self._encode_value(key, value)
|
||||
if self.sep in encoded:
|
||||
if self.__separator__ in encoded:
|
||||
raise ValueError(
|
||||
f"Separator symbol {self.sep!r} can not be used in value {key}={encoded!r}"
|
||||
f"Separator symbol {self.__separator__!r} can not be used "
|
||||
f"in value {key}={encoded!r}"
|
||||
)
|
||||
result.append(encoded)
|
||||
callback_data = self.sep.join(result)
|
||||
callback_data = self.__separator__.join(result)
|
||||
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
|
||||
raise ValueError(
|
||||
f"Resulted callback data is too long! len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
|
||||
f"Resulted callback data is too long! "
|
||||
f"len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
|
||||
)
|
||||
return callback_data
|
||||
|
||||
@classmethod
|
||||
def unpack(cls: Type[T], value: str) -> T:
|
||||
prefix, *parts = value.split(cls.sep)
|
||||
"""
|
||||
Parse callback data string
|
||||
|
||||
:param value: value from Telegram
|
||||
:return: instance of CallbackData
|
||||
"""
|
||||
prefix, *parts = value.split(cls.__separator__)
|
||||
names = cls.__fields__.keys()
|
||||
if len(parts) != len(names):
|
||||
raise TypeError(
|
||||
f"Callback data {cls.__name__!r} takes {len(names)} arguments but {len(parts)} were given"
|
||||
f"Callback data {cls.__name__!r} takes {len(names)} arguments "
|
||||
f"but {len(parts)} were given"
|
||||
)
|
||||
if prefix != cls.prefix:
|
||||
raise ValueError(f"Bad prefix ({prefix!r} != {cls.prefix!r})")
|
||||
if prefix != cls.__prefix__:
|
||||
raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})")
|
||||
payload = {}
|
||||
for k, v in zip(names, parts): # type: str, Optional[str]
|
||||
if field := cls.__fields__.get(k):
|
||||
|
|
@ -87,15 +114,30 @@ class CallbackData(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def filter(cls, rule: Optional[MagicFilter] = None) -> CallbackQueryFilter:
|
||||
"""
|
||||
Generates a filter for callback query with rule
|
||||
|
||||
:param rule: magic rule
|
||||
:return: instance of filter
|
||||
"""
|
||||
return CallbackQueryFilter(callback_data=cls, rule=rule)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
# class Config:
|
||||
# use_enum_values = True
|
||||
|
||||
|
||||
class CallbackQueryFilter(BaseFilter):
|
||||
"""
|
||||
This filter helps to handle callback query.
|
||||
|
||||
Should not be used directly, you should create the instance of this filter
|
||||
via callback data instance
|
||||
"""
|
||||
|
||||
callback_data: Type[CallbackData]
|
||||
"""Expected type of callback data"""
|
||||
rule: Optional[MagicFilter] = None
|
||||
"""Magic rule"""
|
||||
|
||||
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]:
|
||||
if not isinstance(query, CallbackQuery) or not query.data:
|
||||
|
|
@ -111,3 +153,4 @@ class CallbackQueryFilter(BaseFilter):
|
|||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
use_enum_values = True
|
||||
|
|
|
|||
|
|
@ -59,7 +59,10 @@ class Command(BaseFilter):
|
|||
command = await self.parse_command(text=text, bot=bot)
|
||||
except CommandException:
|
||||
return False
|
||||
return {"command": command}
|
||||
result = {"command": command}
|
||||
if command.magic_result and isinstance(command.magic_result, dict):
|
||||
result.update(command.magic_result)
|
||||
return result
|
||||
|
||||
def extract_command(self, text: str) -> CommandObject:
|
||||
# First step: separate command with arguments
|
||||
|
|
@ -110,20 +113,22 @@ class Command(BaseFilter):
|
|||
self.validate_prefix(command=command)
|
||||
await self.validate_mention(bot=bot, command=command)
|
||||
command = self.validate_command(command)
|
||||
self.do_magic(command=command)
|
||||
command = self.do_magic(command=command)
|
||||
return command
|
||||
|
||||
def do_magic(self, command: CommandObject) -> None:
|
||||
def do_magic(self, command: CommandObject) -> Any:
|
||||
if not self.command_magic:
|
||||
return
|
||||
if not self.command_magic.resolve(command):
|
||||
return command
|
||||
result = self.command_magic.resolve(command)
|
||||
if not result:
|
||||
raise CommandException("Rejected via magic filter")
|
||||
return replace(command, magic_result=result)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class CommandObject:
|
||||
"""
|
||||
Instance of this object is always has command and it prefix.
|
||||
|
|
@ -140,6 +145,7 @@ class CommandObject:
|
|||
"""Command argument"""
|
||||
regexp_match: Optional[Match[str]] = field(repr=False, default=None)
|
||||
"""Will be presented match result if the command is presented as regexp in filter"""
|
||||
magic_result: Optional[Any] = field(repr=False, default=None)
|
||||
|
||||
@property
|
||||
def mentioned(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ class MagicData(BaseFilter):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> bool:
|
||||
return bool(
|
||||
self.magic_data.resolve(
|
||||
AttrDict({"event": event, **{k: v for k, v in enumerate(args)}, **kwargs})
|
||||
)
|
||||
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.magic_data.resolve(
|
||||
AttrDict({"event": event, **{k: v for k, v in enumerate(args)}, **kwargs})
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload
|
||||
|
||||
from magic_filter import AttrDict
|
||||
|
||||
from aiogram.dispatcher.flags.getter import extract_flags_from_object
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Flag:
|
||||
|
|
@ -25,11 +28,11 @@ class FlagDecorator:
|
|||
return self._with_flag(new_flag)
|
||||
|
||||
@overload
|
||||
def __call__(self, value: Callable[..., Any]) -> Callable[..., Any]: # type: ignore
|
||||
def __call__(self, value: Callable[..., Any], /) -> Callable[..., Any]: # type: ignore
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __call__(self, value: Any) -> "FlagDecorator":
|
||||
def __call__(self, value: Any, /) -> "FlagDecorator":
|
||||
pass
|
||||
|
||||
@overload
|
||||
|
|
@ -53,8 +56,24 @@ class FlagDecorator:
|
|||
return self._with_value(AttrDict(kwargs) if value is None else value)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _ChatActionFlagProtocol(FlagDecorator):
|
||||
def __call__( # type: ignore[override]
|
||||
self,
|
||||
action: str = ...,
|
||||
interval: float = ...,
|
||||
initial_sleep: float = ...,
|
||||
**kwargs: Any,
|
||||
) -> FlagDecorator:
|
||||
pass
|
||||
|
||||
|
||||
class FlagGenerator:
|
||||
def __getattr__(self, name: str) -> FlagDecorator:
|
||||
if name[0] == "_":
|
||||
raise AttributeError("Flag name must NOT start with underscore")
|
||||
return FlagDecorator(Flag(name, True))
|
||||
|
||||
if TYPE_CHECKING:
|
||||
chat_action: _ChatActionFlagProtocol
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from abc import ABC, abstractmethod
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
|
||||
|
||||
from aioredis import ConnectionPool, Redis
|
||||
from redis.asyncio.client import Redis
|
||||
from redis.asyncio.connection import ConnectionPool
|
||||
from redis.asyncio.lock import Lock
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.dispatcher.fsm.state import State
|
||||
|
|
@ -131,7 +133,7 @@ class RedisStorage(BaseStorage):
|
|||
return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.redis.close() # type: ignore
|
||||
await self.redis.close()
|
||||
|
||||
async def set_state(
|
||||
self,
|
||||
|
|
@ -223,7 +225,7 @@ class RedisEventIsolation(BaseEventIsolation):
|
|||
key: StorageKey,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
redis_key = self.key_builder.build(key, "lock")
|
||||
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
|
||||
async with self.redis.lock(name=redis_key, **self.lock_kwargs, lock_class=Lock):
|
||||
yield None
|
||||
|
||||
async def close(self) -> None:
|
||||
|
|
|
|||
61
aiogram/dispatcher/middlewares/manager.py
Normal file
61
aiogram/dispatcher/middlewares/manager.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import functools
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
|
||||
|
||||
from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType
|
||||
from aiogram.dispatcher.event.handler import HandlerType
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
|
||||
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
|
||||
def __init__(self) -> None:
|
||||
self._middlewares: List[MiddlewareType[TelegramObject]] = []
|
||||
|
||||
def register(
|
||||
self,
|
||||
middleware: MiddlewareType[TelegramObject],
|
||||
) -> MiddlewareType[TelegramObject]:
|
||||
self._middlewares.append(middleware)
|
||||
return middleware
|
||||
|
||||
def unregister(self, middleware: MiddlewareType[TelegramObject]) -> None:
|
||||
self._middlewares.remove(middleware)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
middleware: Optional[MiddlewareType[TelegramObject]] = None,
|
||||
) -> Union[
|
||||
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
|
||||
MiddlewareType[TelegramObject],
|
||||
]:
|
||||
if middleware is None:
|
||||
return self.register
|
||||
return self.register(middleware)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: int) -> MiddlewareType[TelegramObject]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: slice) -> Sequence[MiddlewareType[TelegramObject]]:
|
||||
pass
|
||||
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice]
|
||||
) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]:
|
||||
return self._middlewares[item]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._middlewares)
|
||||
|
||||
@staticmethod
|
||||
def wrap_middlewares(
|
||||
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType
|
||||
) -> NextMiddlewareType[MiddlewareEventType]:
|
||||
@functools.wraps(handler)
|
||||
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
|
||||
return handler(event, **kwargs)
|
||||
|
||||
middleware = handler_wrapper
|
||||
for m in reversed(middlewares):
|
||||
middleware = functools.partial(m, middleware)
|
||||
return middleware
|
||||
|
|
@ -8,6 +8,7 @@ from ..utils.imports import import_module
|
|||
from ..utils.warnings import CodeHasNoEffect
|
||||
from .event.bases import REJECTED, UNHANDLED
|
||||
from .event.event import EventObserver
|
||||
from .event.handler import HandlerType
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from .filters import BUILTIN_FILTERS
|
||||
|
||||
|
|
@ -253,7 +254,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.message
|
||||
|
||||
@property
|
||||
|
|
@ -264,7 +264,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.edited_message
|
||||
|
||||
@property
|
||||
|
|
@ -275,7 +274,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.channel_post
|
||||
|
||||
@property
|
||||
|
|
@ -286,7 +284,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.edited_channel_post
|
||||
|
||||
@property
|
||||
|
|
@ -297,7 +294,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.inline_query
|
||||
|
||||
@property
|
||||
|
|
@ -308,7 +304,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.chosen_inline_result
|
||||
|
||||
@property
|
||||
|
|
@ -319,7 +314,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.callback_query
|
||||
|
||||
@property
|
||||
|
|
@ -330,7 +324,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.shipping_query
|
||||
|
||||
@property
|
||||
|
|
@ -341,7 +334,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.pre_checkout_query
|
||||
|
||||
@property
|
||||
|
|
@ -352,7 +344,6 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.poll
|
||||
|
||||
@property
|
||||
|
|
@ -363,9 +354,38 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.poll_answer
|
||||
|
||||
@property
|
||||
def my_chat_member_handler(self) -> TelegramEventObserver:
|
||||
warnings.warn(
|
||||
"`Router.my_chat_member_handler(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.my_chat_member(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.my_chat_member
|
||||
|
||||
@property
|
||||
def chat_member_handler(self) -> TelegramEventObserver:
|
||||
warnings.warn(
|
||||
"`Router.chat_member_handler(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chat_member(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.chat_member
|
||||
|
||||
@property
|
||||
def chat_join_request_handler(self) -> TelegramEventObserver:
|
||||
warnings.warn(
|
||||
"`Router.chat_join_request_handler(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chat_join_request(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.chat_join_request
|
||||
|
||||
@property
|
||||
def errors_handler(self) -> TelegramEventObserver:
|
||||
warnings.warn(
|
||||
|
|
@ -374,5 +394,139 @@ class Router:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.errors
|
||||
|
||||
def register_message(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_message(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.message.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.message.register(*args, **kwargs)
|
||||
|
||||
def register_edited_message(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_edited_message(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.edited_message.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.edited_message.register(*args, **kwargs)
|
||||
|
||||
def register_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_channel_post(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.channel_post.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.channel_post.register(*args, **kwargs)
|
||||
|
||||
def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_edited_channel_post(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.edited_channel_post.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.edited_channel_post.register(*args, **kwargs)
|
||||
|
||||
def register_inline_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_inline_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.inline_query.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.inline_query.register(*args, **kwargs)
|
||||
|
||||
def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_chosen_inline_result(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chosen_inline_result.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.chosen_inline_result.register(*args, **kwargs)
|
||||
|
||||
def register_callback_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_callback_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.callback_query.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.callback_query.register(*args, **kwargs)
|
||||
|
||||
def register_shipping_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_shipping_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.shipping_query.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.shipping_query.register(*args, **kwargs)
|
||||
|
||||
def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_pre_checkout_query(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.pre_checkout_query.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.pre_checkout_query.register(*args, **kwargs)
|
||||
|
||||
def register_poll(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_poll(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.poll.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.poll.register(*args, **kwargs)
|
||||
|
||||
def register_poll_answer(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_poll_answer(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.poll_answer.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.poll_answer.register(*args, **kwargs)
|
||||
|
||||
def register_my_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_my_chat_member(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.my_chat_member.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.my_chat_member.register(*args, **kwargs)
|
||||
|
||||
def register_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_chat_member(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chat_member.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.chat_member.register(*args, **kwargs)
|
||||
|
||||
def register_chat_join_request(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_chat_join_request(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.chat_join_request.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.chat_join_request.register(*args, **kwargs)
|
||||
|
||||
def register_errors(self, *args: Any, **kwargs: Any) -> HandlerType:
|
||||
warnings.warn(
|
||||
"`Router.register_errors(...)` is deprecated and will be removed in version 3.2 "
|
||||
"use `Router.errors.register(...)`",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.errors.register(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[TelegramType]):
|
|||
async def emit(self, bot: Bot) -> TelegramType:
|
||||
return await bot(self)
|
||||
|
||||
as_ = emit
|
||||
|
||||
def __await__(self) -> Generator[Any, None, TelegramType]:
|
||||
from aiogram.client.bot import Bot
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@ if TYPE_CHECKING:
|
|||
from .keyboard_button_poll_type import KeyboardButtonPollType
|
||||
|
||||
|
||||
class WebApp(MutableTelegramObject):
|
||||
url: str
|
||||
|
||||
|
||||
class KeyboardButton(MutableTelegramObject):
|
||||
"""
|
||||
This object represents one button of the reply keyboard. For simple text buttons *String* can be used instead of this object to specify text of the button. Optional fields *request_contact*, *request_location*, and *request_poll* are mutually exclusive.
|
||||
|
|
@ -26,3 +30,4 @@ class KeyboardButton(MutableTelegramObject):
|
|||
"""*Optional*. If :code:`True`, the user's current location will be sent when the button is pressed. Available in private chats only"""
|
||||
request_poll: Optional[KeyboardButtonPollType] = None
|
||||
"""*Optional*. If specified, the user will be asked to create a poll and send it to the bot when the button is pressed. Available in private chats only"""
|
||||
web_app: Optional[WebApp] = None
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ if TYPE_CHECKING:
|
|||
from .voice_chat_started import VoiceChatStarted
|
||||
|
||||
|
||||
class Message(TelegramObject):
|
||||
class _BaseMessage(TelegramObject):
|
||||
"""
|
||||
This object represents a message.
|
||||
|
||||
|
|
@ -195,6 +195,8 @@ class Message(TelegramObject):
|
|||
reply_markup: Optional[InlineKeyboardMarkup] = None
|
||||
"""*Optional*. Inline keyboard attached to the message. :code:`login_url` buttons are represented as ordinary :code:`url` buttons."""
|
||||
|
||||
|
||||
class Message(_BaseMessage):
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
if self.text:
|
||||
|
|
@ -265,11 +267,8 @@ class Message(TelegramObject):
|
|||
return ContentType.UNKNOWN
|
||||
|
||||
def _unparse_entities(self, text_decoration: TextDecoration) -> str:
|
||||
text = self.text or self.caption
|
||||
if text is None:
|
||||
raise TypeError("This message doesn't have any text.")
|
||||
|
||||
entities = self.entities or self.caption_entities
|
||||
text = self.text or self.caption or ""
|
||||
entities = self.entities or self.caption_entities or []
|
||||
return text_decoration.unparse(text=text, entities=entities)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from aiogram.types import Message, TelegramObject
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_INTERVAL = 5.0
|
||||
DEFAULT_INITIAL_SLEEP = 0.1
|
||||
DEFAULT_INITIAL_SLEEP = 0.0
|
||||
|
||||
|
||||
class ChatActionSender:
|
||||
|
|
|
|||
|
|
@ -1,18 +1,28 @@
|
|||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlencode, urljoin
|
||||
|
||||
BASE_DOCS_URL = "https://docs.aiogram.dev/"
|
||||
BRANCH = "dev-3.x"
|
||||
|
||||
BASE_PAGE_URL = f"{BASE_DOCS_URL}/en/{BRANCH}/"
|
||||
|
||||
|
||||
def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query: Any) -> str:
|
||||
url = urljoin(url, "/".join(path), allow_fragments=True)
|
||||
if query:
|
||||
url += "?" + urlencode(query)
|
||||
if fragment_:
|
||||
url += "#" + fragment_
|
||||
return url
|
||||
|
||||
|
||||
def docs_url(*path: str, fragment_: Optional[str] = None, **query: Any) -> str:
|
||||
return _format_url(BASE_PAGE_URL, *path, fragment_=fragment_, **query)
|
||||
|
||||
|
||||
def create_tg_link(link: str, **kwargs: Any) -> str:
|
||||
url = f"tg://{link}"
|
||||
if kwargs:
|
||||
query = urlencode(kwargs)
|
||||
url += f"?{query}"
|
||||
return url
|
||||
return _format_url(f"tg://{link}", **kwargs)
|
||||
|
||||
|
||||
def create_telegram_link(uri: str, **kwargs: Any) -> str:
|
||||
url = urljoin("https://t.me", uri)
|
||||
if kwargs:
|
||||
query = urlencode(query=kwargs)
|
||||
url += f"?{query}"
|
||||
return url
|
||||
def create_telegram_link(*path: str, **kwargs: Any) -> str:
|
||||
return _format_url("https://t.me", *path, **kwargs)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class TextDecoration(ABC):
|
|||
if entity.type in {"bot_command", "url", "mention", "phone_number"}:
|
||||
# This entities should not be changed
|
||||
return text
|
||||
if entity.type in {"bold", "italic", "code", "underline", "strikethrough"}:
|
||||
if entity.type in {"bold", "italic", "code", "underline", "strikethrough", "spoiler"}:
|
||||
return cast(str, getattr(self, entity.type)(value=text))
|
||||
if entity.type == "pre":
|
||||
return (
|
||||
|
|
@ -102,35 +102,39 @@ class TextDecoration(ABC):
|
|||
yield self.quote(remove_surrogates(text[offset:length]))
|
||||
|
||||
@abstractmethod
|
||||
def link(self, value: str, link: str) -> str: # pragma: no cover
|
||||
def link(self, value: str, link: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def bold(self, value: str) -> str: # pragma: no cover
|
||||
def bold(self, value: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def italic(self, value: str) -> str: # pragma: no cover
|
||||
def italic(self, value: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def code(self, value: str) -> str: # pragma: no cover
|
||||
def code(self, value: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre(self, value: str) -> str: # pragma: no cover
|
||||
def pre(self, value: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_language(self, value: str, language: str) -> str: # pragma: no cover
|
||||
def pre_language(self, value: str, language: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def underline(self, value: str) -> str: # pragma: no cover
|
||||
def underline(self, value: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def strikethrough(self, value: str) -> str: # pragma: no cover
|
||||
def strikethrough(self, value: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def spoiler(self, value: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -139,14 +143,20 @@ class TextDecoration(ABC):
|
|||
|
||||
|
||||
class HtmlDecoration(TextDecoration):
|
||||
BOLD_TAG = "b"
|
||||
ITALIC_TAG = "i"
|
||||
UNDERLINE_TAG = "u"
|
||||
STRIKETHROUGH_TAG = "s"
|
||||
SPOILER_TAG = ('span class="tg-spoiler"', "span")
|
||||
|
||||
def link(self, value: str, link: str) -> str:
|
||||
return f'<a href="{link}">{value}</a>'
|
||||
|
||||
def bold(self, value: str) -> str:
|
||||
return f"<b>{value}</b>"
|
||||
return f"<{self.BOLD_TAG}>{value}</{self.BOLD_TAG}>"
|
||||
|
||||
def italic(self, value: str) -> str:
|
||||
return f"<i>{value}</i>"
|
||||
return f"<{self.ITALIC_TAG}>{value}</{self.ITALIC_TAG}>"
|
||||
|
||||
def code(self, value: str) -> str:
|
||||
return f"<code>{value}</code>"
|
||||
|
|
@ -158,10 +168,13 @@ class HtmlDecoration(TextDecoration):
|
|||
return f'<pre><code class="language-{language}">{value}</code></pre>'
|
||||
|
||||
def underline(self, value: str) -> str:
|
||||
return f"<u>{value}</u>"
|
||||
return f"<{self.UNDERLINE_TAG}>{value}</{self.UNDERLINE_TAG}>"
|
||||
|
||||
def strikethrough(self, value: str) -> str:
|
||||
return f"<s>{value}</s>"
|
||||
return f"<{self.STRIKETHROUGH_TAG}>{value}</{self.STRIKETHROUGH_TAG}>"
|
||||
|
||||
def spoiler(self, value: str) -> str:
|
||||
return f"<{self.SPOILER_TAG[0]}>{value}</{self.SPOILER_TAG[1]}>"
|
||||
|
||||
def quote(self, value: str) -> str:
|
||||
return html.escape(value, quote=False)
|
||||
|
|
@ -194,6 +207,9 @@ class MarkdownDecoration(TextDecoration):
|
|||
def strikethrough(self, value: str) -> str:
|
||||
return f"~{value}~"
|
||||
|
||||
def spoiler(self, value: str) -> str:
|
||||
return f"|{value}|"
|
||||
|
||||
def quote(self, value: str) -> str:
|
||||
return re.sub(pattern=self.MARKDOWN_QUOTE_PATTERN, repl=r"\\\1", string=value)
|
||||
|
||||
|
|
|
|||
118
docs/dispatcher/filters/callback_data.rst
Normal file
118
docs/dispatcher/filters/callback_data.rst
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
==============================
|
||||
Callback Data Factory & Filter
|
||||
==============================
|
||||
|
||||
.. autoclass:: aiogram.dispatcher.filters.callback_data.CallbackData
|
||||
:members:
|
||||
:member-order: bysource
|
||||
:undoc-members: False
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
Create subclass of :code:`CallbackData`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyCallback(CallbackData, prefix="my"):
|
||||
foo: str
|
||||
bar: int
|
||||
|
||||
After that you can generate any callback based on this class, for example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
cb1 = MyCallback(foo="demo", bar=42)
|
||||
cb1.pack() # returns 'my:demo:42'
|
||||
cb1.unpack('my:demo:42') # returns <MyCallback(foo="demo", bar=42)>
|
||||
|
||||
So... Now you can use this class to generate any callbacks with defined structure
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
...
|
||||
# Pass it into the markup
|
||||
InlineKeyboardButton(
|
||||
text="demo",
|
||||
callback_data=MyCallback(foo="demo", bar="42").pack() # value should be packed to string
|
||||
)
|
||||
...
|
||||
|
||||
... and handle by specific rules
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Filter callback by type and value of field :code:`foo`
|
||||
@router.callback_query(MyCallback.filter(F.foo == "demo"))
|
||||
async def my_callback_foo(query: CallbackQuery, callback_data: MyCallback):
|
||||
await query.answer(...)
|
||||
...
|
||||
print("bar =", callback_data.bar)
|
||||
|
||||
Also can be used in :doc:`Keyboard builder </utils/keyboard>`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
builder = InlineKeyboardBuilder()
|
||||
builder.button(
|
||||
text="demo",
|
||||
callback_data=MyCallback(foo="demo", bar="42") # Value can be not packed to string inplace, because builder knows what to do with callback instance
|
||||
)
|
||||
|
||||
|
||||
Another abstract example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Action(str, Enum):
|
||||
ban = "ban"
|
||||
kick = "kick"
|
||||
warn = "warn"
|
||||
|
||||
class AdminAction(CallbackData, prefix="adm"):
|
||||
action: Action
|
||||
chat_id: int
|
||||
user_id: int
|
||||
|
||||
...
|
||||
# Inside handler
|
||||
builder = InlineKeyboardBuilder()
|
||||
for action in Action:
|
||||
builder.button(
|
||||
text=action.value.title(),
|
||||
callback_data=AdminAction(action=action, chat_id=chat_id, user_id=user_id),
|
||||
)
|
||||
await bot.send_message(
|
||||
chat_id=admins_chat,
|
||||
text=f"What do you want to do with {html.quote(name)}",
|
||||
reply_markup=builder.as_markup(),
|
||||
)
|
||||
...
|
||||
|
||||
@router.callback_query(AdminAction.filter(F.action == Action.ban))
|
||||
async def ban_user(query: CallbackQuery, callback_data: AdminAction, bot: Bot):
|
||||
await bot.ban_chat_member(
|
||||
chat_id=callback_data.chat_id,
|
||||
user_id=callback_data.user_id,
|
||||
...
|
||||
)
|
||||
|
||||
Known limitations
|
||||
=================
|
||||
|
||||
Allowed types and their subclasses:
|
||||
|
||||
- :code:`str`
|
||||
- :code:`int`
|
||||
- :code:`bool`
|
||||
- :code:`float`
|
||||
- :code:`Decimal` (:code:`from decimal import Decimal`)
|
||||
- :code:`Fraction` (:code:`from fractions import Fraction`)
|
||||
- :code:`UUID` (:code:`from uuid import UUID`)
|
||||
- :code:`Enum` (:code:`from enum import Enum`, only for string enums)
|
||||
- :code:`IntEnum` (:code:`from enum import IntEnum`, only for int enums)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
Note that the integer Enum's should be always is subclasses of :code:`IntEnum` in due to parsing issues.
|
||||
|
|
@ -22,6 +22,7 @@ Here is list of builtin filters:
|
|||
exception
|
||||
magic_filters
|
||||
magic_data
|
||||
callback_data
|
||||
|
||||
Own filters specification
|
||||
=========================
|
||||
|
|
|
|||
|
|
@ -4,6 +4,6 @@ Utils
|
|||
|
||||
.. toctree::
|
||||
|
||||
i18n
|
||||
keyboard
|
||||
i18n
|
||||
chat_action
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiogram import Bot, Dispatcher, types
|
||||
from aiogram.types import Message
|
||||
|
|
|
|||
2
mypy.ini
2
mypy.ini
|
|
@ -30,7 +30,7 @@ ignore_missing_imports = True
|
|||
[mypy-uvloop]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-aioredis]
|
||||
[mypy-redis.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-babel.*]
|
||||
|
|
|
|||
970
poetry.lock
generated
970
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -37,7 +37,7 @@ classifiers = [
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8"
|
||||
magic-filter = "^1.0.5"
|
||||
magic-filter = "^1.0.6"
|
||||
aiohttp = "^3.8.1"
|
||||
pydantic = "^1.9.0"
|
||||
aiofiles = "^0.8.0"
|
||||
|
|
@ -46,30 +46,29 @@ uvloop = { version = "^0.16.0", markers = "sys_platform == 'darwin' or sys_platf
|
|||
# i18n
|
||||
Babel = { version = "^2.9.1", optional = true }
|
||||
# Proxy
|
||||
aiohttp-socks = {version = "^0.7.1", optional = true}
|
||||
aiohttp-socks = { version = "^0.7.1", optional = true }
|
||||
# Redis
|
||||
aioredis = {version = "^2.0.1", optional = true}
|
||||
redis = { version = "^4.2.2", optional = true }
|
||||
# Docs
|
||||
Sphinx = { version = "^4.2.0", optional = true }
|
||||
sphinx-intl = { version = "^2.0.1", optional = true }
|
||||
sphinx-autobuild = { version = "^2021.3.14", optional = true }
|
||||
sphinx-copybutton = {version = "^0.5.0", optional = true}
|
||||
furo = {version = "^2022.2.14", optional = true}
|
||||
sphinx-copybutton = { version = "^0.5.0", optional = true }
|
||||
furo = { version = "^2022.4.7", optional = true }
|
||||
sphinx-prompt = { version = "^1.5.0", optional = true }
|
||||
Sphinx-Substitution-Extensions = { version = "^2020.9.30", optional = true }
|
||||
towncrier = {version = "^21.9.0", optional = true}
|
||||
towncrier = { version = "^21.9.0", optional = true }
|
||||
pygments = { version = "^2.4", optional = true }
|
||||
pymdown-extensions = {version = "^9.2", optional = true}
|
||||
pymdown-extensions = { version = "^9.3", optional = true }
|
||||
markdown-include = { version = "^0.6", optional = true }
|
||||
Pygments = {version = "^2.11.2", optional = true}
|
||||
Pygments = { version = "^2.11.2", optional = true }
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
ipython = "^8.0.1"
|
||||
ipython = "^8.1.1"
|
||||
black = "^22.1.0"
|
||||
isort = "^5.10.1"
|
||||
flake8 = "^4.0.1"
|
||||
flake8-html = "^0.4.1"
|
||||
mypy = "^0.931"
|
||||
mypy = "^0.942"
|
||||
pytest = "^7.0.1"
|
||||
pytest-html = "^3.1.1"
|
||||
pytest-asyncio = "^0.18.1"
|
||||
|
|
@ -90,7 +89,7 @@ sentry-sdk = "^1.5.5"
|
|||
|
||||
[tool.poetry.extras]
|
||||
fast = ["uvloop"]
|
||||
redis = ["aioredis"]
|
||||
redis = ["redis"]
|
||||
proxy = ["aiohttp-socks"]
|
||||
i18n = ["Babel"]
|
||||
docs = [
|
||||
|
|
@ -110,7 +109,7 @@ docs = [
|
|||
|
||||
[tool.black]
|
||||
line-length = 99
|
||||
target-version = ['py37', 'py38']
|
||||
target-version = ['py38', 'py39', 'py310']
|
||||
exclude = '''
|
||||
(
|
||||
\.eggs
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
from _pytest.config import UsageError
|
||||
from aioredis.connection import parse_url as parse_redis_url
|
||||
from redis.asyncio.connection import parse_url as parse_redis_url
|
||||
|
||||
from aiogram import Bot, Dispatcher
|
||||
from aiogram.dispatcher.fsm.storage.memory import (
|
||||
|
|
|
|||
|
|
@ -215,11 +215,11 @@ class TestBaseSession:
|
|||
return await make_request(bot, method)
|
||||
|
||||
session = CustomSession()
|
||||
assert not session.middlewares
|
||||
assert not session.middleware._middlewares
|
||||
|
||||
session.middleware(my_middleware)
|
||||
assert my_middleware in session.middlewares
|
||||
assert len(session.middlewares) == 1
|
||||
assert my_middleware in session.middleware
|
||||
assert len(session.middleware) == 1
|
||||
|
||||
async def test_use_middleware(self, bot: MockedBot):
|
||||
flag_before = False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
from aiogram import Bot
|
||||
from aiogram.client.session.middlewares.base import (
|
||||
BaseRequestMiddleware,
|
||||
NextRequestMiddlewareType,
|
||||
)
|
||||
from aiogram.client.session.middlewares.manager import RequestMiddlewareManager
|
||||
from aiogram.methods import Response, TelegramMethod
|
||||
from aiogram.types import TelegramObject
|
||||
|
||||
|
||||
class TestMiddlewareManager:
|
||||
async def test_register(self):
|
||||
manager = RequestMiddlewareManager()
|
||||
|
||||
@manager
|
||||
async def middleware(handler, event, data):
|
||||
await handler(event, data)
|
||||
|
||||
assert middleware in manager._middlewares
|
||||
manager.unregister(middleware)
|
||||
assert middleware not in manager._middlewares
|
||||
|
||||
async def test_wrap_middlewares(self):
|
||||
manager = RequestMiddlewareManager()
|
||||
|
||||
class MyMiddleware(BaseRequestMiddleware):
|
||||
async def __call__(
|
||||
self,
|
||||
make_request: NextRequestMiddlewareType,
|
||||
bot: Bot,
|
||||
method: TelegramMethod[TelegramObject],
|
||||
) -> Response[TelegramObject]:
|
||||
return await make_request(bot, method)
|
||||
|
||||
manager.register(MyMiddleware())
|
||||
|
||||
@manager()
|
||||
@manager
|
||||
async def middleware(make_request, bot, method):
|
||||
return await make_request(bot, method)
|
||||
|
||||
async def target_call(bot, method, timeout: int = None):
|
||||
return timeout
|
||||
|
||||
assert await manager.wrap_middlewares(target_call, timeout=42)(None, None) == 42
|
||||
|
|
@ -641,13 +641,15 @@ class TestMessage:
|
|||
assert method.message_id == message.message_id
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,entities,correct",
|
||||
"text,entities,mode,expected_value",
|
||||
[
|
||||
["test", [MessageEntity(type="bold", offset=0, length=4)], True],
|
||||
["", [], False],
|
||||
["test", [MessageEntity(type="bold", offset=0, length=4)], "html", "<b>test</b>"],
|
||||
["test", [MessageEntity(type="bold", offset=0, length=4)], "md", "*test*"],
|
||||
["", [], "html", ""],
|
||||
["", [], "md", ""],
|
||||
],
|
||||
)
|
||||
def test_html_text(self, text, entities, correct):
|
||||
def test_html_text(self, text, entities, mode, expected_value):
|
||||
message = Message(
|
||||
message_id=42,
|
||||
chat=Chat(id=42, type="private"),
|
||||
|
|
@ -655,11 +657,4 @@ class TestMessage:
|
|||
text=text,
|
||||
entities=entities,
|
||||
)
|
||||
if correct:
|
||||
assert message.html_text
|
||||
assert message.md_text
|
||||
else:
|
||||
with pytest.raises(TypeError):
|
||||
assert message.html_text
|
||||
with pytest.raises(TypeError):
|
||||
assert message.md_text
|
||||
assert getattr(message, f"{mode}_text") == expected_value
|
||||
|
|
|
|||
|
|
@ -5,27 +5,38 @@ from aiogram.dispatcher.router import Router
|
|||
from tests.deprecated import check_deprecated
|
||||
|
||||
OBSERVERS = {
|
||||
"callback_query",
|
||||
"channel_post",
|
||||
"chosen_inline_result",
|
||||
"edited_channel_post",
|
||||
"edited_message",
|
||||
"errors",
|
||||
"inline_query",
|
||||
"message",
|
||||
"edited_message",
|
||||
"channel_post",
|
||||
"edited_channel_post",
|
||||
"inline_query",
|
||||
"chosen_inline_result",
|
||||
"callback_query",
|
||||
"shipping_query",
|
||||
"pre_checkout_query",
|
||||
"poll",
|
||||
"poll_answer",
|
||||
"pre_checkout_query",
|
||||
"shipping_query",
|
||||
"my_chat_member",
|
||||
"chat_member",
|
||||
"chat_join_request",
|
||||
"errors",
|
||||
}
|
||||
|
||||
DEPRECATED_OBSERVERS = {observer + "_handler" for observer in OBSERVERS}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("observer_name", DEPRECATED_OBSERVERS)
|
||||
@pytest.mark.parametrize("observer_name", OBSERVERS)
|
||||
def test_deprecated_handlers_name(observer_name: str):
|
||||
router = Router()
|
||||
|
||||
with check_deprecated("3.2", exception=AttributeError):
|
||||
observer = getattr(router, observer_name)
|
||||
observer = getattr(router, f"{observer_name}_handler")
|
||||
assert isinstance(observer, TelegramEventObserver)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("observer_name", OBSERVERS)
|
||||
def test_deprecated_register_handlers(observer_name: str):
|
||||
router = Router()
|
||||
|
||||
with check_deprecated("3.2", exception=AttributeError):
|
||||
register = getattr(router, f"register_{observer_name}")
|
||||
register(lambda event: True)
|
||||
assert callable(register)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,22 @@ class TestDispatcher:
|
|||
|
||||
assert dp.update.handlers
|
||||
assert dp.update.handlers[0].callback == dp._listen_update
|
||||
assert dp.update.outer_middlewares
|
||||
assert dp.update.outer_middleware
|
||||
|
||||
def test_data_bind(self):
|
||||
dp = Dispatcher()
|
||||
assert dp.get("foo") is None
|
||||
assert dp.get("foo", 42) == 42
|
||||
|
||||
dp["foo"] = 1
|
||||
assert dp._data["foo"] == 1
|
||||
assert dp["foo"] == 1
|
||||
|
||||
del dp["foo"]
|
||||
assert "foo" not in dp._data
|
||||
|
||||
def test_storage_property(self, dispatcher: Dispatcher):
|
||||
assert dispatcher.storage is dispatcher.fsm.storage
|
||||
|
||||
def test_parent_router(self, dispatcher: Dispatcher):
|
||||
with pytest.raises(RuntimeError):
|
||||
|
|
|
|||
42
tests/test_dispatcher/test_event/test_middleware.py
Normal file
42
tests/test_dispatcher/test_event/test_middleware.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from functools import partial
|
||||
|
||||
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
||||
|
||||
|
||||
class TestMiddlewareManager:
|
||||
async def test_register(self):
|
||||
manager = MiddlewareManager()
|
||||
|
||||
@manager
|
||||
async def middleware(handler, event, data):
|
||||
await handler(event, data)
|
||||
|
||||
assert middleware in manager._middlewares
|
||||
manager.unregister(middleware)
|
||||
assert middleware not in manager._middlewares
|
||||
|
||||
async def test_wrap_middlewares(self):
|
||||
manager = MiddlewareManager()
|
||||
|
||||
async def target(*args, **kwargs):
|
||||
kwargs["target"] = True
|
||||
kwargs["stack"].append(-1)
|
||||
return kwargs
|
||||
|
||||
async def middleware1(handler, event, data):
|
||||
data["mw1"] = True
|
||||
data["stack"].append(1)
|
||||
return await handler(event, data)
|
||||
|
||||
async def middleware2(handler, event, data):
|
||||
data["mw2"] = True
|
||||
data["stack"].append(2)
|
||||
return await handler(event, data)
|
||||
|
||||
wrapped = manager.wrap_middlewares([middleware1, middleware2], target)
|
||||
|
||||
assert isinstance(wrapped, partial)
|
||||
assert wrapped.func is middleware1
|
||||
|
||||
result = await wrapped(None, {"stack": []})
|
||||
assert result == {"mw1": True, "mw2": True, "target": True, "stack": [1, 2, -1]}
|
||||
|
|
@ -297,10 +297,9 @@ class TestTelegramEventObserver:
|
|||
def test_register_middleware(self, middleware_type):
|
||||
event_observer = TelegramEventObserver(Router(), "test")
|
||||
|
||||
middlewares = getattr(event_observer, f"{middleware_type}s")
|
||||
decorator = getattr(event_observer, middleware_type)
|
||||
middlewares = getattr(event_observer, middleware_type)
|
||||
|
||||
@decorator
|
||||
@middlewares
|
||||
async def my_middleware1(handler, event, data):
|
||||
pass
|
||||
|
||||
|
|
@ -308,7 +307,7 @@ class TestTelegramEventObserver:
|
|||
assert my_middleware1.__name__ == "my_middleware1"
|
||||
assert my_middleware1 in middlewares
|
||||
|
||||
@decorator()
|
||||
@middlewares()
|
||||
async def my_middleware2(handler, event, data):
|
||||
pass
|
||||
|
||||
|
|
@ -319,13 +318,13 @@ class TestTelegramEventObserver:
|
|||
async def my_middleware3(handler, event, data):
|
||||
pass
|
||||
|
||||
decorator(my_middleware3)
|
||||
middlewares(my_middleware3)
|
||||
|
||||
assert my_middleware3 is not None
|
||||
assert my_middleware3.__name__ == "my_middleware3"
|
||||
assert my_middleware3 in middlewares
|
||||
|
||||
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]
|
||||
assert list(middlewares) == [my_middleware1, my_middleware2, my_middleware3]
|
||||
|
||||
def test_register_global_filters(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class MyCallback(CallbackData, prefix="test"):
|
|||
|
||||
class TestCallbackData:
|
||||
def test_init_subclass_prefix_required(self):
|
||||
assert MyCallback.prefix == "test"
|
||||
assert MyCallback.__prefix__ == "test"
|
||||
|
||||
with pytest.raises(ValueError, match="prefix required.+"):
|
||||
|
||||
|
|
@ -38,12 +38,12 @@ class TestCallbackData:
|
|||
pass
|
||||
|
||||
def test_init_subclass_sep_validation(self):
|
||||
assert MyCallback.sep == ":"
|
||||
assert MyCallback.__separator__ == ":"
|
||||
|
||||
class MyCallback2(CallbackData, prefix="test2", sep="@"):
|
||||
pass
|
||||
|
||||
assert MyCallback2.sep == "@"
|
||||
assert MyCallback2.__separator__ == "@"
|
||||
|
||||
with pytest.raises(ValueError, match="Separator symbol '@' .+ 'sp@m'"):
|
||||
|
||||
|
|
|
|||
|
|
@ -92,6 +92,18 @@ class TestCommandFilter:
|
|||
command = Command(commands=["test"])
|
||||
assert bool(await command(message=message, bot=bot)) is result
|
||||
|
||||
async def test_command_magic_result(self, bot: MockedBot):
|
||||
message = Message(
|
||||
message_id=0,
|
||||
text="/test 42",
|
||||
chat=Chat(id=42, type="private"),
|
||||
date=datetime.datetime.now(),
|
||||
)
|
||||
command = Command(commands=["test"], command_magic=(F.args.as_("args")))
|
||||
result = await command(message=message, bot=bot)
|
||||
assert "args" in result
|
||||
assert result["args"] == "42"
|
||||
|
||||
|
||||
class TestCommandObject:
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import re
|
|||
|
||||
import pytest
|
||||
|
||||
from aiogram import Dispatcher, F
|
||||
from aiogram import Dispatcher
|
||||
from aiogram.dispatcher.filters import ExceptionMessageFilter, ExceptionTypeFilter
|
||||
from aiogram.types import Update
|
||||
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ class TestMagicDataFilter:
|
|||
assert value.spam is True
|
||||
return value
|
||||
|
||||
f = MagicData(magic_data=F.func(check))
|
||||
f = MagicData(magic_data=F.func(check).as_("test"))
|
||||
result = await f(Update(update_id=123), "foo", "bar", spam=True)
|
||||
|
||||
assert called
|
||||
assert isinstance(result, bool)
|
||||
assert result
|
||||
assert isinstance(result, dict)
|
||||
assert result["test"]
|
||||
|
|
|
|||
|
|
@ -111,8 +111,8 @@ class TestSimpleI18nMiddleware:
|
|||
middleware = SimpleI18nMiddleware(i18n=i18n)
|
||||
middleware.setup(router=dp)
|
||||
|
||||
assert middleware not in dp.update.outer_middlewares
|
||||
assert middleware in dp.message.outer_middlewares
|
||||
assert middleware not in dp.update.outer_middleware
|
||||
assert middleware in dp.message.outer_middleware
|
||||
|
||||
async def test_get_unknown_locale(self, i18n: I18n):
|
||||
dp = Dispatcher()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Any, Dict
|
|||
|
||||
import pytest
|
||||
|
||||
from aiogram.utils.link import create_telegram_link, create_tg_link
|
||||
from aiogram.utils.link import BRANCH, create_telegram_link, create_tg_link, docs_url
|
||||
|
||||
|
||||
class TestLink:
|
||||
|
|
@ -22,3 +22,12 @@ class TestLink:
|
|||
)
|
||||
def test_create_telegram_link(self, base: str, params: Dict[str, Any], result: str):
|
||||
assert create_telegram_link(base, **params) == result
|
||||
|
||||
def test_fragment(self):
|
||||
assert (
|
||||
docs_url("test.html", fragment_="test")
|
||||
== f"https://docs.aiogram.dev/en/{BRANCH}/test.html#test"
|
||||
)
|
||||
|
||||
def test_docs(self):
|
||||
assert docs_url("test.html") == f"https://docs.aiogram.dev/en/{BRANCH}/test.html"
|
||||
|
|
|
|||
|
|
@ -47,6 +47,11 @@ class TestTextDecoration:
|
|||
'<a href="tg://user?id=42">test</a>',
|
||||
],
|
||||
[html_decoration, MessageEntity(type="url", offset=0, length=5), "test"],
|
||||
[
|
||||
html_decoration,
|
||||
MessageEntity(type="spoiler", offset=0, length=5),
|
||||
'<span class="tg-spoiler">test</span>',
|
||||
],
|
||||
[
|
||||
html_decoration,
|
||||
MessageEntity(type="text_link", offset=0, length=5, url="https://aiogram.dev"),
|
||||
|
|
@ -76,6 +81,7 @@ class TestTextDecoration:
|
|||
[markdown_decoration, MessageEntity(type="bot_command", offset=0, length=5), "test"],
|
||||
[markdown_decoration, MessageEntity(type="email", offset=0, length=5), "test"],
|
||||
[markdown_decoration, MessageEntity(type="phone_number", offset=0, length=5), "test"],
|
||||
[markdown_decoration, MessageEntity(type="spoiler", offset=0, length=5), "|test|"],
|
||||
[
|
||||
markdown_decoration,
|
||||
MessageEntity(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue