From 5797ab6c992ba4a0fe16dcfef799f3682326b892 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Thu, 14 Nov 2019 02:05:08 +0200 Subject: [PATCH] Allow to send files --- aiogram/api/methods/base.py | 2 +- aiogram/api/session/aiohttp.py | 14 +++--- aiogram/api/session/base.py | 15 ++++--- aiogram/api/types/__init__.py | 4 +- aiogram/api/types/input_file.py | 76 ++++++++++++++++++++++++++++++++- 5 files changed, 91 insertions(+), 20 deletions(-) diff --git a/aiogram/api/methods/base.py b/aiogram/api/methods/base.py index 3b0a230e..b16e5c8e 100644 --- a/aiogram/api/methods/base.py +++ b/aiogram/api/methods/base.py @@ -19,7 +19,7 @@ class Request(BaseModel): method: str data: Dict[str, Optional[Any]] - files: Optional[Dict[str, Union[io.BytesIO, bytes, InputFile]]] + files: Optional[Dict[str, InputFile]] class Config(BaseConfig): arbitrary_types_allowed = True diff --git a/aiogram/api/session/aiohttp.py b/aiogram/api/session/aiohttp.py index 4378e4d1..a96fe676 100644 --- a/aiogram/api/session/aiohttp.py +++ b/aiogram/api/session/aiohttp.py @@ -1,12 +1,11 @@ from typing import Optional, TypeVar from aiohttp import ClientSession, FormData -from pydantic import BaseModel -from .base import BaseSession, TelegramAPIServer, PRODUCTION -from ..methods import TelegramMethod, Request +from ..methods import Request, TelegramMethod +from .base import PRODUCTION, BaseSession, TelegramAPIServer -T = TypeVar('T') +T = TypeVar("T") class AiohttpSession(BaseSession): @@ -28,11 +27,11 @@ class AiohttpSession(BaseSession): if value is None: continue if isinstance(value, bool): - print("elif isinstance(value, bool):", key, value) form.add_field(key, value) else: - print("else:", key, value) form.add_field(key, str(value)) + for key, value in request.files.items(): + form.add_field(key, value, filename=value.filename or key) return form async def make_request(self, token: str, call: TelegramMethod[T]) -> T: @@ -46,6 +45,5 @@ class AiohttpSession(BaseSession): raw_result = await response.json() response = call.build_response(raw_result) - if not response.ok: - self.raise_for_status(response) + self.raise_for_status(response) return response.result diff --git a/aiogram/api/session/base.py b/aiogram/api/session/base.py index 323ec98d..14040a8a 100644 --- a/aiogram/api/session/base.py +++ b/aiogram/api/session/base.py @@ -1,12 +1,11 @@ import abc import asyncio -from typing import TypeVar, Generic - -from pydantic.dataclasses import dataclass +from typing import Generic, TypeVar from aiogram.api.methods import Response, TelegramMethod +from pydantic.dataclasses import dataclass -T = TypeVar('T') +T = TypeVar("T") @dataclass @@ -22,8 +21,8 @@ class TelegramAPIServer: PRODUCTION = TelegramAPIServer( - base='https://api.telegram.org/bot{token}/{method}', - file='https://api.telegram.org/file/bot{token}/{path}' + base="https://api.telegram.org/bot{token}/{method}", + file="https://api.telegram.org/file/bot{token}/{path}", ) @@ -32,7 +31,9 @@ class BaseSession(abc.ABC, Generic[T]): self.api = api def raise_for_status(self, response: Response[T]): - print(f"ERROR: {response}") + if response.ok: + return + raise Exception(response.description) @abc.abstractmethod async def close(self): diff --git a/aiogram/api/types/__init__.py b/aiogram/api/types/__init__.py index 60a1e497..fd2ffd4c 100644 --- a/aiogram/api/types/__init__.py +++ b/aiogram/api/types/__init__.py @@ -97,6 +97,7 @@ from .webhook_info import WebhookInfo __all__ = ( "TelegramObject", + "InputFile", "Update", "WebhookInfo", "User", @@ -135,7 +136,6 @@ __all__ = ( "InputMediaAnimation", "InputMediaAudio", "InputMediaDocument", - "InputFile", "Sticker", "StickerSet", "MaskPosition", @@ -195,5 +195,5 @@ __all__ = ( ) # Load typing forward refs for every TelegramObject -for entity in __all__[1:]: +for entity in __all__[2:]: globals()[entity].update_forward_refs(**globals()) diff --git a/aiogram/api/types/input_file.py b/aiogram/api/types/input_file.py index 8f8a8e4d..806ec3d3 100644 --- a/aiogram/api/types/input_file.py +++ b/aiogram/api/types/input_file.py @@ -1,11 +1,83 @@ from __future__ import annotations -from .base import TelegramObject +import io +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union + +import aiofiles as aiofiles + +DEFAULT_CHUNK_SIZE = 64 * 1024 # 64 kb -class InputFile(TelegramObject): +class InputFile(ABC): """ This object represents the contents of a file to be uploaded. Must be posted using multipart/form-data in the usual way that files are uploaded via the browser. Source: https://core.telegram.org/bots/api#inputfile """ + + def __init__(self, filename: Optional[str] = None, chunk_size: int = DEFAULT_CHUNK_SIZE): + self.filename = filename + self.chunk_size = chunk_size + + @classmethod + def __get_validators__(cls): + yield + + @abstractmethod + async def read(self, chunk_size: int): + pass + + async def __aiter__(self): + async for chunk in self.read(self.chunk_size): + yield chunk + + +class BufferedInputFile(InputFile): + def __init__(self, file: bytes, filename: str, chunk_size: int = DEFAULT_CHUNK_SIZE): + super().__init__(filename=filename, chunk_size=chunk_size) + + self.data = file + + @classmethod + def from_file( + cls, + path: Union[str, Path], + filename: Optional[str] = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ): + if filename is None: + filename = os.path.basename(path) + with open(path, "rb") as f: + data = f.read() + return cls(data, filename=filename, chunk_size=chunk_size) + + async def read(self, chunk_size: int): + buffer = io.BytesIO(self.data) + chunk = buffer.read(chunk_size) + while chunk: + yield chunk + chunk = buffer.read(chunk_size) + + +class FSInputFile(InputFile): + def __init__( + self, + path: Union[str, Path], + filename: Optional[str] = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ): + if filename is None: + filename = os.path.basename(path) + super().__init__(filename=filename, chunk_size=chunk_size) + + self.path = path + + async def read(self, chunk_size: int): + async with aiofiles.open(self.path, "rb") as f: + chunk = await f.read(chunk_size) + while chunk: + yield chunk + chunk = await f.read(chunk_size)