diff --git a/.gitignore b/.gitignore index 6c2a9404..a8b34bd1 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,6 @@ experiment.py # Doc's docs/html + +# i18n/l10n +*.mo diff --git a/README.md b/README.md index 1bc5260f..208ce568 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ [![Github issues](https://img.shields.io/github/issues/aiogram/aiogram.svg?style=flat-square)](https://github.com/aiogram/aiogram/issues) [![MIT License](https://img.shields.io/pypi/l/aiogram.svg?style=flat-square)](https://opensource.org/licenses/MIT) -**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.6 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler. +**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.7 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler. You can [read the docs here](http://aiogram.readthedocs.io/en/latest/). diff --git a/aiogram/__init__.py b/aiogram/__init__.py index 2f879d92..0bb6ea26 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -1,14 +1,42 @@ import asyncio +import os +from . import bot +from . import contrib +from . import dispatcher +from . import types +from . import utils from .bot import Bot from .dispatcher import Dispatcher +from .dispatcher import filters +from .dispatcher import middlewares +from .utils import exceptions, executor, helper, markdown as md try: import uvloop except ImportError: uvloop = None else: - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + if 'DISABLE_UVLOOP' not in os.environ: + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -__version__ = '1.4' +__all__ = [ + 'Bot', + 'Dispatcher', + '__api_version__', + '__version__', + 'bot', + 'contrib', + 'dispatcher', + 'exceptions', + 'executor', + 'filters', + 'helper', + 'md', + 'middlewares', + 'types', + 'utils' +] + +__version__ = '2.0.dev1' __api_version__ = '3.6' diff --git a/aiogram/bot/api.py b/aiogram/bot/api.py index 49c7ef63..961e9d10 100644 --- a/aiogram/bot/api.py +++ b/aiogram/bot/api.py @@ -1,8 +1,14 @@ +import abc +import asyncio import logging import os +import ssl +from asyncio import AbstractEventLoop from http import HTTPStatus +from typing import Optional, Tuple import aiohttp +import certifi from .. import types from ..utils import exceptions @@ -34,58 +40,73 @@ def check_token(token: str) -> bool: return True -async def _check_result(method_name, response): - """ - Checks whether `result` is a valid API response. - A result is considered invalid if: - - The server returned an HTTP response code other than 200 - - The content of the result is invalid JSON. - - The method call was unsuccessful (The JSON 'ok' field equals False) +async def check_result(method_name: str, content_type: str, status_code: int, body: str): + """ + Checks whether `result` is a valid API response. + A result is considered invalid if: + - The server returned an HTTP response code other than 200 + - The content of the result is invalid JSON. + - The method call was unsuccessful (The JSON 'ok' field equals False) - :raises ApiException: if one of the above listed cases is applicable - :param method_name: The name of the method called - :param response: The returned response of the method request - :return: The result parsed to a JSON dictionary. - """ - body = await response.text() - log.debug(f"Response for {method_name}: [{response.status}] {body}") + :param method_name: The name of the method called + :param status_code: status code + :param content_type: content type of result + :param body: result body + :return: The result parsed to a JSON dictionary + :raises ApiException: if one of the above listed cases is applicable + """ + log.debug('Response for %s: [%d] "%r"', method_name, status_code, body) - if response.content_type != 'application/json': - raise exceptions.NetworkError(f"Invalid response with content type {response.content_type}: \"{body}\"") + if content_type != 'application/json': + raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"") + try: + result_json = json.loads(body) + except ValueError: + result_json = {} + + description = result_json.get('description') or body + parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {}) + + if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED: + return result_json.get('result') + elif parameters.retry_after: + raise exceptions.RetryAfter(parameters.retry_after) + elif parameters.migrate_to_chat_id: + raise exceptions.MigrateToChat(parameters.migrate_to_chat_id) + elif status_code == HTTPStatus.BAD_REQUEST: + exceptions.BadRequest.detect(description) + elif status_code == HTTPStatus.NOT_FOUND: + exceptions.NotFound.detect(description) + elif status_code == HTTPStatus.CONFLICT: + exceptions.ConflictError.detect(description) + elif status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]: + exceptions.Unauthorized.detect(description) + elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: + raise exceptions.NetworkError('File too large for uploading. ' + 'Check telegram api limits https://core.telegram.org/bots/api#senddocument') + elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR: + if 'restart' in description: + raise exceptions.RestartingTelegram() + raise exceptions.TelegramAPIError(description) + raise exceptions.TelegramAPIError(f"{description} [{status_code}]") + + +async def make_request(session, token, method, data=None, files=None, **kwargs): + # log.debug(f"Make request: '{method}' with data: {data} and files {files}") + log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files) + + url = Methods.api_url(token=token, method=method) + + req = compose_data(data, files) try: - result_json = await response.json(loads=json.loads) - except ValueError: - result_json = {} - - description = result_json.get('description') or body - parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {}) - - if HTTPStatus.OK <= response.status <= HTTPStatus.IM_USED: - return result_json.get('result') - elif parameters.retry_after: - raise exceptions.RetryAfter(parameters.retry_after) - elif parameters.migrate_to_chat_id: - raise exceptions.MigrateToChat(parameters.migrate_to_chat_id) - elif response.status == HTTPStatus.BAD_REQUEST: - exceptions.BadRequest.detect(description) - elif response.status == HTTPStatus.NOT_FOUND: - exceptions.NotFound.detect(description) - elif response.status == HTTPStatus.CONFLICT: - exceptions.ConflictError.detect(description) - elif response.status in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]: - exceptions.Unauthorized.detect(description) - elif response.status == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: - raise exceptions.NetworkError('File too large for uploading. ' - 'Check telegram api limits https://core.telegram.org/bots/api#senddocument') - elif response.status >= HTTPStatus.INTERNAL_SERVER_ERROR: - if 'restart' in description: - raise exceptions.RestartingTelegram() - raise exceptions.TelegramAPIError(description) - raise exceptions.TelegramAPIError(f"{description} [{response.status}]") + async with session.post(url, data=req, **kwargs) as response: + return await check_result(method, response.content_type, response.status, await response.text()) + except aiohttp.ClientError as e: + raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}") -def _guess_filename(obj): +def guess_filename(obj): """ Get file name from object @@ -97,7 +118,7 @@ def _guess_filename(obj): return os.path.basename(name) -def _compose_data(params=None, files=None): +def compose_data(params=None, files=None): """ Prepare request data @@ -121,47 +142,13 @@ def _compose_data(params=None, files=None): elif isinstance(f, types.InputFile): filename, fileobj = f.filename, f.file else: - filename, fileobj = _guess_filename(f) or key, f + filename, fileobj = guess_filename(f) or key, f data.add_field(key, fileobj, filename=filename) return data -async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict: - """ - Make request to API - - That make request with Content-Type: - application/x-www-form-urlencoded - For simple request - and multipart/form-data - for files uploading - - https://core.telegram.org/bots/api#making-requests - - :param session: HTTP Client session - :type session: :obj:`aiohttp.ClientSession` - :param token: BOT token - :type token: :obj:`str` - :param method: API method - :type method: :obj:`str` - :param data: request payload - :type data: :obj:`dict` - :param files: files - :type files: :obj:`dict` - :return: result - :rtype :obj:`bool` or :obj:`dict` - """ - log.debug("Make request: '{0}' with data: {1} and files {2}".format( - method, data or {}, files or {})) - data = _compose_data(data, files) - url = Methods.api_url(token=token, method=method) - try: - async with session.post(url, data=data, **kwargs) as response: - return await _check_result(method, response) - except aiohttp.ClientError as e: - raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}") - - class Methods(Helper): """ Helper for Telegram API Methods listed on https://core.telegram.org/bots/api diff --git a/aiogram/bot/base.py b/aiogram/bot/base.py index ab1acb7b..e37f1923 100644 --- a/aiogram/bot/base.py +++ b/aiogram/bot/base.py @@ -47,7 +47,6 @@ class BaseBot: api.check_token(token) self.__token = token - # Proxy settings self.proxy = proxy self.proxy_auth = proxy_auth @@ -59,37 +58,42 @@ class BaseBot: # aiohttp main session ssl_context = ssl.create_default_context(cafile=certifi.where()) - if isinstance(proxy, str) and proxy.startswith('socks5://'): - from aiosocksy.connector import ProxyClientRequest, ProxyConnector - connector = ProxyConnector(limit=connections_limit, ssl_context=ssl_context, loop=self.loop) - request_class = ProxyClientRequest + if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')): + from aiohttp_socks import SocksConnector + from aiohttp_socks.helpers import parse_socks_url + + socks_ver, host, port, username, password = parse_socks_url(proxy) + if proxy_auth and not username or password: + username = proxy_auth.login + password = proxy_auth.password + + connector = SocksConnector(socks_ver=socks_ver, host=host, port=port, + username=username, password=password, + limit=connections_limit, ssl_context=ssl_context, + loop=self.loop) + + self.proxy = None + self.proxy_auth = None else: connector = aiohttp.TCPConnector(limit=connections_limit, ssl_context=ssl_context, loop=self.loop) - request_class = aiohttp.ClientRequest - self.session = aiohttp.ClientSession(connector=connector, request_class=request_class, - loop=self.loop, json_serialize=json.dumps) - - # Data stored in bot instance - self._data = {} + self.session = aiohttp.ClientSession(connector=connector, loop=self.loop, json_serialize=json.dumps) self.parse_mode = parse_mode - def __del__(self): - # asyncio.ensure_future(self.close()) - pass + # Data stored in bot instance + self._data = {} async def close(self): """ Close all client sessions """ - if self.session and not self.session.closed: - await self.session.close() + await self.session.close() async def request(self, method: base.String, data: Optional[Dict] = None, - files: Optional[Dict] = None) -> Union[List, Dict, base.Boolean]: + files: Optional[Dict] = None, **kwargs) -> Union[List, Dict, base.Boolean]: """ Make an request to Telegram Bot API @@ -105,8 +109,8 @@ class BaseBot: :rtype: Union[List, Dict] :raise: :obj:`aiogram.exceptions.TelegramApiError` """ - return await api.request(self.session, self.__token, method, data, files, - proxy=self.proxy, proxy_auth=self.proxy_auth) + return await api.make_request(self.session, self.__token, method, data, files, + proxy=self.proxy, proxy_auth=self.proxy_auth, **kwargs) async def download_file(self, file_path: base.String, destination: Optional[base.InputFile] = None, diff --git a/aiogram/bot/bot.py b/aiogram/bot/bot.py index 7cc0606a..1b6d1451 100644 --- a/aiogram/bot/bot.py +++ b/aiogram/bot/bot.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import typing +from contextvars import ContextVar from .base import BaseBot, api from .. import types from ..types import base -from ..utils.payload import generate_payload, prepare_arg +from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file class Bot(BaseBot): @@ -30,6 +33,14 @@ class Bot(BaseBot): if hasattr(self, '_me'): delattr(self, '_me') + @classmethod + def current(cls) -> Bot: + """ + Return active bot instance from the current context or None + :return: Bot or None + """ + return bot.get() + async def download_file_by_id(self, file_id: base.String, destination=None, timeout: base.Integer = 30, chunk_size: base.Integer = 65536, seek: base.Boolean = True): @@ -43,7 +54,7 @@ class Bot(BaseBot): :param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO` :param timeout: int :param chunk_size: int - :param seek: bool - go to start of file when downloading is finished. + :param seek: bool - go to start of file when downloading is finished :return: destination """ file = await self.get_file(file_id) @@ -67,21 +78,21 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#getupdates - :param offset: Identifier of the first update to be returned. + :param offset: Identifier of the first update to be returned :type offset: :obj:`typing.Union[base.Integer, None]` - :param limit: Limits the number of updates to be retrieved. + :param limit: Limits the number of updates to be retrieved :type limit: :obj:`typing.Union[base.Integer, None]` - :param timeout: Timeout in seconds for long polling. + :param timeout: Timeout in seconds for long polling :type timeout: :obj:`typing.Union[base.Integer, None]` - :param allowed_updates: List the types of updates you want your bot to receive. + :param allowed_updates: List the types of updates you want your bot to receive :type allowed_updates: :obj:`typing.Union[typing.List[base.String], None]` - :return: An Array of Update objects is returned. + :return: An Array of Update objects is returned :rtype: :obj:`typing.List[types.Update]` """ allowed_updates = prepare_arg(allowed_updates) payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_UPDATES, payload) + result = await self.request(api.Methods.GET_UPDATES, payload, timeout=timeout + 2 if timeout else None) return [types.Update(**update) for update in result] async def set_webhook(self, url: base.String, @@ -98,20 +109,23 @@ class Bot(BaseBot): :param url: HTTPS url to send updates to. Use an empty string to remove webhook integration :type url: :obj:`base.String` - :param certificate: Upload your public key certificate so that the root certificate in use can be checked. + :param certificate: Upload your public key certificate so that the root certificate in use can be checked :type certificate: :obj:`typing.Union[base.InputFile, None]` :param max_connections: Maximum allowed number of simultaneous HTTPS connections to the webhook for update delivery, 1-100. :type max_connections: :obj:`typing.Union[base.Integer, None]` - :param allowed_updates: List the types of updates you want your bot to receive. + :param allowed_updates: List the types of updates you want your bot to receive :type allowed_updates: :obj:`typing.Union[typing.List[base.String], None]` - :return: Returns true. + :return: Returns true :rtype: :obj:`base.Boolean` """ allowed_updates = prepare_arg(allowed_updates) payload = generate_payload(**locals(), exclude=['certificate']) - result = await self.send_file('certificate', api.Methods.SET_WEBHOOK, certificate, payload) + files = {} + prepare_file(payload, files, 'certificate', certificate) + + result = await self.request(api.Methods.SET_WEBHOOK, payload, files) return result async def delete_webhook(self) -> base.Boolean: @@ -121,12 +135,12 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#deletewebhook - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.DELETE_WEBHOOK, payload) + result = await self.request(api.Methods.DELETE_WEBHOOK, payload) return result async def get_webhook_info(self) -> types.WebhookInfo: @@ -137,12 +151,12 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#getwebhookinfo - :return: On success, returns a WebhookInfo object. + :return: On success, returns a WebhookInfo object :rtype: :obj:`types.WebhookInfo` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_WEBHOOK_INFO, payload) + result = await self.request(api.Methods.GET_WEBHOOK_INFO, payload) return types.WebhookInfo(**result) # === Base methods === @@ -154,12 +168,12 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#getme - :return: Returns basic information about the bot in form of a User object. + :return: Returns basic information about the bot in form of a User object :rtype: :obj:`types.User` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_ME, payload) + result = await self.request(api.Methods.GET_ME, payload) return types.User(**result) async def send_message(self, chat_id: typing.Union[base.Integer, base.String], text: base.String, @@ -185,14 +199,14 @@ class Bot(BaseBot): :type parse_mode: :obj:`typing.Union[base.String, None]` :param disable_web_page_preview: Disables link previews for links in this message :type disable_web_page_preview: :obj:`typing.Union[base.Boolean, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) @@ -201,7 +215,6 @@ class Bot(BaseBot): payload.setdefault('parse_mode', self.parse_mode) result = await self.request(api.Methods.SEND_MESSAGE, payload) - return types.Message(**result) async def forward_message(self, chat_id: typing.Union[base.Integer, base.String], @@ -216,16 +229,16 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param from_chat_id: Unique identifier for the chat where the original message was sent :type from_chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param message_id: Message identifier in the chat specified in from_chat_id :type message_id: :obj:`base.Integer` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.FORWARD_MESSAGE, payload) + result = await self.request(api.Methods.FORWARD_MESSAGE, payload) return types.Message(**result) async def send_photo(self, chat_id: typing.Union[base.Integer, base.String], @@ -245,21 +258,21 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param photo: Photo to send. + :param photo: Photo to send :type photo: :obj:`typing.Union[base.InputFile, base.String]` :param caption: Photo caption (may also be used when resending photos by file_id), 0-200 characters :type caption: :obj:`typing.Union[base.String, None]` :param parse_mode: Send Markdown or HTML, if you want Telegram apps to show bold, italic, fixed-width text or inline URLs in your bot's message. :type parse_mode: :obj:`typing.Union[base.String, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) @@ -267,8 +280,10 @@ class Bot(BaseBot): if self.parse_mode: payload.setdefault('parse_mode', self.parse_mode) - result = await self.send_file('photo', api.Methods.SEND_PHOTO, photo, payload) + files = {} + prepare_file(payload, files, 'photo', photo) + result = await self.request(api.Methods.SEND_PHOTO, payload, files) return types.Message(**result) async def send_audio(self, chat_id: typing.Union[base.Integer, base.String], @@ -295,7 +310,7 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param audio: Audio file to send. + :param audio: Audio file to send :type audio: :obj:`typing.Union[base.InputFile, base.String]` :param caption: Audio caption, 0-200 characters :type caption: :obj:`typing.Union[base.String, None]` @@ -307,17 +322,17 @@ class Bot(BaseBot): :param performer: Performer :type performer: :obj:`typing.Union[base.String, None]` :param title: Track name - :param thumb: Thumbnail of the file sent. - :param :obj:`typing.Union[base.InputFile, base.String, None]` :type title: :obj:`typing.Union[base.String, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param thumb: Thumbnail of the file sent + :type thumb: :obj:`typing.Union[base.InputFile, base.String, None]` + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. - :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, - types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :param reply_markup: Additional interface options + :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, + types.ReplyKeyboardRemove, types.ForceReply, None]` + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) @@ -325,8 +340,10 @@ class Bot(BaseBot): if self.parse_mode: payload.setdefault('parse_mode', self.parse_mode) - result = await self.send_file('audio', api.Methods.SEND_AUDIO, audio, payload) + files = {} + prepare_file(payload, files, 'audio', audio) + result = await self.request(api.Methods.SEND_AUDIO, payload, files) return types.Message(**result) async def send_document(self, chat_id: typing.Union[base.Integer, base.String], @@ -349,23 +366,23 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param document: File to send. + :param document: File to send :type document: :obj:`typing.Union[base.InputFile, base.String]` - :param thumb: Thumbnail of the file sent. - :param :obj:`typing.Union[base.InputFile, base.String, None]` + :param thumb: Thumbnail of the file sent + :type thumb: :obj:`typing.Union[base.InputFile, base.String, None]` :param caption: Document caption (may also be used when resending documents by file_id), 0-200 characters :type caption: :obj:`typing.Union[base.String, None]` :param parse_mode: Send Markdown or HTML, if you want Telegram apps to show bold, italic, fixed-width text or inline URLs in your bot's message. :type parse_mode: :obj:`typing.Union[base.String, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. - :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, - types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply], None]` - :return: On success, the sent Message is returned. + :param reply_markup: Additional interface options + :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, + types.ReplyKeyboardRemove, types.ForceReply], None]` + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) @@ -373,8 +390,10 @@ class Bot(BaseBot): if self.parse_mode: payload.setdefault('parse_mode', self.parse_mode) - result = await self.send_file('document', api.Methods.SEND_DOCUMENT, document, payload) + files = {} + prepare_file(payload, files, 'document', document) + result = await self.request(api.Methods.SEND_DOCUMENT, payload, document) return types.Message(**result) async def send_video(self, chat_id: typing.Union[base.Integer, base.String], @@ -400,7 +419,7 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param video: Video to send. + :param video: Video to send :type video: :obj:`typing.Union[base.InputFile, base.String]` :param duration: Duration of sent video in seconds :type duration: :obj:`typing.Union[base.Integer, None]` @@ -408,8 +427,8 @@ class Bot(BaseBot): :type width: :obj:`typing.Union[base.Integer, None]` :param height: Video height :type height: :obj:`typing.Union[base.Integer, None]` - :param thumb: Thumbnail of the file sent. - :param :obj:`typing.Union[base.InputFile, base.String, None]` + :param thumb: Thumbnail of the file sent + :type thumb: :obj:`typing.Union[base.InputFile, base.String, None]` :param caption: Video caption (may also be used when resending videos by file_id), 0-200 characters :type caption: :obj:`typing.Union[base.String, None]` :param parse_mode: Send Markdown or HTML, if you want Telegram apps to show bold, italic, @@ -417,40 +436,44 @@ class Bot(BaseBot): :type parse_mode: :obj:`typing.Union[base.String, None]` :param supports_streaming: Pass True, if the uploaded video is suitable for streaming :type supports_streaming: :obj:`typing.Union[base.Boolean, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) - payload = generate_payload(**locals(), exclude=['video']) + payload = generate_payload(**locals(), exclude=['video', 'thumb']) if self.parse_mode: payload.setdefault('parse_mode', self.parse_mode) - result = await self.send_file('video', api.Methods.SEND_VIDEO, video, payload) + files = {} + prepare_file(payload, files, 'video', video) + prepare_attachment(payload, files, 'thumb', thumb) + result = await self.request(api.Methods.SEND_VIDEO, payload, files) return types.Message(**result) async def send_animation(self, - chat_id: typing.Union[base.Integer, base.String], - animation: typing.Union[base.InputFile, base.String], - duration: typing.Union[base.Integer, None] = None, - width: typing.Union[base.Integer, None] = None, - height: typing.Union[base.Integer, None] = None, - thumb: typing.Union[typing.Union[base.InputFile, base.String], None] = None, - caption: typing.Union[base.String, None] = None, - parse_mode: typing.Union[base.String, None] = None, - disable_notification: typing.Union[base.Boolean, None] = None, - reply_to_message_id: typing.Union[base.Integer, None] = None, - reply_markup: typing.Union[typing.Union[types.InlineKeyboardMarkup, - types.ReplyKeyboardMarkup, - types.ReplyKeyboardRemove, - types.ForceReply], None] = None,) -> types.Message: + chat_id: typing.Union[base.Integer, base.String], + animation: typing.Union[base.InputFile, base.String], + duration: typing.Union[base.Integer, None] = None, + width: typing.Union[base.Integer, None] = None, + height: typing.Union[base.Integer, None] = None, + thumb: typing.Union[typing.Union[base.InputFile, base.String], None] = None, + caption: typing.Union[base.String, None] = None, + parse_mode: typing.Union[base.String, None] = None, + disable_notification: typing.Union[base.Boolean, None] = None, + reply_to_message_id: typing.Union[base.Integer, None] = None, + reply_markup: typing.Union[typing.Union[types.InlineKeyboardMarkup, + types.ReplyKeyboardMarkup, + types.ReplyKeyboardRemove, + types.ForceReply], None] = None + ) -> types.Message: """ Use this method to send animation files (GIF or H.264/MPEG-4 AVC video without sound). @@ -459,9 +482,12 @@ class Bot(BaseBot): Source https://core.telegram.org/bots/api#sendanimation - :param chat_id: Unique identifier for the target chat or username of the target channel (in the format @channelusername) + :param chat_id: Unique identifier for the target chat or username of the target channel + (in the format @channelusername) :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param animation: Animation to send. Pass a file_id as String to send an animation that exists on the Telegram servers (recommended), pass an HTTP URL as a String for Telegram to get an animation from the Internet, or upload a new animation using multipart/form-data. + :param animation: Animation to send. Pass a file_id as String to send an animation that exists + on the Telegram servers (recommended), pass an HTTP URL as a String for Telegram to get an animation + from the Internet, or upload a new animation using multipart/form-data :type animation: :obj:`typing.Union[base.InputFile, base.String]` :param duration: Duration of sent animation in seconds :type duration: :obj:`typing.Union[base.Integer, None]` @@ -469,25 +495,33 @@ class Bot(BaseBot): :type width: :obj:`typing.Union[base.Integer, None]` :param height: Animation height :type height: :obj:`typing.Union[base.Integer, None]` - :param thumb: Thumbnail of the file sent. The thumbnail should be in JPEG format and less than 200 kB in size. A thumbnail‘s width and height should not exceed 90. Ignored if the file is not uploaded using multipart/form-data. Thumbnails can’t be reused and can be only uploaded as a new file, so you can pass “attach://” if the thumbnail was uploaded using multipart/form-data under . + :param thumb: Thumbnail of the file sent. The thumbnail should be in JPEG format and less than 200 kB in size. + A thumbnail‘s width and height should not exceed 90. :type thumb: :obj:`typing.Union[typing.Union[base.InputFile, base.String], None]` :param caption: Animation caption (may also be used when resending animation by file_id), 0-200 characters :type caption: :obj:`typing.Union[base.String, None]` - :param parse_mode: Send Markdown or HTML, if you want Telegram apps to show bold, italic, fixed-width text or inline URLs in the media caption. + :param parse_mode: Send Markdown or HTML, if you want Telegram apps to show bold, italic, + fixed-width text or inline URLs in the media caption :type parse_mode: :obj:`typing.Union[base.String, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. A JSON-serialized object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - :type reply_markup: :obj:`typing.Union[typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply], None]` - :return: On success, the sent Message is returned. + :param reply_markup: Additional interface options. A JSON-serialized object for an inline keyboard, + custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user + :type reply_markup: :obj:`typing.Union[typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, + types.ReplyKeyboardRemove, types.ForceReply], None]` + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) - payload = generate_payload(**locals(), exclude=["animation"]) - result = await self.send_file("animation", api.Methods.SEND_ANIMATION, thumb, payload) + payload = generate_payload(**locals(), exclude=["animation", "thumb"]) + + files = {} + prepare_file(payload, files, 'animation', animation) + prepare_attachment(payload, files, 'thumb', thumb) + result = await self.request(api.Methods.SEND_ANIMATION, payload, files) return types.Message(**result) async def send_voice(self, chat_id: typing.Union[base.Integer, base.String], @@ -512,7 +546,7 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param voice: Audio file to send. + :param voice: Audio file to send :type voice: :obj:`typing.Union[base.InputFile, base.String]` :param caption: Voice message caption, 0-200 characters :type caption: :obj:`typing.Union[base.String, None]` @@ -521,14 +555,14 @@ class Bot(BaseBot): :type parse_mode: :obj:`typing.Union[base.String, None]` :param duration: Duration of the voice message in seconds :type duration: :obj:`typing.Union[base.Integer, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) @@ -536,8 +570,10 @@ class Bot(BaseBot): if self.parse_mode: payload.setdefault('parse_mode', self.parse_mode) - result = await self.send_file('voice', api.Methods.SEND_VOICE, voice, payload) + files = {} + prepare_file(payload, files, 'voice', voice) + result = await self.request(api.Methods.SEND_VOICE, payload, files) return types.Message(**result) async def send_video_note(self, chat_id: typing.Union[base.Integer, base.String], @@ -559,28 +595,31 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param video_note: Video note to send. + :param video_note: Video note to send :type video_note: :obj:`typing.Union[base.InputFile, base.String]` :param duration: Duration of sent video in seconds :type duration: :obj:`typing.Union[base.Integer, None]` :param length: Video width and height :type length: :obj:`typing.Union[base.Integer, None]` - :param thumb: Thumbnail of the file sent. - :param :obj:`typing.Union[base.InputFile, base.String, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param thumb: Thumbnail of the file sent + :type thumb: :obj:`typing.Union[base.InputFile, base.String, None]` + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. - :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, - types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :param reply_markup: Additional interface options + :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, + types.ReplyKeyboardRemove, types.ForceReply, None]` + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals(), exclude=['video_note']) - result = await self.send_file('video_note', api.Methods.SEND_VIDEO_NOTE, video_note, payload) + files = {} + prepare_file(payload, files, 'video_note', video_note) + + result = await self.request(api.Methods.SEND_VIDEO_NOTE, payload, files) return types.Message(**result) async def send_media_group(self, chat_id: typing.Union[base.Integer, base.String], @@ -597,24 +636,23 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param media: A JSON-serialized array describing photos and videos to be sent :type media: :obj:`typing.Union[types.MediaGroup, typing.List]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :return: On success, an array of the sent Messages is returned. + :return: On success, an array of the sent Messages is returned :rtype: typing.List[types.Message] """ # Convert list to MediaGroup if isinstance(media, list): media = types.MediaGroup(media) - # Extract files - files = media.get_files() + files = dict(media.get_files()) media = prepare_arg(media) payload = generate_payload(**locals(), exclude=['files']) - result = await self.request(api.Methods.SEND_MEDIA_GROUP, payload, files) + result = await self.request(api.Methods.SEND_MEDIA_GROUP, payload, files) return [types.Message(**message) for message in result] async def send_location(self, chat_id: typing.Union[base.Integer, base.String], @@ -639,20 +677,20 @@ class Bot(BaseBot): :type longitude: :obj:`base.Float` :param live_period: Period in seconds for which the location will be updated :type live_period: :obj:`typing.Union[base.Integer, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.SEND_LOCATION, payload) + result = await self.request(api.Methods.SEND_LOCATION, payload) return types.Message(**result) async def edit_message_live_location(self, latitude: base.Float, longitude: base.Float, @@ -668,7 +706,7 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#editmessagelivelocation - :param chat_id: Required if inline_message_id is not specified. + :param chat_id: Required if inline_message_id is not specified :type chat_id: :obj:`typing.Union[base.Integer, base.String, None]` :param message_id: Required if inline_message_id is not specified. Identifier of the sent message :type message_id: :obj:`typing.Union[base.Integer, None]` @@ -678,7 +716,7 @@ class Bot(BaseBot): :type latitude: :obj:`base.Float` :param longitude: Longitude of new location :type longitude: :obj:`base.Float` - :param reply_markup: A JSON-serialized object for a new inline keyboard. + :param reply_markup: A JSON-serialized object for a new inline keyboard :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` :return: On success, if the edited message was sent by the bot, the edited Message is returned, otherwise True is returned. @@ -686,11 +724,10 @@ class Bot(BaseBot): """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.EDIT_MESSAGE_LIVE_LOCATION, payload) + result = await self.request(api.Methods.EDIT_MESSAGE_LIVE_LOCATION, payload) if isinstance(result, bool): return result - return types.Message(**result) async def stop_message_live_location(self, @@ -705,13 +742,13 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#stopmessagelivelocation - :param chat_id: Required if inline_message_id is not specified. + :param chat_id: Required if inline_message_id is not specified :type chat_id: :obj:`typing.Union[base.Integer, base.String, None]` :param message_id: Required if inline_message_id is not specified. Identifier of the sent message :type message_id: :obj:`typing.Union[base.Integer, None]` :param inline_message_id: Required if chat_id and message_id are not specified. Identifier of the inline message :type inline_message_id: :obj:`typing.Union[base.String, None]` - :param reply_markup: A JSON-serialized object for a new inline keyboard. + :param reply_markup: A JSON-serialized object for a new inline keyboard :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` :return: On success, if the message was sent by the bot, the sent Message is returned, otherwise True is returned. @@ -719,11 +756,10 @@ class Bot(BaseBot): """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.STOP_MESSAGE_LIVE_LOCATION, payload) + result = await self.request(api.Methods.STOP_MESSAGE_LIVE_LOCATION, payload) if isinstance(result, bool): return result - return types.Message(**result) async def send_venue(self, chat_id: typing.Union[base.Integer, base.String], @@ -754,22 +790,22 @@ class Bot(BaseBot): :type address: :obj:`base.String` :param foursquare_id: Foursquare identifier of the venue :type foursquare_id: :obj:`typing.Union[base.String, None]` - :param foursquare_type: Foursquare type of the venue, if known. + :param foursquare_type: Foursquare type of the venue, if known :type foursquare_type: :obj:`typing.Union[base.String, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.SEND_VENUE, payload) + result = await self.request(api.Methods.SEND_VENUE, payload) return types.Message(**result) async def send_contact(self, chat_id: typing.Union[base.Integer, base.String], @@ -797,20 +833,20 @@ class Bot(BaseBot): :type last_name: :obj:`typing.Union[base.String, None]` :param vcard: vcard :type vcard: :obj:`typing.Union[base.String, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.SEND_CONTACT, payload) + result = await self.request(api.Methods.SEND_CONTACT, payload) return types.Message(**result) async def send_chat_action(self, chat_id: typing.Union[base.Integer, base.String], @@ -827,14 +863,14 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param action: Type of action to broadcast. + :param action: Type of action to broadcast :type action: :obj:`base.String` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.SEND_CHAT_ACTION, payload) + result = await self.request(api.Methods.SEND_CHAT_ACTION, payload) return result async def get_user_profile_photos(self, user_id: base.Integer, offset: typing.Union[base.Integer, None] = None, @@ -846,16 +882,16 @@ class Bot(BaseBot): :param user_id: Unique identifier of the target user :type user_id: :obj:`base.Integer` - :param offset: Sequential number of the first photo to be returned. By default, all photos are returned. + :param offset: Sequential number of the first photo to be returned. By default, all photos are returned :type offset: :obj:`typing.Union[base.Integer, None]` - :param limit: Limits the number of photos to be retrieved. Values between 1—100 are accepted. Defaults to 100. + :param limit: Limits the number of photos to be retrieved. Values between 1—100 are accepted. Defaults to 100 :type limit: :obj:`typing.Union[base.Integer, None]` - :return: Returns a UserProfilePhotos object. + :return: Returns a UserProfilePhotos object :rtype: :obj:`types.UserProfilePhotos` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_USER_PROFILE_PHOTOS, payload) + result = await self.request(api.Methods.GET_USER_PROFILE_PHOTOS, payload) return types.UserProfilePhotos(**result) async def get_file(self, file_id: base.String) -> types.File: @@ -870,12 +906,12 @@ class Bot(BaseBot): :param file_id: File identifier to get info about :type file_id: :obj:`base.String` - :return: On success, a File object is returned. + :return: On success, a File object is returned :rtype: :obj:`types.File` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_FILE, payload) + result = await self.request(api.Methods.GET_FILE, payload) return types.File(**result) async def kick_chat_member(self, chat_id: typing.Union[base.Integer, base.String], user_id: base.Integer, @@ -897,15 +933,15 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param user_id: Unique identifier of the target user :type user_id: :obj:`base.Integer` - :param until_date: Date when the user will be unbanned, unix time. + :param until_date: Date when the user will be unbanned, unix time :type until_date: :obj:`typing.Union[base.Integer, None]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ until_date = prepare_arg(until_date) payload = generate_payload(**locals()) - result = await self.request(api.Methods.KICK_CHAT_MEMBER, payload) + result = await self.request(api.Methods.KICK_CHAT_MEMBER, payload) return result async def unban_chat_member(self, chat_id: typing.Union[base.Integer, base.String], @@ -922,12 +958,12 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param user_id: Unique identifier of the target user :type user_id: :obj:`base.Integer` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.UNBAN_CHAT_MEMBER, payload) + result = await self.request(api.Methods.UNBAN_CHAT_MEMBER, payload) return result async def restrict_chat_member(self, chat_id: typing.Union[base.Integer, base.String], @@ -948,7 +984,7 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param user_id: Unique identifier of the target user :type user_id: :obj:`base.Integer` - :param until_date: Date when restrictions will be lifted for the user, unix time. + :param until_date: Date when restrictions will be lifted for the user, unix time :type until_date: :obj:`typing.Union[base.Integer, None]` :param can_send_messages: Pass True, if the user can send text messages, contacts, locations and venues :type can_send_messages: :obj:`typing.Union[base.Boolean, None]` @@ -961,13 +997,13 @@ class Bot(BaseBot): :param can_add_web_page_previews: Pass True, if the user may add web page previews to their messages, implies can_send_media_messages :type can_add_web_page_previews: :obj:`typing.Union[base.Boolean, None]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ until_date = prepare_arg(until_date) payload = generate_payload(**locals()) - result = await self.request(api.Methods.RESTRICT_CHAT_MEMBER, payload) + result = await self.request(api.Methods.RESTRICT_CHAT_MEMBER, payload) return result async def promote_chat_member(self, chat_id: typing.Union[base.Integer, base.String], @@ -1009,12 +1045,12 @@ class Bot(BaseBot): with a subset of his own privileges or demote administrators that he has promoted, directly or indirectly (promoted by administrators that were appointed by him) :type can_promote_members: :obj:`typing.Union[base.Boolean, None]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.PROMOTE_CHAT_MEMBER, payload) + result = await self.request(api.Methods.PROMOTE_CHAT_MEMBER, payload) return result async def export_chat_invite_link(self, chat_id: typing.Union[base.Integer, base.String]) -> base.String: @@ -1026,12 +1062,12 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :return: Returns exported invite link as String on success. + :return: Returns exported invite link as String on success :rtype: :obj:`base.String` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.EXPORT_CHAT_INVITE_LINK, payload) + result = await self.request(api.Methods.EXPORT_CHAT_INVITE_LINK, payload) return result async def set_chat_photo(self, chat_id: typing.Union[base.Integer, base.String], @@ -1049,12 +1085,15 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param photo: New chat photo, uploaded using multipart/form-data :type photo: :obj:`base.InputFile` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals(), exclude=['photo']) - result = await self.send_file('photo', api.Methods.SET_CHAT_PHOTO, photo, payload) + files = {} + prepare_file(payload, files, 'photo', photo) + + result = await self.request(api.Methods.SET_CHAT_PHOTO, payload, files) return result async def delete_chat_photo(self, chat_id: typing.Union[base.Integer, base.String]) -> base.Boolean: @@ -1069,12 +1108,12 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.DELETE_CHAT_PHOTO, payload) + result = await self.request(api.Methods.DELETE_CHAT_PHOTO, payload) return result async def set_chat_title(self, chat_id: typing.Union[base.Integer, base.String], @@ -1092,12 +1131,12 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param title: New chat title, 1-255 characters :type title: :obj:`base.String` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.SET_CHAT_TITLE, payload) + result = await self.request(api.Methods.SET_CHAT_TITLE, payload) return result async def set_chat_description(self, chat_id: typing.Union[base.Integer, base.String], @@ -1112,12 +1151,12 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param description: New chat description, 0-255 characters :type description: :obj:`typing.Union[base.String, None]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.SET_CHAT_DESCRIPTION, payload) + result = await self.request(api.Methods.SET_CHAT_DESCRIPTION, payload) return result async def pin_chat_message(self, chat_id: typing.Union[base.Integer, base.String], message_id: base.Integer, @@ -1135,12 +1174,12 @@ class Bot(BaseBot): :param disable_notification: Pass True, if it is not necessary to send a notification to all group members about the new pinned message :type disable_notification: :obj:`typing.Union[base.Boolean, None]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.PIN_CHAT_MESSAGE, payload) + result = await self.request(api.Methods.PIN_CHAT_MESSAGE, payload) return result async def unpin_chat_message(self, chat_id: typing.Union[base.Integer, base.String]) -> base.Boolean: @@ -1152,12 +1191,12 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target supergroup :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.UNPIN_CHAT_MESSAGE, payload) + result = await self.request(api.Methods.UNPIN_CHAT_MESSAGE, payload) return result async def leave_chat(self, chat_id: typing.Union[base.Integer, base.String]) -> base.Boolean: @@ -1168,12 +1207,12 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target supergroup or channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.LEAVE_CHAT, payload) + result = await self.request(api.Methods.LEAVE_CHAT, payload) return result async def get_chat(self, chat_id: typing.Union[base.Integer, base.String]) -> types.Chat: @@ -1185,12 +1224,12 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target supergroup or channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :return: Returns a Chat object on success. + :return: Returns a Chat object on success :rtype: :obj:`types.Chat` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_CHAT, payload) + result = await self.request(api.Methods.GET_CHAT, payload) return types.Chat(**result) async def get_chat_administrators(self, chat_id: typing.Union[base.Integer, base.String] @@ -1209,8 +1248,8 @@ class Bot(BaseBot): :rtype: :obj:`typing.List[types.ChatMember]` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_CHAT_ADMINISTRATORS, payload) + result = await self.request(api.Methods.GET_CHAT_ADMINISTRATORS, payload) return [types.ChatMember(**chatmember) for chatmember in result] async def get_chat_members_count(self, chat_id: typing.Union[base.Integer, base.String]) -> base.Integer: @@ -1221,12 +1260,12 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target supergroup or channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :return: Returns Int on success. + :return: Returns Int on success :rtype: :obj:`base.Integer` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_CHAT_MEMBERS_COUNT, payload) + result = await self.request(api.Methods.GET_CHAT_MEMBERS_COUNT, payload) return result async def get_chat_member(self, chat_id: typing.Union[base.Integer, base.String], @@ -1240,12 +1279,12 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param user_id: Unique identifier of the target user :type user_id: :obj:`base.Integer` - :return: Returns a ChatMember object on success. + :return: Returns a ChatMember object on success :rtype: :obj:`types.ChatMember` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_CHAT_MEMBER, payload) + result = await self.request(api.Methods.GET_CHAT_MEMBER, payload) return types.ChatMember(**result) async def set_chat_sticker_set(self, chat_id: typing.Union[base.Integer, base.String], @@ -1263,12 +1302,12 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param sticker_set_name: Name of the sticker set to be set as the group sticker set :type sticker_set_name: :obj:`base.String` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.SET_CHAT_STICKER_SET, payload) + result = await self.request(api.Methods.SET_CHAT_STICKER_SET, payload) return result async def delete_chat_sticker_set(self, chat_id: typing.Union[base.Integer, base.String]) -> base.Boolean: @@ -1283,12 +1322,12 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target supergroup :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.DELETE_CHAT_STICKER_SET, payload) + result = await self.request(api.Methods.DELETE_CHAT_STICKER_SET, payload) return result async def answer_callback_query(self, callback_query_id: base.String, @@ -1313,17 +1352,17 @@ class Bot(BaseBot): :param show_alert: If true, an alert will be shown by the client instead of a notification at the top of the chat screen. Defaults to false. :type show_alert: :obj:`typing.Union[base.Boolean, None]` - :param url: URL that will be opened by the user's client. + :param url: URL that will be opened by the user's client :type url: :obj:`typing.Union[base.String, None]` :param cache_time: The maximum amount of time in seconds that the result of the callback query may be cached client-side. :type cache_time: :obj:`typing.Union[base.Integer, None]` - :return: On success, True is returned. + :return: On success, True is returned :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.ANSWER_CALLBACK_QUERY, payload) + result = await self.request(api.Methods.ANSWER_CALLBACK_QUERY, payload) return result async def edit_message_text(self, text: base.String, @@ -1339,7 +1378,7 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#editmessagetext - :param chat_id: Required if inline_message_id is not specified. + :param chat_id: Required if inline_message_id is not specified Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String, None]` :param message_id: Required if inline_message_id is not specified. Identifier of the sent message @@ -1353,7 +1392,7 @@ class Bot(BaseBot): :type parse_mode: :obj:`typing.Union[base.String, None]` :param disable_web_page_preview: Disables link previews for links in this message :type disable_web_page_preview: :obj:`typing.Union[base.Boolean, None]` - :param reply_markup: A JSON-serialized object for an inline keyboard. + :param reply_markup: A JSON-serialized object for an inline keyboard :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` :return: On success, if edited message is sent by the bot, the edited Message is returned, otherwise True is returned. @@ -1365,10 +1404,8 @@ class Bot(BaseBot): payload.setdefault('parse_mode', self.parse_mode) result = await self.request(api.Methods.EDIT_MESSAGE_TEXT, payload) - if isinstance(result, bool): return result - return types.Message(**result) async def edit_message_caption(self, chat_id: typing.Union[base.Integer, base.String, None] = None, @@ -1383,7 +1420,7 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#editmessagecaption - :param chat_id: Required if inline_message_id is not specified. + :param chat_id: Required if inline_message_id is not specified Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String, None]` :param message_id: Required if inline_message_id is not specified. Identifier of the sent message @@ -1395,7 +1432,7 @@ class Bot(BaseBot): :param parse_mode: Send Markdown or HTML, if you want Telegram apps to show bold, italic, fixed-width text or inline URLs in your bot's message. :type parse_mode: :obj:`typing.Union[base.String, None]` - :param reply_markup: A JSON-serialized object for an inline keyboard. + :param reply_markup: A JSON-serialized object for an inline keyboard :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` :return: On success, if edited message is sent by the bot, the edited Message is returned, otherwise True is returned. @@ -1407,19 +1444,17 @@ class Bot(BaseBot): payload.setdefault('parse_mode', self.parse_mode) result = await self.request(api.Methods.EDIT_MESSAGE_CAPTION, payload) - if isinstance(result, bool): return result - return types.Message(**result) async def edit_message_media(self, - media: types.InputMedia, - chat_id: typing.Union[typing.Union[base.Integer, base.String], None] = None, - message_id: typing.Union[base.Integer, None] = None, - inline_message_id: typing.Union[base.String, None] = None, - reply_markup: typing.Union[types.InlineKeyboardMarkup, None] = None, - ) -> typing.Union[types.Message, base.Boolean]: + media: types.InputMedia, + chat_id: typing.Union[typing.Union[base.Integer, base.String], None] = None, + message_id: typing.Union[base.Integer, None] = None, + inline_message_id: typing.Union[base.String, None] = None, + reply_markup: typing.Union[types.InlineKeyboardMarkup, None] = None, + ) -> typing.Union[types.Message, base.Boolean]: """ Use this method to edit audio, document, photo, or video messages. If a message is a part of a message album, then it can be edited only to a photo or a video. @@ -1432,7 +1467,7 @@ class Bot(BaseBot): Source https://core.telegram.org/bots/api#editmessagemedia - :param chat_id: Required if inline_message_id is not specified. + :param chat_id: Required if inline_message_id is not specified :type chat_id: :obj:`typing.Union[typing.Union[base.Integer, base.String], None]` :param message_id: Required if inline_message_id is not specified. Identifier of the sent message :type message_id: :obj:`typing.Union[base.Integer, None]` @@ -1440,25 +1475,23 @@ class Bot(BaseBot): :type inline_message_id: :obj:`typing.Union[base.String, None]` :param media: A JSON-serialized object for a new media content of the message :type media: :obj:`types.InputMedia` - :param reply_markup: A JSON-serialized object for a new inline keyboard. + :param reply_markup: A JSON-serialized object for a new inline keyboard :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` - :return: On success, if the edited message was sent by the bot, the edited Message is returned, otherwise True is returned. + :return: On success, if the edited message was sent by the bot, the edited Message is returned, + otherwise True is returned :rtype: :obj:`typing.Union[types.Message, base.Boolean]` """ - - if isinstance(media, types.InputMedia) and media.file: - files = {media.attachment_key: media.file} - else: - files = None - reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.EDIT_MESSAGE_MEDIA, payload, files) + if isinstance(media, types.InputMedia): + files = dict(media.get_files()) + else: + files = None + result = await self.request(api.Methods.EDIT_MESSAGE_MEDIA, payload, files) if isinstance(result, bool): return result - return types.Message(**result) async def edit_message_reply_markup(self, @@ -1472,14 +1505,14 @@ class Bot(BaseBot): Source: https://core.telegram.org/bots/api#editmessagereplymarkup - :param chat_id: Required if inline_message_id is not specified. + :param chat_id: Required if inline_message_id is not specified Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String, None]` :param message_id: Required if inline_message_id is not specified. Identifier of the sent message :type message_id: :obj:`typing.Union[base.Integer, None]` :param inline_message_id: Required if chat_id and message_id are not specified. Identifier of the inline message :type inline_message_id: :obj:`typing.Union[base.String, None]` - :param reply_markup: A JSON-serialized object for an inline keyboard. + :param reply_markup: A JSON-serialized object for an inline keyboard :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` :return: On success, if edited message is sent by the bot, the edited Message is returned, otherwise True is returned. @@ -1487,11 +1520,10 @@ class Bot(BaseBot): """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.EDIT_MESSAGE_REPLY_MARKUP, payload) + result = await self.request(api.Methods.EDIT_MESSAGE_REPLY_MARKUP, payload) if isinstance(result, bool): return result - return types.Message(**result) async def delete_message(self, chat_id: typing.Union[base.Integer, base.String], @@ -1512,12 +1544,12 @@ class Bot(BaseBot): :type chat_id: :obj:`typing.Union[base.Integer, base.String]` :param message_id: Identifier of the message to delete :type message_id: :obj:`base.Integer` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.DELETE_MESSAGE, payload) + result = await self.request(api.Methods.DELETE_MESSAGE, payload) return result # === Stickers === @@ -1538,22 +1570,25 @@ class Bot(BaseBot): :param chat_id: Unique identifier for the target chat or username of the target channel :type chat_id: :obj:`typing.Union[base.Integer, base.String]` - :param sticker: Sticker to send. + :param sticker: Sticker to send :type sticker: :obj:`typing.Union[base.InputFile, base.String]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: Additional interface options. + :param reply_markup: Additional interface options :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, types.ReplyKeyboardMarkup, types.ReplyKeyboardRemove, types.ForceReply, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals(), exclude=['sticker']) - result = await self.send_file('sticker', api.Methods.SEND_STICKER, sticker, payload) + files = {} + prepare_file(payload, files, 'sticker', sticker) + + result = await self.request(api.Methods.SEND_STICKER, payload, files) return types.Message(**result) async def get_sticker_set(self, name: base.String) -> types.StickerSet: @@ -1564,12 +1599,12 @@ class Bot(BaseBot): :param name: Name of the sticker set :type name: :obj:`base.String` - :return: On success, a StickerSet object is returned. + :return: On success, a StickerSet object is returned :rtype: :obj:`types.StickerSet` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_STICKER_SET, payload) + result = await self.request(api.Methods.GET_STICKER_SET, payload) return types.StickerSet(**result) async def upload_sticker_file(self, user_id: base.Integer, png_sticker: base.InputFile) -> types.File: @@ -1584,12 +1619,15 @@ class Bot(BaseBot): :param png_sticker: Png image with the sticker, must be up to 512 kilobytes in size, dimensions must not exceed 512px, and either width or height must be exactly 512px. :type png_sticker: :obj:`base.InputFile` - :return: Returns the uploaded File on success. + :return: Returns the uploaded File on success :rtype: :obj:`types.File` """ payload = generate_payload(**locals(), exclude=['png_sticker']) - result = await self.send_file('png_sticker', api.Methods.UPLOAD_STICKER_FILE, png_sticker, payload) + files = {} + prepare_file(payload, files, 'png_sticker', png_sticker) + + result = await self.request(api.Methods.UPLOAD_STICKER_FILE, payload, files) return types.File(**result) async def create_new_sticker_set(self, user_id: base.Integer, name: base.String, title: base.String, @@ -1603,7 +1641,7 @@ class Bot(BaseBot): :param user_id: User identifier of created sticker set owner :type user_id: :obj:`base.Integer` - :param name: Short name of sticker set, to be used in t.me/addstickers/ URLs (e.g., animals). + :param name: Short name of sticker set, to be used in t.me/addstickers/ URLs (e.g., animals) :type name: :obj:`base.String` :param title: Sticker set title, 1-64 characters :type title: :obj:`base.String` @@ -1616,13 +1654,16 @@ class Bot(BaseBot): :type contains_masks: :obj:`typing.Union[base.Boolean, None]` :param mask_position: A JSON-serialized object for position where the mask should be placed on faces :type mask_position: :obj:`typing.Union[types.MaskPosition, None]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ mask_position = prepare_arg(mask_position) payload = generate_payload(**locals(), exclude=['png_sticker']) - result = await self.send_file('png_sticker', api.Methods.CREATE_NEW_STICKER_SET, png_sticker, payload) + files = {} + prepare_file(payload, files, 'png_sticker', png_sticker) + + result = await self.request(api.Methods.CREATE_NEW_STICKER_SET, payload, files) return result async def add_sticker_to_set(self, user_id: base.Integer, name: base.String, @@ -1644,13 +1685,16 @@ class Bot(BaseBot): :type emojis: :obj:`base.String` :param mask_position: A JSON-serialized object for position where the mask should be placed on faces :type mask_position: :obj:`typing.Union[types.MaskPosition, None]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ mask_position = prepare_arg(mask_position) payload = generate_payload(**locals(), exclude=['png_sticker']) - result = await self.send_file('png_sticker', api.Methods.ADD_STICKER_TO_SET, png_sticker, payload) + files = {} + prepare_file(payload, files, 'png_sticker', png_sticker) + + result = await self.request(api.Methods.ADD_STICKER_TO_SET, payload, files) return result async def set_sticker_position_in_set(self, sticker: base.String, position: base.Integer) -> base.Boolean: @@ -1663,7 +1707,7 @@ class Bot(BaseBot): :type sticker: :obj:`base.String` :param position: New sticker position in the set, zero-based :type position: :obj:`base.Integer` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) @@ -1681,12 +1725,12 @@ class Bot(BaseBot): :param sticker: File identifier of the sticker :type sticker: :obj:`base.String` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.DELETE_STICKER_FROM_SET, payload) + result = await self.request(api.Methods.DELETE_STICKER_FROM_SET, payload) return result async def answer_inline_query(self, inline_query_id: base.String, @@ -1724,13 +1768,13 @@ class Bot(BaseBot): :param switch_pm_parameter: Deep-linking parameter for the /start message sent to the bot when user presses the switch button. 1-64 characters, only A-Z, a-z, 0-9, _ and - are allowed. :type switch_pm_parameter: :obj:`typing.Union[base.String, None]` - :return: On success, True is returned. + :return: On success, True is returned :rtype: :obj:`base.Boolean` """ results = prepare_arg(results) payload = generate_payload(**locals()) - result = await self.request(api.Methods.ANSWER_INLINE_QUERY, payload) + result = await self.request(api.Methods.ANSWER_INLINE_QUERY, payload) return result # === Payments === @@ -1764,7 +1808,7 @@ class Bot(BaseBot): :type title: :obj:`base.String` :param description: Product description, 1-255 characters :type description: :obj:`base.String` - :param payload: Bot-defined invoice payload, 1-128 bytes. + :param payload: Bot-defined invoice payload, 1-128 bytes This will not be displayed to the user, use for your internal processes. :type payload: :obj:`base.String` :param provider_token: Payments provider token, obtained via Botfather @@ -1777,9 +1821,9 @@ class Bot(BaseBot): :param prices: Price breakdown, a list of components (e.g. product price, tax, discount, delivery cost, delivery tax, bonus, etc.) :type prices: :obj:`typing.List[types.LabeledPrice]` - :param provider_data: JSON-encoded data about the invoice, which will be shared with the payment provider. + :param provider_data: JSON-encoded data about the invoice, which will be shared with the payment provider :type provider_data: :obj:`typing.Union[typing.Dict, None]` - :param photo_url: URL of the product photo for the invoice. + :param photo_url: URL of the product photo for the invoice :type photo_url: :obj:`typing.Union[base.String, None]` :param photo_size: Photo size :type photo_size: :obj:`typing.Union[base.Integer, None]` @@ -1797,21 +1841,21 @@ class Bot(BaseBot): :type need_shipping_address: :obj:`typing.Union[base.Boolean, None]` :param is_flexible: Pass True, if the final price depends on the shipping method :type is_flexible: :obj:`typing.Union[base.Boolean, None]` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: A JSON-serialized object for an inline keyboard. + :param reply_markup: A JSON-serialized object for an inline keyboard If empty, one 'Pay total price' button will be shown. If not empty, the first button must be a Pay button. :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ prices = prepare_arg([price.to_python() if hasattr(price, 'to_python') else price for price in prices]) reply_markup = prepare_arg(reply_markup) payload_ = generate_payload(**locals()) - result = await self.request(api.Methods.SEND_INVOICE, payload_) + result = await self.request(api.Methods.SEND_INVOICE, payload_) return types.Message(**result) async def answer_shipping_query(self, shipping_query_id: base.String, ok: base.Boolean, @@ -1828,14 +1872,14 @@ class Bot(BaseBot): :param ok: Specify True if delivery to the specified address is possible and False if there are any problems (for example, if delivery to the specified address is not possible) :type ok: :obj:`base.Boolean` - :param shipping_options: Required if ok is True. A JSON-serialized array of available shipping options. + :param shipping_options: Required if ok is True. A JSON-serialized array of available shipping options :type shipping_options: :obj:`typing.Union[typing.List[types.ShippingOption], None]` - :param error_message: Required if ok is False. + :param error_message: Required if ok is False Error message in human readable form that explains why it is impossible to complete the order (e.g. "Sorry, delivery to your desired address is unavailable'). Telegram will display this message to the user. :type error_message: :obj:`typing.Union[base.String, None]` - :return: On success, True is returned. + :return: On success, True is returned :rtype: :obj:`base.Boolean` """ if shipping_options: @@ -1844,8 +1888,8 @@ class Bot(BaseBot): else shipping_option for shipping_option in shipping_options]) payload = generate_payload(**locals()) - result = await self.request(api.Methods.ANSWER_SHIPPING_QUERY, payload) + result = await self.request(api.Methods.ANSWER_SHIPPING_QUERY, payload) return result async def answer_pre_checkout_query(self, pre_checkout_query_id: base.String, ok: base.Boolean, @@ -1862,18 +1906,18 @@ class Bot(BaseBot): :param ok: Specify True if everything is alright (goods are available, etc.) and the bot is ready to proceed with the order. Use False if there are any problems. :type ok: :obj:`base.Boolean` - :param error_message: Required if ok is False. + :param error_message: Required if ok is False Error message in human readable form that explains the reason for failure to proceed with the checkout (e.g. "Sorry, somebody just bought the last of our amazing black T-shirts while you were busy filling out your payment details. Please choose a different color or garment!"). Telegram will display this message to the user. :type error_message: :obj:`typing.Union[base.String, None]` - :return: On success, True is returned. + :return: On success, True is returned :rtype: :obj:`base.Boolean` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.ANSWER_PRE_CHECKOUT_QUERY, payload) + result = await self.request(api.Methods.ANSWER_PRE_CHECKOUT_QUERY, payload) return result # === Games === @@ -1899,13 +1943,13 @@ class Bot(BaseBot): :type user_id: :obj:`base.Integer` :param errors: A JSON-serialized array describing the errors :type errors: :obj:`typing.List[types.PassportElementError]` - :return: Returns True on success. + :return: Returns True on success :rtype: :obj:`base.Boolean` """ errors = prepare_arg(errors) payload = generate_payload(**locals()) - result = await self.request(api.Methods.SET_PASSPORT_DATA_ERRORS, payload) + result = await self.request(api.Methods.SET_PASSPORT_DATA_ERRORS, payload) return result # === Games === @@ -1925,20 +1969,20 @@ class Bot(BaseBot): :param game_short_name: Short name of the game, serves as the unique identifier for the game. \ Set up your games via Botfather. :type game_short_name: :obj:`base.String` - :param disable_notification: Sends the message silently. Users will receive a notification with no sound. + :param disable_notification: Sends the message silently. Users will receive a notification with no sound :type disable_notification: :obj:`typing.Union[base.Boolean, None]` :param reply_to_message_id: If the message is a reply, ID of the original message :type reply_to_message_id: :obj:`typing.Union[base.Integer, None]` - :param reply_markup: A JSON-serialized object for an inline keyboard. + :param reply_markup: A JSON-serialized object for an inline keyboard If empty, one ‘Play game_title’ button will be shown. If not empty, the first button must launch the game. :type reply_markup: :obj:`typing.Union[types.InlineKeyboardMarkup, None]` - :return: On success, the sent Message is returned. + :return: On success, the sent Message is returned :rtype: :obj:`types.Message` """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals()) - result = await self.request(api.Methods.SEND_GAME, payload) + result = await self.request(api.Methods.SEND_GAME, payload) return types.Message(**result) async def set_game_score(self, user_id: base.Integer, score: base.Integer, @@ -1957,7 +2001,7 @@ class Bot(BaseBot): :type user_id: :obj:`base.Integer` :param score: New score, must be non-negative :type score: :obj:`base.Integer` - :param force: Pass True, if the high score is allowed to decrease. + :param force: Pass True, if the high score is allowed to decrease This can be useful when fixing mistakes or banning cheaters :type force: :obj:`typing.Union[base.Boolean, None]` :param disable_edit_message: Pass True, if the game message should not be automatically @@ -1969,17 +2013,16 @@ class Bot(BaseBot): :type message_id: :obj:`typing.Union[base.Integer, None]` :param inline_message_id: Required if chat_id and message_id are not specified. Identifier of the inline message :type inline_message_id: :obj:`typing.Union[base.String, None]` - :return: On success, if the message was sent by the bot, returns the edited Message, otherwise returns True. + :return: On success, if the message was sent by the bot, returns the edited Message, otherwise returns True Returns an error, if the new score is not greater than the user's current score in the chat and force is False. :rtype: :obj:`typing.Union[types.Message, base.Boolean]` """ payload = generate_payload(**locals()) - result = await self.request(api.Methods.SET_GAME_SCORE, payload) + result = await self.request(api.Methods.SET_GAME_SCORE, payload) if isinstance(result, bool): return result - return types.Message(**result) async def get_game_high_scores(self, user_id: base.Integer, @@ -2004,7 +2047,7 @@ class Bot(BaseBot): :type message_id: :obj:`typing.Union[base.Integer, None]` :param inline_message_id: Required if chat_id and message_id are not specified. Identifier of the inline message :type inline_message_id: :obj:`typing.Union[base.String, None]` - :return: Will return the score of the specified user and several of his neighbors in a game. + :return: Will return the score of the specified user and several of his neighbors in a game On success, returns an Array of GameHighScore objects. This method will currently return scores for the target user, plus two of his closest neighbors on each side. Will also return the top three users if the @@ -2015,3 +2058,6 @@ class Bot(BaseBot): result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload) return [types.GameHighScore(**gamehighscore) for gamehighscore in result] + + +bot: ContextVar[Bot] = ContextVar('bot_instance', default=None) diff --git a/aiogram/contrib/fsm_storage/memory.py b/aiogram/contrib/fsm_storage/memory.py index f8670ec4..d526e90e 100644 --- a/aiogram/contrib/fsm_storage/memory.py +++ b/aiogram/contrib/fsm_storage/memory.py @@ -1,6 +1,6 @@ import typing -from ...dispatcher import BaseStorage +from ...dispatcher.storage import BaseStorage class MemoryStorage(BaseStorage): @@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage): chat, user = self.check_address(chat=chat, user=user) user = self._get_user(chat, user) if data is None: - data = [] + data = {} user['data'].update(data, **kwargs) async def set_state(self, *, diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index e3103ae3..eaaf3985 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -303,6 +303,8 @@ class RedisStorage2(BaseStorage): async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None, **kwargs): + if data is None: + data = {} temp_data = await self.get_data(chat=chat, user=user, default={}) temp_data.update(data, **kwargs) await self.set_data(chat=chat, user=user, data=temp_data) @@ -330,6 +332,8 @@ class RedisStorage2(BaseStorage): async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None, **kwargs): + if bucket is None: + bucket = {} temp_bucket = await self.get_data(chat=chat, user=user) temp_bucket.update(bucket, **kwargs) await self.set_data(chat=chat, user=user, data=temp_bucket) diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index cfa71663..5c755af6 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -4,7 +4,7 @@ import weakref import rethinkdb as r -from ...dispatcher import BaseStorage +from ...dispatcher.storage import BaseStorage __all__ = ['RethinkDBStorage', 'ConnectionNotClosed'] diff --git a/aiogram/contrib/middlewares/context.py b/aiogram/contrib/middlewares/context.py deleted file mode 100644 index 8e6dce7a..00000000 --- a/aiogram/contrib/middlewares/context.py +++ /dev/null @@ -1,115 +0,0 @@ -from aiogram import types -from aiogram.dispatcher import ctx -from aiogram.dispatcher.middlewares import BaseMiddleware - -OBJ_KEY = '_context_data' - - -class ContextMiddleware(BaseMiddleware): - """ - Allow to store data at all of lifetime of Update object - """ - - async def on_pre_process_update(self, update: types.Update): - """ - Start of Update lifetime - - :param update: - :return: - """ - self._configure_update(update) - - async def on_post_process_update(self, update: types.Update, result): - """ - On finishing of processing update - - :param update: - :param result: - :return: - """ - if OBJ_KEY in update.conf: - del update.conf[OBJ_KEY] - - def _configure_update(self, update: types.Update = None): - """ - Setup data storage - - :param update: - :return: - """ - obj = update.conf[OBJ_KEY] = {} - return obj - - def _get_dict(self): - """ - Get data from update stored in current context - - :return: - """ - update = ctx.get_update() - obj = update.conf.get(OBJ_KEY, None) - if obj is None: - obj = self._configure_update(update) - return obj - - def __getitem__(self, item): - """ - Item getter - - :param item: - :return: - """ - return self._get_dict()[item] - - def __setitem__(self, key, value): - """ - Item setter - - :param key: - :param value: - :return: - """ - data = self._get_dict() - data[key] = value - - def __iter__(self): - """ - Iterate over dict - - :return: - """ - return self._get_dict().__iter__() - - def keys(self): - """ - Iterate over dict keys - - :return: - """ - return self._get_dict().keys() - - def values(self): - """ - Iterate over dict values - - :return: - """ - return self._get_dict().values() - - def get(self, key, default=None): - """ - Get item from dit or return default value - - :param key: - :param default: - :return: - """ - return self._get_dict().get(key, default) - - def export(self): - """ - Export all data s dict - - :return: - """ - return self._get_dict() diff --git a/aiogram/contrib/middlewares/environment.py b/aiogram/contrib/middlewares/environment.py new file mode 100644 index 00000000..0427a739 --- /dev/null +++ b/aiogram/contrib/middlewares/environment.py @@ -0,0 +1,25 @@ +from aiogram.dispatcher.middlewares import BaseMiddleware + + +class EnvironmentMiddleware(BaseMiddleware): + def __init__(self, context=None): + super(EnvironmentMiddleware, self).__init__() + + if context is None: + context = {} + self.context = context + + def update_data(self, data): + dp = self.manager.dispatcher + data.update( + bot=dp.bot, + dispatcher=dp, + loop=dp.loop + ) + if self.context: + data.update(self.context) + + async def trigger(self, action, args): + if 'error' not in action and action.startswith('pre_process_'): + self.update_data(args[-1]) + return True diff --git a/aiogram/contrib/middlewares/fsm.py b/aiogram/contrib/middlewares/fsm.py new file mode 100644 index 00000000..e3550a34 --- /dev/null +++ b/aiogram/contrib/middlewares/fsm.py @@ -0,0 +1,80 @@ +import copy +import weakref + +from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware +from aiogram.dispatcher.storage import FSMContext + + +class FSMMiddleware(LifetimeControllerMiddleware): + skip_patterns = ['error', 'update'] + + def __init__(self): + super(FSMMiddleware, self).__init__() + self._proxies = weakref.WeakKeyDictionary() + + async def pre_process(self, obj, data, *args): + proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state()) + data['state_data'] = proxy + + async def post_process(self, obj, data, *args): + proxy = data.get('state_data', None) + if isinstance(proxy, FSMSStorageProxy): + await proxy.save() + + +class FSMSStorageProxy(dict): + def __init__(self, fsm_context: FSMContext): + super(FSMSStorageProxy, self).__init__() + self.fsm_context = fsm_context + self._copy = {} + self._data = {} + self._state = None + self._is_dirty = False + + @classmethod + async def create(cls, fsm_context: FSMContext): + """ + :param fsm_context: + :return: + """ + proxy = cls(fsm_context) + await proxy.load() + return proxy + + async def load(self): + self.clear() + self._state = await self.fsm_context.get_state() + self.update(await self.fsm_context.get_data()) + self._copy = copy.deepcopy(self) + self._is_dirty = False + + @property + def state(self): + return self._state + + @state.setter + def state(self, value): + self._state = value + self._is_dirty = True + + @state.deleter + def state(self): + self._state = None + self._is_dirty = True + + async def save(self, force=False): + if self._copy != self or force: + await self.fsm_context.set_data(data=self) + if self._is_dirty or force: + await self.fsm_context.set_state(self.state) + self._is_dirty = False + self._copy = copy.deepcopy(self) + + def __str__(self): + s = super(FSMSStorageProxy, self).__str__() + readable_state = f"'{self.state}'" if self.state else "''" + return f"<{self.__class__.__name__}(state={readable_state}, data={s})>" + + def clear(self): + del self.state + return super(FSMSStorageProxy, self).clear() diff --git a/aiogram/contrib/middlewares/i18n.py b/aiogram/contrib/middlewares/i18n.py new file mode 100644 index 00000000..2ecc167a --- /dev/null +++ b/aiogram/contrib/middlewares/i18n.py @@ -0,0 +1,140 @@ +import gettext +import os +from contextvars import ContextVar +from typing import Any, Dict, Tuple + +from babel import Locale + +from ... import types +from ...dispatcher.middlewares import BaseMiddleware + + +class I18nMiddleware(BaseMiddleware): + """ + I18n middleware based on gettext util + + >>> dp = Dispatcher(bot) + >>> i18n = I18nMiddleware(DOMAIN, LOCALES_DIR) + >>> dp.middleware.setup(i18n) + and then + >>> _ = i18n.gettext + or + >>> _ = i18n = I18nMiddleware(DOMAIN_NAME, LOCALES_DIR) + """ + + ctx_locale = ContextVar('ctx_user_locale', default=None) + + def __init__(self, domain, path=None, default='en'): + """ + :param domain: domain + :param path: path where located all *.mo files + :param default: default locale name + """ + super(I18nMiddleware, self).__init__() + + if path is None: + path = os.path.join(os.getcwd(), 'locales') + + self.domain = domain + self.path = path + self.default = default + + self.locales = self.find_locales() + + def find_locales(self) -> Dict[str, gettext.GNUTranslations]: + """ + Load all compiled locales from path + + :return: dict with locales + """ + translations = {} + + for name in os.listdir(self.path): + if not os.path.isdir(os.path.join(self.path, name)): + continue + mo_path = os.path.join(self.path, name, 'LC_MESSAGES', self.domain + '.mo') + + if os.path.exists(mo_path): + with open(mo_path, 'rb') as fp: + translations[name] = gettext.GNUTranslations(fp) + elif os.path.exists(mo_path[:-2] + 'po'): + raise RuntimeError(f"Found locale '{name} but this language is not compiled!") + + return translations + + def reload(self): + """ + Hot reload locles + """ + self.locales = self.find_locales() + + @property + def available_locales(self) -> Tuple[str]: + """ + list of loaded locales + + :return: + """ + return tuple(self.locales.keys()) + + def __call__(self, singular, plural=None, n=1, locale=None) -> str: + return self.gettext(singular, plural, n, locale) + + def gettext(self, singular, plural=None, n=1, locale=None) -> str: + """ + Get text + + :param singular: + :param plural: + :param n: + :param locale: + :return: + """ + if locale is None: + locale = self.ctx_locale.get() + + if locale not in self.locales: + if n is 1: + return singular + else: + return plural + + translator = self.locales[locale] + + if plural is None: + return translator.gettext(singular) + else: + return translator.ngettext(singular, plural, n) + + # noinspection PyMethodMayBeStatic,PyUnusedLocal + async def get_user_locale(self, action: str, args: Tuple[Any]) -> str: + """ + User locale getter + You can override the method if you want to use different way of getting user language. + + :param action: event name + :param args: event arguments + :return: locale name + """ + user: types.User = types.User.current() + locale: Locale = user.locale + + if locale: + *_, data = args + language = data['locale'] = locale.language + return language + + async def trigger(self, action, args): + """ + Event trigger + + :param action: event name + :param args: event arguments + :return: + """ + if 'update' not in action \ + and 'error' not in action \ + and action.startswith('pre_process'): + locale = await self.get_user_locale(action, args) + self.ctx_locale.set(locale) + return True diff --git a/aiogram/contrib/middlewares/logging.py b/aiogram/contrib/middlewares/logging.py index 234c1053..ff870b2a 100644 --- a/aiogram/contrib/middlewares/logging.py +++ b/aiogram/contrib/middlewares/logging.py @@ -23,70 +23,70 @@ class LoggingMiddleware(BaseMiddleware): return round((time.time() - start) * 1000) return -1 - async def on_pre_process_update(self, update: types.Update): + async def on_pre_process_update(self, update: types.Update, data: dict): update.conf['_start'] = time.time() self.logger.debug(f"Received update [ID:{update.update_id}]") - async def on_post_process_update(self, update: types.Update, result): + async def on_post_process_update(self, update: types.Update, result, data: dict): timeout = self.check_timeout(update) if timeout > 0: self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)") - async def on_pre_process_message(self, message: types.Message): + async def on_pre_process_message(self, message: types.Message, data: dict): self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") - async def on_post_process_message(self, message: types.Message, results): + async def on_post_process_message(self, message: types.Message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") - async def on_pre_process_edited_message(self, edited_message): + async def on_pre_process_edited_message(self, edited_message, data: dict): self.logger.info(f"Received edited message [ID:{edited_message.message_id}] " f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") - async def on_post_process_edited_message(self, edited_message, results): + async def on_post_process_edited_message(self, edited_message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"edited message [ID:{edited_message.message_id}] " f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") - async def on_pre_process_channel_post(self, channel_post: types.Message): + async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict): self.logger.info(f"Received channel post [ID:{channel_post.message_id}] " f"in channel [ID:{channel_post.chat.id}]") - async def on_post_process_channel_post(self, channel_post: types.Message, results): + async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"channel post [ID:{channel_post.message_id}] " f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]") - async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message): + async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict): self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] " f"in channel [ID:{edited_channel_post.chat.id}]") - async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results): + async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"edited channel post [ID:{edited_channel_post.message_id}] " f"in channel [ID:{edited_channel_post.chat.id}]") - async def on_pre_process_inline_query(self, inline_query: types.InlineQuery): + async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict): self.logger.info(f"Received inline query [ID:{inline_query.id}] " f"from user [ID:{inline_query.from_user.id}]") - async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results): + async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"inline query [ID:{inline_query.id}] " f"from user [ID:{inline_query.from_user.id}]") - async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult): + async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict): self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " f"from user [ID:{chosen_inline_result.from_user.id}] " f"result [ID:{chosen_inline_result.result_id}]") - async def on_post_process_chosen_inline_result(self, chosen_inline_result, results): + async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " f"from user [ID:{chosen_inline_result.from_user.id}] " f"result [ID:{chosen_inline_result.result_id}]") - async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery): + async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict): if callback_query.message: if callback_query.message.from_user: self.logger.info(f"Received callback query [ID:{callback_query.id}] " @@ -100,7 +100,7 @@ class LoggingMiddleware(BaseMiddleware): f"from inline message [ID:{callback_query.inline_message_id}] " f"from user [ID:{callback_query.from_user.id}]") - async def on_post_process_callback_query(self, callback_query, results): + async def on_post_process_callback_query(self, callback_query, results, data: dict): if callback_query.message: if callback_query.message.from_user: self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " @@ -117,25 +117,25 @@ class LoggingMiddleware(BaseMiddleware): f"from inline message [ID:{callback_query.inline_message_id}] " f"from user [ID:{callback_query.from_user.id}]") - async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery): + async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict): self.logger.info(f"Received shipping query [ID:{shipping_query.id}] " f"from user [ID:{shipping_query.from_user.id}]") - async def on_post_process_shipping_query(self, shipping_query, results): + async def on_post_process_shipping_query(self, shipping_query, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"shipping query [ID:{shipping_query.id}] " f"from user [ID:{shipping_query.from_user.id}]") - async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery): + async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict): self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] " f"from user [ID:{pre_checkout_query.from_user.id}]") - async def on_post_process_pre_checkout_query(self, pre_checkout_query, results): + async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict): self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " f"pre-checkout query [ID:{pre_checkout_query.id}] " f"from user [ID:{pre_checkout_query.from_user.id}]") - async def on_pre_process_error(self, dispatcher, update, error): + async def on_pre_process_error(self, dispatcher, update, error, data: dict): timeout = self.check_timeout(update) if timeout > 0: self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)") diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 3cea91d8..2ff5dc90 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -1,1056 +1,18 @@ -import asyncio -import functools -import itertools -import logging -import time -import typing - -from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, \ - USER_STATE, generate_default_filters -from .handler import CancelHandler, Handler, SkipHandler -from .middlewares import MiddlewareManager -from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ - LAST_CALL, RATE_LIMIT, RESULT -from .webhook import BaseResponse -from ..bot import Bot -from ..types.message import ContentType -from ..utils import context -from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled - -log = logging.getLogger(__name__) - -MODE = 'MODE' -LONG_POLLING = 'long-polling' -UPDATE_OBJECT = 'update_object' - -DEFAULT_RATE_LIMIT = .1 - - -class Dispatcher: - """ - Simple Updates dispatcher - - It will process incoming updates: messages, edited messages, channel posts, edited channel posts, - inline queries, chosen inline results, callback queries, shipping queries, pre-checkout queries. - """ - - def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None, - run_tasks_by_default: bool = False, - throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False): - - if loop is None: - loop = bot.loop - if storage is None: - storage = DisabledStorage() - - self.bot: Bot = bot - self.loop = loop - self.storage = storage - self.run_tasks_by_default = run_tasks_by_default - - self.throttling_rate_limit = throttling_rate_limit - self.no_throttle_error = no_throttle_error - - self.last_update_id = 0 - - self.updates_handler = Handler(self, middleware_key='update') - self.message_handlers = Handler(self, middleware_key='message') - self.edited_message_handlers = Handler(self, middleware_key='edited_message') - self.channel_post_handlers = Handler(self, middleware_key='channel_post') - self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post') - self.inline_query_handlers = Handler(self, middleware_key='inline_query') - self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result') - self.callback_query_handlers = Handler(self, middleware_key='callback_query') - self.shipping_query_handlers = Handler(self, middleware_key='shipping_query') - self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query') - self.errors_handlers = Handler(self, once=False, middleware_key='error') - - self.middleware = MiddlewareManager(self) - - self.updates_handler.register(self.process_update) - - self._polling = False - self._closed = True - self._close_waiter = loop.create_future() - - def __del__(self): - self.stop_polling() - - @property - def data(self): - return self.bot.data - - def __setitem__(self, key, value): - self.bot.data[key] = value - - def __getitem__(self, item): - return self.bot.data[item] - - def get(self, key, default=None): - return self.bot.data.get(key, default) - - async def skip_updates(self): - """ - You can skip old incoming updates from queue. - This method is not recommended to use if you use payments or you bot has high-load. - - :return: count of skipped updates - """ - total = 0 - updates = await self.bot.get_updates(offset=self.last_update_id, timeout=1) - while updates: - total += len(updates) - for update in updates: - if update.update_id > self.last_update_id: - self.last_update_id = update.update_id - updates = await self.bot.get_updates(offset=self.last_update_id + 1, timeout=1) - return total - - async def process_updates(self, updates): - """ - Process list of updates - - :param updates: - :return: - """ - tasks = [] - for update in updates: - tasks.append(self.updates_handler.notify(update)) - return await asyncio.gather(*tasks) - - async def process_update(self, update): - """ - Process single update object - - :param update: - :return: - """ - self.last_update_id = update.update_id - context.set_value(UPDATE_OBJECT, update) - try: - if update.message: - state = await self.storage.get_state(chat=update.message.chat.id, - user=update.message.from_user.id) - context.update_state(chat=update.message.chat.id, - user=update.message.from_user.id, - state=state) - return await self.message_handlers.notify(update.message) - if update.edited_message: - state = await self.storage.get_state(chat=update.edited_message.chat.id, - user=update.edited_message.from_user.id) - context.update_state(chat=update.edited_message.chat.id, - user=update.edited_message.from_user.id, - state=state) - return await self.edited_message_handlers.notify(update.edited_message) - if update.channel_post: - state = await self.storage.get_state(chat=update.channel_post.chat.id) - context.update_state(chat=update.channel_post.chat.id, - state=state) - return await self.channel_post_handlers.notify(update.channel_post) - if update.edited_channel_post: - state = await self.storage.get_state(chat=update.edited_channel_post.chat.id) - context.update_state(chat=update.edited_channel_post.chat.id, - state=state) - return await self.edited_channel_post_handlers.notify(update.edited_channel_post) - if update.inline_query: - state = await self.storage.get_state(user=update.inline_query.from_user.id) - context.update_state(user=update.inline_query.from_user.id, - state=state) - return await self.inline_query_handlers.notify(update.inline_query) - if update.chosen_inline_result: - state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id) - context.update_state(user=update.chosen_inline_result.from_user.id, - state=state) - return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) - if update.callback_query: - state = await self.storage.get_state( - chat=update.callback_query.message.chat.id if update.callback_query.message else None, - user=update.callback_query.from_user.id) - context.update_state(user=update.callback_query.from_user.id, - state=state) - return await self.callback_query_handlers.notify(update.callback_query) - if update.shipping_query: - state = await self.storage.get_state(user=update.shipping_query.from_user.id) - context.update_state(user=update.shipping_query.from_user.id, - state=state) - return await self.shipping_query_handlers.notify(update.shipping_query) - if update.pre_checkout_query: - state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id) - context.update_state(user=update.pre_checkout_query.from_user.id, - state=state) - return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) - except Exception as e: - err = await self.errors_handlers.notify(self, update, e) - if err: - return err - raise - - async def reset_webhook(self, check=True) -> bool: - """ - Reset webhook - - :param check: check before deleting - :return: - """ - if check: - wh = await self.bot.get_webhook_info() - if not wh.url: - return False - - return await self.bot.delete_webhook() - - async def start_polling(self, timeout=20, relax=0.1, limit=None, reset_webhook=None): - """ - Start long-polling - - :param timeout: - :param relax: - :param limit: - :param reset_webhook: - :return: - """ - if self._polling: - raise RuntimeError('Polling already started') - - log.info('Start polling.') - - context.set_value(MODE, LONG_POLLING) - context.set_value('dispatcher', self) - context.set_value('bot', self.bot) - - if reset_webhook is None: - await self.reset_webhook(check=False) - if reset_webhook: - await self.reset_webhook(check=True) - - self._polling = True - offset = None - try: - while self._polling: - try: - updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout) - except: - log.exception('Cause exception while getting updates.') - await asyncio.sleep(15) - continue - - if updates: - log.debug(f"Received {len(updates)} updates.") - offset = updates[-1].update_id + 1 - - self.loop.create_task(self._process_polling_updates(updates)) - - if relax: - await asyncio.sleep(relax) - finally: - self._close_waiter.set_result(None) - log.warning('Polling is stopped.') - - async def _process_polling_updates(self, updates): - """ - Process updates received from long-polling. - - :param updates: list of updates. - """ - need_to_call = [] - for responses in itertools.chain.from_iterable(await self.process_updates(updates)): - for response in responses: - if not isinstance(response, BaseResponse): - continue - need_to_call.append(response.execute_response(self.bot)) - if need_to_call: - try: - asyncio.gather(*need_to_call) - except TelegramAPIError: - log.exception('Cause exception while processing updates.') - - def stop_polling(self): - """ - Break long-polling process. - - :return: - """ - if self._polling: - log.info('Stop polling...') - self._polling = False - - async def wait_closed(self): - """ - Wait for the long-polling to close - - :return: - """ - await asyncio.shield(self._close_waiter, loop=self.loop) - - def is_polling(self): - """ - Check if polling is enabled - - :return: - """ - return self._polling - - def register_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None, - state=None, custom_filters=None, run_task=None, **kwargs): - """ - Register handler for message - - .. code-block:: python3 - - # This handler works only if state is None (by default). - dp.register_message_handler(cmd_start, commands=['start', 'about']) - dp.register_message_handler(entry_point, commands=['setup']) - - # This handler works only if current state is "first_step" - dp.register_message_handler(step_handler_1, state="first_step") - - # If you want to handle all states by one handler, use `state="*"`. - dp.register_message_handler(cancel_handler, commands=['cancel'], state="*") - dp.register_message_handler(cancel_handler, func=lambda msg: msg.text.lower() == 'cancel', state="*") - - :param callback: - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param custom_filters: list of custom filters - :param kwargs: - :param state: - :return: decorated function - """ - if content_types is None: - content_types = ContentType.TEXT - if custom_filters is None: - custom_filters = [] - - filters_set = generate_default_filters(self, - *custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - func=func, - state=state, - **kwargs) - self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None, - run_task=None, **kwargs): - """ - Decorator for message handler - - Examples: - - Simple commands handler: - - .. code-block:: python3 - - @dp.message_handler(commands=['start', 'welcome', 'about']) - async def cmd_handler(message: types.Message): - - Filter messages by regular expression: - - .. code-block:: python3 - - @dp.message_handler(rexexp='^[a-z]+-[0-9]+') - async def msg_handler(message: types.Message): - - Filter messages by command regular expression: - - .. code-block:: python3 - - @dp.message_handler(filters.RegexpCommandsFilter(regexp_commands=['item_([0-9]*)'])) - async def send_welcome(message: types.Message): - - Filter by content type: - - .. code-block:: python3 - - @dp.message_handler(content_types=ContentType.PHOTO | ContentType.DOCUMENT) - async def audio_handler(message: types.Message): - - Filter by custom function: - - .. code-block:: python3 - - @dp.message_handler(func=lambda message: message.text and 'hello' in message.text.lower()) - async def text_handler(message: types.Message): - - Use multiple filters: - - .. code-block:: python3 - - @dp.message_handler(commands=['command'], content_types=ContentType.TEXT) - async def text_handler(message: types.Message): - - Register multiple filters set for one handler: - - .. code-block:: python3 - - @dp.message_handler(commands=['command']) - @dp.message_handler(func=lambda message: demojize(message.text) == ':new_moon_with_face:') - async def text_handler(message: types.Message): - - This handler will be called if the message starts with '/command' OR is some emoji - - By default content_type is :class:`ContentType.TEXT` - - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param custom_filters: list of custom filters - :param kwargs: - :param state: - :param run_task: run callback in task (no wait results) - :return: decorated function - """ - - def decorator(callback): - self.register_message_handler(callback, - commands=commands, regexp=regexp, content_types=content_types, - func=func, state=state, custom_filters=custom_filters, run_task=run_task, - **kwargs) - return callback - - return decorator - - def register_edited_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None, - state=None, custom_filters=None, run_task=None, **kwargs): - """ - Register handler for edited message - - :param callback: - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param state: - :param custom_filters: list of custom filters - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - if content_types is None: - content_types = ContentType.TEXT - if custom_filters is None: - custom_filters = [] - - filters_set = generate_default_filters(self, - *custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - func=func, - state=state, - **kwargs) - self.edited_message_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def edited_message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, - state=None, run_task=None, **kwargs): - """ - Decorator for edited message handler - - You can use combination of different handlers - - .. code-block:: python3 - - @dp.message_handler() - @dp.edited_message_handler() - async def msg_handler(message: types.Message): - - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param state: - :param custom_filters: list of custom filters - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - - def decorator(callback): - self.register_edited_message_handler(callback, commands=commands, regexp=regexp, - content_types=content_types, func=func, state=state, - custom_filters=custom_filters, run_task=run_task, **kwargs) - return callback - - return decorator - - def register_channel_post_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None, - state=None, custom_filters=None, run_task=None, **kwargs): - """ - Register handler for channel post - - :param callback: - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param state: - :param custom_filters: list of custom filters - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - if content_types is None: - content_types = ContentType.TEXT - if custom_filters is None: - custom_filters = [] - - filters_set = generate_default_filters(self, - *custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - func=func, - state=state, - **kwargs) - self.channel_post_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def channel_post_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, - state=None, run_task=None, **kwargs): - """ - Decorator for channel post handler - - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param state: - :param custom_filters: list of custom filters - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - - def decorator(callback): - self.register_channel_post_handler(callback, commands=commands, regexp=regexp, content_types=content_types, - func=func, state=state, custom_filters=custom_filters, - run_task=run_task, **kwargs) - return callback - - return decorator - - def register_edited_channel_post_handler(self, callback, *, commands=None, regexp=None, content_types=None, - func=None, state=None, custom_filters=None, run_task=None, **kwargs): - """ - Register handler for edited channel post - - :param callback: - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param state: - :param custom_filters: list of custom filters - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - if content_types is None: - content_types = ContentType.TEXT - if custom_filters is None: - custom_filters = [] - - filters_set = generate_default_filters(self, - *custom_filters, - commands=commands, - regexp=regexp, - content_types=content_types, - func=func, - state=state, - **kwargs) - self.edited_channel_post_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def edited_channel_post_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, - state=None, run_task=None, **kwargs): - """ - Decorator for edited channel post handler - - :param commands: list of commands - :param regexp: REGEXP - :param content_types: List of content types. - :param func: custom any callable object - :param custom_filters: list of custom filters - :param state: - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - - def decorator(callback): - self.register_edited_channel_post_handler(callback, commands=commands, regexp=regexp, - content_types=content_types, func=func, state=state, - custom_filters=custom_filters, run_task=run_task, **kwargs) - return callback - - return decorator - - def register_inline_handler(self, callback, *, func=None, state=None, custom_filters=None, run_task=None, **kwargs): - """ - Register handler for inline query - - Example: - - .. code-block:: python3 - - dp.register_inline_handler(some_inline_handler, func=lambda inline_query: True) - - :param callback: - :param func: custom any callable object - :param custom_filters: list of custom filters - :param state: - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - if custom_filters is None: - custom_filters = [] - filters_set = generate_default_filters(self, - *custom_filters, - func=func, - state=state, - **kwargs) - self.inline_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def inline_handler(self, *custom_filters, func=None, state=None, run_task=None, **kwargs): - """ - Decorator for inline query handler - - Example: - - .. code-block:: python3 - - @dp.inline_handler(func=lambda inline_query: True) - async def some_inline_handler(inline_query: types.InlineQuery) - - :param func: custom any callable object - :param state: - :param custom_filters: list of custom filters - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: decorated function - """ - - def decorator(callback): - self.register_inline_handler(callback, func=func, state=state, custom_filters=custom_filters, - run_task=run_task, **kwargs) - return callback - - return decorator - - def register_chosen_inline_handler(self, callback, *, func=None, state=None, custom_filters=None, run_task=None, - **kwargs): - """ - Register handler for chosen inline query - - Example: - - .. code-block:: python3 - - dp.register_chosen_inline_handler(some_chosen_inline_handler, func=lambda chosen_inline_query: True) - - :param callback: - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: - """ - if custom_filters is None: - custom_filters = [] - filters_set = generate_default_filters(self, - *custom_filters, - func=func, - state=state, - **kwargs) - self.chosen_inline_result_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def chosen_inline_handler(self, *custom_filters, func=None, state=None, run_task=None, **kwargs): - """ - Decorator for chosen inline query handler - - Example: - - .. code-block:: python3 - - @dp.chosen_inline_handler(func=lambda chosen_inline_query: True) - async def some_chosen_inline_handler(chosen_inline_query: types.ChosenInlineResult) - - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - :return: - """ - - def decorator(callback): - self.register_chosen_inline_handler(callback, func=func, state=state, custom_filters=custom_filters, - run_task=run_task, **kwargs) - return callback - - return decorator - - def register_callback_query_handler(self, callback, *, func=None, state=None, custom_filters=None, run_task=None, - **kwargs): - """ - Register handler for callback query - - Example: - - .. code-block:: python3 - - dp.register_callback_query_handler(some_callback_handler, func=lambda callback_query: True) - - :param callback: - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - """ - if custom_filters is None: - custom_filters = [] - filters_set = generate_default_filters(self, - *custom_filters, - func=func, - state=state, - **kwargs) - self.callback_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def callback_query_handler(self, *custom_filters, func=None, state=None, run_task=None, **kwargs): - """ - Decorator for callback query handler - - Example: - - .. code-block:: python3 - - @dp.callback_query_handler(func=lambda callback_query: True) - async def some_callback_handler(callback_query: types.CallbackQuery) - - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - """ - - def decorator(callback): - self.register_callback_query_handler(callback, func=func, state=state, custom_filters=custom_filters, - run_task=run_task, **kwargs) - return callback - - return decorator - - def register_shipping_query_handler(self, callback, *, func=None, state=None, custom_filters=None, run_task=None, - **kwargs): - """ - Register handler for shipping query - - Example: - - .. code-block:: python3 - - dp.register_shipping_query_handler(some_shipping_query_handler, func=lambda shipping_query: True) - - :param callback: - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - """ - if custom_filters is None: - custom_filters = [] - filters_set = generate_default_filters(self, - *custom_filters, - func=func, - state=state, - **kwargs) - self.shipping_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def shipping_query_handler(self, *custom_filters, func=None, state=None, run_task=None, **kwargs): - """ - Decorator for shipping query handler - - Example: - - .. code-block:: python3 - - @dp.shipping_query_handler(func=lambda shipping_query: True) - async def some_shipping_query_handler(shipping_query: types.ShippingQuery) - - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - """ - - def decorator(callback): - self.register_shipping_query_handler(callback, func=func, state=state, custom_filters=custom_filters, - run_task=run_task, **kwargs) - return callback - - return decorator - - def register_pre_checkout_query_handler(self, callback, *, func=None, state=None, custom_filters=None, - run_task=None, **kwargs): - """ - Register handler for pre-checkout query - - Example: - - .. code-block:: python3 - - dp.register_pre_checkout_query_handler(some_pre_checkout_query_handler, func=lambda shipping_query: True) - - :param callback: - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - """ - if custom_filters is None: - custom_filters = [] - filters_set = generate_default_filters(self, - *custom_filters, - func=func, - state=state, - **kwargs) - self.pre_checkout_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def pre_checkout_query_handler(self, *custom_filters, func=None, state=None, run_task=None, **kwargs): - """ - Decorator for pre-checkout query handler - - Example: - - .. code-block:: python3 - - @dp.pre_checkout_query_handler(func=lambda shipping_query: True) - async def some_pre_checkout_query_handler(shipping_query: types.ShippingQuery) - - :param func: custom any callable object - :param state: - :param custom_filters: - :param run_task: run callback in task (no wait results) - :param kwargs: - """ - - def decorator(callback): - self.register_pre_checkout_query_handler(callback, func=func, state=state, custom_filters=custom_filters, - run_task=run_task, **kwargs) - return callback - - return decorator - - def register_errors_handler(self, callback, *, func=None, exception=None, run_task=None): - """ - Register handler for errors - - :param callback: - :param func: - :param exception: you can make handler for specific errors type - :param run_task: run callback in task (no wait results) - """ - filters_set = [] - if func is not None: - filters_set.append(func) - if exception is not None: - filters_set.append(ExceptionsFilter(exception)) - self.errors_handlers.register(self._wrap_async_task(callback, run_task), filters_set) - - def errors_handler(self, func=None, exception=None, run_task=None): - """ - Decorator for errors handler - - :param func: - :param exception: you can make handler for specific errors type - :param run_task: run callback in task (no wait results) - :return: - """ - - def decorator(callback): - self.register_errors_handler(self._wrap_async_task(callback, run_task), - func=func, exception=exception) - return callback - - return decorator - - def current_state(self, *, - chat: typing.Union[str, int, None] = None, - user: typing.Union[str, int, None] = None) -> FSMContext: - """ - Get current state for user in chat as context - - .. code-block:: python3 - - with dp.current_state(chat=message.chat.id, user=message.user.id) as state: - pass - - state = dp.current_state() - state.set_state('my_state') - - :param chat: - :param user: - :return: - """ - if chat is None: - from .ctx import get_chat - chat = get_chat() - if user is None: - from .ctx import get_user - user = get_user() - - return FSMContext(storage=self.storage, chat=chat, user=user) - - async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool: - """ - Execute throttling manager. - Returns True if limit has not exceeded otherwise raises ThrottleError or returns False - - :param key: key in storage - :param rate: limit (by default is equal to default rate limit) - :param user: user id - :param chat: chat id - :param no_error: return boolean value instead of raising error - :return: bool - """ - if not self.storage.has_bucket(): - raise RuntimeError('This storage does not provide Leaky Bucket') - - if no_error is None: - no_error = self.no_throttle_error - if rate is None: - rate = self.throttling_rate_limit - if user is None and chat is None: - from . import ctx - user = ctx.get_user() - chat = ctx.get_chat() - - # Detect current time - now = time.time() - - bucket = await self.storage.get_bucket(chat=chat, user=user) - - # Fix bucket - if bucket is None: - bucket = {key: {}} - if key not in bucket: - bucket[key] = {} - data = bucket[key] - - # Calculate - called = data.get(LAST_CALL, now) - delta = now - called - result = delta >= rate or delta <= 0 - - # Save results - data[RESULT] = result - data[RATE_LIMIT] = rate - data[LAST_CALL] = now - data[DELTA] = delta - if not result: - data[EXCEEDED_COUNT] += 1 - else: - data[EXCEEDED_COUNT] = 1 - bucket[key].update(data) - await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) - - if not result and not no_error: - # Raise if it is allowed - raise Throttled(key=key, chat=chat, user=user, **data) - return result - - async def check_key(self, key, chat=None, user=None): - """ - Get information about key in bucket - - :param key: - :param chat: - :param user: - :return: - """ - if not self.storage.has_bucket(): - raise RuntimeError('This storage does not provide Leaky Bucket') - - if user is None and chat is None: - from . import ctx - user = ctx.get_user() - chat = ctx.get_chat() - - bucket = await self.storage.get_bucket(chat=chat, user=user) - data = bucket.get(key, {}) - return Throttled(key=key, chat=chat, user=user, **data) - - async def release_key(self, key, chat=None, user=None): - """ - Release blocked key - - :param key: - :param chat: - :param user: - :return: - """ - if not self.storage.has_bucket(): - raise RuntimeError('This storage does not provide Leaky Bucket') - - if user is None and chat is None: - from . import ctx - user = ctx.get_user() - chat = ctx.get_chat() - - bucket = await self.storage.get_bucket(chat=chat, user=user) - if bucket and key in bucket: - del bucket['key'] - await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) - return True - return False - - def async_task(self, func): - """ - Execute handler as task and return None. - Use this decorator for slow handlers (with timeouts) - - .. code-block:: python3 - - @dp.message_handler(commands=['command']) - @dp.async_task - async def cmd_with_timeout(message: types.Message): - await asyncio.sleep(120) - return SendMessage(message.chat.id, 'KABOOM').reply(message) - - :param func: - :return: - """ - - def process_response(task): - try: - response = task.result() - except Exception as e: - self.loop.create_task( - self.errors_handlers.notify(self, task.context.get(UPDATE_OBJECT, None), e)) - else: - if isinstance(response, BaseResponse): - self.loop.create_task(response.execute_response(self.bot)) - - @functools.wraps(func) - async def wrapper(*args, **kwargs): - task = self.loop.create_task(func(*args, **kwargs)) - task.add_done_callback(process_response) - - return wrapper - - def _wrap_async_task(self, callback, run_task=None) -> callable: - if run_task is None: - run_task = self.run_tasks_by_default - - if run_task: - return self.async_task(callback) - return callback +from . import filters +from . import handler +from . import middlewares +from . import storage +from . import webhook +from .dispatcher import Dispatcher, dispatcher, FSMContext, DEFAULT_RATE_LIMIT + +__all__ = [ + 'DEFAULT_RATE_LIMIT', + 'Dispatcher', + 'dispatcher', + 'FSMContext', + 'filters', + 'handler', + 'middlewares', + 'storage', + 'webhook' +] diff --git a/aiogram/dispatcher/ctx.py b/aiogram/dispatcher/ctx.py deleted file mode 100644 index 18229125..00000000 --- a/aiogram/dispatcher/ctx.py +++ /dev/null @@ -1,42 +0,0 @@ -from . import Bot -from .. import types -from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT -from ..utils import context - - -def _get(key, default=None, no_error=False): - result = context.get_value(key, default) - if not no_error and result is None: - raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n" - f"Maybe asyncio task factory is not configured!\n" - f"\t>>> from aiogram.utils import context\n" - f"\t>>> loop.set_task_factory(context.task_factory)") - return result - - -def get_bot() -> Bot: - return _get('bot') - - -def get_dispatcher() -> Dispatcher: - return _get('dispatcher') - - -def get_update() -> types.Update: - return _get(UPDATE_OBJECT) - - -def get_mode() -> str: - return _get(MODE, 'unknown') - - -def get_chat() -> int: - return _get('chat', no_error=True) - - -def get_user() -> int: - return _get('user', no_error=True) - - -def get_state() -> FSMContext: - return get_dispatcher().current_state() diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py new file mode 100644 index 00000000..aaa7a7cb --- /dev/null +++ b/aiogram/dispatcher/dispatcher.py @@ -0,0 +1,1007 @@ +import asyncio +import functools +import itertools +import logging +import time +import typing +from contextvars import ContextVar + +from aiogram.dispatcher.filters import Command +from .filters import ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \ + Regexp, StateFilter, Text +from .handler import Handler +from .middlewares import MiddlewareManager +from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ + LAST_CALL, RATE_LIMIT, RESULT +from .webhook import BaseResponse +from .. import types +from ..bot import Bot, bot +from ..utils.exceptions import TelegramAPIError, Throttled + +log = logging.getLogger(__name__) + +MODE = 'MODE' +LONG_POLLING = 'long-polling' +UPDATE_OBJECT = 'update_object' + +DEFAULT_RATE_LIMIT = .1 + + +class Dispatcher: + """ + Simple Updates dispatcher + + It will process incoming updates: messages, edited messages, channel posts, edited channel posts, + inline queries, chosen inline results, callback queries, shipping queries, pre-checkout queries. + """ + + def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None, + run_tasks_by_default: bool = False, + throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False, + filters_factory=None): + + if loop is None: + loop = bot.loop + if storage is None: + storage = DisabledStorage() + if filters_factory is None: + filters_factory = FiltersFactory(self) + + self.bot: Bot = bot + self.loop = loop + self.storage = storage + self.run_tasks_by_default = run_tasks_by_default + + self.throttling_rate_limit = throttling_rate_limit + self.no_throttle_error = no_throttle_error + + self.last_update_id = 0 + + self.filters_factory: FiltersFactory = filters_factory + self.updates_handler = Handler(self, middleware_key='update') + self.message_handlers = Handler(self, middleware_key='message') + self.edited_message_handlers = Handler(self, middleware_key='edited_message') + self.channel_post_handlers = Handler(self, middleware_key='channel_post') + self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post') + self.inline_query_handlers = Handler(self, middleware_key='inline_query') + self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result') + self.callback_query_handlers = Handler(self, middleware_key='callback_query') + self.shipping_query_handlers = Handler(self, middleware_key='shipping_query') + self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query') + self.errors_handlers = Handler(self, once=False, middleware_key='error') + + self.middleware = MiddlewareManager(self) + + self.updates_handler.register(self.process_update) + + self._polling = False + self._closed = True + self._close_waiter = loop.create_future() + + filters_factory.bind(StateFilter, exclude_event_handlers=[ + self.errors_handlers + ]) + filters_factory.bind(ContentTypeFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + ]), + filters_factory.bind(Command, event_handlers=[ + self.message_handlers, self.edited_message_handlers + ]) + filters_factory.bind(Text, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + self.callback_query_handlers + ]) + filters_factory.bind(Regexp, event_handlers=[ + self.message_handlers, self.edited_message_handlers, + self.channel_post_handlers, self.edited_channel_post_handlers, + self.callback_query_handlers + ]) + filters_factory.bind(RegexpCommandsFilter, event_handlers=[ + self.message_handlers, self.edited_message_handlers + ]) + filters_factory.bind(ExceptionsFilter, event_handlers=[ + self.errors_handlers + ]) + + def __del__(self): + self.stop_polling() + + @property + def data(self): + return self.bot.data + + def __setitem__(self, key, value): + self.bot.data[key] = value + + def __getitem__(self, item): + return self.bot.data[item] + + def get(self, key, default=None): + return self.bot.data.get(key, default) + + @classmethod + def current(cls): + return dispatcher.get() + + async def skip_updates(self): + """ + You can skip old incoming updates from queue. + This method is not recommended to use if you use payments or you bot has high-load. + + :return: count of skipped updates + """ + total = 0 + updates = await self.bot.get_updates(offset=self.last_update_id, timeout=1) + while updates: + total += len(updates) + for update in updates: + if update.update_id > self.last_update_id: + self.last_update_id = update.update_id + updates = await self.bot.get_updates(offset=self.last_update_id + 1, timeout=1) + return total + + async def process_updates(self, updates): + """ + Process list of updates + + :param updates: + :return: + """ + tasks = [] + for update in updates: + tasks.append(self.updates_handler.notify(update)) + return await asyncio.gather(*tasks) + + async def process_update(self, update: types.Update): + """ + Process single update object + + :param update: + :return: + """ + self.last_update_id = update.update_id + types.Update.set_current(update) + + try: + if update.message: + types.User.set_current(update.message.from_user) + types.Chat.set_current(update.message.chat) + return await self.message_handlers.notify(update.message) + if update.edited_message: + types.User.set_current(update.edited_message.from_user) + types.Chat.set_current(update.edited_message.chat) + return await self.edited_message_handlers.notify(update.edited_message) + if update.channel_post: + types.Chat.set_current(update.channel_post.chat) + return await self.channel_post_handlers.notify(update.channel_post) + if update.edited_channel_post: + types.Chat.set_current(update.edited_channel_post.chat) + return await self.edited_channel_post_handlers.notify(update.edited_channel_post) + if update.inline_query: + types.User.set_current(update.inline_query.from_user) + return await self.inline_query_handlers.notify(update.inline_query) + if update.chosen_inline_result: + types.User.set_current(update.chosen_inline_result.from_user) + return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) + if update.callback_query: + if update.callback_query.message: + types.Chat.set_current(update.callback_query.message.chat) + types.User.set_current(update.callback_query.from_user) + return await self.callback_query_handlers.notify(update.callback_query) + if update.shipping_query: + types.User.set_current(update.shipping_query.from_user) + return await self.shipping_query_handlers.notify(update.shipping_query) + if update.pre_checkout_query: + types.User.set_current(update.pre_checkout_query.from_user) + return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) + except Exception as e: + err = await self.errors_handlers.notify(self, update, e) + if err: + return err + raise + + async def reset_webhook(self, check=True) -> bool: + """ + Reset webhook + + :param check: check before deleting + :return: + """ + if check: + wh = await self.bot.get_webhook_info() + if not wh.url: + return False + + return await self.bot.delete_webhook() + + async def start_polling(self, timeout=20, relax=0.1, limit=None, reset_webhook=None): + """ + Start long-polling + + :param timeout: + :param relax: + :param limit: + :param reset_webhook: + :return: + """ + if self._polling: + raise RuntimeError('Polling already started') + + log.info('Start polling.') + + # context.set_value(MODE, LONG_POLLING) + dispatcher.set(self) + bot.bot.set(self.bot) + + if reset_webhook is None: + await self.reset_webhook(check=False) + if reset_webhook: + await self.reset_webhook(check=True) + + self._polling = True + offset = None + try: + while self._polling: + try: + updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout) + except: + log.exception('Cause exception while getting updates.') + await asyncio.sleep(15) + continue + + if updates: + log.debug(f"Received {len(updates)} updates.") + offset = updates[-1].update_id + 1 + + self.loop.create_task(self._process_polling_updates(updates)) + + if relax: + await asyncio.sleep(relax) + finally: + self._close_waiter._set_result(None) + log.warning('Polling is stopped.') + + async def _process_polling_updates(self, updates): + """ + Process updates received from long-polling. + + :param updates: list of updates. + """ + need_to_call = [] + for responses in itertools.chain.from_iterable(await self.process_updates(updates)): + for response in responses: + if not isinstance(response, BaseResponse): + continue + need_to_call.append(response.execute_response(self.bot)) + if need_to_call: + try: + asyncio.gather(*need_to_call) + except TelegramAPIError: + log.exception('Cause exception while processing updates.') + + def stop_polling(self): + """ + Break long-polling process. + + :return: + """ + if self._polling: + log.info('Stop polling...') + self._polling = False + + async def wait_closed(self): + """ + Wait for the long-polling to close + + :return: + """ + await asyncio.shield(self._close_waiter, loop=self.loop) + + def is_polling(self): + """ + Check if polling is enabled + + :return: + """ + return self._polling + + def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None, + state=None, run_task=None, **kwargs): + """ + Register handler for message + + .. code-block:: python3 + + # This handler works only if state is None (by default). + dp.register_message_handler(cmd_start, commands=['start', 'about']) + dp.register_message_handler(entry_point, commands=['setup']) + + # This handler works only if current state is "first_step" + dp.register_message_handler(step_handler_1, state="first_step") + + # If you want to handle all states by one handler, use `state="*"`. + dp.register_message_handler(cancel_handler, commands=['cancel'], state="*") + dp.register_message_handler(cancel_handler, lambda msg: msg.text.lower() == 'cancel', state="*") + + :param callback: + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param custom_filters: list of custom filters + :param kwargs: + :param state: + :return: decorated function + """ + filters_set = self.filters_factory.resolve(self.message_handlers, + *custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + state=state, + **kwargs) + self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, state=None, + run_task=None, **kwargs): + """ + Decorator for message handler + + Examples: + + Simple commands handler: + + .. code-block:: python3 + + @dp.message_handler(commands=['start', 'welcome', 'about']) + async def cmd_handler(message: types.Message): + + Filter messages by regular expression: + + .. code-block:: python3 + + @dp.message_handler(rexexp='^[a-z]+-[0-9]+') + async def msg_handler(message: types.Message): + + Filter messages by command regular expression: + + .. code-block:: python3 + + @dp.message_handler(filters.RegexpCommandsFilter(regexp_commands=['item_([0-9]*)'])) + async def send_welcome(message: types.Message): + + Filter by content type: + + .. code-block:: python3 + + @dp.message_handler(content_types=ContentType.PHOTO | ContentType.DOCUMENT) + async def audio_handler(message: types.Message): + + Filter by custom function: + + .. code-block:: python3 + + @dp.message_handler(lambda message: message.text and 'hello' in message.text.lower()) + async def text_handler(message: types.Message): + + Use multiple filters: + + .. code-block:: python3 + + @dp.message_handler(commands=['command'], content_types=ContentType.TEXT) + async def text_handler(message: types.Message): + + Register multiple filters set for one handler: + + .. code-block:: python3 + + @dp.message_handler(commands=['command']) + @dp.message_handler(lambda message: demojize(message.text) == ':new_moon_with_face:') + async def text_handler(message: types.Message): + + This handler will be called if the message starts with '/command' OR is some emoji + + By default content_type is :class:`ContentType.TEXT` + + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param custom_filters: list of custom filters + :param kwargs: + :param state: + :param run_task: run callback in task (no wait results) + :return: decorated function + """ + + def decorator(callback): + self.register_message_handler(callback, *custom_filters, + commands=commands, regexp=regexp, content_types=content_types, + state=state, run_task=run_task, **kwargs) + return callback + + return decorator + + def register_edited_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None, + state=None, run_task=None, **kwargs): + """ + Register handler for edited message + + :param callback: + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param state: + :param custom_filters: list of custom filters + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + filters_set = self.filters_factory.resolve(self.edited_message_handlers, + *custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + state=state, + **kwargs) + self.edited_message_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def edited_message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, + state=None, run_task=None, **kwargs): + """ + Decorator for edited message handler + + You can use combination of different handlers + + .. code-block:: python3 + + @dp.message_handler() + @dp.edited_message_handler() + async def msg_handler(message: types.Message): + + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param state: + :param custom_filters: list of custom filters + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + + def decorator(callback): + self.register_edited_message_handler(callback, *custom_filters, commands=commands, regexp=regexp, + content_types=content_types, state=state, run_task=run_task, **kwargs) + return callback + + return decorator + + def register_channel_post_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None, + state=None, run_task=None, **kwargs): + """ + Register handler for channel post + + :param callback: + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param state: + :param custom_filters: list of custom filters + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + filters_set = self.filters_factory.resolve(self.channel_post_handlers, + *custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + state=state, + **kwargs) + self.channel_post_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def channel_post_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, + state=None, run_task=None, **kwargs): + """ + Decorator for channel post handler + + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param state: + :param custom_filters: list of custom filters + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + + def decorator(callback): + self.register_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp, + content_types=content_types, state=state, run_task=run_task, **kwargs) + return callback + + return decorator + + def register_edited_channel_post_handler(self, callback, *custom_filters, commands=None, regexp=None, + content_types=None, state=None, run_task=None, **kwargs): + """ + Register handler for edited channel post + + :param callback: + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param state: + :param custom_filters: list of custom filters + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + filters_set = self.filters_factory.resolve(self.edited_message_handlers, + *custom_filters, + commands=commands, + regexp=regexp, + content_types=content_types, + state=state, + **kwargs) + self.edited_channel_post_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def edited_channel_post_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, + state=None, run_task=None, **kwargs): + """ + Decorator for edited channel post handler + + :param commands: list of commands + :param regexp: REGEXP + :param content_types: List of content types. + :param custom_filters: list of custom filters + :param state: + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + + def decorator(callback): + self.register_edited_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp, + content_types=content_types, state=state, run_task=run_task, + **kwargs) + return callback + + return decorator + + def register_inline_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs): + """ + Register handler for inline query + + Example: + + .. code-block:: python3 + + dp.register_inline_handler(some_inline_handler, lambda inline_query: True) + + :param callback: + :param custom_filters: list of custom filters + :param state: + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + if custom_filters is None: + custom_filters = [] + filters_set = self.filters_factory.resolve(self.inline_query_handlers, + *custom_filters, + state=state, + **kwargs) + self.inline_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def inline_handler(self, *custom_filters, state=None, run_task=None, **kwargs): + """ + Decorator for inline query handler + + Example: + + .. code-block:: python3 + + @dp.inline_handler(lambda inline_query: True) + async def some_inline_handler(inline_query: types.InlineQuery) + + :param state: + :param custom_filters: list of custom filters + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: decorated function + """ + + def decorator(callback): + self.register_inline_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) + return callback + + return decorator + + def register_chosen_inline_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs): + """ + Register handler for chosen inline query + + Example: + + .. code-block:: python3 + + dp.register_chosen_inline_handler(some_chosen_inline_handler, lambda chosen_inline_query: True) + + :param callback: + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: + """ + if custom_filters is None: + custom_filters = [] + filters_set = self.filters_factory.resolve(self.chosen_inline_result_handlers, + *custom_filters, + state=state, + **kwargs) + self.chosen_inline_result_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def chosen_inline_handler(self, *custom_filters, state=None, run_task=None, **kwargs): + """ + Decorator for chosen inline query handler + + Example: + + .. code-block:: python3 + + @dp.chosen_inline_handler(lambda chosen_inline_query: True) + async def some_chosen_inline_handler(chosen_inline_query: types.ChosenInlineResult) + + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + :return: + """ + + def decorator(callback): + self.register_chosen_inline_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) + return callback + + return decorator + + def register_callback_query_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs): + """ + Register handler for callback query + + Example: + + .. code-block:: python3 + + dp.register_callback_query_handler(some_callback_handler, lambda callback_query: True) + + :param callback: + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + """ + filters_set = self.filters_factory.resolve(self.callback_query_handlers, + *custom_filters, + state=state, + **kwargs) + self.callback_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def callback_query_handler(self, *custom_filters, state=None, run_task=None, **kwargs): + """ + Decorator for callback query handler + + Example: + + .. code-block:: python3 + + @dp.callback_query_handler(lambda callback_query: True) + async def some_callback_handler(callback_query: types.CallbackQuery) + + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + """ + + def decorator(callback): + self.register_callback_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) + return callback + + return decorator + + def register_shipping_query_handler(self, callback, *custom_filters, state=None, run_task=None, + **kwargs): + """ + Register handler for shipping query + + Example: + + .. code-block:: python3 + + dp.register_shipping_query_handler(some_shipping_query_handler, lambda shipping_query: True) + + :param callback: + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + """ + filters_set = self.filters_factory.resolve(self.shipping_query_handlers, + *custom_filters, + state=state, + **kwargs) + self.shipping_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def shipping_query_handler(self, *custom_filters, state=None, run_task=None, **kwargs): + """ + Decorator for shipping query handler + + Example: + + .. code-block:: python3 + + @dp.shipping_query_handler(lambda shipping_query: True) + async def some_shipping_query_handler(shipping_query: types.ShippingQuery) + + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + """ + + def decorator(callback): + self.register_shipping_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs) + return callback + + return decorator + + def register_pre_checkout_query_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs): + """ + Register handler for pre-checkout query + + Example: + + .. code-block:: python3 + + dp.register_pre_checkout_query_handler(some_pre_checkout_query_handler, lambda shipping_query: True) + + :param callback: + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + """ + filters_set = self.filters_factory.resolve(self.pre_checkout_query_handlers, + *custom_filters, + state=state, + **kwargs) + self.pre_checkout_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def pre_checkout_query_handler(self, *custom_filters, state=None, run_task=None, **kwargs): + """ + Decorator for pre-checkout query handler + + Example: + + .. code-block:: python3 + + @dp.pre_checkout_query_handler(lambda shipping_query: True) + async def some_pre_checkout_query_handler(shipping_query: types.ShippingQuery) + + :param state: + :param custom_filters: + :param run_task: run callback in task (no wait results) + :param kwargs: + """ + + def decorator(callback): + self.register_pre_checkout_query_handler(callback, *custom_filters, state=state, run_task=run_task, + **kwargs) + return callback + + return decorator + + def register_errors_handler(self, callback, *custom_filters, exception=None, run_task=None, **kwargs): + """ + Register handler for errors + + :param callback: + :param exception: you can make handler for specific errors type + :param run_task: run callback in task (no wait results) + """ + filters_set = self.filters_factory.resolve(self.errors_handlers, + *custom_filters, + exception=exception, + **kwargs) + self.errors_handlers.register(self._wrap_async_task(callback, run_task), filters_set) + + def errors_handler(self, *custom_filters, exception=None, run_task=None, **kwargs): + """ + Decorator for errors handler + + :param exception: you can make handler for specific errors type + :param run_task: run callback in task (no wait results) + :return: + """ + + def decorator(callback): + self.register_errors_handler(self._wrap_async_task(callback, run_task), + *custom_filters, exception=exception, **kwargs) + return callback + + return decorator + + def current_state(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None) -> FSMContext: + """ + Get current state for user in chat as context + + .. code-block:: python3 + + with dp.current_state(chat=message.chat.id, user=message.user.id) as state: + pass + + state = dp.current_state() + state.set_state('my_state') + + :param chat: + :param user: + :return: + """ + if chat is None: + chat_obj = types.Chat.current() + chat = chat_obj.id if chat_obj else None + if user is None: + user_obj = types.User.current() + user = user_obj.id if user_obj else None + + return FSMContext(storage=self.storage, chat=chat, user=user) + + async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool: + """ + Execute throttling manager. + Returns True if limit has not exceeded otherwise raises ThrottleError or returns False + + :param key: key in storage + :param rate: limit (by default is equal to default rate limit) + :param user: user id + :param chat: chat id + :param no_error: return boolean value instead of raising error + :return: bool + """ + if not self.storage.has_bucket(): + raise RuntimeError('This storage does not provide Leaky Bucket') + + if no_error is None: + no_error = self.no_throttle_error + if rate is None: + rate = self.throttling_rate_limit + if user is None and chat is None: + user = types.User.current() + chat = types.Chat.current() + + # Detect current time + now = time.time() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + + # Fix bucket + if bucket is None: + bucket = {key: {}} + if key not in bucket: + bucket[key] = {} + data = bucket[key] + + # Calculate + called = data.get(LAST_CALL, now) + delta = now - called + result = delta >= rate or delta <= 0 + + # Save results + data[RESULT] = result + data[RATE_LIMIT] = rate + data[LAST_CALL] = now + data[DELTA] = delta + if not result: + data[EXCEEDED_COUNT] += 1 + else: + data[EXCEEDED_COUNT] = 1 + bucket[key].update(data) + await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + + if not result and not no_error: + # Raise if it is allowed + raise Throttled(key=key, chat=chat, user=user, **data) + return result + + async def check_key(self, key, chat=None, user=None): + """ + Get information about key in bucket + + :param key: + :param chat: + :param user: + :return: + """ + if not self.storage.has_bucket(): + raise RuntimeError('This storage does not provide Leaky Bucket') + + if user is None and chat is None: + user = types.User.current() + chat = types.Chat.current() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + data = bucket.get(key, {}) + return Throttled(key=key, chat=chat, user=user, **data) + + async def release_key(self, key, chat=None, user=None): + """ + Release blocked key + + :param key: + :param chat: + :param user: + :return: + """ + if not self.storage.has_bucket(): + raise RuntimeError('This storage does not provide Leaky Bucket') + + if user is None and chat is None: + user = types.User.current() + chat = types.Chat.current() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + if bucket and key in bucket: + del bucket['key'] + await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + return True + return False + + def async_task(self, func): + """ + Execute handler as task and return None. + Use this decorator for slow handlers (with timeouts) + + .. code-block:: python3 + + @dp.message_handler(commands=['command']) + @dp.async_task + async def cmd_with_timeout(message: types.Message): + await asyncio.sleep(120) + return SendMessage(message.chat.id, 'KABOOM').reply(message) + + :param func: + :return: + """ + + def process_response(task): + try: + response = task.result() + except Exception as e: + self.loop.create_task( + self.errors_handlers.notify(self, types.Update.current(), e)) + else: + if isinstance(response, BaseResponse): + self.loop.create_task(response.execute_response(self.bot)) + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + task = self.loop.create_task(func(*args, **kwargs)) + task.add_done_callback(process_response) + + return wrapper + + def _wrap_async_task(self, callback, run_task=None) -> callable: + if run_task is None: + run_task = self.run_tasks_by_default + + if run_task: + return self.async_task(callback) + return callback + + +dispatcher: ContextVar[Dispatcher] = ContextVar('dispatcher_instance', default=None) diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py deleted file mode 100644 index fb3f04a0..00000000 --- a/aiogram/dispatcher/filters.py +++ /dev/null @@ -1,289 +0,0 @@ -import asyncio -import inspect -import re - -from ..types import CallbackQuery, ContentType, Message -from ..utils import context -from ..utils.helper import Helper, HelperMode, Item - -USER_STATE = 'USER_STATE' - - -async def check_filter(filter_, args): - """ - Helper for executing filter - - :param filter_: - :param args: - :param kwargs: - :return: - """ - if not callable(filter_): - raise TypeError('Filter must be callable and/or awaitable!') - - if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_): - return await filter_(*args) - else: - return filter_(*args) - - -async def check_filters(filters, args): - """ - Check list of filters - - :param filters: - :param args: - :return: - """ - if filters is not None: - for filter_ in filters: - f = await check_filter(filter_, args) - if not f: - return False - return True - - -class Filter: - """ - Base class for filters - """ - - def __call__(self, *args, **kwargs): - return self.check(*args, **kwargs) - - def check(self, *args, **kwargs): - raise NotImplementedError - - -class AsyncFilter(Filter): - """ - Base class for asynchronous filters - """ - - def __aiter__(self): - return None - - def __await__(self): - return self.check - - async def check(self, *args, **kwargs): - pass - - -class AnyFilter(AsyncFilter): - """ - One filter from many - """ - - def __init__(self, *filters: callable): - self.filters = filters - - async def check(self, *args): - f = (check_filter(filter_, args) for filter_ in self.filters) - return any(await asyncio.gather(*f)) - - -class NotFilter(AsyncFilter): - """ - Reverse filter - """ - - def __init__(self, filter_: callable): - self.filter = filter_ - - async def check(self, *args): - return not await check_filter(self.filter, args) - - -class CommandsFilter(AsyncFilter): - """ - Check commands in message - """ - - def __init__(self, commands): - self.commands = commands - - async def check(self, message): - if not message.is_command(): - return False - - command = message.text.split()[0][1:] - command, _, mention = command.partition('@') - - if mention and mention != (await message.bot.me).username: - return False - - if command not in self.commands: - return False - - return True - - -class RegexpFilter(Filter): - """ - Regexp filter for messages - """ - - def __init__(self, regexp): - self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) - - def check(self, obj): - if isinstance(obj, Message) and obj.text: - return bool(self.regexp.search(obj.text)) - elif isinstance(obj, CallbackQuery) and obj.data: - return bool(self.regexp.search(obj.data)) - return False - - -class RegexpCommandsFilter(AsyncFilter): - """ - Check commands by regexp in message - """ - - def __init__(self, regexp_commands): - self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] - - async def check(self, message): - if not message.is_command(): - return False - - command = message.text.split()[0][1:] - command, _, mention = command.partition('@') - - if mention and mention != (await message.bot.me).username: - return False - - for command in self.regexp_commands: - search = command.search(message.text) - if search: - message.conf['regexp_command'] = search - return True - return False - - -class ContentTypeFilter(Filter): - """ - Check message content type - """ - - def __init__(self, content_types): - self.content_types = content_types - - def check(self, message): - return ContentType.ANY[0] in self.content_types or \ - message.content_type in self.content_types - - -class CancelFilter(Filter): - """ - Find cancel in message text - """ - - def __init__(self, cancel_set=None): - if cancel_set is None: - cancel_set = ['/cancel', 'cancel', 'cancel.'] - self.cancel_set = cancel_set - - def check(self, message): - if message.text: - return message.text.lower() in self.cancel_set - - -class StateFilter(AsyncFilter): - """ - Check user state - """ - - def __init__(self, dispatcher, state): - self.dispatcher = dispatcher - self.state = state - - def get_target(self, obj): - return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) - - async def check(self, obj): - if self.state == '*': - return True - - if context.check_value(USER_STATE): - context_state = context.get_value(USER_STATE) - return self.state == context_state - else: - chat, user = self.get_target(obj) - - if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state - return False - - -class StatesListFilter(StateFilter): - """ - List of states - """ - - async def check(self, obj): - chat, user = self.get_target(obj) - - if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state - return False - - -class ExceptionsFilter(Filter): - """ - Filter for exceptions - """ - - def __init__(self, exception): - self.exception = exception - - def check(self, dispatcher, update, exception): - return isinstance(exception, self.exception) - - -def generate_default_filters(dispatcher, *args, **kwargs): - """ - Prepare filters - - :param dispatcher: - :param args: - :param kwargs: - :return: - """ - filters_set = [] - - for name, filter_ in kwargs.items(): - if filter_ is None and name != DefaultFilters.STATE: - continue - if name == DefaultFilters.COMMANDS: - if isinstance(filter_, str): - filters_set.append(CommandsFilter([filter_])) - else: - filters_set.append(CommandsFilter(filter_)) - elif name == DefaultFilters.REGEXP: - filters_set.append(RegexpFilter(filter_)) - elif name == DefaultFilters.CONTENT_TYPES: - filters_set.append(ContentTypeFilter(filter_)) - elif name == DefaultFilters.FUNC: - filters_set.append(filter_) - elif name == DefaultFilters.STATE: - if isinstance(filter_, (list, set, tuple)): - filters_set.append(StatesListFilter(dispatcher, filter_)) - else: - filters_set.append(StateFilter(dispatcher, filter_)) - elif isinstance(filter_, Filter): - filters_set.append(filter_) - - filters_set += list(args) - - return filters_set - - -class DefaultFilters(Helper): - mode = HelperMode.snake_case - - COMMANDS = Item() # commands - REGEXP = Item() # regexp - CONTENT_TYPES = Item() # content_type - FUNC = Item() # func - STATE = Item() # state diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py new file mode 100644 index 00000000..aa3a3ecf --- /dev/null +++ b/aiogram/dispatcher/filters/__init__.py @@ -0,0 +1,24 @@ +from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \ + RegexpCommandsFilter, StateFilter, Text +from .factory import FiltersFactory +from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters + +__all__ = [ + 'AbstractFilter', + 'BoundFilter', + 'Command', + 'CommandStart', + 'CommandHelp', + 'ContentTypeFilter', + 'ExceptionsFilter', + 'Filter', + 'FilterNotPassed', + 'FilterRecord', + 'FiltersFactory', + 'RegexpCommandsFilter', + 'Regexp', + 'StateFilter', + 'Text', + 'check_filter', + 'check_filters' +] diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py new file mode 100644 index 00000000..c2251b56 --- /dev/null +++ b/aiogram/dispatcher/filters/builtin.py @@ -0,0 +1,313 @@ +import inspect +import re +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import Any, Dict, Iterable, Optional, Union + +from aiogram import types +from aiogram.dispatcher.filters.filters import BoundFilter, Filter +from aiogram.types import CallbackQuery, Message + + +class Command(Filter): + """ + You can handle commands by using this filter + """ + + def __init__(self, commands: Union[Iterable, str], + prefixes: Union[Iterable, str] = '/', + ignore_case: bool = True, + ignore_mention: bool = False): + """ + Filter can be initialized from filters factory or by simply creating instance of this class + + :param commands: command or list of commands + :param prefixes: + :param ignore_case: + :param ignore_mention: + """ + if isinstance(commands, str): + commands = (commands,) + + self.commands = list(map(str.lower, commands)) if ignore_case else commands + self.prefixes = prefixes + self.ignore_case = ignore_case + self.ignore_mention = ignore_mention + + @classmethod + def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Validator for filters factory + + :param full_config: + :return: config or empty dict + """ + config = {} + if 'commands' in full_config: + config['commands'] = full_config.pop('commands') + if 'commands_prefix' in full_config: + config['prefixes'] = full_config.pop('commands_prefix') + if 'commands_ignore_mention' in full_config: + config['ignore_mention'] = full_config.pop('commands_ignore_mention') + return config + + async def check(self, message: types.Message): + return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention) + + @staticmethod + async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False): + full_command = message.text.split()[0] + prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@') + + if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower(): + return False + elif prefix not in prefixes: + return False + elif (command.lower() if ignore_case else command) not in commands: + return False + + return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)} + + @dataclass + class CommandObj: + prefix: str = '/' + command: str = '' + mention: str = None + args: str = field(repr=False, default=None) + + @property + def mentioned(self) -> bool: + return bool(self.mention) + + @property + def text(self) -> str: + line = self.prefix + self.command + if self.mentioned: + line += '@' + self.mention + if self.args: + line += ' ' + self.args + return line + + +class CommandStart(Command): + def __init__(self): + super(CommandStart, self).__init__(['start']) + + +class CommandHelp(Command): + def __init__(self): + super(CommandHelp, self).__init__(['help']) + + +class Text(Filter): + """ + Simple text filter + """ + + def __init__(self, + equals: Optional[str] = None, + contains: Optional[str] = None, + startswith: Optional[str] = None, + endswith: Optional[str] = None, + ignore_case=False): + """ + Check text for one of pattern. Only one mode can be used in one filter. + + :param equals: + :param contains: + :param startswith: + :param endswith: + :param ignore_case: case insensitive + """ + # Only one mode can be used. check it. + check = sum(map(bool, (equals, contains, startswith, endswith))) + if check > 1: + args = "' and '".join([arg[0] for arg in [('equals', equals), + ('contains', contains), + ('startswith', startswith), + ('endswith', endswith) + ] if arg[1]]) + raise ValueError(f"Arguments '{args}' cannot be used together.") + elif check == 0: + raise ValueError(f"No one mode is specified!") + + self.equals = equals + self.contains = contains + self.endswith = endswith + self.startswith = startswith + self.ignore_case = ignore_case + + @classmethod + def validate(cls, full_config: Dict[str, Any]): + if 'text' in full_config: + return {'equals': full_config.pop('text')} + elif 'text_contains' in full_config: + return {'contains': full_config.pop('text_contains')} + elif 'text_startswith' in full_config: + return {'startswith': full_config.pop('text_startswith')} + elif 'text_endswith' in full_config: + return {'endswith': full_config.pop('text_endswith')} + + async def check(self, obj: Union[Message, CallbackQuery]): + if isinstance(obj, Message): + text = obj.text or obj.caption or '' + elif isinstance(obj, CallbackQuery): + text = obj.data + else: + return False + + if self.ignore_case: + text = text.lower() + + if self.equals: + return text == self.equals + elif self.contains: + return self.contains in text + elif self.startswith: + return text.startswith(self.startswith) + elif self.endswith: + return text.endswith(self.endswith) + + return False + + +class Regexp(Filter): + """ + Regexp filter for messages and callback query + """ + + def __init__(self, regexp): + if not isinstance(regexp, re.Pattern): + regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) + self.regexp = regexp + + @classmethod + def validate(cls, full_config: Dict[str, Any]): + if 'regexp' in full_config: + return {'regexp': full_config.pop('regexp')} + + async def check(self, obj: Union[Message, CallbackQuery]): + if isinstance(obj, Message): + match = self.regexp.search(obj.text or obj.caption or '') + elif isinstance(obj, CallbackQuery) and obj.data: + match = self.regexp.search(obj.data) + else: + return False + + if match: + return {'regexp': match} + return False + + +class RegexpCommandsFilter(BoundFilter): + """ + Check commands by regexp in message + """ + + key = 'regexp_commands' + + def __init__(self, regexp_commands): + self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] + + async def check(self, message): + if not message.is_command(): + return False + + command = message.text.split()[0][1:] + command, _, mention = command.partition('@') + + if mention and mention != (await message.bot.me).username: + return False + + for command in self.regexp_commands: + search = command.search(message.text) + if search: + return {'regexp_command': search} + return False + + +class ContentTypeFilter(BoundFilter): + """ + Check message content type + """ + + key = 'content_types' + required = True + default = types.ContentTypes.TEXT + + def __init__(self, content_types): + self.content_types = content_types + + async def check(self, message): + return types.ContentType.ANY in self.content_types or \ + message.content_type in self.content_types + + +class StateFilter(BoundFilter): + """ + Check user state + """ + key = 'state' + required = True + + ctx_state = ContextVar('user_state') + + def __init__(self, dispatcher, state): + from aiogram.dispatcher.filters.state import State, StatesGroup + + self.dispatcher = dispatcher + states = [] + if not isinstance(state, (list, set, tuple, frozenset)) or state is None: + state = [state, ] + for item in state: + if isinstance(item, State): + states.append(item.state) + elif inspect.isclass(item) and issubclass(item, StatesGroup): + states.extend(item.all_states_names) + else: + states.append(item) + self.states = states + + def get_target(self, obj): + return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) + + async def check(self, obj): + if '*' in self.states: + return {'state': self.dispatcher.current_state()} + + try: + state = self.ctx_state.get() + except LookupError: + chat, user = self.get_target(obj) + + if chat or user: + state = await self.dispatcher.storage.get_state(chat=chat, user=user) + self.ctx_state.set(state) + if state in self.states: + return {'state': self.dispatcher.current_state(), 'raw_state': state} + + else: + if state in self.states: + return {'state': self.dispatcher.current_state(), 'raw_state': state} + + return False + + +class ExceptionsFilter(BoundFilter): + """ + Filter for exceptions + """ + + key = 'exception' + + def __init__(self, dispatcher, exception): + super().__init__(dispatcher) + self.exception = exception + + async def check(self, dispatcher, update, exception): + try: + raise exception + except self.exception: + return True + except: + return False diff --git a/aiogram/dispatcher/filters/factory.py b/aiogram/dispatcher/filters/factory.py new file mode 100644 index 00000000..099a9b60 --- /dev/null +++ b/aiogram/dispatcher/filters/factory.py @@ -0,0 +1,73 @@ +import typing + +from .filters import AbstractFilter, FilterRecord +from ..handler import Handler + + +class FiltersFactory: + """ + Default filters factory + """ + + def __init__(self, dispatcher): + self._dispatcher = dispatcher + self._registered: typing.List[FilterRecord] = [] + + def bind(self, callback: typing.Union[typing.Callable, AbstractFilter], + validator: typing.Optional[typing.Callable] = None, + event_handlers: typing.Optional[typing.List[Handler]] = None, + exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None): + """ + Register filter + + :param callback: callable or subclass of :obj:`AbstractFilter` + :param validator: custom validator. + :param event_handlers: list of instances of :obj:`Handler` + :param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`) + """ + record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers) + self._registered.append(record) + + def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]): + """ + Unregister callback + + :param callback: callable of subclass of :obj:`AbstractFilter` + """ + for record in self._registered: + if record.callback == callback: + self._registered.remove(record) + + def resolve(self, event_handler, *custom_filters, **full_config + ) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]: + """ + Resolve filters to filters-set + + :param event_handler: + :param custom_filters: + :param full_config: + :return: + """ + filters_set = [] + filters_set.extend(self._resolve_registered(event_handler, + {k: v for k, v in full_config.items() if v is not None})) + if custom_filters: + filters_set.extend(custom_filters) + + return filters_set + + def _resolve_registered(self, event_handler, full_config) -> typing.Generator: + """ + Resolve registered filters + + :param event_handler: + :param full_config: + :return: + """ + for record in self._registered: + filter_ = record.resolve(self._dispatcher, event_handler, full_config) + if filter_: + yield filter_ + + if full_config: + raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'') diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py new file mode 100644 index 00000000..816f4722 --- /dev/null +++ b/aiogram/dispatcher/filters/filters.py @@ -0,0 +1,250 @@ +import abc +import inspect +import typing + +from ..handler import Handler +from ...types.base import TelegramObject + + +class FilterNotPassed(Exception): + pass + + +def wrap_async(func): + async def async_wrapper(*args, **kwargs): + return func(*args, **kwargs) + + if inspect.isawaitable(func) \ + or inspect.iscoroutinefunction(func) \ + or isinstance(func, AbstractFilter): + return func + return async_wrapper + + +async def check_filter(dispatcher, filter_, args): + """ + Helper for executing filter + + :param dispatcher: + :param filter_: + :param args: + :return: + """ + kwargs = {} + if not callable(filter_): + raise TypeError('Filter must be callable and/or awaitable!') + + spec = inspect.getfullargspec(filter_) + if 'dispatcher' in spec: + kwargs['dispatcher'] = dispatcher + if inspect.isawaitable(filter_) \ + or inspect.iscoroutinefunction(filter_) \ + or isinstance(filter_, AbstractFilter): + return await filter_(*args, **kwargs) + else: + return filter_(*args, **kwargs) + + +async def check_filters(dispatcher, filters, args): + """ + Check list of filters + + :param dispatcher: + :param filters: + :param args: + :return: + """ + data = {} + if filters is not None: + for filter_ in filters: + f = await check_filter(dispatcher, filter_, args) + if not f: + raise FilterNotPassed() + elif isinstance(f, dict): + data.update(f) + return data + + +class FilterRecord: + """ + Filters record for factory + """ + + def __init__(self, callback: typing.Callable, + validator: typing.Optional[typing.Callable] = None, + event_handlers: typing.Optional[typing.Iterable[Handler]] = None, + exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None): + if event_handlers and exclude_event_handlers: + raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.") + + self.callback = callback + self.event_handlers = event_handlers + self.exclude_event_handlers = exclude_event_handlers + + if validator is not None: + if not callable(validator): + raise TypeError(f"validator must be callable, not {type(validator)}") + self.resolver = validator + elif issubclass(callback, AbstractFilter): + self.resolver = callback.validate + else: + raise RuntimeError('validator is required!') + + def resolve(self, dispatcher, event_handler, full_config): + if not self._check_event_handler(event_handler): + return + config = self.resolver(full_config) + if config: + if 'dispatcher' not in config: + spec = inspect.getfullargspec(self.callback) + if 'dispatcher' in spec.args: + config['dispatcher'] = dispatcher + + for key in config: + if key in full_config: + full_config.pop(key) + + return self.callback(**config) + + def _check_event_handler(self, event_handler) -> bool: + if self.event_handlers: + return event_handler in self.event_handlers + elif self.exclude_event_handlers: + return event_handler not in self.exclude_event_handlers + return True + + +class AbstractFilter(abc.ABC): + """ + Abstract class for custom filters + """ + + @classmethod + @abc.abstractmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + """ + Validate and parse config + + :param full_config: + :return: config + """ + pass + + @abc.abstractmethod + async def check(self, *args) -> bool: + """ + Check object + + :param args: + :return: + """ + pass + + async def __call__(self, obj: TelegramObject) -> bool: + return await self.check(obj) + + def __invert__(self): + return NotFilter(self) + + def __and__(self, other): + if isinstance(self, AndFilter): + self.append(other) + return self + return AndFilter(self, other) + + def __or__(self, other): + if isinstance(self, OrFilter): + self.append(other) + return self + return OrFilter(self, other) + + +class Filter(AbstractFilter): + """ + You can make subclasses of that class for custom filters + """ + + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + pass + + +class BoundFilter(Filter): + """ + Base class for filters with default validator + """ + key = None + required = False + default = None + + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: + if cls.key is not None: + if cls.key in full_config: + return {cls.key: full_config[cls.key]} + elif cls.required: + return {cls.key: cls.default} + + +class _LogicFilter(Filter): + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]): + raise ValueError('That filter can\'t be used in filters factory!') + + +class NotFilter(_LogicFilter): + def __init__(self, target): + self.target = wrap_async(target) + + async def check(self, *args): + return not bool(await self.target(*args)) + + +class AndFilter(_LogicFilter): + + def __init__(self, *targets): + self.targets = list(wrap_async(target) for target in targets) + + async def check(self, *args): + """ + All filters must return a positive result + + :param args: + :return: + """ + data = {} + for target in self.targets: + result = await target(*args) + if not result: + return False + if isinstance(result, dict): + data.update(result) + if not data: + return True + return data + + def append(self, target): + self.targets.append(wrap_async(target)) + + +class OrFilter(_LogicFilter): + def __init__(self, *targets): + self.targets = list(wrap_async(target) for target in targets) + + async def check(self, *args): + """ + One of filters must return a positive result + + :param args: + :return: + """ + for target in self.targets: + result = await target(*args) + if result: + if isinstance(result, dict): + return result + return True + return False + + def append(self, target): + self.targets.append(wrap_async(target)) diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py new file mode 100644 index 00000000..fadc3687 --- /dev/null +++ b/aiogram/dispatcher/filters/state.py @@ -0,0 +1,198 @@ +import inspect +from typing import Optional + +from ..dispatcher import Dispatcher + + +class State: + """ + State object + """ + + def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None): + self._state = state + self._group_name = group_name + self._group = None + + @property + def group(self): + if not self._group: + raise RuntimeError('This state is not in any group.') + return self._group + + def get_root(self): + return self.group.get_root() + + @property + def state(self): + if self._state is None: + return None + elif self._state == '*': + return self._state + elif self._group_name is None and self._group: + group = self._group.__full_group_name__ + elif self._group_name: + group = self._group_name + else: + group = '@' + return f"{group}:{self._state}" + + def set_parent(self, group): + if not issubclass(group, StatesGroup): + raise ValueError('Group must be subclass of StatesGroup') + self._group = group + + def __set_name__(self, owner, name): + if self._state is None: + self._state = name + self.set_parent(owner) + + def __str__(self): + return f"" + + __repr__ = __str__ + + async def set(self): + state = Dispatcher.current().current_state() + await state.set_state(self.state) + + +class StatesGroupMeta(type): + def __new__(mcs, name, bases, namespace, **kwargs): + cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace) + + states = [] + childs = [] + + cls._group_name = name + + for name, prop in namespace.items(): + + if isinstance(prop, State): + states.append(prop) + elif inspect.isclass(prop) and issubclass(prop, StatesGroup): + childs.append(prop) + prop._parent = cls + # continue + + cls._parent = None + cls._childs = tuple(childs) + cls._states = tuple(states) + cls._state_names = tuple(state.state for state in states) + + return cls + + @property + def __group_name__(cls): + return cls._group_name + + @property + def __full_group_name__(cls): + if cls._parent: + return cls._parent.__full_group_name__ + '.' + cls._group_name + return cls._group_name + + @property + def states(cls) -> tuple: + return cls._states + + @property + def childs(cls): + return cls._childs + + @property + def all_childs(cls): + result = cls.childs + for child in cls.childs: + result += child.childs + return result + + @property + def all_states(cls): + result = cls.states + for group in cls.childs: + result += group.all_states + return result + + @property + def all_states_names(cls): + return tuple(state.state for state in cls.all_states) + + @property + def states_names(cls) -> tuple: + return tuple(state.state for state in cls.states) + + def get_root(cls): + if cls._parent is None: + return cls + return cls._parent.get_root() + + def __contains__(cls, item): + if isinstance(item, str): + return item in cls.all_states_names + elif isinstance(item, State): + return item in cls.all_states + elif isinstance(item, StatesGroup): + return item in cls.all_childs + return False + + def __str__(self): + return f"" + + +class StatesGroup(metaclass=StatesGroupMeta): + @classmethod + async def next(cls) -> str: + state = Dispatcher.current().current_state() + state_name = await state.get_state() + + try: + next_step = cls.states_names.index(state_name) + 1 + except ValueError: + next_step = 0 + + try: + next_state_name = cls.states[next_step].state + except IndexError: + next_state_name = None + + await state.set_state(next_state_name) + return next_state_name + + @classmethod + async def previous(cls) -> str: + state = Dispatcher.current().current_state() + state_name = await state.get_state() + + try: + previous_step = cls.states_names.index(state_name) - 1 + except ValueError: + previous_step = 0 + + if previous_step < 0: + previous_state_name = None + else: + previous_state_name = cls.states[previous_step].state + + await state.set_state(previous_state_name) + return previous_state_name + + @classmethod + async def first(cls) -> str: + state = Dispatcher.current().current_state() + first_step_name = cls.states_names[0] + + await state.set_state(first_step_name) + return first_step_name + + @classmethod + async def last(cls) -> str: + state = Dispatcher.current().current_state() + last_step_name = cls.states_names[-1] + + await state.set_state(last_step_name) + return last_step_name + + +default_state = State() +any_state = State(state='*') diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 8d75273f..4ded9316 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,5 +1,7 @@ -from .filters import check_filters -from ..utils import context +import inspect +from contextvars import ContextVar + +ctx_data = ContextVar('ctx_handler_data') class SkipHandler(BaseException): @@ -10,6 +12,14 @@ class CancelHandler(BaseException): pass +def _check_spec(func: callable, kwargs: dict): + spec = inspect.getfullargspec(func) + if spec.varkw: + return kwargs + + return {k: v for k, v in kwargs.items() if k in spec.args} + + class Handler: def __init__(self, dispatcher, once=True, middleware_key=None): self.dispatcher = dispatcher @@ -57,31 +67,43 @@ class Handler: :param args: :return: """ + from .filters import check_filters, FilterNotPassed + results = [] + data = {} + ctx_data.set(data) + if self.middleware_key: try: - await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args) + await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,)) except CancelHandler: # Allow to cancel current event return results - for filters, handler in self.handlers: - if await check_filters(filters, args): + try: + for filters, handler in self.handlers: try: - if self.middleware_key: - context.set_value('handler', handler) - await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args) - response = await handler(*args) - if response is not None: - results.append(response) - if self.once: - break - except SkipHandler: + data.update(await check_filters(self.dispatcher, filters, args)) + except FilterNotPassed: continue - except CancelHandler: - break - if self.middleware_key: - await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", - args + (results,)) + else: + try: + if self.middleware_key: + # context.set_value('handler', handler) + await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,)) + partial_data = _check_spec(handler, data) + response = await handler(*args, **partial_data) + if response is not None: + results.append(response) + if self.once: + break + except SkipHandler: + continue + except CancelHandler: + break + finally: + if self.middleware_key: + await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", + args + (results, data,)) return results diff --git a/aiogram/dispatcher/middlewares.py b/aiogram/dispatcher/middlewares.py index 4de9d61f..dba3db4c 100644 --- a/aiogram/dispatcher/middlewares.py +++ b/aiogram/dispatcher/middlewares.py @@ -101,3 +101,28 @@ class BaseMiddleware: if not handler: return None await handler(*args) + + +class LifetimeControllerMiddleware(BaseMiddleware): + # TODO: Rename class + + skip_patterns = None + + async def pre_process(self, obj, data, *args): + pass + + async def post_process(self, obj, data, *args): + pass + + async def trigger(self, action, args): + if self.skip_patterns is not None and any(item in action for item in self.skip_patterns): + return False + + obj, *args, data = args + if action.startswith('pre_process_'): + await self.pre_process(obj, data, *args) + elif action.startswith('post_process_'): + await self.post_process(obj, data, *args) + else: + return False + return True diff --git a/aiogram/dispatcher/storage.py b/aiogram/dispatcher/storage.py index 76a23ee6..96431796 100644 --- a/aiogram/dispatcher/storage.py +++ b/aiogram/dispatcher/storage.py @@ -281,8 +281,20 @@ class FSMContext: def __exit__(self, exc_type, exc_val, exc_tb): pass + @staticmethod + def _resolve_state(value): + from .filters.state import State + + if value is None: + return + elif isinstance(value, str): + return value + elif isinstance(value, State): + return value.state + return str(value) + async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]: - return await self.storage.get_state(chat=self.chat, user=self.user, default=default) + return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default)) async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict: return await self.storage.get_data(chat=self.chat, user=self.user, default=default) @@ -291,7 +303,7 @@ class FSMContext: await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs) async def set_state(self, state: typing.Union[typing.AnyStr, None] = None): - await self.storage.set_state(chat=self.chat, user=self.user, state=state) + await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state)) async def set_data(self, data: typing.Dict = None): await self.storage.set_data(chat=self.chat, user=self.user, data=data) diff --git a/aiogram/dispatcher/webhook.py b/aiogram/dispatcher/webhook.py index c9e39a4f..bc2a0e60 100644 --- a/aiogram/dispatcher/webhook.py +++ b/aiogram/dispatcher/webhook.py @@ -9,11 +9,11 @@ from typing import Dict, List, Optional, Union from aiohttp import web from aiohttp.web_exceptions import HTTPGone + from .. import types from ..bot import api from ..types import ParseMode from ..types.base import Boolean, Float, Integer, String -from ..utils import context from ..utils import helper, markdown from ..utils import json from ..utils.deprecated import warn_deprecated as warn @@ -89,8 +89,10 @@ class WebhookRequestHandler(web.View): """ dp = self.request.app[BOT_DISPATCHER_KEY] try: - context.set_value('dispatcher', dp) - context.set_value('bot', dp.bot) + from aiogram.bot import bot + from aiogram.dispatcher import dispatcher + dispatcher.set(dp) + bot.bot.set(dp.bot) except RuntimeError: pass return dp @@ -117,9 +119,9 @@ class WebhookRequestHandler(web.View): """ self.validate_ip() - context.update_state({'CALLER': WEBHOOK, - WEBHOOK_CONNECTION: True, - WEBHOOK_REQUEST: self.request}) + # context.update_state({'CALLER': WEBHOOK, + # WEBHOOK_CONNECTION: True, + # WEBHOOK_REQUEST: self.request}) dispatcher = self.get_dispatcher() update = await self.parse_update(dispatcher.bot) @@ -177,7 +179,7 @@ class WebhookRequestHandler(web.View): if fut.done(): return fut.result() else: - context.set_value(WEBHOOK_CONNECTION, False) + # context.set_value(WEBHOOK_CONNECTION, False) fut.remove_done_callback(cb) fut.add_done_callback(self.respond_via_request) finally: @@ -202,7 +204,7 @@ class WebhookRequestHandler(web.View): results = task.result() except Exception as e: loop.create_task( - dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e)) + dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e)) else: response = self.get_response(results) if response is not None: @@ -249,7 +251,7 @@ class WebhookRequestHandler(web.View): ip_address, accept = self.check_ip() if not accept: raise web.HTTPUnauthorized() - context.set_value('TELEGRAM_IP', ip_address) + # context.set_value('TELEGRAM_IP', ip_address) class GoneRequestHandler(web.View): @@ -352,8 +354,8 @@ class BaseResponse: async def __call__(self, bot=None): if bot is None: - from aiogram.dispatcher import ctx - bot = ctx.get_bot() + from aiogram import Bot + bot = Bot.current() return await self.execute_response(bot) async def __aenter__(self): @@ -446,7 +448,8 @@ class ParseModeMixin: :return: """ - bot = context.get_value('bot', None) + from aiogram import Bot + bot = Bot.current() if bot is not None: return bot.parse_mode @@ -952,7 +955,7 @@ class SendMediaGroup(BaseResponse, ReplyToMixin, DisableNotificationMixin): self.reply_to_message_id = reply_to_message_id def prepare(self): - files = self.media.get_files() + files = dict(self.media.get_files()) if files: raise TypeError('Allowed only file ID or URL\'s') diff --git a/aiogram/types/__init__.py b/aiogram/types/__init__.py index 943efb4b..a1e98158 100644 --- a/aiogram/types/__init__.py +++ b/aiogram/types/__init__.py @@ -34,7 +34,7 @@ from .invoice import Invoice from .labeled_price import LabeledPrice from .location import Location from .mask_position import MaskPosition -from .message import ContentType, Message, ParseMode +from .message import ContentType, ContentTypes, Message, ParseMode from .message_entity import MessageEntity, MessageEntityType from .order_info import OrderInfo from .passport_data import PassportData @@ -77,6 +77,7 @@ __all__ = ( 'ChosenInlineResult', 'Contact', 'ContentType', + 'ContentTypes', 'Document', 'EncryptedCredentials', 'EncryptedPassportElement', diff --git a/aiogram/types/base.py b/aiogram/types/base.py index 166a9848..9982ad35 100644 --- a/aiogram/types/base.py +++ b/aiogram/types/base.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import io import typing +from contextvars import ContextVar from typing import TypeVar from .fields import BaseField @@ -53,6 +56,8 @@ class MetaTelegramObject(type): setattr(cls, ALIASES_ATTR_NAME, aliases) mcs._objects[cls.__name__] = cls + + cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None? return cls @property @@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject): if value.default and key not in self.values: self.values[key] = value.default + @classmethod + def current(cls): + return cls._current.get() + + @classmethod + def set_current(cls, obj: TelegramObject): + return cls._current.set(obj) + @property def conf(self) -> typing.Dict[str, typing.Any]: return self._conf @@ -137,8 +150,8 @@ class TelegramObject(metaclass=MetaTelegramObject): @property def bot(self): - from ..dispatcher import ctx - return ctx.get_bot() + from ..bot.bot import Bot + return Bot.current() def to_python(self) -> typing.Dict: """ diff --git a/aiogram/types/chat.py b/aiogram/types/chat.py index ae70c519..947add4d 100644 --- a/aiogram/types/chat.py +++ b/aiogram/types/chat.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import asyncio import typing +from contextvars import ContextVar from . import base from . import fields @@ -64,7 +67,7 @@ class Chat(base.TelegramObject): if as_html: return markdown.hlink(name, self.user_url) return markdown.link(name, self.user_url) - + async def get_url(self): """ Use this method to get chat link. @@ -507,8 +510,8 @@ class ChatActions(helper.Helper): @classmethod async def _do(cls, action: str, sleep=None): - from ..dispatcher.ctx import get_bot, get_chat - await get_bot().send_chat_action(get_chat(), action) + from aiogram import Bot + await Bot.current().send_chat_action(Chat.current(), action) if sleep: await asyncio.sleep(sleep) diff --git a/aiogram/types/fields.py b/aiogram/types/fields.py index fc12dd2e..81156d5d 100644 --- a/aiogram/types/fields.py +++ b/aiogram/types/fields.py @@ -9,7 +9,7 @@ class BaseField(metaclass=abc.ABCMeta): Base field (prop) """ - def __init__(self, *, base=None, default=None, alias=None): + def __init__(self, *, base=None, default=None, alias=None, on_change=None): """ Init prop @@ -17,10 +17,12 @@ class BaseField(metaclass=abc.ABCMeta): :param default: default value :param alias: alias name (for e.g. field 'from' has to be named 'from_user' as 'from' is a builtin Python keyword + :param on_change: callback will be called when value is changed """ self.base_object = base self.default = default self.alias = alias + self.on_change = on_change def __set_name__(self, owner, name): if self.alias is None: @@ -53,6 +55,13 @@ class BaseField(metaclass=abc.ABCMeta): self.resolve_base(instance) value = self.deserialize(value, parent) instance.values[self.alias] = value + self._trigger_changed(instance, value) + + def _trigger_changed(self, instance, value): + if not self.on_change and instance is not None: + return + callback = getattr(instance, self.on_change) + callback(value) def __get__(self, instance, owner): return self.get_value(instance) @@ -154,7 +163,7 @@ class ListOfLists(Field): return result -class DateTimeField(BaseField): +class DateTimeField(Field): """ In this field st_ored datetime @@ -167,3 +176,24 @@ class DateTimeField(BaseField): def deserialize(self, value, parent=None): return datetime.datetime.fromtimestamp(value) + + +class TextField(Field): + def __init__(self, *, prefix=None, suffix=None, default=None, alias=None): + super(TextField, self).__init__(default=default, alias=alias) + self.prefix = prefix + self.suffix = suffix + + def serialize(self, value): + if value is None: + return value + if self.prefix: + value = self.prefix + value + if self.suffix: + value += self.suffix + return value + + def deserialize(self, value, parent=None): + if value is not None and not isinstance(value, str): + raise TypeError(f"Field '{self.alias}' should be str not {type(value).__name__}") + return value diff --git a/aiogram/types/input_file.py b/aiogram/types/input_file.py index 9d42c6b7..5c271701 100644 --- a/aiogram/types/input_file.py +++ b/aiogram/types/input_file.py @@ -1,6 +1,7 @@ import io import logging import os +import secrets import time import aiohttp @@ -45,6 +46,8 @@ class InputFile(base.TelegramObject): self._filename = filename + self.attachment_key = secrets.token_urlsafe(16) + def __del__(self): """ Close file descriptor @@ -54,13 +57,17 @@ class InputFile(base.TelegramObject): @property def filename(self): if self._filename is None: - self._filename = api._guess_filename(self._file) + self._filename = api.guess_filename(self._file) return self._filename @filename.setter def filename(self, value): self._filename = value + @property + def attach(self): + return f"attach://{self.attachment_key}" + def get_filename(self) -> str: """ Get file name @@ -159,6 +166,9 @@ class InputFile(base.TelegramObject): return writer + def __str__(self): + return f"" + def to_python(self): raise TypeError('Object of this type is not exportable!') diff --git a/aiogram/types/input_media.py b/aiogram/types/input_media.py index 1f68e632..7bb58a7a 100644 --- a/aiogram/types/input_media.py +++ b/aiogram/types/input_media.py @@ -12,6 +12,9 @@ ATTACHMENT_PREFIX = 'attach://' class InputMedia(base.TelegramObject): """ This object represents the content of a media message to be sent. It should be one of + - InputMediaAnimation + - InputMediaDocument + - InputMediaAudio - InputMediaPhoto - InputMediaVideo @@ -20,36 +23,76 @@ class InputMedia(base.TelegramObject): https://core.telegram.org/bots/api#inputmedia """ type: base.String = fields.Field(default='photo') - media: base.String = fields.Field() - thumb: typing.Union[base.InputFile, base.String] = fields.Field() + media: base.String = fields.Field(alias='media', on_change='_media_changed') + thumb: typing.Union[base.InputFile, base.String] = fields.Field(alias='thumb', on_change='_thumb_changed') caption: base.String = fields.Field() parse_mode: base.Boolean = fields.Field() def __init__(self, *args, **kwargs): + self._thumb_file = None + self._media_file = None + + media = kwargs.pop('media', None) + if isinstance(media, (io.IOBase, InputFile)): + self.file = media + elif media is not None: + self.media = media + + thumb = kwargs.pop('thumb', None) + if isinstance(thumb, (io.IOBase, InputFile)): + self.thumb_file = thumb + elif thumb is not None: + self.thumb = thumb + super(InputMedia, self).__init__(*args, **kwargs) + try: - if self.parse_mode is None and self.bot.parse_mode: + if self.parse_mode is None and self.bot and self.bot.parse_mode: self.parse_mode = self.bot.parse_mode except RuntimeError: pass @property def file(self): - return getattr(self, '_file', None) + return self._media_file @file.setter def file(self, file: io.IOBase): - setattr(self, '_file', file) - attachment_key = self.attachment_key = secrets.token_urlsafe(16) - self.media = ATTACHMENT_PREFIX + attachment_key + self.media = 'attach://' + secrets.token_urlsafe(16) + self._media_file = file + + @file.deleter + def file(self): + self.media = None + self._media_file = None + + def _media_changed(self, value): + if value is None or isinstance(value, str) and not value.startswith('attach://'): + self._media_file = None @property - def attachment_key(self): - return self.conf.get('attachment_key', None) + def thumb_file(self): + return self._thumb_file - @attachment_key.setter - def attachment_key(self, value): - self.conf['attachment_key'] = value + @thumb_file.setter + def thumb_file(self, file: io.IOBase): + self.thumb = 'attach://' + secrets.token_urlsafe(16) + self._thumb_file = file + + @thumb_file.deleter + def thumb_file(self): + self.thumb = None + self._thumb_file = None + + def _thumb_changed(self, value): + if value is None or isinstance(value, str) and not value.startswith('attach://'): + self._thumb_file = None + + def get_files(self): + if self._media_file: + yield self.media[9:], self._media_file + if self._thumb_file: + yield self.thumb[9:], self._thumb_file class InputMediaAnimation(InputMedia): @@ -72,9 +115,6 @@ class InputMediaAnimation(InputMedia): width=width, height=height, duration=duration, parse_mode=parse_mode, conf=kwargs) - if isinstance(media, (io.IOBase, InputFile)): - self.file = media - class InputMediaDocument(InputMedia): """ @@ -89,9 +129,6 @@ class InputMediaDocument(InputMedia): caption=caption, parse_mode=parse_mode, conf=kwargs) - if isinstance(media, (io.IOBase, InputFile)): - self.file = media - class InputMediaAudio(InputMedia): """ @@ -119,9 +156,6 @@ class InputMediaAudio(InputMedia): performer=performer, title=title, parse_mode=parse_mode, conf=kwargs) - if isinstance(media, (io.IOBase, InputFile)): - self.file = media - class InputMediaPhoto(InputMedia): """ @@ -136,9 +170,6 @@ class InputMediaPhoto(InputMedia): caption=caption, parse_mode=parse_mode, conf=kwargs) - if isinstance(media, (io.IOBase, InputFile)): - self.file = media - class InputMediaVideo(InputMedia): """ @@ -151,18 +182,17 @@ class InputMediaVideo(InputMedia): duration: base.Integer = fields.Field() supports_streaming: base.Boolean = fields.Field() - def __init__(self, media: base.InputFile, caption: base.String = None, + def __init__(self, media: base.InputFile, + thumb: typing.Union[base.InputFile, base.String] = None, + caption: base.String = None, width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None, parse_mode: base.Boolean = None, supports_streaming: base.Boolean = None, **kwargs): - super(InputMediaVideo, self).__init__(type='video', media=media, caption=caption, + super(InputMediaVideo, self).__init__(type='video', media=media, thumb=thumb, caption=caption, width=width, height=height, duration=duration, parse_mode=parse_mode, supports_streaming=supports_streaming, conf=kwargs) - if isinstance(media, (io.IOBase, InputFile)): - self.file = media - class MediaGroup(base.TelegramObject): """ @@ -296,6 +326,7 @@ class MediaGroup(base.TelegramObject): self.attach(photo) def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile], + thumb: typing.Union[base.InputFile, base.String] = None, caption: base.String = None, width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None): """ @@ -308,7 +339,7 @@ class MediaGroup(base.TelegramObject): :param duration: """ if not isinstance(video, InputMedia): - video = InputMediaVideo(media=video, caption=caption, + video = InputMediaVideo(media=video, thumb=thumb, caption=caption, width=width, height=height, duration=duration) self.attach(video) @@ -327,6 +358,7 @@ class MediaGroup(base.TelegramObject): return result def get_files(self): - return {inputmedia.attachment_key: inputmedia.file - for inputmedia in self.media - if isinstance(inputmedia, InputMedia) and inputmedia.file} + for inputmedia in self.media: + if not isinstance(inputmedia, InputMedia) or not inputmedia.file: + continue + yield from inputmedia.get_files() diff --git a/aiogram/types/message.py b/aiogram/types/message.py index ead27faa..43e962db 100644 --- a/aiogram/types/message.py +++ b/aiogram/types/message.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import functools import sys @@ -42,7 +44,7 @@ class Message(base.TelegramObject): forward_from_message_id: base.Integer = fields.Field() forward_signature: base.String = fields.Field() forward_date: datetime.datetime = fields.DateTimeField() - reply_to_message: 'Message' = fields.Field(base='Message') + reply_to_message: Message = fields.Field(base='Message') edit_date: datetime.datetime = fields.DateTimeField() media_group_id: base.String = fields.Field() author_signature: base.String = fields.Field() @@ -72,7 +74,7 @@ class Message(base.TelegramObject): channel_chat_created: base.Boolean = fields.Field() migrate_to_chat_id: base.Integer = fields.Field() migrate_from_chat_id: base.Integer = fields.Field() - pinned_message: 'Message' = fields.Field(base='Message') + pinned_message: Message = fields.Field(base='Message') invoice: Invoice = fields.Field(base=Invoice) successful_payment: SuccessfulPayment = fields.Field(base=SuccessfulPayment) connected_website: base.String = fields.Field() @@ -82,59 +84,59 @@ class Message(base.TelegramObject): @functools.lru_cache() def content_type(self): if self.text: - return ContentType.TEXT[0] + return ContentType.TEXT elif self.audio: - return ContentType.AUDIO[0] + return ContentType.AUDIO elif self.animation: - return ContentType.ANIMATION[0] + return ContentType.ANIMATION elif self.document: - return ContentType.DOCUMENT[0] + return ContentType.DOCUMENT elif self.game: - return ContentType.GAME[0] + return ContentType.GAME elif self.photo: - return ContentType.PHOTO[0] + return ContentType.PHOTO elif self.sticker: - return ContentType.STICKER[0] + return ContentType.STICKER elif self.video: - return ContentType.VIDEO[0] + return ContentType.VIDEO elif self.video_note: - return ContentType.VIDEO_NOTE[0] + return ContentType.VIDEO_NOTE elif self.voice: - return ContentType.VOICE[0] + return ContentType.VOICE elif self.contact: - return ContentType.CONTACT[0] + return ContentType.CONTACT elif self.venue: - return ContentType.VENUE[0] + return ContentType.VENUE elif self.location: - return ContentType.LOCATION[0] + return ContentType.LOCATION elif self.new_chat_members: - return ContentType.NEW_CHAT_MEMBERS[0] + return ContentType.NEW_CHAT_MEMBERS elif self.left_chat_member: - return ContentType.LEFT_CHAT_MEMBER[0] + return ContentType.LEFT_CHAT_MEMBER elif self.invoice: - return ContentType.INVOICE[0] + return ContentType.INVOICE elif self.successful_payment: - return ContentType.SUCCESSFUL_PAYMENT[0] + return ContentType.SUCCESSFUL_PAYMENT elif self.connected_website: - return ContentType.CONNECTED_WEBSITE[0] + return ContentType.CONNECTED_WEBSITE elif self.migrate_from_chat_id: - return ContentType.MIGRATE_FROM_CHAT_ID[0] + return ContentType.MIGRATE_FROM_CHAT_ID elif self.migrate_to_chat_id: - return ContentType.MIGRATE_TO_CHAT_ID[0] + return ContentType.MIGRATE_TO_CHAT_ID elif self.pinned_message: - return ContentType.PINNED_MESSAGE[0] + return ContentType.PINNED_MESSAGE elif self.new_chat_title: - return ContentType.NEW_CHAT_TITLE[0] + return ContentType.NEW_CHAT_TITLE elif self.new_chat_photo: - return ContentType.NEW_CHAT_PHOTO[0] + return ContentType.NEW_CHAT_PHOTO elif self.delete_chat_photo: - return ContentType.DELETE_CHAT_PHOTO[0] + return ContentType.DELETE_CHAT_PHOTO elif self.group_chat_created: - return ContentType.GROUP_CHAT_CREATED[0] + return ContentType.GROUP_CHAT_CREATED elif self.passport_data: - return ContentType.PASSPORT_DATA[0] + return ContentType.PASSPORT_DATA else: - return ContentType.UNKNOWN[0] + return ContentType.UNKNOWN def is_command(self): """ @@ -239,7 +241,7 @@ class Message(base.TelegramObject): return self.parse_entities() async def reply(self, text, parse_mode=None, disable_web_page_preview=None, - disable_notification=None, reply_markup=None, reply=True) -> 'Message': + disable_notification=None, reply_markup=None, reply=True) -> Message: """ Reply to this message @@ -729,6 +731,69 @@ class ContentType(helper.Helper): """ List of message content types + WARNING: Single elements + + :key: TEXT + :key: AUDIO + :key: DOCUMENT + :key: GAME + :key: PHOTO + :key: STICKER + :key: VIDEO + :key: VIDEO_NOTE + :key: VOICE + :key: CONTACT + :key: LOCATION + :key: VENUE + :key: NEW_CHAT_MEMBERS + :key: LEFT_CHAT_MEMBER + :key: INVOICE + :key: SUCCESSFUL_PAYMENT + :key: CONNECTED_WEBSITE + :key: MIGRATE_TO_CHAT_ID + :key: MIGRATE_FROM_CHAT_ID + :key: UNKNOWN + :key: ANY + """ + mode = helper.HelperMode.snake_case + + TEXT = helper.Item() # text + AUDIO = helper.Item() # audio + DOCUMENT = helper.Item() # document + ANIMATION = helper.Item() # animation + GAME = helper.Item() # game + PHOTO = helper.Item() # photo + STICKER = helper.Item() # sticker + VIDEO = helper.Item() # video + VIDEO_NOTE = helper.Item() # video_note + VOICE = helper.Item() # voice + CONTACT = helper.Item() # contact + LOCATION = helper.Item() # location + VENUE = helper.Item() # venue + NEW_CHAT_MEMBERS = helper.Item() # new_chat_member + LEFT_CHAT_MEMBER = helper.Item() # left_chat_member + INVOICE = helper.Item() # invoice + SUCCESSFUL_PAYMENT = helper.Item() # successful_payment + CONNECTED_WEBSITE = helper.Item() # connected_website + MIGRATE_TO_CHAT_ID = helper.Item() # migrate_to_chat_id + MIGRATE_FROM_CHAT_ID = helper.Item() # migrate_from_chat_id + PINNED_MESSAGE = helper.Item() # pinned_message + NEW_CHAT_TITLE = helper.Item() # new_chat_title + NEW_CHAT_PHOTO = helper.Item() # new_chat_photo + DELETE_CHAT_PHOTO = helper.Item() # delete_chat_photo + GROUP_CHAT_CREATED = helper.Item() # group_chat_created + PASSPORT_DATA = helper.Item() # passport_data + + UNKNOWN = helper.Item() # unknown + ANY = helper.Item() # any + + +class ContentTypes(helper.Helper): + """ + List of message content types + + WARNING: List elements. + :key: TEXT :key: AUDIO :key: DOCUMENT diff --git a/aiogram/types/update.py b/aiogram/types/update.py index 7f9cf11a..2753ae5f 100644 --- a/aiogram/types/update.py +++ b/aiogram/types/update.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from . import base from . import fields from .callback_query import CallbackQuery diff --git a/aiogram/types/user.py b/aiogram/types/user.py index c4c64844..441c275f 100644 --- a/aiogram/types/user.py +++ b/aiogram/types/user.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import babel from . import base diff --git a/aiogram/utils/context.py b/aiogram/utils/context.py deleted file mode 100644 index 376d9aa9..00000000 --- a/aiogram/utils/context.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -You need to setup task factory: - >>> from aiogram.utils import context - >>> loop = asyncio.get_event_loop() - >>> loop.set_task_factory(context.task_factory) -""" - -import asyncio -import typing - -CONFIGURED = '@CONFIGURED_TASK_FACTORY' - - -def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine): - """ - Task factory for implementing context processor - - :param loop: - :param coro: - :return: new task - :rtype: :obj:`asyncio.Task` - """ - # Is not allowed when loop is closed. - if loop.is_closed(): - raise RuntimeError('Event loop is closed.') - - task = asyncio.Task(coro, loop=loop) - - # Hide factory - if task._source_traceback: - del task._source_traceback[-1] - - try: - task.context = asyncio.Task.current_task().context.copy() - except AttributeError: - task.context = {CONFIGURED: True} - - return task - - -def get_current_state() -> typing.Dict: - """ - Get current execution context from task - - :return: context - :rtype: :obj:`dict` - """ - task = asyncio.Task.current_task() - if task is None: - raise RuntimeError('Can be used only in Task context.') - context_ = getattr(task, 'context', None) - if context_ is None: - context_ = task.context = {} - return context_ - - -def get_value(key, default=None): - """ - Get value from task - - :param key: - :param default: - :return: value - """ - return get_current_state().get(key, default) - - -def check_value(key): - """ - Key in context? - - :param key: - :return: - """ - return key in get_current_state() - - -def set_value(key, value): - """ - Set value - - :param key: - :param value: - :return: - """ - get_current_state()[key] = value - - -def del_value(key): - """ - Remove value from context - - :param key: - :return: - """ - del get_current_state()[key] - - -def update_state(data=None, **kwargs): - """ - Update multiple state items - - :param data: - :param kwargs: - :return: - """ - if data is None: - data = {} - state = get_current_state() - state.update(data, **kwargs) - - -def check_configured(): - """ - Check loop is configured - :return: - """ - return get_value(CONFIGURED) - - -class _Context: - """ - Other things for interactions with the execution context. - """ - - def __getitem__(self, item): - return get_value(item) - - def __setitem__(self, key, value): - set_value(key, value) - - def __delitem__(self, key): - del_value(key) - - @staticmethod - def get_context(): - return get_current_state() - - -context = _Context() diff --git a/aiogram/utils/exceptions.py b/aiogram/utils/exceptions.py index 3e4a7dba..cddd0c74 100644 --- a/aiogram/utils/exceptions.py +++ b/aiogram/utils/exceptions.py @@ -182,6 +182,10 @@ class MessageToEditNotFound(MessageError): match = 'message to edit not found' +class MessageIsTooLong(MessageError): + match = 'message is too long' + + class ToMuchMessages(MessageError): """ Will be raised when you try to send media group with more than 10 items. diff --git a/aiogram/utils/executor.py b/aiogram/utils/executor.py index e3d8aa1f..ac1a9657 100644 --- a/aiogram/utils/executor.py +++ b/aiogram/utils/executor.py @@ -6,7 +6,6 @@ from warnings import warn from aiohttp import web -from . import context from ..bot.api import log from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler @@ -104,6 +103,11 @@ class Executor: self._freeze = False + from aiogram.bot.bot import bot as ctx_bot + from aiogram.dispatcher import dispatcher as ctx_dp + ctx_bot.set(dispatcher.bot) + ctx_dp.set(dispatcher) + @property def frozen(self): return self._freeze @@ -176,13 +180,13 @@ class Executor: self._check_frozen() self._freeze = True - self.loop.set_task_factory(context.task_factory) + # self.loop.set_task_factory(context.task_factory) def _prepare_webhook(self, path=None, handler=WebhookRequestHandler): self._check_frozen() self._freeze = True - self.loop.set_task_factory(context.task_factory) + # self.loop.set_task_factory(context.task_factory) app = self._web_app if app is None: @@ -203,6 +207,7 @@ class Executor: for callback in self._on_startup_webhook: app.on_startup.append(functools.partial(_wrap_callback, callback)) + # for callback in self._on_shutdown_webhook: # app.on_shutdown.append(functools.partial(_wrap_callback, callback)) diff --git a/aiogram/utils/json.py b/aiogram/utils/json.py index 4cd02bc6..a8777593 100644 --- a/aiogram/utils/json.py +++ b/aiogram/utils/json.py @@ -1,26 +1,52 @@ -import json +import os + +JSON = 'json' +RAPIDJSON = 'rapidjson' +UJSON = 'ujson' try: - import ujson + if 'DISABLE_UJSON' not in os.environ: + import ujson as json + mode = UJSON + + + def dumps(data): + return json.dumps(data, ensure_ascii=False) + + else: + mode = JSON except ImportError: - ujson = None + mode = JSON -_use_ujson = True if ujson else False +try: + if 'DISABLE_RAPIDJSON' not in os.environ: + import rapidjson as json + + mode = RAPIDJSON -def disable_ujson(): - global _use_ujson - _use_ujson = False + def dumps(data): + return json.dumps(data, ensure_ascii=False, number_mode=json.NM_NATIVE, + datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC) -def dumps(data): - if _use_ujson: - return ujson.dumps(data) - return json.dumps(data) + def loads(data): + return json.loads(data, number_mode=json.NM_NATIVE, + datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC) + + else: + mode = JSON +except ImportError: + mode = JSON + +if mode == JSON: + import json -def loads(data): - if _use_ujson: - return ujson.loads(data) - return json.loads(data) + def dumps(data): + return json.dumps(data, ensure_ascii=False) + + + def loads(data): + return json.loads(data) diff --git a/aiogram/utils/payload.py b/aiogram/utils/payload.py index dac43492..bbed1967 100644 --- a/aiogram/utils/payload.py +++ b/aiogram/utils/payload.py @@ -1,5 +1,7 @@ import datetime +import secrets +from aiogram import types from . import json DEFAULT_FILTER = ['self', 'cls'] @@ -56,3 +58,22 @@ def prepare_arg(value): elif isinstance(value, datetime.datetime): return round(value.timestamp()) return value + + +def prepare_file(payload, files, key, file): + if isinstance(file, str): + payload[key] = file + elif file is not None: + files[key] = file + + +def prepare_attachment(payload, files, key, file): + if isinstance(file, str): + payload[key] = file + elif isinstance(file, types.InputFile): + payload[key] = file.attach + files[file.attachment_key] = file.file + elif file is not None: + file_attach_name = secrets.token_urlsafe(16) + payload[key] = "attach://" + file_attach_name + files[file_attach_name] = file diff --git a/dev_requirements.txt b/dev_requirements.txt index 7f6f9b19..ea4d686a 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,6 +1,7 @@ -r requirements.txt ujson>=1.35 +python-rapidjson>=0.6.3 emoji>=0.5.0 pytest>=3.5.0 pytest-asyncio>=0.8.0 diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index c05f0ca1..6ce619be 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -25,16 +25,16 @@ Next step: interaction with bots starts with one command. Register your first co .. code-block:: python3 - @dp.message_handler(commands=['start', 'help']) - async def send_welcome(message: types.Message): - await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") + @dp.message_handler(commands=['start', 'help']) + async def send_welcome(message: types.Message): + await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") Last step: run long polling. .. code-block:: python3 - if __name__ == '__main__': - executor.start_polling(dp) + if __name__ == '__main__': + executor.start_polling(dp) Summary ------- @@ -48,9 +48,9 @@ Summary bot = Bot(token='BOT TOKEN HERE') dp = Dispatcher(bot) - @dp.message_handler(commands=['start', 'help']) - async def send_welcome(message: types.Message): - await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") + @dp.message_handler(commands=['start', 'help']) + async def send_welcome(message: types.Message): + await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") - if __name__ == '__main__': - executor.start_polling(dp) + if __name__ == '__main__': + executor.start_polling(dp) diff --git a/environment.yml b/environment.yml index b6d1c93e..026e9cf8 100644 --- a/environment.yml +++ b/environment.yml @@ -1,8 +1,8 @@ -name: py36 +name: py37 channels: - conda-forge dependencies: - - python=3.6 + - python=3.7 - sphinx=1.5.3 - sphinx_rtd_theme=0.2.4 - pip diff --git a/examples/broadcast_example.py b/examples/broadcast_example.py index 40ba5e0c..9e654d44 100644 --- a/examples/broadcast_example.py +++ b/examples/broadcast_example.py @@ -29,6 +29,7 @@ async def send_message(user_id: int, text: str, disable_notification: bool = Fal :param user_id: :param text: + :param disable_notification: :return: """ try: diff --git a/examples/check_user_language.py b/examples/check_user_language.py index 1e1046a9..bd0ba7f9 100644 --- a/examples/check_user_language.py +++ b/examples/check_user_language.py @@ -5,18 +5,14 @@ Babel is required. import asyncio import logging -from aiogram import Bot, types -from aiogram.dispatcher import Dispatcher -from aiogram.types import ParseMode -from aiogram.utils.executor import start_polling -from aiogram.utils.markdown import * +from aiogram import Bot, Dispatcher, executor, md, types API_TOKEN = 'BOT TOKEN HERE' logging.basicConfig(level=logging.INFO) loop = asyncio.get_event_loop() -bot = Bot(token=API_TOKEN, loop=loop) +bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.MARKDOWN) dp = Dispatcher(bot) @@ -24,14 +20,14 @@ dp = Dispatcher(bot) async def check_language(message: types.Message): locale = message.from_user.locale - await message.reply(text( - bold('Info about your language:'), - text(' 🔸', bold('Code:'), italic(locale.locale)), - text(' 🔸', bold('Territory:'), italic(locale.territory or 'Unknown')), - text(' 🔸', bold('Language name:'), italic(locale.language_name)), - text(' 🔸', bold('English language name:'), italic(locale.english_name)), - sep='\n'), parse_mode=ParseMode.MARKDOWN) + await message.reply(md.text( + md.bold('Info about your language:'), + md.text(' 🔸', md.bold('Code:'), md.italic(locale.locale)), + md.text(' 🔸', md.bold('Territory:'), md.italic(locale.territory or 'Unknown')), + md.text(' 🔸', md.bold('Language name:'), md.italic(locale.language_name)), + md.text(' 🔸', md.bold('English language name:'), md.italic(locale.english_name)), + sep='\n')) if __name__ == '__main__': - start_polling(dp, loop=loop, skip_updates=True) + executor.start_polling(dp, loop=loop, skip_updates=True) diff --git a/examples/echo_bot.py b/examples/echo_bot.py index 7f4b0324..617dbad7 100644 --- a/examples/echo_bot.py +++ b/examples/echo_bot.py @@ -1,9 +1,7 @@ import asyncio import logging -from aiogram import Bot, types -from aiogram.dispatcher import Dispatcher -from aiogram.utils.executor import start_polling +from aiogram import Bot, types, Dispatcher, executor API_TOKEN = 'BOT TOKEN HERE' @@ -32,10 +30,4 @@ async def echo(message: types.Message): if __name__ == '__main__': - start_polling(dp, loop=loop, skip_updates=True) - - # Also you can use another execution method - # >>> try: - # >>> loop.run_until_complete(main()) - # >>> except KeyboardInterrupt: - # >>> loop.stop() + executor.start_polling(dp, loop=loop, skip_updates=True) diff --git a/examples/example_context_middleware.py b/examples/example_context_middleware.py deleted file mode 100644 index d909b52d..00000000 --- a/examples/example_context_middleware.py +++ /dev/null @@ -1,47 +0,0 @@ -from aiogram import Bot, types -from aiogram.contrib.middlewares.context import ContextMiddleware -from aiogram.dispatcher import Dispatcher -from aiogram.types import ParseMode -from aiogram.utils import markdown as md -from aiogram.utils.executor import start_polling - -API_TOKEN = 'BOT TOKEN HERE' - -bot = Bot(token=API_TOKEN) -dp = Dispatcher(bot) - -# Setup Context middleware -data: ContextMiddleware = dp.middleware.setup(ContextMiddleware()) - - -# Write custom filter -async def demo_filter(message: types.Message): - # Store some data in context - command = data['command'] = message.get_command() or '' - args = data['args'] = message.get_args() or '' - data['has_args'] = bool(args) - data['some_random_data'] = 42 - return command != '/bad_command' - - -@dp.message_handler(demo_filter) -async def send_welcome(message: types.Message): - # Get data from context - # All of this is available only in current context and from current update object - # `data`- pseudo-alias for `ctx.get_update().conf['_context_data']` - command = data['command'] - args = data['args'] - rand = data['some_random_data'] - has_args = data['has_args'] - - # Send as pre-formatted code block. - await message.reply(md.hpre(f"""command: {command} -args: {['Not available', 'available'][has_args]}: {args} -some random data: {rand} -message ID: {message.message_id} -message: {message.html_text} - """), parse_mode=ParseMode.HTML) - - -if __name__ == '__main__': - start_polling(dp) diff --git a/examples/finite_state_machine_example.py b/examples/finite_state_machine_example.py index e9a25ef2..7a989e5b 100644 --- a/examples/finite_state_machine_example.py +++ b/examples/finite_state_machine_example.py @@ -1,11 +1,13 @@ import asyncio +from typing import Optional -from aiogram import Bot, types +import aiogram.utils.markdown as md +from aiogram import Bot, Dispatcher, types from aiogram.contrib.fsm_storage.memory import MemoryStorage -from aiogram.dispatcher import Dispatcher +from aiogram.dispatcher import FSMContext +from aiogram.dispatcher.filters.state import State, StatesGroup from aiogram.types import ParseMode from aiogram.utils import executor -from aiogram.utils.markdown import text, bold API_TOKEN = 'BOT TOKEN HERE' @@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop) storage = MemoryStorage() dp = Dispatcher(bot, storage=storage) + # States -AGE = 'process_age' -NAME = 'process_name' -GENDER = 'process_gender' +class Form(StatesGroup): + name = State() # Will be represented in storage as 'Form:name' + age = State() # Will be represented in storage as 'Form:age' + gender = State() # Will be represented in storage as 'Form:gender' @dp.message_handler(commands=['start']) @@ -28,48 +32,41 @@ async def cmd_start(message: types.Message): """ Conversation's entry point """ - # Get current state - state = dp.current_state(chat=message.chat.id, user=message.from_user.id) - # Update user's state - await state.set_state(NAME) + # Set state + await Form.name.set() await message.reply("Hi there! What's your name?") # You can use state '*' if you need to handle all states @dp.message_handler(state='*', commands=['cancel']) -@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel') -async def cancel_handler(message: types.Message): +@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*') +async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None): """ Allow user to cancel any action """ - with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: - # Ignore command if user is not in any (defined) state - if await state.get_state() is None: - return + if raw_state is None: + return - # Otherwise cancel state and inform user about it - # And remove keyboard (just in case) - await state.reset_state(with_data=True) - await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) + # Cancel state and inform user about it + await state.finish() + # And remove keyboard (just in case) + await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) -@dp.message_handler(state=NAME) -async def process_name(message: types.Message): +@dp.message_handler(state=Form.name) +async def process_name(message: types.Message, state: FSMContext): """ Process user name """ - # Save name to storage and go to next step - # You can use context manager - with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: - await state.update_data(name=message.text) - await state.set_state(AGE) + await Form.next() + await state.update_data(name=message.text) await message.reply("How old are you?") # Check age. Age gotta be digit -@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit()) +@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age) async def failed_process_age(message: types.Message): """ If age is invalid @@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message): return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") -@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit()) -async def process_age(message: types.Message): +@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age) +async def process_age(message: types.Message, state: FSMContext): # Update state and data - with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: - await state.set_state(GENDER) - await state.update_data(age=int(message.text)) + await Form.next() + await state.update_data(age=int(message.text)) # Configure ReplyKeyboardMarkup markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) @@ -92,7 +88,7 @@ async def process_age(message: types.Message): await message.reply("What is your gender?", reply_markup=markup) -@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"]) +@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender) async def failed_process_gender(message: types.Message): """ In this example gender has to be one of: Male, Female, Other. @@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message): return await message.reply("Bad gender name. Choose you gender from keyboard.") -@dp.message_handler(state=GENDER) -async def process_gender(message: types.Message): - state = dp.current_state(chat=message.chat.id, user=message.from_user.id) - +@dp.message_handler(state=Form.gender) +async def process_gender(message: types.Message, state: FSMContext): data = await state.get_data() data['gender'] = message.text @@ -111,10 +105,10 @@ async def process_gender(message: types.Message): markup = types.ReplyKeyboardRemove() # And send message - await bot.send_message(message.chat.id, text( - text('Hi! Nice to meet you,', bold(data['name'])), - text('Age:', data['age']), - text('Gender:', data['gender']), + await bot.send_message(message.chat.id, md.text( + md.text('Hi! Nice to meet you,', md.bold(data['name'])), + md.text('Age:', data['age']), + md.text('Gender:', data['gender']), sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN) # Finish conversation @@ -122,10 +116,5 @@ async def process_gender(message: types.Message): await state.finish() -async def shutdown(dispatcher: Dispatcher): - await dispatcher.storage.close() - await dispatcher.storage.wait_closed() - - if __name__ == '__main__': - executor.start_polling(dp, loop=loop, skip_updates=True, on_shutdown=shutdown) + executor.start_polling(dp, loop=loop, skip_updates=True) diff --git a/examples/finite_state_machine_example_2.py b/examples/finite_state_machine_example_2.py new file mode 100644 index 00000000..5a2996bd --- /dev/null +++ b/examples/finite_state_machine_example_2.py @@ -0,0 +1,126 @@ +""" +This example is equals with 'finite_state_machine_example.py' but with FSM Middleware + +Note that FSM Middleware implements the more simple methods for working with storage. + +With that middleware all data from storage will be loaded before event will be processed +and data will be stored after processing the event. +""" +import asyncio + +import aiogram.utils.markdown as md +from aiogram import Bot, Dispatcher, types +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.contrib.middlewares.fsm import FSMMiddleware, FSMSStorageProxy +from aiogram.dispatcher.filters.state import State, StatesGroup +from aiogram.utils import executor + +API_TOKEN = 'BOT TOKEN HERE' + +loop = asyncio.get_event_loop() + +bot = Bot(token=API_TOKEN, loop=loop) + +# For example use simple MemoryStorage for Dispatcher. +storage = MemoryStorage() +dp = Dispatcher(bot, storage=storage) +dp.middleware.setup(FSMMiddleware()) + + +# States +class Form(StatesGroup): + name = State() # Will be represented in storage as 'Form:name' + age = State() # Will be represented in storage as 'Form:age' + gender = State() # Will be represented in storage as 'Form:gender' + + +@dp.message_handler(commands=['start']) +async def cmd_start(message: types.Message): + """ + Conversation's entry point + """ + # Set state + await Form.first() + + await message.reply("Hi there! What's your name?") + + +# You can use state '*' if you need to handle all states +@dp.message_handler(state='*', commands=['cancel']) +@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*') +async def cancel_handler(message: types.Message, state_data: FSMSStorageProxy): + """ + Allow user to cancel any action + """ + if state_data.state is None: + return + + # Cancel state and inform user about it + del state_data.state + # And remove keyboard (just in case) + await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) + + +@dp.message_handler(state=Form.name) +async def process_name(message: types.Message, state_data: FSMSStorageProxy): + """ + Process user name + """ + state_data.state = Form.age + state_data['name'] = message.text + + await message.reply("How old are you?") + + +# Check age. Age gotta be digit +@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age) +async def failed_process_age(message: types.Message): + """ + If age is invalid + """ + return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") + + +@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age) +async def process_age(message: types.Message, state_data: FSMSStorageProxy): + # Update state and data + state_data.state = Form.gender + state_data['age'] = int(message.text) + + # Configure ReplyKeyboardMarkup + markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) + markup.add("Male", "Female") + markup.add("Other") + + await message.reply("What is your gender?", reply_markup=markup) + + +@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender) +async def failed_process_gender(message: types.Message): + """ + In this example gender has to be one of: Male, Female, Other. + """ + return await message.reply("Bad gender name. Choose you gender from keyboard.") + + +@dp.message_handler(state=Form.gender) +async def process_gender(message: types.Message, state_data: FSMSStorageProxy): + state_data['gender'] = message.text + + # Remove keyboard + markup = types.ReplyKeyboardRemove() + + # And send message + await bot.send_message(message.chat.id, md.text( + md.text('Hi! Nice to meet you,', md.bold(state_data['name'])), + md.text('Age:', state_data['age']), + md.text('Gender:', state_data['gender']), + sep='\n'), reply_markup=markup, parse_mode=types.ParseMode.MARKDOWN) + + # Finish conversation + # WARNING! This method will destroy all data in storage for current user! + state_data.clear() + + +if __name__ == '__main__': + executor.start_polling(dp, loop=loop, skip_updates=True) diff --git a/examples/i18n_example.py b/examples/i18n_example.py new file mode 100644 index 00000000..6469ed5b --- /dev/null +++ b/examples/i18n_example.py @@ -0,0 +1,56 @@ +""" +Internalize your bot + +Step 1: extract texts + # pybabel extract i18n_example.py -o locales/mybot.pot +Step 2: create *.po files. For e.g. create en, ru, uk locales. + # echo {en,ru,uk} | xargs -n1 pybabel init -i locales/mybot.pot -d locales -D mybot -l +Step 3: translate texts +Step 4: compile translations + # pybabel compile -d locales -D mybot + +Step 5: When you change the code of your bot you need to update po & mo files. + Step 5.1: regenerate pot file: + command from step 1 + Step 5.2: update po files + # pybabel update -d locales -D mybot -i locales/mybot.pot + Step 5.3: update your translations + Step 5.4: compile mo files + command from step 4 +""" + +from pathlib import Path + +from aiogram import Bot, Dispatcher, executor, types +from aiogram.contrib.middlewares.i18n import I18nMiddleware + +TOKEN = 'BOT TOKEN HERE' +I18N_DOMAIN = 'mybot' + +BASE_DIR = Path(__file__).parent +LOCALES_DIR = BASE_DIR / 'locales' + +bot = Bot(TOKEN, parse_mode=types.ParseMode.HTML) +dp = Dispatcher(bot) + +# Setup i18n middleware +i18n = I18nMiddleware(I18N_DOMAIN, LOCALES_DIR) +dp.middleware.setup(i18n) + +# Alias for gettext method +_ = i18n.gettext + + +@dp.message_handler(commands=['start']) +async def cmd_start(message: types.Message): + # Simply use `_('message')` instead of `'message'` and never use f-strings for translatable texts. + await message.reply(_('Hello, {user}!').format(user=message.from_user.full_name)) + + +@dp.message_handler(commands=['lang']) +async def cmd_lang(message: types.Message, locale): + await message.reply(_('Your current language: {language}').format(language=locale)) + + +if __name__ == '__main__': + executor.start_polling(dp, skip_updates=True) diff --git a/examples/inline_bot.py b/examples/inline_bot.py index bb6d0f89..4a771210 100644 --- a/examples/inline_bot.py +++ b/examples/inline_bot.py @@ -1,9 +1,7 @@ import asyncio import logging -from aiogram import Bot, types -from aiogram.dispatcher import Dispatcher -from aiogram.utils.executor import start_polling +from aiogram import Bot, types, Dispatcher, executor API_TOKEN = 'BOT TOKEN HERE' @@ -23,4 +21,4 @@ async def inline_echo(inline_query: types.InlineQuery): if __name__ == '__main__': - start_polling(dp, loop=loop, skip_updates=True) + executor.start_polling(dp, loop=loop, skip_updates=True) diff --git a/examples/locales/en/LC_MESSAGES/mybot.po b/examples/locales/en/LC_MESSAGES/mybot.po new file mode 100644 index 00000000..75970929 --- /dev/null +++ b/examples/locales/en/LC_MESSAGES/mybot.po @@ -0,0 +1,28 @@ +# English translations for PROJECT. +# Copyright (C) 2018 ORGANIZATION +# This file is distributed under the same license as the PROJECT project. +# FIRST AUTHOR , 2018. +# +msgid "" +msgstr "" +"Project-Id-Version: PROJECT VERSION\n" +"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n" +"POT-Creation-Date: 2018-06-30 03:50+0300\n" +"PO-Revision-Date: 2018-06-30 03:43+0300\n" +"Last-Translator: FULL NAME \n" +"Language: en\n" +"Language-Team: en \n" +"Plural-Forms: nplurals=2; plural=(n != 1)\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.6.0\n" + +#: i18n_example.py:48 +msgid "Hello, {user}!" +msgstr "" + +#: i18n_example.py:53 +msgid "Your current language: {language}" +msgstr "" + diff --git a/examples/locales/mybot.pot b/examples/locales/mybot.pot new file mode 100644 index 00000000..988ed463 --- /dev/null +++ b/examples/locales/mybot.pot @@ -0,0 +1,27 @@ +# Translations template for PROJECT. +# Copyright (C) 2018 ORGANIZATION +# This file is distributed under the same license as the PROJECT project. +# FIRST AUTHOR , 2018. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PROJECT VERSION\n" +"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n" +"POT-Creation-Date: 2018-06-30 03:50+0300\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.6.0\n" + +#: i18n_example.py:48 +msgid "Hello, {user}!" +msgstr "" + +#: i18n_example.py:53 +msgid "Your current language: {language}" +msgstr "" + diff --git a/examples/locales/ru/LC_MESSAGES/mybot.po b/examples/locales/ru/LC_MESSAGES/mybot.po new file mode 100644 index 00000000..73876f30 --- /dev/null +++ b/examples/locales/ru/LC_MESSAGES/mybot.po @@ -0,0 +1,29 @@ +# Russian translations for PROJECT. +# Copyright (C) 2018 ORGANIZATION +# This file is distributed under the same license as the PROJECT project. +# FIRST AUTHOR , 2018. +# +msgid "" +msgstr "" +"Project-Id-Version: PROJECT VERSION\n" +"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n" +"POT-Creation-Date: 2018-06-30 03:50+0300\n" +"PO-Revision-Date: 2018-06-30 03:43+0300\n" +"Last-Translator: FULL NAME \n" +"Language: ru\n" +"Language-Team: ru \n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && " +"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.6.0\n" + +#: i18n_example.py:48 +msgid "Hello, {user}!" +msgstr "Привет, {user}!" + +#: i18n_example.py:53 +msgid "Your current language: {language}" +msgstr "Твой язык: {language}" + diff --git a/examples/locales/uk/LC_MESSAGES/mybot.po b/examples/locales/uk/LC_MESSAGES/mybot.po new file mode 100644 index 00000000..25970c19 --- /dev/null +++ b/examples/locales/uk/LC_MESSAGES/mybot.po @@ -0,0 +1,29 @@ +# Ukrainian translations for PROJECT. +# Copyright (C) 2018 ORGANIZATION +# This file is distributed under the same license as the PROJECT project. +# FIRST AUTHOR , 2018. +# +msgid "" +msgstr "" +"Project-Id-Version: PROJECT VERSION\n" +"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n" +"POT-Creation-Date: 2018-06-30 03:50+0300\n" +"PO-Revision-Date: 2018-06-30 03:43+0300\n" +"Last-Translator: FULL NAME \n" +"Language: uk\n" +"Language-Team: uk \n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && " +"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.6.0\n" + +#: i18n_example.py:48 +msgid "Hello, {user}!" +msgstr "Привіт, {user}!" + +#: i18n_example.py:53 +msgid "Your current language: {language}" +msgstr "Твоя мова: {language}" + diff --git a/examples/media_group.py b/examples/media_group.py index 0194733d..b1f5246a 100644 --- a/examples/media_group.py +++ b/examples/media_group.py @@ -1,9 +1,6 @@ import asyncio -from aiogram import Bot, types -from aiogram.dispatcher import Dispatcher -from aiogram.types import ChatActions -from aiogram.utils.executor import start_polling +from aiogram import Bot, Dispatcher, executor, filters, types API_TOKEN = 'BOT TOKEN HERE' @@ -12,7 +9,7 @@ bot = Bot(token=API_TOKEN, loop=loop) dp = Dispatcher(bot) -@dp.message_handler(commands=['start']) +@dp.message_handler(filters.CommandStart()) async def send_welcome(message: types.Message): # So... At first I want to send something like this: await message.reply("Do you want to see many pussies? Are you ready?") @@ -21,7 +18,7 @@ async def send_welcome(message: types.Message): await asyncio.sleep(1) # Good bots should send chat actions. Or not. - await ChatActions.upload_photo() + await types.ChatActions.upload_photo() # Create media group media = types.MediaGroup() @@ -39,9 +36,8 @@ async def send_welcome(message: types.Message): # media.attach_photo('', 'cat-cat-cat.') # Done! Send media group - await bot.send_media_group(message.chat.id, media=media, - reply_to_message_id=message.message_id) + await message.reply_media_group(media=media) if __name__ == '__main__': - start_polling(dp, loop=loop, skip_updates=True) + executor.start_polling(dp, loop=loop, skip_updates=True) diff --git a/examples/middleware_and_antiflood.py b/examples/middleware_and_antiflood.py index 2d0f002c..7b83d9a4 100644 --- a/examples/middleware_and_antiflood.py +++ b/examples/middleware_and_antiflood.py @@ -1,11 +1,10 @@ import asyncio -from aiogram import Bot, types +from aiogram import Bot, Dispatcher, executor, types from aiogram.contrib.fsm_storage.redis import RedisStorage2 -from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx +from aiogram.dispatcher import DEFAULT_RATE_LIMIT +from aiogram.dispatcher.handler import CancelHandler from aiogram.dispatcher.middlewares import BaseMiddleware -from aiogram.utils import context, executor -from aiogram.utils.exceptions import Throttled TOKEN = 'BOT TOKEN HERE' @@ -53,10 +52,10 @@ class ThrottlingMiddleware(BaseMiddleware): :param message: """ # Get current handler - handler = context.get_value('handler') + # handler = context.get_value('handler') # Get dispatcher from context - dispatcher = ctx.get_dispatcher() + dispatcher = Dispatcher.current() # If handler was configured, get rate limit and key from handler if handler: @@ -83,8 +82,8 @@ class ThrottlingMiddleware(BaseMiddleware): :param message: :param throttled: """ - handler = context.get_value('handler') - dispatcher = ctx.get_dispatcher() + # handler = context.get_value('handler') + dispatcher = Dispatcher.current() if handler: key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") else: diff --git a/examples/payments.py b/examples/payments.py index 74b78456..d85e94ab 100644 --- a/examples/payments.py +++ b/examples/payments.py @@ -4,7 +4,7 @@ from aiogram import Bot from aiogram import types from aiogram.utils import executor from aiogram.dispatcher import Dispatcher -from aiogram.types.message import ContentType +from aiogram.types.message import ContentTypes BOT_TOKEN = 'BOT TOKEN HERE' @@ -86,7 +86,7 @@ async def checkout(pre_checkout_query: types.PreCheckoutQuery): " try to pay again in a few minutes, we need a small rest.") -@dp.message_handler(content_types=ContentType.SUCCESSFUL_PAYMENT) +@dp.message_handler(content_types=ContentTypes.SUCCESSFUL_PAYMENT) async def got_payment(message: types.Message): await bot.send_message(message.chat.id, 'Hoooooray! Thanks for payment! We will proceed your order for `{} {}`' diff --git a/examples/proxy_and_emojize.py b/examples/proxy_and_emojize.py index d979243c..7e4452ee 100644 --- a/examples/proxy_and_emojize.py +++ b/examples/proxy_and_emojize.py @@ -18,7 +18,7 @@ PROXY_URL = 'http://PROXY_URL' # Or 'socks5://...' # PROXY_AUTH = aiohttp.BasicAuth(login='login', password='password') # And add `proxy_auth=PROXY_AUTH` argument in line 25, like this: # >>> bot = Bot(token=API_TOKEN, loop=loop, proxy=PROXY_URL, proxy_auth=PROXY_AUTH) -# Also you can use Socks5 proxy but you need manually install aiosocksy package. +# Also you can use Socks5 proxy but you need manually install aiohttp_socks package. # Get my ip URL GET_IP_URL = 'http://bot.whatismyipaddress.com/' diff --git a/examples/webhook_example.py b/examples/webhook_example.py index 1a4b8198..a1b48c0f 100644 --- a/examples/webhook_example.py +++ b/examples/webhook_example.py @@ -9,7 +9,7 @@ from aiogram import Bot, types, Version from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.dispatcher import Dispatcher from aiogram.dispatcher.webhook import get_new_configured_app, SendMessage -from aiogram.types import ChatType, ParseMode, ContentType +from aiogram.types import ChatType, ParseMode, ContentTypes from aiogram.utils.markdown import hbold, bold, text, link TOKEN = 'BOT TOKEN HERE' @@ -31,7 +31,7 @@ WEBHOOK_URL = f"https://{WEBHOOK_HOST}:{WEBHOOK_PORT}{WEBHOOK_URL_PATH}" WEBAPP_HOST = 'localhost' WEBAPP_PORT = 3001 -BAD_CONTENT = ContentType.PHOTO & ContentType.DOCUMENT & ContentType.STICKER & ContentType.AUDIO +BAD_CONTENT = ContentTypes.PHOTO & ContentTypes.DOCUMENT & ContentTypes.STICKER & ContentTypes.AUDIO loop = asyncio.get_event_loop() bot = Bot(TOKEN, loop=loop) diff --git a/setup.py b/setup.py index 9b583400..630325e0 100755 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ except ImportError: # pip >= 10.0.0 WORK_DIR = pathlib.Path(__file__).parent # Check python version -MINIMAL_PY_VERSION = (3, 6) +MINIMAL_PY_VERSION = (3, 7) if sys.version_info < MINIMAL_PY_VERSION: raise RuntimeError('aiogram works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION)))) @@ -65,7 +65,7 @@ setup( url='https://github.com/aiogram/aiogram', license='MIT', author='Alex Root Junior', - requires_python='>=3.6', + requires_python='>=3.7', author_email='aiogram@illemius.xyz', description='Is a pretty simple and fully asynchronous library for Telegram Bot API', long_description=get_description(), @@ -76,7 +76,7 @@ setup( 'Intended Audience :: Developers', 'Intended Audience :: System Administrators', 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Software Development :: Libraries :: Application Frameworks', ], install_requires=get_requirements() diff --git a/tests/states_group.py b/tests/states_group.py new file mode 100644 index 00000000..8593cea3 --- /dev/null +++ b/tests/states_group.py @@ -0,0 +1,102 @@ +import pytest + +from aiogram.dispatcher.filters.state import State, StatesGroup, any_state, default_state + + +class MyGroup(StatesGroup): + state = State() + state_1 = State() + state_2 = State() + + class MySubGroup(StatesGroup): + sub_state = State() + sub_state_1 = State() + sub_state_2 = State() + + in_custom_group = State(group_name='custom_group') + + class NewGroup(StatesGroup): + spam = State() + renamed_state = State(state='spam_state') + + +alone_state = State('alone') +alone_in_group = State('alone', group_name='home') + + +def test_default_state(): + assert default_state.state is None + + +def test_any_state(): + assert any_state.state == '*' + + +def test_alone_state(): + assert alone_state.state == '@:alone' + assert alone_in_group.state == 'home:alone' + + +def test_group_names(): + assert MyGroup.__group_name__ == 'MyGroup' + assert MyGroup.__full_group_name__ == 'MyGroup' + + assert MyGroup.MySubGroup.__group_name__ == 'MySubGroup' + assert MyGroup.MySubGroup.__full_group_name__ == 'MyGroup.MySubGroup' + + assert MyGroup.MySubGroup.NewGroup.__group_name__ == 'NewGroup' + assert MyGroup.MySubGroup.NewGroup.__full_group_name__ == 'MyGroup.MySubGroup.NewGroup' + + +def test_custom_group_in_group(): + assert MyGroup.MySubGroup.in_custom_group.state == 'custom_group:in_custom_group' + + +def test_custom_state_name_in_group(): + assert MyGroup.MySubGroup.NewGroup.renamed_state.state == 'MyGroup.MySubGroup.NewGroup:spam_state' + + +def test_group_states_names(): + assert len(MyGroup.states) == 3 + assert len(MyGroup.all_states) == 9 + + assert MyGroup.states_names == ('MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2') + assert MyGroup.MySubGroup.states_names == ( + 'MyGroup.MySubGroup:sub_state', 'MyGroup.MySubGroup:sub_state_1', 'MyGroup.MySubGroup:sub_state_2', + 'custom_group:in_custom_group') + assert MyGroup.MySubGroup.NewGroup.states_names == ( + 'MyGroup.MySubGroup.NewGroup:spam', 'MyGroup.MySubGroup.NewGroup:spam_state') + + assert MyGroup.all_states_names == ( + 'MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2', + 'MyGroup.MySubGroup:sub_state', + 'MyGroup.MySubGroup:sub_state_1', + 'MyGroup.MySubGroup:sub_state_2', + 'custom_group:in_custom_group', + 'MyGroup.MySubGroup.NewGroup:spam', + 'MyGroup.MySubGroup.NewGroup:spam_state') + + assert MyGroup.MySubGroup.all_states_names == ( + 'MyGroup.MySubGroup:sub_state', + 'MyGroup.MySubGroup:sub_state_1', + 'MyGroup.MySubGroup:sub_state_2', + 'custom_group:in_custom_group', + 'MyGroup.MySubGroup.NewGroup:spam', + 'MyGroup.MySubGroup.NewGroup:spam_state') + + assert MyGroup.MySubGroup.NewGroup.all_states_names == ( + 'MyGroup.MySubGroup.NewGroup:spam', + 'MyGroup.MySubGroup.NewGroup:spam_state') + + +def test_root_element(): + root = MyGroup.MySubGroup.NewGroup.spam.get_root() + + assert issubclass(root, StatesGroup) + assert root == MyGroup + + assert root == MyGroup.state.get_root() + assert root == MyGroup.MySubGroup.get_root() + + with pytest.raises(RuntimeError): + any_state.get_root() diff --git a/tox.ini b/tox.ini index 1460b55c..aff44213 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36 +envlist = py37 [testenv] deps = -rdev_requirements.txt