mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Add token validation util, fix deepcopy of sessions and make bot hashable and comparable
This commit is contained in:
parent
9adc2f91bd
commit
c674b5547b
11 changed files with 223 additions and 41 deletions
|
|
@ -1,6 +1,10 @@
|
|||
from typing import TypeVar
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TypeVar, Dict, Any
|
||||
|
||||
from ...utils.mixins import ContextInstanceMixin
|
||||
from ...utils.token import extract_bot_id, validate_token
|
||||
from ..methods import TelegramMethod
|
||||
from .session.aiohttp import AiohttpSession
|
||||
from .session.base import BaseSession
|
||||
|
|
@ -10,13 +14,20 @@ T = TypeVar("T")
|
|||
|
||||
class BaseBot(ContextInstanceMixin):
|
||||
def __init__(self, token: str, session: BaseSession = None):
|
||||
validate_token(token)
|
||||
|
||||
if session is None:
|
||||
session = AiohttpSession()
|
||||
|
||||
self.session = session
|
||||
self.token = token
|
||||
self.__token = token
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return extract_bot_id(self.__token)
|
||||
|
||||
async def emit(self, method: TelegramMethod[T]) -> T:
|
||||
return await self.session.make_request(self.token, method)
|
||||
return await self.session.make_request(self.__token, method)
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
|
|
@ -26,3 +37,11 @@ class BaseBot(ContextInstanceMixin):
|
|||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.session.close()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.__token)
|
||||
|
||||
def __eq__(self, other: BaseBot):
|
||||
if not isinstance(other, BaseBot):
|
||||
return False
|
||||
return hash(self) == hash(other)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from typing import Callable, Optional, TypeVar, cast
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Callable, Optional, TypeVar, cast, Dict, Any
|
||||
|
||||
from aiohttp import ClientSession, FormData
|
||||
|
||||
|
|
@ -53,3 +56,17 @@ class AiohttpSession(BaseSession):
|
|||
response = call.build_response(raw_result)
|
||||
self.raise_for_status(response)
|
||||
return cast(T, response.result)
|
||||
|
||||
async def __aenter__(self) -> AiohttpSession:
|
||||
await self.create_session()
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, memodict: Dict[str, Any]):
|
||||
cls = self.__class__
|
||||
result = cls.__new__(cls)
|
||||
memodict[id(self)] = result
|
||||
for key, value in self.__dict__.items():
|
||||
# aiohttp ClientSession cannot be copied.
|
||||
copied_value = copy.deepcopy(value, memo=memodict) if key != '_session' else None
|
||||
setattr(result, key, copied_value)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import datetime
|
||||
import json
|
||||
|
|
@ -13,9 +15,9 @@ class BaseSession(abc.ABC):
|
|||
def __init__(
|
||||
self,
|
||||
api: Optional[TelegramAPIServer] = None,
|
||||
json_loads: Optional[Callable] = None,
|
||||
json_dumps: Optional[Callable] = None,
|
||||
):
|
||||
json_loads: Optional[Callable[[Any], Any]] = None,
|
||||
json_dumps: Optional[Callable[[Any], Any]] = None,
|
||||
) -> None:
|
||||
if api is None:
|
||||
api = PRODUCTION
|
||||
if json_loads is None:
|
||||
|
|
@ -27,7 +29,7 @@ class BaseSession(abc.ABC):
|
|||
self.json_loads = json_loads
|
||||
self.json_dumps = json_dumps
|
||||
|
||||
def raise_for_status(self, response: Response[T]):
|
||||
def raise_for_status(self, response: Response[T]) -> None:
|
||||
if response.ok:
|
||||
return
|
||||
raise Exception(response.description)
|
||||
|
|
@ -37,7 +39,7 @@ class BaseSession(abc.ABC):
|
|||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def make_request(self, token: str, method: TelegramMethod[T]) -> T:
|
||||
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
|
||||
pass
|
||||
|
||||
def prepare_value(self, value: Any) -> Union[str, int, bool]:
|
||||
|
|
@ -53,9 +55,15 @@ class BaseSession(abc.ABC):
|
|||
else:
|
||||
return str(value)
|
||||
|
||||
def clean_json(self, value: Any):
|
||||
def clean_json(self, value: Any) -> Any:
|
||||
if isinstance(value, list):
|
||||
return [self.clean_json(v) for v in value if v is not None]
|
||||
elif isinstance(value, dict):
|
||||
return {k: self.clean_json(v) for k, v in value.items() if v is not None}
|
||||
return value
|
||||
|
||||
async def __aenter__(self) -> BaseSession:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
|
|
|||
42
aiogram/utils/token.py
Normal file
42
aiogram/utils/token.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from functools import lru_cache
|
||||
|
||||
|
||||
class TokenValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def validate_token(token: str) -> bool:
|
||||
"""
|
||||
Validate Telegram token
|
||||
|
||||
:param token:
|
||||
:return:
|
||||
"""
|
||||
if not isinstance(token, str):
|
||||
raise TokenValidationError(
|
||||
f"Token is invalid! It must be 'str' type instead of {type(token)} type."
|
||||
)
|
||||
|
||||
if any(x.isspace() for x in token):
|
||||
message = "Token is invalid! It can't contains spaces."
|
||||
raise TokenValidationError(message)
|
||||
|
||||
left, sep, right = token.partition(":")
|
||||
if (not sep) or (not left.isdigit()) or (not right):
|
||||
raise TokenValidationError("Token is invalid!")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def extract_bot_id(token: str) -> int:
|
||||
"""
|
||||
Extract bot ID from Telegram token
|
||||
|
||||
:param token:
|
||||
:return:
|
||||
"""
|
||||
validate_token(token)
|
||||
raw_bot_id, *_ = token.split(":")
|
||||
return int(raw_bot_id)
|
||||
Loading…
Add table
Add a link
Reference in a new issue