Add token validation util, fix deepcopy of sessions and make bot hashable and comparable

This commit is contained in:
Alex Root Junior 2019-11-28 23:12:44 +02:00
parent 9adc2f91bd
commit c674b5547b
11 changed files with 223 additions and 41 deletions

View file

@ -1,6 +1,10 @@
from typing import TypeVar
from __future__ import annotations
import copy
from typing import TypeVar, Dict, Any
from ...utils.mixins import ContextInstanceMixin
from ...utils.token import extract_bot_id, validate_token
from ..methods import TelegramMethod
from .session.aiohttp import AiohttpSession
from .session.base import BaseSession
@ -10,13 +14,20 @@ T = TypeVar("T")
class BaseBot(ContextInstanceMixin):
def __init__(self, token: str, session: BaseSession = None):
validate_token(token)
if session is None:
session = AiohttpSession()
self.session = session
self.token = token
self.__token = token
@property
def id(self):
return extract_bot_id(self.__token)
async def emit(self, method: TelegramMethod[T]) -> T:
return await self.session.make_request(self.token, method)
return await self.session.make_request(self.__token, method)
async def close(self):
await self.session.close()
@ -26,3 +37,11 @@ class BaseBot(ContextInstanceMixin):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.session.close()
def __hash__(self):
return hash(self.__token)
def __eq__(self, other: BaseBot):
if not isinstance(other, BaseBot):
return False
return hash(self) == hash(other)

View file

@ -1,4 +1,7 @@
from typing import Callable, Optional, TypeVar, cast
from __future__ import annotations
import copy
from typing import Callable, Optional, TypeVar, cast, Dict, Any
from aiohttp import ClientSession, FormData
@ -53,3 +56,17 @@ class AiohttpSession(BaseSession):
response = call.build_response(raw_result)
self.raise_for_status(response)
return cast(T, response.result)
async def __aenter__(self) -> AiohttpSession:
await self.create_session()
return self
def __deepcopy__(self, memodict: Dict[str, Any]):
cls = self.__class__
result = cls.__new__(cls)
memodict[id(self)] = result
for key, value in self.__dict__.items():
# aiohttp ClientSession cannot be copied.
copied_value = copy.deepcopy(value, memo=memodict) if key != '_session' else None
setattr(result, key, copied_value)
return result

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import abc
import datetime
import json
@ -13,9 +15,9 @@ class BaseSession(abc.ABC):
def __init__(
self,
api: Optional[TelegramAPIServer] = None,
json_loads: Optional[Callable] = None,
json_dumps: Optional[Callable] = None,
):
json_loads: Optional[Callable[[Any], Any]] = None,
json_dumps: Optional[Callable[[Any], Any]] = None,
) -> None:
if api is None:
api = PRODUCTION
if json_loads is None:
@ -27,7 +29,7 @@ class BaseSession(abc.ABC):
self.json_loads = json_loads
self.json_dumps = json_dumps
def raise_for_status(self, response: Response[T]):
def raise_for_status(self, response: Response[T]) -> None:
if response.ok:
return
raise Exception(response.description)
@ -37,7 +39,7 @@ class BaseSession(abc.ABC):
pass
@abc.abstractmethod
async def make_request(self, token: str, method: TelegramMethod[T]) -> T:
async def make_request(self, token: str, method: TelegramMethod[T]) -> T: # pragma: no cover
pass
def prepare_value(self, value: Any) -> Union[str, int, bool]:
@ -53,9 +55,15 @@ class BaseSession(abc.ABC):
else:
return str(value)
def clean_json(self, value: Any):
def clean_json(self, value: Any) -> Any:
if isinstance(value, list):
return [self.clean_json(v) for v in value if v is not None]
elif isinstance(value, dict):
return {k: self.clean_json(v) for k, v in value.items() if v is not None}
return value
async def __aenter__(self) -> BaseSession:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()

42
aiogram/utils/token.py Normal file
View file

@ -0,0 +1,42 @@
from functools import lru_cache
class TokenValidationError(Exception):
pass
@lru_cache()
def validate_token(token: str) -> bool:
"""
Validate Telegram token
:param token:
:return:
"""
if not isinstance(token, str):
raise TokenValidationError(
f"Token is invalid! It must be 'str' type instead of {type(token)} type."
)
if any(x.isspace() for x in token):
message = "Token is invalid! It can't contains spaces."
raise TokenValidationError(message)
left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (not right):
raise TokenValidationError("Token is invalid!")
return True
@lru_cache()
def extract_bot_id(token: str) -> int:
"""
Extract bot ID from Telegram token
:param token:
:return:
"""
validate_token(token)
raw_bot_id, *_ = token.split(":")
return int(raw_bot_id)