Fixed typehints

This commit is contained in:
Alex Root Junior 2023-07-02 14:05:12 +03:00
parent 51aeceddad
commit 55ae13189c
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
7 changed files with 57 additions and 50 deletions

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Awaitable, Callable, Union
from typing import TYPE_CHECKING, Protocol
from aiogram.methods import Response, TelegramMethod
from aiogram.methods.base import TelegramType
@ -9,14 +9,24 @@ from aiogram.methods.base import TelegramType
if TYPE_CHECKING:
from ...bot import Bot
NextRequestMiddlewareType = Callable[["Bot", TelegramMethod], Awaitable[Response]]
RequestMiddlewareType = Union[
"BaseRequestMiddleware",
Callable[
[NextRequestMiddlewareType, "Bot", TelegramMethod],
Awaitable[Response],
],
]
class NextRequestMiddlewareType(Protocol[TelegramType]):
async def __call__(
self,
bot: "Bot",
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
pass
class RequestMiddlewareProtocol(Protocol):
async def __call__(
self,
make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot",
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
pass
class BaseRequestMiddleware(ABC):

View file

@ -1,63 +1,51 @@
from __future__ import annotations
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Optional,
Sequence,
Union,
overload,
)
from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload
from aiogram.client.session.middlewares.base import (
NextRequestMiddlewareType,
RequestMiddlewareType,
RequestMiddlewareProtocol,
)
from aiogram.methods import Response
from aiogram.methods.base import TelegramMethod, TelegramType
from aiogram.types import TelegramObject
if TYPE_CHECKING:
from aiogram import Bot
from aiogram.methods.base import TelegramType
class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
class RequestMiddlewareManager(Sequence[RequestMiddlewareProtocol]):
def __init__(self) -> None:
self._middlewares: List[RequestMiddlewareType] = []
self._middlewares: List[RequestMiddlewareProtocol] = []
def register(
self,
middleware: RequestMiddlewareType,
) -> RequestMiddlewareType:
middleware: RequestMiddlewareProtocol,
) -> RequestMiddlewareProtocol:
self._middlewares.append(middleware)
return middleware
def unregister(self, middleware: RequestMiddlewareType) -> None:
def unregister(self, middleware: RequestMiddlewareProtocol) -> None:
self._middlewares.remove(middleware)
def __call__(
self,
middleware: Optional[RequestMiddlewareType] = None,
) -> Union[Callable[[RequestMiddlewareType], RequestMiddlewareType], RequestMiddlewareType,]:
middleware: Optional[RequestMiddlewareProtocol] = None,
) -> Union[
Callable[[RequestMiddlewareProtocol], RequestMiddlewareProtocol],
RequestMiddlewareProtocol,
]:
if middleware is None:
return self.register
return self.register(middleware)
@overload
def __getitem__(self, item: int) -> RequestMiddlewareType:
def __getitem__(self, item: int) -> RequestMiddlewareProtocol:
pass
@overload
def __getitem__(self, item: slice) -> Sequence[RequestMiddlewareType]:
def __getitem__(self, item: slice) -> Sequence[RequestMiddlewareProtocol]:
pass
def __getitem__(
self, item: Union[int, slice]
) -> Union[RequestMiddlewareType, Sequence[RequestMiddlewareType]]:
) -> Union[RequestMiddlewareProtocol, Sequence[RequestMiddlewareProtocol]]:
return self._middlewares[item]
def __len__(self) -> int:
@ -65,10 +53,10 @@ class RequestMiddlewareManager(Sequence[RequestMiddlewareType]):
def wrap_middlewares(
self,
callback: Callable[[Bot, TelegramMethod], Awaitable[Response]],
callback: NextRequestMiddlewareType[TelegramType],
**kwargs: Any,
) -> NextRequestMiddlewareType:
) -> NextRequestMiddlewareType[TelegramType]:
middleware = partial(callback, **kwargs)
for m in reversed(self._middlewares):
middleware = partial(m, middleware) # type: ignore
return middleware
middleware = partial(m, middleware)
return cast(NextRequestMiddlewareType[TelegramType], middleware)

View file

@ -24,9 +24,9 @@ class RequestLogging(BaseRequestMiddleware):
async def __call__(
self,
make_request: NextRequestMiddlewareType,
make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot",
method: TelegramMethod,
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
if type(method) not in self.ignore_methods:
loggers.middlewares.info(

View file

@ -1,9 +1,19 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, Optional, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
Generic,
Optional,
TypeVar,
)
from pydantic import BaseConfig, BaseModel, ConfigDict, Extra, root_validator
from pydantic import BaseModel, ConfigDict
from pydantic.functional_validators import model_validator
from ..types import InputFile, ResponseParameters
@ -68,8 +78,6 @@ class TelegramMethod(BaseModel, Generic[TelegramType], ABC):
async def emit(self, bot: Bot) -> TelegramType:
return await bot(self)
as_ = emit
def __await__(self) -> Generator[Any, None, TelegramType]:
from aiogram.client.bot import Bot

View file

@ -10,13 +10,14 @@ from aiogram.utils.text_decorations import (
html_decoration,
markdown_decoration,
)
from ..enums import ContentType
from .base import (
UNSET_DISABLE_WEB_PAGE_PREVIEW,
UNSET_PARSE_MODE,
UNSET_PROTECT_CONTENT,
TelegramObject,
)
from ..enums import ContentType
if TYPE_CHECKING:
from ..methods import (

View file

@ -2,8 +2,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional, cast
from .base import TelegramObject
from ..utils.mypy_hacks import lru_cache
from .base import TelegramObject
if TYPE_CHECKING:
from .callback_query import CallbackQuery

View file

@ -2,9 +2,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
from .base import TelegramObject
from ..utils import markdown
from ..utils.link import create_tg_link
from .base import TelegramObject
if TYPE_CHECKING:
from ..methods import GetUserProfilePhotos