mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge c733257be5 into 5f7a495b10
This commit is contained in:
commit
bfccd6268c
64 changed files with 3455 additions and 2299 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -57,3 +57,6 @@ experiment.py
|
||||||
|
|
||||||
# Doc's
|
# Doc's
|
||||||
docs/html
|
docs/html
|
||||||
|
|
||||||
|
# i18n/l10n
|
||||||
|
*.mo
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@
|
||||||
[](https://github.com/aiogram/aiogram/issues)
|
[](https://github.com/aiogram/aiogram/issues)
|
||||||
[](https://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.6 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler.
|
**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.7 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler.
|
||||||
|
|
||||||
You can [read the docs here](http://aiogram.readthedocs.io/en/latest/).
|
You can [read the docs here](http://aiogram.readthedocs.io/en/latest/).
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,42 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from . import bot
|
||||||
|
from . import contrib
|
||||||
|
from . import dispatcher
|
||||||
|
from . import types
|
||||||
|
from . import utils
|
||||||
from .bot import Bot
|
from .bot import Bot
|
||||||
from .dispatcher import Dispatcher
|
from .dispatcher import Dispatcher
|
||||||
|
from .dispatcher import filters
|
||||||
|
from .dispatcher import middlewares
|
||||||
|
from .utils import exceptions, executor, helper, markdown as md
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import uvloop
|
import uvloop
|
||||||
except ImportError:
|
except ImportError:
|
||||||
uvloop = None
|
uvloop = None
|
||||||
else:
|
else:
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
if 'DISABLE_UVLOOP' not in os.environ:
|
||||||
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
__version__ = '1.4'
|
__all__ = [
|
||||||
|
'Bot',
|
||||||
|
'Dispatcher',
|
||||||
|
'__api_version__',
|
||||||
|
'__version__',
|
||||||
|
'bot',
|
||||||
|
'contrib',
|
||||||
|
'dispatcher',
|
||||||
|
'exceptions',
|
||||||
|
'executor',
|
||||||
|
'filters',
|
||||||
|
'helper',
|
||||||
|
'md',
|
||||||
|
'middlewares',
|
||||||
|
'types',
|
||||||
|
'utils'
|
||||||
|
]
|
||||||
|
|
||||||
|
__version__ = '2.0.dev1'
|
||||||
__api_version__ = '3.6'
|
__api_version__ = '3.6'
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,14 @@
|
||||||
|
import abc
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import ssl
|
||||||
|
from asyncio import AbstractEventLoop
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import certifi
|
||||||
|
|
||||||
from .. import types
|
from .. import types
|
||||||
from ..utils import exceptions
|
from ..utils import exceptions
|
||||||
|
|
@ -34,58 +40,73 @@ def check_token(token: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def _check_result(method_name, response):
|
async def check_result(method_name: str, content_type: str, status_code: int, body: str):
|
||||||
"""
|
"""
|
||||||
Checks whether `result` is a valid API response.
|
Checks whether `result` is a valid API response.
|
||||||
A result is considered invalid if:
|
A result is considered invalid if:
|
||||||
- The server returned an HTTP response code other than 200
|
- The server returned an HTTP response code other than 200
|
||||||
- The content of the result is invalid JSON.
|
- The content of the result is invalid JSON.
|
||||||
- The method call was unsuccessful (The JSON 'ok' field equals False)
|
- The method call was unsuccessful (The JSON 'ok' field equals False)
|
||||||
|
|
||||||
:raises ApiException: if one of the above listed cases is applicable
|
:param method_name: The name of the method called
|
||||||
:param method_name: The name of the method called
|
:param status_code: status code
|
||||||
:param response: The returned response of the method request
|
:param content_type: content type of result
|
||||||
:return: The result parsed to a JSON dictionary.
|
:param body: result body
|
||||||
"""
|
:return: The result parsed to a JSON dictionary
|
||||||
body = await response.text()
|
:raises ApiException: if one of the above listed cases is applicable
|
||||||
log.debug(f"Response for {method_name}: [{response.status}] {body}")
|
"""
|
||||||
|
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
|
||||||
|
|
||||||
if response.content_type != 'application/json':
|
if content_type != 'application/json':
|
||||||
raise exceptions.NetworkError(f"Invalid response with content type {response.content_type}: \"{body}\"")
|
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result_json = json.loads(body)
|
||||||
|
except ValueError:
|
||||||
|
result_json = {}
|
||||||
|
|
||||||
|
description = result_json.get('description') or body
|
||||||
|
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
|
||||||
|
|
||||||
|
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
|
||||||
|
return result_json.get('result')
|
||||||
|
elif parameters.retry_after:
|
||||||
|
raise exceptions.RetryAfter(parameters.retry_after)
|
||||||
|
elif parameters.migrate_to_chat_id:
|
||||||
|
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
|
||||||
|
elif status_code == HTTPStatus.BAD_REQUEST:
|
||||||
|
exceptions.BadRequest.detect(description)
|
||||||
|
elif status_code == HTTPStatus.NOT_FOUND:
|
||||||
|
exceptions.NotFound.detect(description)
|
||||||
|
elif status_code == HTTPStatus.CONFLICT:
|
||||||
|
exceptions.ConflictError.detect(description)
|
||||||
|
elif status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
|
||||||
|
exceptions.Unauthorized.detect(description)
|
||||||
|
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
|
||||||
|
raise exceptions.NetworkError('File too large for uploading. '
|
||||||
|
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
|
||||||
|
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
||||||
|
if 'restart' in description:
|
||||||
|
raise exceptions.RestartingTelegram()
|
||||||
|
raise exceptions.TelegramAPIError(description)
|
||||||
|
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
|
||||||
|
|
||||||
|
|
||||||
|
async def make_request(session, token, method, data=None, files=None, **kwargs):
|
||||||
|
# log.debug(f"Make request: '{method}' with data: {data} and files {files}")
|
||||||
|
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files)
|
||||||
|
|
||||||
|
url = Methods.api_url(token=token, method=method)
|
||||||
|
|
||||||
|
req = compose_data(data, files)
|
||||||
try:
|
try:
|
||||||
result_json = await response.json(loads=json.loads)
|
async with session.post(url, data=req, **kwargs) as response:
|
||||||
except ValueError:
|
return await check_result(method, response.content_type, response.status, await response.text())
|
||||||
result_json = {}
|
except aiohttp.ClientError as e:
|
||||||
|
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
||||||
description = result_json.get('description') or body
|
|
||||||
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
|
|
||||||
|
|
||||||
if HTTPStatus.OK <= response.status <= HTTPStatus.IM_USED:
|
|
||||||
return result_json.get('result')
|
|
||||||
elif parameters.retry_after:
|
|
||||||
raise exceptions.RetryAfter(parameters.retry_after)
|
|
||||||
elif parameters.migrate_to_chat_id:
|
|
||||||
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
|
|
||||||
elif response.status == HTTPStatus.BAD_REQUEST:
|
|
||||||
exceptions.BadRequest.detect(description)
|
|
||||||
elif response.status == HTTPStatus.NOT_FOUND:
|
|
||||||
exceptions.NotFound.detect(description)
|
|
||||||
elif response.status == HTTPStatus.CONFLICT:
|
|
||||||
exceptions.ConflictError.detect(description)
|
|
||||||
elif response.status in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
|
|
||||||
exceptions.Unauthorized.detect(description)
|
|
||||||
elif response.status == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
|
|
||||||
raise exceptions.NetworkError('File too large for uploading. '
|
|
||||||
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
|
|
||||||
elif response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
|
||||||
if 'restart' in description:
|
|
||||||
raise exceptions.RestartingTelegram()
|
|
||||||
raise exceptions.TelegramAPIError(description)
|
|
||||||
raise exceptions.TelegramAPIError(f"{description} [{response.status}]")
|
|
||||||
|
|
||||||
|
|
||||||
def _guess_filename(obj):
|
def guess_filename(obj):
|
||||||
"""
|
"""
|
||||||
Get file name from object
|
Get file name from object
|
||||||
|
|
||||||
|
|
@ -97,7 +118,7 @@ def _guess_filename(obj):
|
||||||
return os.path.basename(name)
|
return os.path.basename(name)
|
||||||
|
|
||||||
|
|
||||||
def _compose_data(params=None, files=None):
|
def compose_data(params=None, files=None):
|
||||||
"""
|
"""
|
||||||
Prepare request data
|
Prepare request data
|
||||||
|
|
||||||
|
|
@ -121,47 +142,13 @@ def _compose_data(params=None, files=None):
|
||||||
elif isinstance(f, types.InputFile):
|
elif isinstance(f, types.InputFile):
|
||||||
filename, fileobj = f.filename, f.file
|
filename, fileobj = f.filename, f.file
|
||||||
else:
|
else:
|
||||||
filename, fileobj = _guess_filename(f) or key, f
|
filename, fileobj = guess_filename(f) or key, f
|
||||||
|
|
||||||
data.add_field(key, fileobj, filename=filename)
|
data.add_field(key, fileobj, filename=filename)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict:
|
|
||||||
"""
|
|
||||||
Make request to API
|
|
||||||
|
|
||||||
That make request with Content-Type:
|
|
||||||
application/x-www-form-urlencoded - For simple request
|
|
||||||
and multipart/form-data - for files uploading
|
|
||||||
|
|
||||||
https://core.telegram.org/bots/api#making-requests
|
|
||||||
|
|
||||||
:param session: HTTP Client session
|
|
||||||
:type session: :obj:`aiohttp.ClientSession`
|
|
||||||
:param token: BOT token
|
|
||||||
:type token: :obj:`str`
|
|
||||||
:param method: API method
|
|
||||||
:type method: :obj:`str`
|
|
||||||
:param data: request payload
|
|
||||||
:type data: :obj:`dict`
|
|
||||||
:param files: files
|
|
||||||
:type files: :obj:`dict`
|
|
||||||
:return: result
|
|
||||||
:rtype :obj:`bool` or :obj:`dict`
|
|
||||||
"""
|
|
||||||
log.debug("Make request: '{0}' with data: {1} and files {2}".format(
|
|
||||||
method, data or {}, files or {}))
|
|
||||||
data = _compose_data(data, files)
|
|
||||||
url = Methods.api_url(token=token, method=method)
|
|
||||||
try:
|
|
||||||
async with session.post(url, data=data, **kwargs) as response:
|
|
||||||
return await _check_result(method, response)
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class Methods(Helper):
|
class Methods(Helper):
|
||||||
"""
|
"""
|
||||||
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api
|
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,6 @@ class BaseBot:
|
||||||
api.check_token(token)
|
api.check_token(token)
|
||||||
self.__token = token
|
self.__token = token
|
||||||
|
|
||||||
# Proxy settings
|
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self.proxy_auth = proxy_auth
|
self.proxy_auth = proxy_auth
|
||||||
|
|
||||||
|
|
@ -59,37 +58,42 @@ class BaseBot:
|
||||||
# aiohttp main session
|
# aiohttp main session
|
||||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||||
|
|
||||||
if isinstance(proxy, str) and proxy.startswith('socks5://'):
|
if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
|
||||||
from aiosocksy.connector import ProxyClientRequest, ProxyConnector
|
from aiohttp_socks import SocksConnector
|
||||||
connector = ProxyConnector(limit=connections_limit, ssl_context=ssl_context, loop=self.loop)
|
from aiohttp_socks.helpers import parse_socks_url
|
||||||
request_class = ProxyClientRequest
|
|
||||||
|
socks_ver, host, port, username, password = parse_socks_url(proxy)
|
||||||
|
if proxy_auth and not username or password:
|
||||||
|
username = proxy_auth.login
|
||||||
|
password = proxy_auth.password
|
||||||
|
|
||||||
|
connector = SocksConnector(socks_ver=socks_ver, host=host, port=port,
|
||||||
|
username=username, password=password,
|
||||||
|
limit=connections_limit, ssl_context=ssl_context,
|
||||||
|
loop=self.loop)
|
||||||
|
|
||||||
|
self.proxy = None
|
||||||
|
self.proxy_auth = None
|
||||||
else:
|
else:
|
||||||
connector = aiohttp.TCPConnector(limit=connections_limit, ssl_context=ssl_context,
|
connector = aiohttp.TCPConnector(limit=connections_limit, ssl_context=ssl_context,
|
||||||
loop=self.loop)
|
loop=self.loop)
|
||||||
request_class = aiohttp.ClientRequest
|
|
||||||
|
|
||||||
self.session = aiohttp.ClientSession(connector=connector, request_class=request_class,
|
self.session = aiohttp.ClientSession(connector=connector, loop=self.loop, json_serialize=json.dumps)
|
||||||
loop=self.loop, json_serialize=json.dumps)
|
|
||||||
|
|
||||||
# Data stored in bot instance
|
|
||||||
self._data = {}
|
|
||||||
|
|
||||||
self.parse_mode = parse_mode
|
self.parse_mode = parse_mode
|
||||||
|
|
||||||
def __del__(self):
|
# Data stored in bot instance
|
||||||
# asyncio.ensure_future(self.close())
|
self._data = {}
|
||||||
pass
|
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""
|
"""
|
||||||
Close all client sessions
|
Close all client sessions
|
||||||
"""
|
"""
|
||||||
if self.session and not self.session.closed:
|
await self.session.close()
|
||||||
await self.session.close()
|
|
||||||
|
|
||||||
async def request(self, method: base.String,
|
async def request(self, method: base.String,
|
||||||
data: Optional[Dict] = None,
|
data: Optional[Dict] = None,
|
||||||
files: Optional[Dict] = None) -> Union[List, Dict, base.Boolean]:
|
files: Optional[Dict] = None, **kwargs) -> Union[List, Dict, base.Boolean]:
|
||||||
"""
|
"""
|
||||||
Make an request to Telegram Bot API
|
Make an request to Telegram Bot API
|
||||||
|
|
||||||
|
|
@ -105,8 +109,8 @@ class BaseBot:
|
||||||
:rtype: Union[List, Dict]
|
:rtype: Union[List, Dict]
|
||||||
:raise: :obj:`aiogram.exceptions.TelegramApiError`
|
:raise: :obj:`aiogram.exceptions.TelegramApiError`
|
||||||
"""
|
"""
|
||||||
return await api.request(self.session, self.__token, method, data, files,
|
return await api.make_request(self.session, self.__token, method, data, files,
|
||||||
proxy=self.proxy, proxy_auth=self.proxy_auth)
|
proxy=self.proxy, proxy_auth=self.proxy_auth, **kwargs)
|
||||||
|
|
||||||
async def download_file(self, file_path: base.String,
|
async def download_file(self, file_path: base.String,
|
||||||
destination: Optional[base.InputFile] = None,
|
destination: Optional[base.InputFile] = None,
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...dispatcher import BaseStorage
|
from ...dispatcher.storage import BaseStorage
|
||||||
|
|
||||||
|
|
||||||
class MemoryStorage(BaseStorage):
|
class MemoryStorage(BaseStorage):
|
||||||
|
|
@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage):
|
||||||
chat, user = self.check_address(chat=chat, user=user)
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
user = self._get_user(chat, user)
|
user = self._get_user(chat, user)
|
||||||
if data is None:
|
if data is None:
|
||||||
data = []
|
data = {}
|
||||||
user['data'].update(data, **kwargs)
|
user['data'].update(data, **kwargs)
|
||||||
|
|
||||||
async def set_state(self, *,
|
async def set_state(self, *,
|
||||||
|
|
|
||||||
|
|
@ -303,6 +303,8 @@ class RedisStorage2(BaseStorage):
|
||||||
|
|
||||||
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
data: typing.Dict = None, **kwargs):
|
data: typing.Dict = None, **kwargs):
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
temp_data = await self.get_data(chat=chat, user=user, default={})
|
temp_data = await self.get_data(chat=chat, user=user, default={})
|
||||||
temp_data.update(data, **kwargs)
|
temp_data.update(data, **kwargs)
|
||||||
await self.set_data(chat=chat, user=user, data=temp_data)
|
await self.set_data(chat=chat, user=user, data=temp_data)
|
||||||
|
|
@ -330,6 +332,8 @@ class RedisStorage2(BaseStorage):
|
||||||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||||
user: typing.Union[str, int, None] = None,
|
user: typing.Union[str, int, None] = None,
|
||||||
bucket: typing.Dict = None, **kwargs):
|
bucket: typing.Dict = None, **kwargs):
|
||||||
|
if bucket is None:
|
||||||
|
bucket = {}
|
||||||
temp_bucket = await self.get_data(chat=chat, user=user)
|
temp_bucket = await self.get_data(chat=chat, user=user)
|
||||||
temp_bucket.update(bucket, **kwargs)
|
temp_bucket.update(bucket, **kwargs)
|
||||||
await self.set_data(chat=chat, user=user, data=temp_bucket)
|
await self.set_data(chat=chat, user=user, data=temp_bucket)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import weakref
|
||||||
|
|
||||||
import rethinkdb as r
|
import rethinkdb as r
|
||||||
|
|
||||||
from ...dispatcher import BaseStorage
|
from ...dispatcher.storage import BaseStorage
|
||||||
|
|
||||||
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed']
|
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed']
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,115 +0,0 @@
|
||||||
from aiogram import types
|
|
||||||
from aiogram.dispatcher import ctx
|
|
||||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
|
||||||
|
|
||||||
OBJ_KEY = '_context_data'
|
|
||||||
|
|
||||||
|
|
||||||
class ContextMiddleware(BaseMiddleware):
|
|
||||||
"""
|
|
||||||
Allow to store data at all of lifetime of Update object
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def on_pre_process_update(self, update: types.Update):
|
|
||||||
"""
|
|
||||||
Start of Update lifetime
|
|
||||||
|
|
||||||
:param update:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
self._configure_update(update)
|
|
||||||
|
|
||||||
async def on_post_process_update(self, update: types.Update, result):
|
|
||||||
"""
|
|
||||||
On finishing of processing update
|
|
||||||
|
|
||||||
:param update:
|
|
||||||
:param result:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if OBJ_KEY in update.conf:
|
|
||||||
del update.conf[OBJ_KEY]
|
|
||||||
|
|
||||||
def _configure_update(self, update: types.Update = None):
|
|
||||||
"""
|
|
||||||
Setup data storage
|
|
||||||
|
|
||||||
:param update:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
obj = update.conf[OBJ_KEY] = {}
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def _get_dict(self):
|
|
||||||
"""
|
|
||||||
Get data from update stored in current context
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
update = ctx.get_update()
|
|
||||||
obj = update.conf.get(OBJ_KEY, None)
|
|
||||||
if obj is None:
|
|
||||||
obj = self._configure_update(update)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
"""
|
|
||||||
Item getter
|
|
||||||
|
|
||||||
:param item:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_dict()[item]
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
"""
|
|
||||||
Item setter
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param value:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
data = self._get_dict()
|
|
||||||
data[key] = value
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
"""
|
|
||||||
Iterate over dict
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_dict().__iter__()
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
"""
|
|
||||||
Iterate over dict keys
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_dict().keys()
|
|
||||||
|
|
||||||
def values(self):
|
|
||||||
"""
|
|
||||||
Iterate over dict values
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_dict().values()
|
|
||||||
|
|
||||||
def get(self, key, default=None):
|
|
||||||
"""
|
|
||||||
Get item from dit or return default value
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param default:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_dict().get(key, default)
|
|
||||||
|
|
||||||
def export(self):
|
|
||||||
"""
|
|
||||||
Export all data s dict
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_dict()
|
|
||||||
25
aiogram/contrib/middlewares/environment.py
Normal file
25
aiogram/contrib/middlewares/environment.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentMiddleware(BaseMiddleware):
|
||||||
|
def __init__(self, context=None):
|
||||||
|
super(EnvironmentMiddleware, self).__init__()
|
||||||
|
|
||||||
|
if context is None:
|
||||||
|
context = {}
|
||||||
|
self.context = context
|
||||||
|
|
||||||
|
def update_data(self, data):
|
||||||
|
dp = self.manager.dispatcher
|
||||||
|
data.update(
|
||||||
|
bot=dp.bot,
|
||||||
|
dispatcher=dp,
|
||||||
|
loop=dp.loop
|
||||||
|
)
|
||||||
|
if self.context:
|
||||||
|
data.update(self.context)
|
||||||
|
|
||||||
|
async def trigger(self, action, args):
|
||||||
|
if 'error' not in action and action.startswith('pre_process_'):
|
||||||
|
self.update_data(args[-1])
|
||||||
|
return True
|
||||||
80
aiogram/contrib/middlewares/fsm.py
Normal file
80
aiogram/contrib/middlewares/fsm.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
import copy
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
|
||||||
|
from aiogram.dispatcher.storage import FSMContext
|
||||||
|
|
||||||
|
|
||||||
|
class FSMMiddleware(LifetimeControllerMiddleware):
|
||||||
|
skip_patterns = ['error', 'update']
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(FSMMiddleware, self).__init__()
|
||||||
|
self._proxies = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
async def pre_process(self, obj, data, *args):
|
||||||
|
proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state())
|
||||||
|
data['state_data'] = proxy
|
||||||
|
|
||||||
|
async def post_process(self, obj, data, *args):
|
||||||
|
proxy = data.get('state_data', None)
|
||||||
|
if isinstance(proxy, FSMSStorageProxy):
|
||||||
|
await proxy.save()
|
||||||
|
|
||||||
|
|
||||||
|
class FSMSStorageProxy(dict):
|
||||||
|
def __init__(self, fsm_context: FSMContext):
|
||||||
|
super(FSMSStorageProxy, self).__init__()
|
||||||
|
self.fsm_context = fsm_context
|
||||||
|
self._copy = {}
|
||||||
|
self._data = {}
|
||||||
|
self._state = None
|
||||||
|
self._is_dirty = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, fsm_context: FSMContext):
|
||||||
|
"""
|
||||||
|
:param fsm_context:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
proxy = cls(fsm_context)
|
||||||
|
await proxy.load()
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
async def load(self):
|
||||||
|
self.clear()
|
||||||
|
self._state = await self.fsm_context.get_state()
|
||||||
|
self.update(await self.fsm_context.get_data())
|
||||||
|
self._copy = copy.deepcopy(self)
|
||||||
|
self._is_dirty = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@state.setter
|
||||||
|
def state(self, value):
|
||||||
|
self._state = value
|
||||||
|
self._is_dirty = True
|
||||||
|
|
||||||
|
@state.deleter
|
||||||
|
def state(self):
|
||||||
|
self._state = None
|
||||||
|
self._is_dirty = True
|
||||||
|
|
||||||
|
async def save(self, force=False):
|
||||||
|
if self._copy != self or force:
|
||||||
|
await self.fsm_context.set_data(data=self)
|
||||||
|
if self._is_dirty or force:
|
||||||
|
await self.fsm_context.set_state(self.state)
|
||||||
|
self._is_dirty = False
|
||||||
|
self._copy = copy.deepcopy(self)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
s = super(FSMSStorageProxy, self).__str__()
|
||||||
|
readable_state = f"'{self.state}'" if self.state else "''"
|
||||||
|
return f"<{self.__class__.__name__}(state={readable_state}, data={s})>"
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
del self.state
|
||||||
|
return super(FSMSStorageProxy, self).clear()
|
||||||
140
aiogram/contrib/middlewares/i18n.py
Normal file
140
aiogram/contrib/middlewares/i18n.py
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
import gettext
|
||||||
|
import os
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
from babel import Locale
|
||||||
|
|
||||||
|
from ... import types
|
||||||
|
from ...dispatcher.middlewares import BaseMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class I18nMiddleware(BaseMiddleware):
|
||||||
|
"""
|
||||||
|
I18n middleware based on gettext util
|
||||||
|
|
||||||
|
>>> dp = Dispatcher(bot)
|
||||||
|
>>> i18n = I18nMiddleware(DOMAIN, LOCALES_DIR)
|
||||||
|
>>> dp.middleware.setup(i18n)
|
||||||
|
and then
|
||||||
|
>>> _ = i18n.gettext
|
||||||
|
or
|
||||||
|
>>> _ = i18n = I18nMiddleware(DOMAIN_NAME, LOCALES_DIR)
|
||||||
|
"""
|
||||||
|
|
||||||
|
ctx_locale = ContextVar('ctx_user_locale', default=None)
|
||||||
|
|
||||||
|
def __init__(self, domain, path=None, default='en'):
|
||||||
|
"""
|
||||||
|
:param domain: domain
|
||||||
|
:param path: path where located all *.mo files
|
||||||
|
:param default: default locale name
|
||||||
|
"""
|
||||||
|
super(I18nMiddleware, self).__init__()
|
||||||
|
|
||||||
|
if path is None:
|
||||||
|
path = os.path.join(os.getcwd(), 'locales')
|
||||||
|
|
||||||
|
self.domain = domain
|
||||||
|
self.path = path
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
self.locales = self.find_locales()
|
||||||
|
|
||||||
|
def find_locales(self) -> Dict[str, gettext.GNUTranslations]:
|
||||||
|
"""
|
||||||
|
Load all compiled locales from path
|
||||||
|
|
||||||
|
:return: dict with locales
|
||||||
|
"""
|
||||||
|
translations = {}
|
||||||
|
|
||||||
|
for name in os.listdir(self.path):
|
||||||
|
if not os.path.isdir(os.path.join(self.path, name)):
|
||||||
|
continue
|
||||||
|
mo_path = os.path.join(self.path, name, 'LC_MESSAGES', self.domain + '.mo')
|
||||||
|
|
||||||
|
if os.path.exists(mo_path):
|
||||||
|
with open(mo_path, 'rb') as fp:
|
||||||
|
translations[name] = gettext.GNUTranslations(fp)
|
||||||
|
elif os.path.exists(mo_path[:-2] + 'po'):
|
||||||
|
raise RuntimeError(f"Found locale '{name} but this language is not compiled!")
|
||||||
|
|
||||||
|
return translations
|
||||||
|
|
||||||
|
def reload(self):
|
||||||
|
"""
|
||||||
|
Hot reload locles
|
||||||
|
"""
|
||||||
|
self.locales = self.find_locales()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_locales(self) -> Tuple[str]:
|
||||||
|
"""
|
||||||
|
list of loaded locales
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return tuple(self.locales.keys())
|
||||||
|
|
||||||
|
def __call__(self, singular, plural=None, n=1, locale=None) -> str:
|
||||||
|
return self.gettext(singular, plural, n, locale)
|
||||||
|
|
||||||
|
def gettext(self, singular, plural=None, n=1, locale=None) -> str:
|
||||||
|
"""
|
||||||
|
Get text
|
||||||
|
|
||||||
|
:param singular:
|
||||||
|
:param plural:
|
||||||
|
:param n:
|
||||||
|
:param locale:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if locale is None:
|
||||||
|
locale = self.ctx_locale.get()
|
||||||
|
|
||||||
|
if locale not in self.locales:
|
||||||
|
if n is 1:
|
||||||
|
return singular
|
||||||
|
else:
|
||||||
|
return plural
|
||||||
|
|
||||||
|
translator = self.locales[locale]
|
||||||
|
|
||||||
|
if plural is None:
|
||||||
|
return translator.gettext(singular)
|
||||||
|
else:
|
||||||
|
return translator.ngettext(singular, plural, n)
|
||||||
|
|
||||||
|
# noinspection PyMethodMayBeStatic,PyUnusedLocal
|
||||||
|
async def get_user_locale(self, action: str, args: Tuple[Any]) -> str:
|
||||||
|
"""
|
||||||
|
User locale getter
|
||||||
|
You can override the method if you want to use different way of getting user language.
|
||||||
|
|
||||||
|
:param action: event name
|
||||||
|
:param args: event arguments
|
||||||
|
:return: locale name
|
||||||
|
"""
|
||||||
|
user: types.User = types.User.current()
|
||||||
|
locale: Locale = user.locale
|
||||||
|
|
||||||
|
if locale:
|
||||||
|
*_, data = args
|
||||||
|
language = data['locale'] = locale.language
|
||||||
|
return language
|
||||||
|
|
||||||
|
async def trigger(self, action, args):
|
||||||
|
"""
|
||||||
|
Event trigger
|
||||||
|
|
||||||
|
:param action: event name
|
||||||
|
:param args: event arguments
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if 'update' not in action \
|
||||||
|
and 'error' not in action \
|
||||||
|
and action.startswith('pre_process'):
|
||||||
|
locale = await self.get_user_locale(action, args)
|
||||||
|
self.ctx_locale.set(locale)
|
||||||
|
return True
|
||||||
|
|
@ -23,70 +23,70 @@ class LoggingMiddleware(BaseMiddleware):
|
||||||
return round((time.time() - start) * 1000)
|
return round((time.time() - start) * 1000)
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
async def on_pre_process_update(self, update: types.Update):
|
async def on_pre_process_update(self, update: types.Update, data: dict):
|
||||||
update.conf['_start'] = time.time()
|
update.conf['_start'] = time.time()
|
||||||
self.logger.debug(f"Received update [ID:{update.update_id}]")
|
self.logger.debug(f"Received update [ID:{update.update_id}]")
|
||||||
|
|
||||||
async def on_post_process_update(self, update: types.Update, result):
|
async def on_post_process_update(self, update: types.Update, result, data: dict):
|
||||||
timeout = self.check_timeout(update)
|
timeout = self.check_timeout(update)
|
||||||
if timeout > 0:
|
if timeout > 0:
|
||||||
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
|
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
|
||||||
|
|
||||||
async def on_pre_process_message(self, message: types.Message):
|
async def on_pre_process_message(self, message: types.Message, data: dict):
|
||||||
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||||
|
|
||||||
async def on_post_process_message(self, message: types.Message, results):
|
async def on_post_process_message(self, message: types.Message, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||||
|
|
||||||
async def on_pre_process_edited_message(self, edited_message):
|
async def on_pre_process_edited_message(self, edited_message, data: dict):
|
||||||
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
|
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
|
||||||
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||||
|
|
||||||
async def on_post_process_edited_message(self, edited_message, results):
|
async def on_post_process_edited_message(self, edited_message, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"edited message [ID:{edited_message.message_id}] "
|
f"edited message [ID:{edited_message.message_id}] "
|
||||||
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||||
|
|
||||||
async def on_pre_process_channel_post(self, channel_post: types.Message):
|
async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict):
|
||||||
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
|
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
|
||||||
f"in channel [ID:{channel_post.chat.id}]")
|
f"in channel [ID:{channel_post.chat.id}]")
|
||||||
|
|
||||||
async def on_post_process_channel_post(self, channel_post: types.Message, results):
|
async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"channel post [ID:{channel_post.message_id}] "
|
f"channel post [ID:{channel_post.message_id}] "
|
||||||
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
|
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
|
||||||
|
|
||||||
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message):
|
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict):
|
||||||
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
|
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
|
||||||
f"in channel [ID:{edited_channel_post.chat.id}]")
|
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||||
|
|
||||||
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results):
|
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"edited channel post [ID:{edited_channel_post.message_id}] "
|
f"edited channel post [ID:{edited_channel_post.message_id}] "
|
||||||
f"in channel [ID:{edited_channel_post.chat.id}]")
|
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||||
|
|
||||||
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery):
|
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict):
|
||||||
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
|
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
|
||||||
f"from user [ID:{inline_query.from_user.id}]")
|
f"from user [ID:{inline_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results):
|
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"inline query [ID:{inline_query.id}] "
|
f"inline query [ID:{inline_query.id}] "
|
||||||
f"from user [ID:{inline_query.from_user.id}]")
|
f"from user [ID:{inline_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
|
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict):
|
||||||
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||||
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||||
f"result [ID:{chosen_inline_result.result_id}]")
|
f"result [ID:{chosen_inline_result.result_id}]")
|
||||||
|
|
||||||
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results):
|
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||||
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||||
f"result [ID:{chosen_inline_result.result_id}]")
|
f"result [ID:{chosen_inline_result.result_id}]")
|
||||||
|
|
||||||
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery):
|
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict):
|
||||||
if callback_query.message:
|
if callback_query.message:
|
||||||
if callback_query.message.from_user:
|
if callback_query.message.from_user:
|
||||||
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
|
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
|
||||||
|
|
@ -100,7 +100,7 @@ class LoggingMiddleware(BaseMiddleware):
|
||||||
f"from inline message [ID:{callback_query.inline_message_id}] "
|
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||||
f"from user [ID:{callback_query.from_user.id}]")
|
f"from user [ID:{callback_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_post_process_callback_query(self, callback_query, results):
|
async def on_post_process_callback_query(self, callback_query, results, data: dict):
|
||||||
if callback_query.message:
|
if callback_query.message:
|
||||||
if callback_query.message.from_user:
|
if callback_query.message.from_user:
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
|
@ -117,25 +117,25 @@ class LoggingMiddleware(BaseMiddleware):
|
||||||
f"from inline message [ID:{callback_query.inline_message_id}] "
|
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||||
f"from user [ID:{callback_query.from_user.id}]")
|
f"from user [ID:{callback_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery):
|
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict):
|
||||||
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
|
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
|
||||||
f"from user [ID:{shipping_query.from_user.id}]")
|
f"from user [ID:{shipping_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_post_process_shipping_query(self, shipping_query, results):
|
async def on_post_process_shipping_query(self, shipping_query, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"shipping query [ID:{shipping_query.id}] "
|
f"shipping query [ID:{shipping_query.id}] "
|
||||||
f"from user [ID:{shipping_query.from_user.id}]")
|
f"from user [ID:{shipping_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
|
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict):
|
||||||
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
|
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||||
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results):
|
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict):
|
||||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
f"pre-checkout query [ID:{pre_checkout_query.id}] "
|
f"pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||||
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||||
|
|
||||||
async def on_pre_process_error(self, dispatcher, update, error):
|
async def on_pre_process_error(self, dispatcher, update, error, data: dict):
|
||||||
timeout = self.check_timeout(update)
|
timeout = self.check_timeout(update)
|
||||||
if timeout > 0:
|
if timeout > 0:
|
||||||
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")
|
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,42 +0,0 @@
|
||||||
from . import Bot
|
|
||||||
from .. import types
|
|
||||||
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
|
|
||||||
from ..utils import context
|
|
||||||
|
|
||||||
|
|
||||||
def _get(key, default=None, no_error=False):
|
|
||||||
result = context.get_value(key, default)
|
|
||||||
if not no_error and result is None:
|
|
||||||
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
|
|
||||||
f"Maybe asyncio task factory is not configured!\n"
|
|
||||||
f"\t>>> from aiogram.utils import context\n"
|
|
||||||
f"\t>>> loop.set_task_factory(context.task_factory)")
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_bot() -> Bot:
|
|
||||||
return _get('bot')
|
|
||||||
|
|
||||||
|
|
||||||
def get_dispatcher() -> Dispatcher:
|
|
||||||
return _get('dispatcher')
|
|
||||||
|
|
||||||
|
|
||||||
def get_update() -> types.Update:
|
|
||||||
return _get(UPDATE_OBJECT)
|
|
||||||
|
|
||||||
|
|
||||||
def get_mode() -> str:
|
|
||||||
return _get(MODE, 'unknown')
|
|
||||||
|
|
||||||
|
|
||||||
def get_chat() -> int:
|
|
||||||
return _get('chat', no_error=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_user() -> int:
|
|
||||||
return _get('user', no_error=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_state() -> FSMContext:
|
|
||||||
return get_dispatcher().current_state()
|
|
||||||
1007
aiogram/dispatcher/dispatcher.py
Normal file
1007
aiogram/dispatcher/dispatcher.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,289 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
import re
|
|
||||||
|
|
||||||
from ..types import CallbackQuery, ContentType, Message
|
|
||||||
from ..utils import context
|
|
||||||
from ..utils.helper import Helper, HelperMode, Item
|
|
||||||
|
|
||||||
USER_STATE = 'USER_STATE'
|
|
||||||
|
|
||||||
|
|
||||||
async def check_filter(filter_, args):
|
|
||||||
"""
|
|
||||||
Helper for executing filter
|
|
||||||
|
|
||||||
:param filter_:
|
|
||||||
:param args:
|
|
||||||
:param kwargs:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if not callable(filter_):
|
|
||||||
raise TypeError('Filter must be callable and/or awaitable!')
|
|
||||||
|
|
||||||
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
|
||||||
return await filter_(*args)
|
|
||||||
else:
|
|
||||||
return filter_(*args)
|
|
||||||
|
|
||||||
|
|
||||||
async def check_filters(filters, args):
|
|
||||||
"""
|
|
||||||
Check list of filters
|
|
||||||
|
|
||||||
:param filters:
|
|
||||||
:param args:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if filters is not None:
|
|
||||||
for filter_ in filters:
|
|
||||||
f = await check_filter(filter_, args)
|
|
||||||
if not f:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
|
||||||
"""
|
|
||||||
Base class for filters
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
return self.check(*args, **kwargs)
|
|
||||||
|
|
||||||
def check(self, *args, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncFilter(Filter):
|
|
||||||
"""
|
|
||||||
Base class for asynchronous filters
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __aiter__(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __await__(self):
|
|
||||||
return self.check
|
|
||||||
|
|
||||||
async def check(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AnyFilter(AsyncFilter):
|
|
||||||
"""
|
|
||||||
One filter from many
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *filters: callable):
|
|
||||||
self.filters = filters
|
|
||||||
|
|
||||||
async def check(self, *args):
|
|
||||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
|
||||||
return any(await asyncio.gather(*f))
|
|
||||||
|
|
||||||
|
|
||||||
class NotFilter(AsyncFilter):
|
|
||||||
"""
|
|
||||||
Reverse filter
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, filter_: callable):
|
|
||||||
self.filter = filter_
|
|
||||||
|
|
||||||
async def check(self, *args):
|
|
||||||
return not await check_filter(self.filter, args)
|
|
||||||
|
|
||||||
|
|
||||||
class CommandsFilter(AsyncFilter):
|
|
||||||
"""
|
|
||||||
Check commands in message
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, commands):
|
|
||||||
self.commands = commands
|
|
||||||
|
|
||||||
async def check(self, message):
|
|
||||||
if not message.is_command():
|
|
||||||
return False
|
|
||||||
|
|
||||||
command = message.text.split()[0][1:]
|
|
||||||
command, _, mention = command.partition('@')
|
|
||||||
|
|
||||||
if mention and mention != (await message.bot.me).username:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if command not in self.commands:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class RegexpFilter(Filter):
|
|
||||||
"""
|
|
||||||
Regexp filter for messages
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, regexp):
|
|
||||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
|
||||||
|
|
||||||
def check(self, obj):
|
|
||||||
if isinstance(obj, Message) and obj.text:
|
|
||||||
return bool(self.regexp.search(obj.text))
|
|
||||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
|
||||||
return bool(self.regexp.search(obj.data))
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class RegexpCommandsFilter(AsyncFilter):
|
|
||||||
"""
|
|
||||||
Check commands by regexp in message
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, regexp_commands):
|
|
||||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
|
||||||
|
|
||||||
async def check(self, message):
|
|
||||||
if not message.is_command():
|
|
||||||
return False
|
|
||||||
|
|
||||||
command = message.text.split()[0][1:]
|
|
||||||
command, _, mention = command.partition('@')
|
|
||||||
|
|
||||||
if mention and mention != (await message.bot.me).username:
|
|
||||||
return False
|
|
||||||
|
|
||||||
for command in self.regexp_commands:
|
|
||||||
search = command.search(message.text)
|
|
||||||
if search:
|
|
||||||
message.conf['regexp_command'] = search
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeFilter(Filter):
|
|
||||||
"""
|
|
||||||
Check message content type
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, content_types):
|
|
||||||
self.content_types = content_types
|
|
||||||
|
|
||||||
def check(self, message):
|
|
||||||
return ContentType.ANY[0] in self.content_types or \
|
|
||||||
message.content_type in self.content_types
|
|
||||||
|
|
||||||
|
|
||||||
class CancelFilter(Filter):
|
|
||||||
"""
|
|
||||||
Find cancel in message text
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, cancel_set=None):
|
|
||||||
if cancel_set is None:
|
|
||||||
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
|
||||||
self.cancel_set = cancel_set
|
|
||||||
|
|
||||||
def check(self, message):
|
|
||||||
if message.text:
|
|
||||||
return message.text.lower() in self.cancel_set
|
|
||||||
|
|
||||||
|
|
||||||
class StateFilter(AsyncFilter):
|
|
||||||
"""
|
|
||||||
Check user state
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dispatcher, state):
|
|
||||||
self.dispatcher = dispatcher
|
|
||||||
self.state = state
|
|
||||||
|
|
||||||
def get_target(self, obj):
|
|
||||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
|
||||||
|
|
||||||
async def check(self, obj):
|
|
||||||
if self.state == '*':
|
|
||||||
return True
|
|
||||||
|
|
||||||
if context.check_value(USER_STATE):
|
|
||||||
context_state = context.get_value(USER_STATE)
|
|
||||||
return self.state == context_state
|
|
||||||
else:
|
|
||||||
chat, user = self.get_target(obj)
|
|
||||||
|
|
||||||
if chat or user:
|
|
||||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class StatesListFilter(StateFilter):
|
|
||||||
"""
|
|
||||||
List of states
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def check(self, obj):
|
|
||||||
chat, user = self.get_target(obj)
|
|
||||||
|
|
||||||
if chat or user:
|
|
||||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class ExceptionsFilter(Filter):
|
|
||||||
"""
|
|
||||||
Filter for exceptions
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, exception):
|
|
||||||
self.exception = exception
|
|
||||||
|
|
||||||
def check(self, dispatcher, update, exception):
|
|
||||||
return isinstance(exception, self.exception)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_default_filters(dispatcher, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Prepare filters
|
|
||||||
|
|
||||||
:param dispatcher:
|
|
||||||
:param args:
|
|
||||||
:param kwargs:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
filters_set = []
|
|
||||||
|
|
||||||
for name, filter_ in kwargs.items():
|
|
||||||
if filter_ is None and name != DefaultFilters.STATE:
|
|
||||||
continue
|
|
||||||
if name == DefaultFilters.COMMANDS:
|
|
||||||
if isinstance(filter_, str):
|
|
||||||
filters_set.append(CommandsFilter([filter_]))
|
|
||||||
else:
|
|
||||||
filters_set.append(CommandsFilter(filter_))
|
|
||||||
elif name == DefaultFilters.REGEXP:
|
|
||||||
filters_set.append(RegexpFilter(filter_))
|
|
||||||
elif name == DefaultFilters.CONTENT_TYPES:
|
|
||||||
filters_set.append(ContentTypeFilter(filter_))
|
|
||||||
elif name == DefaultFilters.FUNC:
|
|
||||||
filters_set.append(filter_)
|
|
||||||
elif name == DefaultFilters.STATE:
|
|
||||||
if isinstance(filter_, (list, set, tuple)):
|
|
||||||
filters_set.append(StatesListFilter(dispatcher, filter_))
|
|
||||||
else:
|
|
||||||
filters_set.append(StateFilter(dispatcher, filter_))
|
|
||||||
elif isinstance(filter_, Filter):
|
|
||||||
filters_set.append(filter_)
|
|
||||||
|
|
||||||
filters_set += list(args)
|
|
||||||
|
|
||||||
return filters_set
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultFilters(Helper):
|
|
||||||
mode = HelperMode.snake_case
|
|
||||||
|
|
||||||
COMMANDS = Item() # commands
|
|
||||||
REGEXP = Item() # regexp
|
|
||||||
CONTENT_TYPES = Item() # content_type
|
|
||||||
FUNC = Item() # func
|
|
||||||
STATE = Item() # state
|
|
||||||
24
aiogram/dispatcher/filters/__init__.py
Normal file
24
aiogram/dispatcher/filters/__init__.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
|
||||||
|
RegexpCommandsFilter, StateFilter, Text
|
||||||
|
from .factory import FiltersFactory
|
||||||
|
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'AbstractFilter',
|
||||||
|
'BoundFilter',
|
||||||
|
'Command',
|
||||||
|
'CommandStart',
|
||||||
|
'CommandHelp',
|
||||||
|
'ContentTypeFilter',
|
||||||
|
'ExceptionsFilter',
|
||||||
|
'Filter',
|
||||||
|
'FilterNotPassed',
|
||||||
|
'FilterRecord',
|
||||||
|
'FiltersFactory',
|
||||||
|
'RegexpCommandsFilter',
|
||||||
|
'Regexp',
|
||||||
|
'StateFilter',
|
||||||
|
'Text',
|
||||||
|
'check_filter',
|
||||||
|
'check_filters'
|
||||||
|
]
|
||||||
313
aiogram/dispatcher/filters/builtin.py
Normal file
313
aiogram/dispatcher/filters/builtin.py
Normal file
|
|
@ -0,0 +1,313 @@
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
|
from aiogram import types
|
||||||
|
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
|
||||||
|
from aiogram.types import CallbackQuery, Message
|
||||||
|
|
||||||
|
|
||||||
|
class Command(Filter):
|
||||||
|
"""
|
||||||
|
You can handle commands by using this filter
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, commands: Union[Iterable, str],
|
||||||
|
prefixes: Union[Iterable, str] = '/',
|
||||||
|
ignore_case: bool = True,
|
||||||
|
ignore_mention: bool = False):
|
||||||
|
"""
|
||||||
|
Filter can be initialized from filters factory or by simply creating instance of this class
|
||||||
|
|
||||||
|
:param commands: command or list of commands
|
||||||
|
:param prefixes:
|
||||||
|
:param ignore_case:
|
||||||
|
:param ignore_mention:
|
||||||
|
"""
|
||||||
|
if isinstance(commands, str):
|
||||||
|
commands = (commands,)
|
||||||
|
|
||||||
|
self.commands = list(map(str.lower, commands)) if ignore_case else commands
|
||||||
|
self.prefixes = prefixes
|
||||||
|
self.ignore_case = ignore_case
|
||||||
|
self.ignore_mention = ignore_mention
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Validator for filters factory
|
||||||
|
|
||||||
|
:param full_config:
|
||||||
|
:return: config or empty dict
|
||||||
|
"""
|
||||||
|
config = {}
|
||||||
|
if 'commands' in full_config:
|
||||||
|
config['commands'] = full_config.pop('commands')
|
||||||
|
if 'commands_prefix' in full_config:
|
||||||
|
config['prefixes'] = full_config.pop('commands_prefix')
|
||||||
|
if 'commands_ignore_mention' in full_config:
|
||||||
|
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def check(self, message: types.Message):
|
||||||
|
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
|
||||||
|
full_command = message.text.split()[0]
|
||||||
|
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
|
||||||
|
|
||||||
|
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
|
||||||
|
return False
|
||||||
|
elif prefix not in prefixes:
|
||||||
|
return False
|
||||||
|
elif (command.lower() if ignore_case else command) not in commands:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandObj:
|
||||||
|
prefix: str = '/'
|
||||||
|
command: str = ''
|
||||||
|
mention: str = None
|
||||||
|
args: str = field(repr=False, default=None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mentioned(self) -> bool:
|
||||||
|
return bool(self.mention)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
line = self.prefix + self.command
|
||||||
|
if self.mentioned:
|
||||||
|
line += '@' + self.mention
|
||||||
|
if self.args:
|
||||||
|
line += ' ' + self.args
|
||||||
|
return line
|
||||||
|
|
||||||
|
|
||||||
|
class CommandStart(Command):
|
||||||
|
def __init__(self):
|
||||||
|
super(CommandStart, self).__init__(['start'])
|
||||||
|
|
||||||
|
|
||||||
|
class CommandHelp(Command):
|
||||||
|
def __init__(self):
|
||||||
|
super(CommandHelp, self).__init__(['help'])
|
||||||
|
|
||||||
|
|
||||||
|
class Text(Filter):
|
||||||
|
"""
|
||||||
|
Simple text filter
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
equals: Optional[str] = None,
|
||||||
|
contains: Optional[str] = None,
|
||||||
|
startswith: Optional[str] = None,
|
||||||
|
endswith: Optional[str] = None,
|
||||||
|
ignore_case=False):
|
||||||
|
"""
|
||||||
|
Check text for one of pattern. Only one mode can be used in one filter.
|
||||||
|
|
||||||
|
:param equals:
|
||||||
|
:param contains:
|
||||||
|
:param startswith:
|
||||||
|
:param endswith:
|
||||||
|
:param ignore_case: case insensitive
|
||||||
|
"""
|
||||||
|
# Only one mode can be used. check it.
|
||||||
|
check = sum(map(bool, (equals, contains, startswith, endswith)))
|
||||||
|
if check > 1:
|
||||||
|
args = "' and '".join([arg[0] for arg in [('equals', equals),
|
||||||
|
('contains', contains),
|
||||||
|
('startswith', startswith),
|
||||||
|
('endswith', endswith)
|
||||||
|
] if arg[1]])
|
||||||
|
raise ValueError(f"Arguments '{args}' cannot be used together.")
|
||||||
|
elif check == 0:
|
||||||
|
raise ValueError(f"No one mode is specified!")
|
||||||
|
|
||||||
|
self.equals = equals
|
||||||
|
self.contains = contains
|
||||||
|
self.endswith = endswith
|
||||||
|
self.startswith = startswith
|
||||||
|
self.ignore_case = ignore_case
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: Dict[str, Any]):
|
||||||
|
if 'text' in full_config:
|
||||||
|
return {'equals': full_config.pop('text')}
|
||||||
|
elif 'text_contains' in full_config:
|
||||||
|
return {'contains': full_config.pop('text_contains')}
|
||||||
|
elif 'text_startswith' in full_config:
|
||||||
|
return {'startswith': full_config.pop('text_startswith')}
|
||||||
|
elif 'text_endswith' in full_config:
|
||||||
|
return {'endswith': full_config.pop('text_endswith')}
|
||||||
|
|
||||||
|
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||||
|
if isinstance(obj, Message):
|
||||||
|
text = obj.text or obj.caption or ''
|
||||||
|
elif isinstance(obj, CallbackQuery):
|
||||||
|
text = obj.data
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.ignore_case:
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
|
if self.equals:
|
||||||
|
return text == self.equals
|
||||||
|
elif self.contains:
|
||||||
|
return self.contains in text
|
||||||
|
elif self.startswith:
|
||||||
|
return text.startswith(self.startswith)
|
||||||
|
elif self.endswith:
|
||||||
|
return text.endswith(self.endswith)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Regexp(Filter):
|
||||||
|
"""
|
||||||
|
Regexp filter for messages and callback query
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, regexp):
|
||||||
|
if not isinstance(regexp, re.Pattern):
|
||||||
|
regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||||
|
self.regexp = regexp
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: Dict[str, Any]):
|
||||||
|
if 'regexp' in full_config:
|
||||||
|
return {'regexp': full_config.pop('regexp')}
|
||||||
|
|
||||||
|
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||||
|
if isinstance(obj, Message):
|
||||||
|
match = self.regexp.search(obj.text or obj.caption or '')
|
||||||
|
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||||
|
match = self.regexp.search(obj.data)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if match:
|
||||||
|
return {'regexp': match}
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class RegexpCommandsFilter(BoundFilter):
|
||||||
|
"""
|
||||||
|
Check commands by regexp in message
|
||||||
|
"""
|
||||||
|
|
||||||
|
key = 'regexp_commands'
|
||||||
|
|
||||||
|
def __init__(self, regexp_commands):
|
||||||
|
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||||
|
|
||||||
|
async def check(self, message):
|
||||||
|
if not message.is_command():
|
||||||
|
return False
|
||||||
|
|
||||||
|
command = message.text.split()[0][1:]
|
||||||
|
command, _, mention = command.partition('@')
|
||||||
|
|
||||||
|
if mention and mention != (await message.bot.me).username:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for command in self.regexp_commands:
|
||||||
|
search = command.search(message.text)
|
||||||
|
if search:
|
||||||
|
return {'regexp_command': search}
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ContentTypeFilter(BoundFilter):
|
||||||
|
"""
|
||||||
|
Check message content type
|
||||||
|
"""
|
||||||
|
|
||||||
|
key = 'content_types'
|
||||||
|
required = True
|
||||||
|
default = types.ContentTypes.TEXT
|
||||||
|
|
||||||
|
def __init__(self, content_types):
|
||||||
|
self.content_types = content_types
|
||||||
|
|
||||||
|
async def check(self, message):
|
||||||
|
return types.ContentType.ANY in self.content_types or \
|
||||||
|
message.content_type in self.content_types
|
||||||
|
|
||||||
|
|
||||||
|
class StateFilter(BoundFilter):
|
||||||
|
"""
|
||||||
|
Check user state
|
||||||
|
"""
|
||||||
|
key = 'state'
|
||||||
|
required = True
|
||||||
|
|
||||||
|
ctx_state = ContextVar('user_state')
|
||||||
|
|
||||||
|
def __init__(self, dispatcher, state):
|
||||||
|
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||||
|
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
states = []
|
||||||
|
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
||||||
|
state = [state, ]
|
||||||
|
for item in state:
|
||||||
|
if isinstance(item, State):
|
||||||
|
states.append(item.state)
|
||||||
|
elif inspect.isclass(item) and issubclass(item, StatesGroup):
|
||||||
|
states.extend(item.all_states_names)
|
||||||
|
else:
|
||||||
|
states.append(item)
|
||||||
|
self.states = states
|
||||||
|
|
||||||
|
def get_target(self, obj):
|
||||||
|
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||||
|
|
||||||
|
async def check(self, obj):
|
||||||
|
if '*' in self.states:
|
||||||
|
return {'state': self.dispatcher.current_state()}
|
||||||
|
|
||||||
|
try:
|
||||||
|
state = self.ctx_state.get()
|
||||||
|
except LookupError:
|
||||||
|
chat, user = self.get_target(obj)
|
||||||
|
|
||||||
|
if chat or user:
|
||||||
|
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
|
||||||
|
self.ctx_state.set(state)
|
||||||
|
if state in self.states:
|
||||||
|
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||||
|
|
||||||
|
else:
|
||||||
|
if state in self.states:
|
||||||
|
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionsFilter(BoundFilter):
|
||||||
|
"""
|
||||||
|
Filter for exceptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
key = 'exception'
|
||||||
|
|
||||||
|
def __init__(self, dispatcher, exception):
|
||||||
|
super().__init__(dispatcher)
|
||||||
|
self.exception = exception
|
||||||
|
|
||||||
|
async def check(self, dispatcher, update, exception):
|
||||||
|
try:
|
||||||
|
raise exception
|
||||||
|
except self.exception:
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
73
aiogram/dispatcher/filters/factory.py
Normal file
73
aiogram/dispatcher/filters/factory.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from .filters import AbstractFilter, FilterRecord
|
||||||
|
from ..handler import Handler
|
||||||
|
|
||||||
|
|
||||||
|
class FiltersFactory:
|
||||||
|
"""
|
||||||
|
Default filters factory
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dispatcher):
|
||||||
|
self._dispatcher = dispatcher
|
||||||
|
self._registered: typing.List[FilterRecord] = []
|
||||||
|
|
||||||
|
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
|
||||||
|
validator: typing.Optional[typing.Callable] = None,
|
||||||
|
event_handlers: typing.Optional[typing.List[Handler]] = None,
|
||||||
|
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||||
|
"""
|
||||||
|
Register filter
|
||||||
|
|
||||||
|
:param callback: callable or subclass of :obj:`AbstractFilter`
|
||||||
|
:param validator: custom validator.
|
||||||
|
:param event_handlers: list of instances of :obj:`Handler`
|
||||||
|
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
|
||||||
|
"""
|
||||||
|
record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
|
||||||
|
self._registered.append(record)
|
||||||
|
|
||||||
|
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
|
||||||
|
"""
|
||||||
|
Unregister callback
|
||||||
|
|
||||||
|
:param callback: callable of subclass of :obj:`AbstractFilter`
|
||||||
|
"""
|
||||||
|
for record in self._registered:
|
||||||
|
if record.callback == callback:
|
||||||
|
self._registered.remove(record)
|
||||||
|
|
||||||
|
def resolve(self, event_handler, *custom_filters, **full_config
|
||||||
|
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
|
||||||
|
"""
|
||||||
|
Resolve filters to filters-set
|
||||||
|
|
||||||
|
:param event_handler:
|
||||||
|
:param custom_filters:
|
||||||
|
:param full_config:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
filters_set = []
|
||||||
|
filters_set.extend(self._resolve_registered(event_handler,
|
||||||
|
{k: v for k, v in full_config.items() if v is not None}))
|
||||||
|
if custom_filters:
|
||||||
|
filters_set.extend(custom_filters)
|
||||||
|
|
||||||
|
return filters_set
|
||||||
|
|
||||||
|
def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
|
||||||
|
"""
|
||||||
|
Resolve registered filters
|
||||||
|
|
||||||
|
:param event_handler:
|
||||||
|
:param full_config:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for record in self._registered:
|
||||||
|
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
|
||||||
|
if filter_:
|
||||||
|
yield filter_
|
||||||
|
|
||||||
|
if full_config:
|
||||||
|
raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'')
|
||||||
250
aiogram/dispatcher/filters/filters.py
Normal file
250
aiogram/dispatcher/filters/filters.py
Normal file
|
|
@ -0,0 +1,250 @@
|
||||||
|
import abc
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from ..handler import Handler
|
||||||
|
from ...types.base import TelegramObject
|
||||||
|
|
||||||
|
|
||||||
|
class FilterNotPassed(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_async(func):
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
if inspect.isawaitable(func) \
|
||||||
|
or inspect.iscoroutinefunction(func) \
|
||||||
|
or isinstance(func, AbstractFilter):
|
||||||
|
return func
|
||||||
|
return async_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
async def check_filter(dispatcher, filter_, args):
|
||||||
|
"""
|
||||||
|
Helper for executing filter
|
||||||
|
|
||||||
|
:param dispatcher:
|
||||||
|
:param filter_:
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
kwargs = {}
|
||||||
|
if not callable(filter_):
|
||||||
|
raise TypeError('Filter must be callable and/or awaitable!')
|
||||||
|
|
||||||
|
spec = inspect.getfullargspec(filter_)
|
||||||
|
if 'dispatcher' in spec:
|
||||||
|
kwargs['dispatcher'] = dispatcher
|
||||||
|
if inspect.isawaitable(filter_) \
|
||||||
|
or inspect.iscoroutinefunction(filter_) \
|
||||||
|
or isinstance(filter_, AbstractFilter):
|
||||||
|
return await filter_(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return filter_(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_filters(dispatcher, filters, args):
|
||||||
|
"""
|
||||||
|
Check list of filters
|
||||||
|
|
||||||
|
:param dispatcher:
|
||||||
|
:param filters:
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
data = {}
|
||||||
|
if filters is not None:
|
||||||
|
for filter_ in filters:
|
||||||
|
f = await check_filter(dispatcher, filter_, args)
|
||||||
|
if not f:
|
||||||
|
raise FilterNotPassed()
|
||||||
|
elif isinstance(f, dict):
|
||||||
|
data.update(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class FilterRecord:
|
||||||
|
"""
|
||||||
|
Filters record for factory
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, callback: typing.Callable,
|
||||||
|
validator: typing.Optional[typing.Callable] = None,
|
||||||
|
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
||||||
|
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||||
|
if event_handlers and exclude_event_handlers:
|
||||||
|
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
|
||||||
|
|
||||||
|
self.callback = callback
|
||||||
|
self.event_handlers = event_handlers
|
||||||
|
self.exclude_event_handlers = exclude_event_handlers
|
||||||
|
|
||||||
|
if validator is not None:
|
||||||
|
if not callable(validator):
|
||||||
|
raise TypeError(f"validator must be callable, not {type(validator)}")
|
||||||
|
self.resolver = validator
|
||||||
|
elif issubclass(callback, AbstractFilter):
|
||||||
|
self.resolver = callback.validate
|
||||||
|
else:
|
||||||
|
raise RuntimeError('validator is required!')
|
||||||
|
|
||||||
|
def resolve(self, dispatcher, event_handler, full_config):
|
||||||
|
if not self._check_event_handler(event_handler):
|
||||||
|
return
|
||||||
|
config = self.resolver(full_config)
|
||||||
|
if config:
|
||||||
|
if 'dispatcher' not in config:
|
||||||
|
spec = inspect.getfullargspec(self.callback)
|
||||||
|
if 'dispatcher' in spec.args:
|
||||||
|
config['dispatcher'] = dispatcher
|
||||||
|
|
||||||
|
for key in config:
|
||||||
|
if key in full_config:
|
||||||
|
full_config.pop(key)
|
||||||
|
|
||||||
|
return self.callback(**config)
|
||||||
|
|
||||||
|
def _check_event_handler(self, event_handler) -> bool:
|
||||||
|
if self.event_handlers:
|
||||||
|
return event_handler in self.event_handlers
|
||||||
|
elif self.exclude_event_handlers:
|
||||||
|
return event_handler not in self.exclude_event_handlers
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractFilter(abc.ABC):
|
||||||
|
"""
|
||||||
|
Abstract class for custom filters
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
|
"""
|
||||||
|
Validate and parse config
|
||||||
|
|
||||||
|
:param full_config:
|
||||||
|
:return: config
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def check(self, *args) -> bool:
|
||||||
|
"""
|
||||||
|
Check object
|
||||||
|
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __call__(self, obj: TelegramObject) -> bool:
|
||||||
|
return await self.check(obj)
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return NotFilter(self)
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
if isinstance(self, AndFilter):
|
||||||
|
self.append(other)
|
||||||
|
return self
|
||||||
|
return AndFilter(self, other)
|
||||||
|
|
||||||
|
def __or__(self, other):
|
||||||
|
if isinstance(self, OrFilter):
|
||||||
|
self.append(other)
|
||||||
|
return self
|
||||||
|
return OrFilter(self, other)
|
||||||
|
|
||||||
|
|
||||||
|
class Filter(AbstractFilter):
|
||||||
|
"""
|
||||||
|
You can make subclasses of that class for custom filters
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoundFilter(Filter):
|
||||||
|
"""
|
||||||
|
Base class for filters with default validator
|
||||||
|
"""
|
||||||
|
key = None
|
||||||
|
required = False
|
||||||
|
default = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||||
|
if cls.key is not None:
|
||||||
|
if cls.key in full_config:
|
||||||
|
return {cls.key: full_config[cls.key]}
|
||||||
|
elif cls.required:
|
||||||
|
return {cls.key: cls.default}
|
||||||
|
|
||||||
|
|
||||||
|
class _LogicFilter(Filter):
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
||||||
|
raise ValueError('That filter can\'t be used in filters factory!')
|
||||||
|
|
||||||
|
|
||||||
|
class NotFilter(_LogicFilter):
|
||||||
|
def __init__(self, target):
|
||||||
|
self.target = wrap_async(target)
|
||||||
|
|
||||||
|
async def check(self, *args):
|
||||||
|
return not bool(await self.target(*args))
|
||||||
|
|
||||||
|
|
||||||
|
class AndFilter(_LogicFilter):
|
||||||
|
|
||||||
|
def __init__(self, *targets):
|
||||||
|
self.targets = list(wrap_async(target) for target in targets)
|
||||||
|
|
||||||
|
async def check(self, *args):
|
||||||
|
"""
|
||||||
|
All filters must return a positive result
|
||||||
|
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
data = {}
|
||||||
|
for target in self.targets:
|
||||||
|
result = await target(*args)
|
||||||
|
if not result:
|
||||||
|
return False
|
||||||
|
if isinstance(result, dict):
|
||||||
|
data.update(result)
|
||||||
|
if not data:
|
||||||
|
return True
|
||||||
|
return data
|
||||||
|
|
||||||
|
def append(self, target):
|
||||||
|
self.targets.append(wrap_async(target))
|
||||||
|
|
||||||
|
|
||||||
|
class OrFilter(_LogicFilter):
|
||||||
|
def __init__(self, *targets):
|
||||||
|
self.targets = list(wrap_async(target) for target in targets)
|
||||||
|
|
||||||
|
async def check(self, *args):
|
||||||
|
"""
|
||||||
|
One of filters must return a positive result
|
||||||
|
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for target in self.targets:
|
||||||
|
result = await target(*args)
|
||||||
|
if result:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def append(self, target):
|
||||||
|
self.targets.append(wrap_async(target))
|
||||||
198
aiogram/dispatcher/filters/state.py
Normal file
198
aiogram/dispatcher/filters/state.py
Normal file
|
|
@ -0,0 +1,198 @@
|
||||||
|
import inspect
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..dispatcher import Dispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
"""
|
||||||
|
State object
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None):
|
||||||
|
self._state = state
|
||||||
|
self._group_name = group_name
|
||||||
|
self._group = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def group(self):
|
||||||
|
if not self._group:
|
||||||
|
raise RuntimeError('This state is not in any group.')
|
||||||
|
return self._group
|
||||||
|
|
||||||
|
def get_root(self):
|
||||||
|
return self.group.get_root()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
if self._state is None:
|
||||||
|
return None
|
||||||
|
elif self._state == '*':
|
||||||
|
return self._state
|
||||||
|
elif self._group_name is None and self._group:
|
||||||
|
group = self._group.__full_group_name__
|
||||||
|
elif self._group_name:
|
||||||
|
group = self._group_name
|
||||||
|
else:
|
||||||
|
group = '@'
|
||||||
|
return f"{group}:{self._state}"
|
||||||
|
|
||||||
|
def set_parent(self, group):
|
||||||
|
if not issubclass(group, StatesGroup):
|
||||||
|
raise ValueError('Group must be subclass of StatesGroup')
|
||||||
|
self._group = group
|
||||||
|
|
||||||
|
def __set_name__(self, owner, name):
|
||||||
|
if self._state is None:
|
||||||
|
self._state = name
|
||||||
|
self.set_parent(owner)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<State '{self.state or ''}'>"
|
||||||
|
|
||||||
|
__repr__ = __str__
|
||||||
|
|
||||||
|
async def set(self):
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
await state.set_state(self.state)
|
||||||
|
|
||||||
|
|
||||||
|
class StatesGroupMeta(type):
|
||||||
|
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||||
|
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
|
||||||
|
|
||||||
|
states = []
|
||||||
|
childs = []
|
||||||
|
|
||||||
|
cls._group_name = name
|
||||||
|
|
||||||
|
for name, prop in namespace.items():
|
||||||
|
|
||||||
|
if isinstance(prop, State):
|
||||||
|
states.append(prop)
|
||||||
|
elif inspect.isclass(prop) and issubclass(prop, StatesGroup):
|
||||||
|
childs.append(prop)
|
||||||
|
prop._parent = cls
|
||||||
|
# continue
|
||||||
|
|
||||||
|
cls._parent = None
|
||||||
|
cls._childs = tuple(childs)
|
||||||
|
cls._states = tuple(states)
|
||||||
|
cls._state_names = tuple(state.state for state in states)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __group_name__(cls):
|
||||||
|
return cls._group_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __full_group_name__(cls):
|
||||||
|
if cls._parent:
|
||||||
|
return cls._parent.__full_group_name__ + '.' + cls._group_name
|
||||||
|
return cls._group_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def states(cls) -> tuple:
|
||||||
|
return cls._states
|
||||||
|
|
||||||
|
@property
|
||||||
|
def childs(cls):
|
||||||
|
return cls._childs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_childs(cls):
|
||||||
|
result = cls.childs
|
||||||
|
for child in cls.childs:
|
||||||
|
result += child.childs
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_states(cls):
|
||||||
|
result = cls.states
|
||||||
|
for group in cls.childs:
|
||||||
|
result += group.all_states
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_states_names(cls):
|
||||||
|
return tuple(state.state for state in cls.all_states)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def states_names(cls) -> tuple:
|
||||||
|
return tuple(state.state for state in cls.states)
|
||||||
|
|
||||||
|
def get_root(cls):
|
||||||
|
if cls._parent is None:
|
||||||
|
return cls
|
||||||
|
return cls._parent.get_root()
|
||||||
|
|
||||||
|
def __contains__(cls, item):
|
||||||
|
if isinstance(item, str):
|
||||||
|
return item in cls.all_states_names
|
||||||
|
elif isinstance(item, State):
|
||||||
|
return item in cls.all_states
|
||||||
|
elif isinstance(item, StatesGroup):
|
||||||
|
return item in cls.all_childs
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<StatesGroup '{self.__full_group_name__}'>"
|
||||||
|
|
||||||
|
|
||||||
|
class StatesGroup(metaclass=StatesGroupMeta):
|
||||||
|
@classmethod
|
||||||
|
async def next(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
state_name = await state.get_state()
|
||||||
|
|
||||||
|
try:
|
||||||
|
next_step = cls.states_names.index(state_name) + 1
|
||||||
|
except ValueError:
|
||||||
|
next_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
next_state_name = cls.states[next_step].state
|
||||||
|
except IndexError:
|
||||||
|
next_state_name = None
|
||||||
|
|
||||||
|
await state.set_state(next_state_name)
|
||||||
|
return next_state_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def previous(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
state_name = await state.get_state()
|
||||||
|
|
||||||
|
try:
|
||||||
|
previous_step = cls.states_names.index(state_name) - 1
|
||||||
|
except ValueError:
|
||||||
|
previous_step = 0
|
||||||
|
|
||||||
|
if previous_step < 0:
|
||||||
|
previous_state_name = None
|
||||||
|
else:
|
||||||
|
previous_state_name = cls.states[previous_step].state
|
||||||
|
|
||||||
|
await state.set_state(previous_state_name)
|
||||||
|
return previous_state_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def first(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
first_step_name = cls.states_names[0]
|
||||||
|
|
||||||
|
await state.set_state(first_step_name)
|
||||||
|
return first_step_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def last(cls) -> str:
|
||||||
|
state = Dispatcher.current().current_state()
|
||||||
|
last_step_name = cls.states_names[-1]
|
||||||
|
|
||||||
|
await state.set_state(last_step_name)
|
||||||
|
return last_step_name
|
||||||
|
|
||||||
|
|
||||||
|
default_state = State()
|
||||||
|
any_state = State(state='*')
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
from .filters import check_filters
|
import inspect
|
||||||
from ..utils import context
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
ctx_data = ContextVar('ctx_handler_data')
|
||||||
|
|
||||||
|
|
||||||
class SkipHandler(BaseException):
|
class SkipHandler(BaseException):
|
||||||
|
|
@ -10,6 +12,14 @@ class CancelHandler(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _check_spec(func: callable, kwargs: dict):
|
||||||
|
spec = inspect.getfullargspec(func)
|
||||||
|
if spec.varkw:
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
return {k: v for k, v in kwargs.items() if k in spec.args}
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
def __init__(self, dispatcher, once=True, middleware_key=None):
|
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
|
|
@ -57,31 +67,43 @@ class Handler:
|
||||||
:param args:
|
:param args:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
from .filters import check_filters, FilterNotPassed
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
ctx_data.set(data)
|
||||||
|
|
||||||
if self.middleware_key:
|
if self.middleware_key:
|
||||||
try:
|
try:
|
||||||
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
|
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,))
|
||||||
except CancelHandler: # Allow to cancel current event
|
except CancelHandler: # Allow to cancel current event
|
||||||
return results
|
return results
|
||||||
|
|
||||||
for filters, handler in self.handlers:
|
try:
|
||||||
if await check_filters(filters, args):
|
for filters, handler in self.handlers:
|
||||||
try:
|
try:
|
||||||
if self.middleware_key:
|
data.update(await check_filters(self.dispatcher, filters, args))
|
||||||
context.set_value('handler', handler)
|
except FilterNotPassed:
|
||||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
|
||||||
response = await handler(*args)
|
|
||||||
if response is not None:
|
|
||||||
results.append(response)
|
|
||||||
if self.once:
|
|
||||||
break
|
|
||||||
except SkipHandler:
|
|
||||||
continue
|
continue
|
||||||
except CancelHandler:
|
else:
|
||||||
break
|
try:
|
||||||
if self.middleware_key:
|
if self.middleware_key:
|
||||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
# context.set_value('handler', handler)
|
||||||
args + (results,))
|
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,))
|
||||||
|
partial_data = _check_spec(handler, data)
|
||||||
|
response = await handler(*args, **partial_data)
|
||||||
|
if response is not None:
|
||||||
|
results.append(response)
|
||||||
|
if self.once:
|
||||||
|
break
|
||||||
|
except SkipHandler:
|
||||||
|
continue
|
||||||
|
except CancelHandler:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
if self.middleware_key:
|
||||||
|
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||||
|
args + (results, data,))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
|
|
@ -101,3 +101,28 @@ class BaseMiddleware:
|
||||||
if not handler:
|
if not handler:
|
||||||
return None
|
return None
|
||||||
await handler(*args)
|
await handler(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class LifetimeControllerMiddleware(BaseMiddleware):
|
||||||
|
# TODO: Rename class
|
||||||
|
|
||||||
|
skip_patterns = None
|
||||||
|
|
||||||
|
async def pre_process(self, obj, data, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def post_process(self, obj, data, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def trigger(self, action, args):
|
||||||
|
if self.skip_patterns is not None and any(item in action for item in self.skip_patterns):
|
||||||
|
return False
|
||||||
|
|
||||||
|
obj, *args, data = args
|
||||||
|
if action.startswith('pre_process_'):
|
||||||
|
await self.pre_process(obj, data, *args)
|
||||||
|
elif action.startswith('post_process_'):
|
||||||
|
await self.post_process(obj, data, *args)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
|
||||||
|
|
@ -281,8 +281,20 @@ class FSMContext:
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_state(value):
|
||||||
|
from .filters.state import State
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
return
|
||||||
|
elif isinstance(value, str):
|
||||||
|
return value
|
||||||
|
elif isinstance(value, State):
|
||||||
|
return value.state
|
||||||
|
return str(value)
|
||||||
|
|
||||||
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
|
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||||
return await self.storage.get_state(chat=self.chat, user=self.user, default=default)
|
return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default))
|
||||||
|
|
||||||
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
|
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
|
||||||
return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
|
return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
|
||||||
|
|
@ -291,7 +303,7 @@ class FSMContext:
|
||||||
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
|
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
|
||||||
|
|
||||||
async def set_state(self, state: typing.Union[typing.AnyStr, None] = None):
|
async def set_state(self, state: typing.Union[typing.AnyStr, None] = None):
|
||||||
await self.storage.set_state(chat=self.chat, user=self.user, state=state)
|
await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state))
|
||||||
|
|
||||||
async def set_data(self, data: typing.Dict = None):
|
async def set_data(self, data: typing.Dict = None):
|
||||||
await self.storage.set_data(chat=self.chat, user=self.user, data=data)
|
await self.storage.set_data(chat=self.chat, user=self.user, data=data)
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,11 @@ from typing import Dict, List, Optional, Union
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.web_exceptions import HTTPGone
|
from aiohttp.web_exceptions import HTTPGone
|
||||||
|
|
||||||
|
|
||||||
from .. import types
|
from .. import types
|
||||||
from ..bot import api
|
from ..bot import api
|
||||||
from ..types import ParseMode
|
from ..types import ParseMode
|
||||||
from ..types.base import Boolean, Float, Integer, String
|
from ..types.base import Boolean, Float, Integer, String
|
||||||
from ..utils import context
|
|
||||||
from ..utils import helper, markdown
|
from ..utils import helper, markdown
|
||||||
from ..utils import json
|
from ..utils import json
|
||||||
from ..utils.deprecated import warn_deprecated as warn
|
from ..utils.deprecated import warn_deprecated as warn
|
||||||
|
|
@ -89,8 +89,10 @@ class WebhookRequestHandler(web.View):
|
||||||
"""
|
"""
|
||||||
dp = self.request.app[BOT_DISPATCHER_KEY]
|
dp = self.request.app[BOT_DISPATCHER_KEY]
|
||||||
try:
|
try:
|
||||||
context.set_value('dispatcher', dp)
|
from aiogram.bot import bot
|
||||||
context.set_value('bot', dp.bot)
|
from aiogram.dispatcher import dispatcher
|
||||||
|
dispatcher.set(dp)
|
||||||
|
bot.bot.set(dp.bot)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
return dp
|
return dp
|
||||||
|
|
@ -117,9 +119,9 @@ class WebhookRequestHandler(web.View):
|
||||||
"""
|
"""
|
||||||
self.validate_ip()
|
self.validate_ip()
|
||||||
|
|
||||||
context.update_state({'CALLER': WEBHOOK,
|
# context.update_state({'CALLER': WEBHOOK,
|
||||||
WEBHOOK_CONNECTION: True,
|
# WEBHOOK_CONNECTION: True,
|
||||||
WEBHOOK_REQUEST: self.request})
|
# WEBHOOK_REQUEST: self.request})
|
||||||
|
|
||||||
dispatcher = self.get_dispatcher()
|
dispatcher = self.get_dispatcher()
|
||||||
update = await self.parse_update(dispatcher.bot)
|
update = await self.parse_update(dispatcher.bot)
|
||||||
|
|
@ -177,7 +179,7 @@ class WebhookRequestHandler(web.View):
|
||||||
if fut.done():
|
if fut.done():
|
||||||
return fut.result()
|
return fut.result()
|
||||||
else:
|
else:
|
||||||
context.set_value(WEBHOOK_CONNECTION, False)
|
# context.set_value(WEBHOOK_CONNECTION, False)
|
||||||
fut.remove_done_callback(cb)
|
fut.remove_done_callback(cb)
|
||||||
fut.add_done_callback(self.respond_via_request)
|
fut.add_done_callback(self.respond_via_request)
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -202,7 +204,7 @@ class WebhookRequestHandler(web.View):
|
||||||
results = task.result()
|
results = task.result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
loop.create_task(
|
loop.create_task(
|
||||||
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
|
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
|
||||||
else:
|
else:
|
||||||
response = self.get_response(results)
|
response = self.get_response(results)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
|
|
@ -249,7 +251,7 @@ class WebhookRequestHandler(web.View):
|
||||||
ip_address, accept = self.check_ip()
|
ip_address, accept = self.check_ip()
|
||||||
if not accept:
|
if not accept:
|
||||||
raise web.HTTPUnauthorized()
|
raise web.HTTPUnauthorized()
|
||||||
context.set_value('TELEGRAM_IP', ip_address)
|
# context.set_value('TELEGRAM_IP', ip_address)
|
||||||
|
|
||||||
|
|
||||||
class GoneRequestHandler(web.View):
|
class GoneRequestHandler(web.View):
|
||||||
|
|
@ -352,8 +354,8 @@ class BaseResponse:
|
||||||
|
|
||||||
async def __call__(self, bot=None):
|
async def __call__(self, bot=None):
|
||||||
if bot is None:
|
if bot is None:
|
||||||
from aiogram.dispatcher import ctx
|
from aiogram import Bot
|
||||||
bot = ctx.get_bot()
|
bot = Bot.current()
|
||||||
return await self.execute_response(bot)
|
return await self.execute_response(bot)
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
|
|
@ -446,7 +448,8 @@ class ParseModeMixin:
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
bot = context.get_value('bot', None)
|
from aiogram import Bot
|
||||||
|
bot = Bot.current()
|
||||||
if bot is not None:
|
if bot is not None:
|
||||||
return bot.parse_mode
|
return bot.parse_mode
|
||||||
|
|
||||||
|
|
@ -952,7 +955,7 @@ class SendMediaGroup(BaseResponse, ReplyToMixin, DisableNotificationMixin):
|
||||||
self.reply_to_message_id = reply_to_message_id
|
self.reply_to_message_id = reply_to_message_id
|
||||||
|
|
||||||
def prepare(self):
|
def prepare(self):
|
||||||
files = self.media.get_files()
|
files = dict(self.media.get_files())
|
||||||
if files:
|
if files:
|
||||||
raise TypeError('Allowed only file ID or URL\'s')
|
raise TypeError('Allowed only file ID or URL\'s')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ from .invoice import Invoice
|
||||||
from .labeled_price import LabeledPrice
|
from .labeled_price import LabeledPrice
|
||||||
from .location import Location
|
from .location import Location
|
||||||
from .mask_position import MaskPosition
|
from .mask_position import MaskPosition
|
||||||
from .message import ContentType, Message, ParseMode
|
from .message import ContentType, ContentTypes, Message, ParseMode
|
||||||
from .message_entity import MessageEntity, MessageEntityType
|
from .message_entity import MessageEntity, MessageEntityType
|
||||||
from .order_info import OrderInfo
|
from .order_info import OrderInfo
|
||||||
from .passport_data import PassportData
|
from .passport_data import PassportData
|
||||||
|
|
@ -77,6 +77,7 @@ __all__ = (
|
||||||
'ChosenInlineResult',
|
'ChosenInlineResult',
|
||||||
'Contact',
|
'Contact',
|
||||||
'ContentType',
|
'ContentType',
|
||||||
|
'ContentTypes',
|
||||||
'Document',
|
'Document',
|
||||||
'EncryptedCredentials',
|
'EncryptedCredentials',
|
||||||
'EncryptedPassportElement',
|
'EncryptedPassportElement',
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import typing
|
import typing
|
||||||
|
from contextvars import ContextVar
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from .fields import BaseField
|
from .fields import BaseField
|
||||||
|
|
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
|
||||||
setattr(cls, ALIASES_ATTR_NAME, aliases)
|
setattr(cls, ALIASES_ATTR_NAME, aliases)
|
||||||
|
|
||||||
mcs._objects[cls.__name__] = cls
|
mcs._objects[cls.__name__] = cls
|
||||||
|
|
||||||
|
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
||||||
if value.default and key not in self.values:
|
if value.default and key not in self.values:
|
||||||
self.values[key] = value.default
|
self.values[key] = value.default
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def current(cls):
|
||||||
|
return cls._current.get()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_current(cls, obj: TelegramObject):
|
||||||
|
return cls._current.set(obj)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def conf(self) -> typing.Dict[str, typing.Any]:
|
def conf(self) -> typing.Dict[str, typing.Any]:
|
||||||
return self._conf
|
return self._conf
|
||||||
|
|
@ -137,8 +150,8 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self):
|
def bot(self):
|
||||||
from ..dispatcher import ctx
|
from ..bot.bot import Bot
|
||||||
return ctx.get_bot()
|
return Bot.current()
|
||||||
|
|
||||||
def to_python(self) -> typing.Dict:
|
def to_python(self) -> typing.Dict:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
from . import base
|
from . import base
|
||||||
from . import fields
|
from . import fields
|
||||||
|
|
@ -64,7 +67,7 @@ class Chat(base.TelegramObject):
|
||||||
if as_html:
|
if as_html:
|
||||||
return markdown.hlink(name, self.user_url)
|
return markdown.hlink(name, self.user_url)
|
||||||
return markdown.link(name, self.user_url)
|
return markdown.link(name, self.user_url)
|
||||||
|
|
||||||
async def get_url(self):
|
async def get_url(self):
|
||||||
"""
|
"""
|
||||||
Use this method to get chat link.
|
Use this method to get chat link.
|
||||||
|
|
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _do(cls, action: str, sleep=None):
|
async def _do(cls, action: str, sleep=None):
|
||||||
from ..dispatcher.ctx import get_bot, get_chat
|
from aiogram import Bot
|
||||||
await get_bot().send_chat_action(get_chat(), action)
|
await Bot.current().send_chat_action(Chat.current(), action)
|
||||||
if sleep:
|
if sleep:
|
||||||
await asyncio.sleep(sleep)
|
await asyncio.sleep(sleep)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ class BaseField(metaclass=abc.ABCMeta):
|
||||||
Base field (prop)
|
Base field (prop)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, base=None, default=None, alias=None):
|
def __init__(self, *, base=None, default=None, alias=None, on_change=None):
|
||||||
"""
|
"""
|
||||||
Init prop
|
Init prop
|
||||||
|
|
||||||
|
|
@ -17,10 +17,12 @@ class BaseField(metaclass=abc.ABCMeta):
|
||||||
:param default: default value
|
:param default: default value
|
||||||
:param alias: alias name (for e.g. field 'from' has to be named 'from_user'
|
:param alias: alias name (for e.g. field 'from' has to be named 'from_user'
|
||||||
as 'from' is a builtin Python keyword
|
as 'from' is a builtin Python keyword
|
||||||
|
:param on_change: callback will be called when value is changed
|
||||||
"""
|
"""
|
||||||
self.base_object = base
|
self.base_object = base
|
||||||
self.default = default
|
self.default = default
|
||||||
self.alias = alias
|
self.alias = alias
|
||||||
|
self.on_change = on_change
|
||||||
|
|
||||||
def __set_name__(self, owner, name):
|
def __set_name__(self, owner, name):
|
||||||
if self.alias is None:
|
if self.alias is None:
|
||||||
|
|
@ -53,6 +55,13 @@ class BaseField(metaclass=abc.ABCMeta):
|
||||||
self.resolve_base(instance)
|
self.resolve_base(instance)
|
||||||
value = self.deserialize(value, parent)
|
value = self.deserialize(value, parent)
|
||||||
instance.values[self.alias] = value
|
instance.values[self.alias] = value
|
||||||
|
self._trigger_changed(instance, value)
|
||||||
|
|
||||||
|
def _trigger_changed(self, instance, value):
|
||||||
|
if not self.on_change and instance is not None:
|
||||||
|
return
|
||||||
|
callback = getattr(instance, self.on_change)
|
||||||
|
callback(value)
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
return self.get_value(instance)
|
return self.get_value(instance)
|
||||||
|
|
@ -154,7 +163,7 @@ class ListOfLists(Field):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class DateTimeField(BaseField):
|
class DateTimeField(Field):
|
||||||
"""
|
"""
|
||||||
In this field st_ored datetime
|
In this field st_ored datetime
|
||||||
|
|
||||||
|
|
@ -167,3 +176,24 @@ class DateTimeField(BaseField):
|
||||||
|
|
||||||
def deserialize(self, value, parent=None):
|
def deserialize(self, value, parent=None):
|
||||||
return datetime.datetime.fromtimestamp(value)
|
return datetime.datetime.fromtimestamp(value)
|
||||||
|
|
||||||
|
|
||||||
|
class TextField(Field):
|
||||||
|
def __init__(self, *, prefix=None, suffix=None, default=None, alias=None):
|
||||||
|
super(TextField, self).__init__(default=default, alias=alias)
|
||||||
|
self.prefix = prefix
|
||||||
|
self.suffix = suffix
|
||||||
|
|
||||||
|
def serialize(self, value):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
if self.prefix:
|
||||||
|
value = self.prefix + value
|
||||||
|
if self.suffix:
|
||||||
|
value += self.suffix
|
||||||
|
return value
|
||||||
|
|
||||||
|
def deserialize(self, value, parent=None):
|
||||||
|
if value is not None and not isinstance(value, str):
|
||||||
|
raise TypeError(f"Field '{self.alias}' should be str not {type(value).__name__}")
|
||||||
|
return value
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import secrets
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
@ -45,6 +46,8 @@ class InputFile(base.TelegramObject):
|
||||||
|
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
|
|
||||||
|
self.attachment_key = secrets.token_urlsafe(16)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
"""
|
"""
|
||||||
Close file descriptor
|
Close file descriptor
|
||||||
|
|
@ -54,13 +57,17 @@ class InputFile(base.TelegramObject):
|
||||||
@property
|
@property
|
||||||
def filename(self):
|
def filename(self):
|
||||||
if self._filename is None:
|
if self._filename is None:
|
||||||
self._filename = api._guess_filename(self._file)
|
self._filename = api.guess_filename(self._file)
|
||||||
return self._filename
|
return self._filename
|
||||||
|
|
||||||
@filename.setter
|
@filename.setter
|
||||||
def filename(self, value):
|
def filename(self, value):
|
||||||
self._filename = value
|
self._filename = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attach(self):
|
||||||
|
return f"attach://{self.attachment_key}"
|
||||||
|
|
||||||
def get_filename(self) -> str:
|
def get_filename(self) -> str:
|
||||||
"""
|
"""
|
||||||
Get file name
|
Get file name
|
||||||
|
|
@ -159,6 +166,9 @@ class InputFile(base.TelegramObject):
|
||||||
|
|
||||||
return writer
|
return writer
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<InputFile 'attach://{self.attachment_key}' with file='{self.file}'>"
|
||||||
|
|
||||||
def to_python(self):
|
def to_python(self):
|
||||||
raise TypeError('Object of this type is not exportable!')
|
raise TypeError('Object of this type is not exportable!')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,9 @@ ATTACHMENT_PREFIX = 'attach://'
|
||||||
class InputMedia(base.TelegramObject):
|
class InputMedia(base.TelegramObject):
|
||||||
"""
|
"""
|
||||||
This object represents the content of a media message to be sent. It should be one of
|
This object represents the content of a media message to be sent. It should be one of
|
||||||
|
- InputMediaAnimation
|
||||||
|
- InputMediaDocument
|
||||||
|
- InputMediaAudio
|
||||||
- InputMediaPhoto
|
- InputMediaPhoto
|
||||||
- InputMediaVideo
|
- InputMediaVideo
|
||||||
|
|
||||||
|
|
@ -20,36 +23,76 @@ class InputMedia(base.TelegramObject):
|
||||||
https://core.telegram.org/bots/api#inputmedia
|
https://core.telegram.org/bots/api#inputmedia
|
||||||
"""
|
"""
|
||||||
type: base.String = fields.Field(default='photo')
|
type: base.String = fields.Field(default='photo')
|
||||||
media: base.String = fields.Field()
|
media: base.String = fields.Field(alias='media', on_change='_media_changed')
|
||||||
thumb: typing.Union[base.InputFile, base.String] = fields.Field()
|
thumb: typing.Union[base.InputFile, base.String] = fields.Field(alias='thumb', on_change='_thumb_changed')
|
||||||
caption: base.String = fields.Field()
|
caption: base.String = fields.Field()
|
||||||
parse_mode: base.Boolean = fields.Field()
|
parse_mode: base.Boolean = fields.Field()
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._thumb_file = None
|
||||||
|
self._media_file = None
|
||||||
|
|
||||||
|
media = kwargs.pop('media', None)
|
||||||
|
if isinstance(media, (io.IOBase, InputFile)):
|
||||||
|
self.file = media
|
||||||
|
elif media is not None:
|
||||||
|
self.media = media
|
||||||
|
|
||||||
|
thumb = kwargs.pop('thumb', None)
|
||||||
|
if isinstance(thumb, (io.IOBase, InputFile)):
|
||||||
|
self.thumb_file = thumb
|
||||||
|
elif thumb is not None:
|
||||||
|
self.thumb = thumb
|
||||||
|
|
||||||
super(InputMedia, self).__init__(*args, **kwargs)
|
super(InputMedia, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.parse_mode is None and self.bot.parse_mode:
|
if self.parse_mode is None and self.bot and self.bot.parse_mode:
|
||||||
self.parse_mode = self.bot.parse_mode
|
self.parse_mode = self.bot.parse_mode
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def file(self):
|
def file(self):
|
||||||
return getattr(self, '_file', None)
|
return self._media_file
|
||||||
|
|
||||||
@file.setter
|
@file.setter
|
||||||
def file(self, file: io.IOBase):
|
def file(self, file: io.IOBase):
|
||||||
setattr(self, '_file', file)
|
self.media = 'attach://' + secrets.token_urlsafe(16)
|
||||||
attachment_key = self.attachment_key = secrets.token_urlsafe(16)
|
self._media_file = file
|
||||||
self.media = ATTACHMENT_PREFIX + attachment_key
|
|
||||||
|
@file.deleter
|
||||||
|
def file(self):
|
||||||
|
self.media = None
|
||||||
|
self._media_file = None
|
||||||
|
|
||||||
|
def _media_changed(self, value):
|
||||||
|
if value is None or isinstance(value, str) and not value.startswith('attach://'):
|
||||||
|
self._media_file = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def attachment_key(self):
|
def thumb_file(self):
|
||||||
return self.conf.get('attachment_key', None)
|
return self._thumb_file
|
||||||
|
|
||||||
@attachment_key.setter
|
@thumb_file.setter
|
||||||
def attachment_key(self, value):
|
def thumb_file(self, file: io.IOBase):
|
||||||
self.conf['attachment_key'] = value
|
self.thumb = 'attach://' + secrets.token_urlsafe(16)
|
||||||
|
self._thumb_file = file
|
||||||
|
|
||||||
|
@thumb_file.deleter
|
||||||
|
def thumb_file(self):
|
||||||
|
self.thumb = None
|
||||||
|
self._thumb_file = None
|
||||||
|
|
||||||
|
def _thumb_changed(self, value):
|
||||||
|
if value is None or isinstance(value, str) and not value.startswith('attach://'):
|
||||||
|
self._thumb_file = None
|
||||||
|
|
||||||
|
def get_files(self):
|
||||||
|
if self._media_file:
|
||||||
|
yield self.media[9:], self._media_file
|
||||||
|
if self._thumb_file:
|
||||||
|
yield self.thumb[9:], self._thumb_file
|
||||||
|
|
||||||
|
|
||||||
class InputMediaAnimation(InputMedia):
|
class InputMediaAnimation(InputMedia):
|
||||||
|
|
@ -72,9 +115,6 @@ class InputMediaAnimation(InputMedia):
|
||||||
width=width, height=height, duration=duration,
|
width=width, height=height, duration=duration,
|
||||||
parse_mode=parse_mode, conf=kwargs)
|
parse_mode=parse_mode, conf=kwargs)
|
||||||
|
|
||||||
if isinstance(media, (io.IOBase, InputFile)):
|
|
||||||
self.file = media
|
|
||||||
|
|
||||||
|
|
||||||
class InputMediaDocument(InputMedia):
|
class InputMediaDocument(InputMedia):
|
||||||
"""
|
"""
|
||||||
|
|
@ -89,9 +129,6 @@ class InputMediaDocument(InputMedia):
|
||||||
caption=caption, parse_mode=parse_mode,
|
caption=caption, parse_mode=parse_mode,
|
||||||
conf=kwargs)
|
conf=kwargs)
|
||||||
|
|
||||||
if isinstance(media, (io.IOBase, InputFile)):
|
|
||||||
self.file = media
|
|
||||||
|
|
||||||
|
|
||||||
class InputMediaAudio(InputMedia):
|
class InputMediaAudio(InputMedia):
|
||||||
"""
|
"""
|
||||||
|
|
@ -119,9 +156,6 @@ class InputMediaAudio(InputMedia):
|
||||||
performer=performer, title=title,
|
performer=performer, title=title,
|
||||||
parse_mode=parse_mode, conf=kwargs)
|
parse_mode=parse_mode, conf=kwargs)
|
||||||
|
|
||||||
if isinstance(media, (io.IOBase, InputFile)):
|
|
||||||
self.file = media
|
|
||||||
|
|
||||||
|
|
||||||
class InputMediaPhoto(InputMedia):
|
class InputMediaPhoto(InputMedia):
|
||||||
"""
|
"""
|
||||||
|
|
@ -136,9 +170,6 @@ class InputMediaPhoto(InputMedia):
|
||||||
caption=caption, parse_mode=parse_mode,
|
caption=caption, parse_mode=parse_mode,
|
||||||
conf=kwargs)
|
conf=kwargs)
|
||||||
|
|
||||||
if isinstance(media, (io.IOBase, InputFile)):
|
|
||||||
self.file = media
|
|
||||||
|
|
||||||
|
|
||||||
class InputMediaVideo(InputMedia):
|
class InputMediaVideo(InputMedia):
|
||||||
"""
|
"""
|
||||||
|
|
@ -151,18 +182,17 @@ class InputMediaVideo(InputMedia):
|
||||||
duration: base.Integer = fields.Field()
|
duration: base.Integer = fields.Field()
|
||||||
supports_streaming: base.Boolean = fields.Field()
|
supports_streaming: base.Boolean = fields.Field()
|
||||||
|
|
||||||
def __init__(self, media: base.InputFile, caption: base.String = None,
|
def __init__(self, media: base.InputFile,
|
||||||
|
thumb: typing.Union[base.InputFile, base.String] = None,
|
||||||
|
caption: base.String = None,
|
||||||
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None,
|
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None,
|
||||||
parse_mode: base.Boolean = None,
|
parse_mode: base.Boolean = None,
|
||||||
supports_streaming: base.Boolean = None, **kwargs):
|
supports_streaming: base.Boolean = None, **kwargs):
|
||||||
super(InputMediaVideo, self).__init__(type='video', media=media, caption=caption,
|
super(InputMediaVideo, self).__init__(type='video', media=media, thumb=thumb, caption=caption,
|
||||||
width=width, height=height, duration=duration,
|
width=width, height=height, duration=duration,
|
||||||
parse_mode=parse_mode,
|
parse_mode=parse_mode,
|
||||||
supports_streaming=supports_streaming, conf=kwargs)
|
supports_streaming=supports_streaming, conf=kwargs)
|
||||||
|
|
||||||
if isinstance(media, (io.IOBase, InputFile)):
|
|
||||||
self.file = media
|
|
||||||
|
|
||||||
|
|
||||||
class MediaGroup(base.TelegramObject):
|
class MediaGroup(base.TelegramObject):
|
||||||
"""
|
"""
|
||||||
|
|
@ -296,6 +326,7 @@ class MediaGroup(base.TelegramObject):
|
||||||
self.attach(photo)
|
self.attach(photo)
|
||||||
|
|
||||||
def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile],
|
def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile],
|
||||||
|
thumb: typing.Union[base.InputFile, base.String] = None,
|
||||||
caption: base.String = None,
|
caption: base.String = None,
|
||||||
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None):
|
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None):
|
||||||
"""
|
"""
|
||||||
|
|
@ -308,7 +339,7 @@ class MediaGroup(base.TelegramObject):
|
||||||
:param duration:
|
:param duration:
|
||||||
"""
|
"""
|
||||||
if not isinstance(video, InputMedia):
|
if not isinstance(video, InputMedia):
|
||||||
video = InputMediaVideo(media=video, caption=caption,
|
video = InputMediaVideo(media=video, thumb=thumb, caption=caption,
|
||||||
width=width, height=height, duration=duration)
|
width=width, height=height, duration=duration)
|
||||||
self.attach(video)
|
self.attach(video)
|
||||||
|
|
||||||
|
|
@ -327,6 +358,7 @@ class MediaGroup(base.TelegramObject):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_files(self):
|
def get_files(self):
|
||||||
return {inputmedia.attachment_key: inputmedia.file
|
for inputmedia in self.media:
|
||||||
for inputmedia in self.media
|
if not isinstance(inputmedia, InputMedia) or not inputmedia.file:
|
||||||
if isinstance(inputmedia, InputMedia) and inputmedia.file}
|
continue
|
||||||
|
yield from inputmedia.get_files()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -42,7 +44,7 @@ class Message(base.TelegramObject):
|
||||||
forward_from_message_id: base.Integer = fields.Field()
|
forward_from_message_id: base.Integer = fields.Field()
|
||||||
forward_signature: base.String = fields.Field()
|
forward_signature: base.String = fields.Field()
|
||||||
forward_date: datetime.datetime = fields.DateTimeField()
|
forward_date: datetime.datetime = fields.DateTimeField()
|
||||||
reply_to_message: 'Message' = fields.Field(base='Message')
|
reply_to_message: Message = fields.Field(base='Message')
|
||||||
edit_date: datetime.datetime = fields.DateTimeField()
|
edit_date: datetime.datetime = fields.DateTimeField()
|
||||||
media_group_id: base.String = fields.Field()
|
media_group_id: base.String = fields.Field()
|
||||||
author_signature: base.String = fields.Field()
|
author_signature: base.String = fields.Field()
|
||||||
|
|
@ -72,7 +74,7 @@ class Message(base.TelegramObject):
|
||||||
channel_chat_created: base.Boolean = fields.Field()
|
channel_chat_created: base.Boolean = fields.Field()
|
||||||
migrate_to_chat_id: base.Integer = fields.Field()
|
migrate_to_chat_id: base.Integer = fields.Field()
|
||||||
migrate_from_chat_id: base.Integer = fields.Field()
|
migrate_from_chat_id: base.Integer = fields.Field()
|
||||||
pinned_message: 'Message' = fields.Field(base='Message')
|
pinned_message: Message = fields.Field(base='Message')
|
||||||
invoice: Invoice = fields.Field(base=Invoice)
|
invoice: Invoice = fields.Field(base=Invoice)
|
||||||
successful_payment: SuccessfulPayment = fields.Field(base=SuccessfulPayment)
|
successful_payment: SuccessfulPayment = fields.Field(base=SuccessfulPayment)
|
||||||
connected_website: base.String = fields.Field()
|
connected_website: base.String = fields.Field()
|
||||||
|
|
@ -82,59 +84,59 @@ class Message(base.TelegramObject):
|
||||||
@functools.lru_cache()
|
@functools.lru_cache()
|
||||||
def content_type(self):
|
def content_type(self):
|
||||||
if self.text:
|
if self.text:
|
||||||
return ContentType.TEXT[0]
|
return ContentType.TEXT
|
||||||
elif self.audio:
|
elif self.audio:
|
||||||
return ContentType.AUDIO[0]
|
return ContentType.AUDIO
|
||||||
elif self.animation:
|
elif self.animation:
|
||||||
return ContentType.ANIMATION[0]
|
return ContentType.ANIMATION
|
||||||
elif self.document:
|
elif self.document:
|
||||||
return ContentType.DOCUMENT[0]
|
return ContentType.DOCUMENT
|
||||||
elif self.game:
|
elif self.game:
|
||||||
return ContentType.GAME[0]
|
return ContentType.GAME
|
||||||
elif self.photo:
|
elif self.photo:
|
||||||
return ContentType.PHOTO[0]
|
return ContentType.PHOTO
|
||||||
elif self.sticker:
|
elif self.sticker:
|
||||||
return ContentType.STICKER[0]
|
return ContentType.STICKER
|
||||||
elif self.video:
|
elif self.video:
|
||||||
return ContentType.VIDEO[0]
|
return ContentType.VIDEO
|
||||||
elif self.video_note:
|
elif self.video_note:
|
||||||
return ContentType.VIDEO_NOTE[0]
|
return ContentType.VIDEO_NOTE
|
||||||
elif self.voice:
|
elif self.voice:
|
||||||
return ContentType.VOICE[0]
|
return ContentType.VOICE
|
||||||
elif self.contact:
|
elif self.contact:
|
||||||
return ContentType.CONTACT[0]
|
return ContentType.CONTACT
|
||||||
elif self.venue:
|
elif self.venue:
|
||||||
return ContentType.VENUE[0]
|
return ContentType.VENUE
|
||||||
elif self.location:
|
elif self.location:
|
||||||
return ContentType.LOCATION[0]
|
return ContentType.LOCATION
|
||||||
elif self.new_chat_members:
|
elif self.new_chat_members:
|
||||||
return ContentType.NEW_CHAT_MEMBERS[0]
|
return ContentType.NEW_CHAT_MEMBERS
|
||||||
elif self.left_chat_member:
|
elif self.left_chat_member:
|
||||||
return ContentType.LEFT_CHAT_MEMBER[0]
|
return ContentType.LEFT_CHAT_MEMBER
|
||||||
elif self.invoice:
|
elif self.invoice:
|
||||||
return ContentType.INVOICE[0]
|
return ContentType.INVOICE
|
||||||
elif self.successful_payment:
|
elif self.successful_payment:
|
||||||
return ContentType.SUCCESSFUL_PAYMENT[0]
|
return ContentType.SUCCESSFUL_PAYMENT
|
||||||
elif self.connected_website:
|
elif self.connected_website:
|
||||||
return ContentType.CONNECTED_WEBSITE[0]
|
return ContentType.CONNECTED_WEBSITE
|
||||||
elif self.migrate_from_chat_id:
|
elif self.migrate_from_chat_id:
|
||||||
return ContentType.MIGRATE_FROM_CHAT_ID[0]
|
return ContentType.MIGRATE_FROM_CHAT_ID
|
||||||
elif self.migrate_to_chat_id:
|
elif self.migrate_to_chat_id:
|
||||||
return ContentType.MIGRATE_TO_CHAT_ID[0]
|
return ContentType.MIGRATE_TO_CHAT_ID
|
||||||
elif self.pinned_message:
|
elif self.pinned_message:
|
||||||
return ContentType.PINNED_MESSAGE[0]
|
return ContentType.PINNED_MESSAGE
|
||||||
elif self.new_chat_title:
|
elif self.new_chat_title:
|
||||||
return ContentType.NEW_CHAT_TITLE[0]
|
return ContentType.NEW_CHAT_TITLE
|
||||||
elif self.new_chat_photo:
|
elif self.new_chat_photo:
|
||||||
return ContentType.NEW_CHAT_PHOTO[0]
|
return ContentType.NEW_CHAT_PHOTO
|
||||||
elif self.delete_chat_photo:
|
elif self.delete_chat_photo:
|
||||||
return ContentType.DELETE_CHAT_PHOTO[0]
|
return ContentType.DELETE_CHAT_PHOTO
|
||||||
elif self.group_chat_created:
|
elif self.group_chat_created:
|
||||||
return ContentType.GROUP_CHAT_CREATED[0]
|
return ContentType.GROUP_CHAT_CREATED
|
||||||
elif self.passport_data:
|
elif self.passport_data:
|
||||||
return ContentType.PASSPORT_DATA[0]
|
return ContentType.PASSPORT_DATA
|
||||||
else:
|
else:
|
||||||
return ContentType.UNKNOWN[0]
|
return ContentType.UNKNOWN
|
||||||
|
|
||||||
def is_command(self):
|
def is_command(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -239,7 +241,7 @@ class Message(base.TelegramObject):
|
||||||
return self.parse_entities()
|
return self.parse_entities()
|
||||||
|
|
||||||
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
|
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
|
||||||
disable_notification=None, reply_markup=None, reply=True) -> 'Message':
|
disable_notification=None, reply_markup=None, reply=True) -> Message:
|
||||||
"""
|
"""
|
||||||
Reply to this message
|
Reply to this message
|
||||||
|
|
||||||
|
|
@ -729,6 +731,69 @@ class ContentType(helper.Helper):
|
||||||
"""
|
"""
|
||||||
List of message content types
|
List of message content types
|
||||||
|
|
||||||
|
WARNING: Single elements
|
||||||
|
|
||||||
|
:key: TEXT
|
||||||
|
:key: AUDIO
|
||||||
|
:key: DOCUMENT
|
||||||
|
:key: GAME
|
||||||
|
:key: PHOTO
|
||||||
|
:key: STICKER
|
||||||
|
:key: VIDEO
|
||||||
|
:key: VIDEO_NOTE
|
||||||
|
:key: VOICE
|
||||||
|
:key: CONTACT
|
||||||
|
:key: LOCATION
|
||||||
|
:key: VENUE
|
||||||
|
:key: NEW_CHAT_MEMBERS
|
||||||
|
:key: LEFT_CHAT_MEMBER
|
||||||
|
:key: INVOICE
|
||||||
|
:key: SUCCESSFUL_PAYMENT
|
||||||
|
:key: CONNECTED_WEBSITE
|
||||||
|
:key: MIGRATE_TO_CHAT_ID
|
||||||
|
:key: MIGRATE_FROM_CHAT_ID
|
||||||
|
:key: UNKNOWN
|
||||||
|
:key: ANY
|
||||||
|
"""
|
||||||
|
mode = helper.HelperMode.snake_case
|
||||||
|
|
||||||
|
TEXT = helper.Item() # text
|
||||||
|
AUDIO = helper.Item() # audio
|
||||||
|
DOCUMENT = helper.Item() # document
|
||||||
|
ANIMATION = helper.Item() # animation
|
||||||
|
GAME = helper.Item() # game
|
||||||
|
PHOTO = helper.Item() # photo
|
||||||
|
STICKER = helper.Item() # sticker
|
||||||
|
VIDEO = helper.Item() # video
|
||||||
|
VIDEO_NOTE = helper.Item() # video_note
|
||||||
|
VOICE = helper.Item() # voice
|
||||||
|
CONTACT = helper.Item() # contact
|
||||||
|
LOCATION = helper.Item() # location
|
||||||
|
VENUE = helper.Item() # venue
|
||||||
|
NEW_CHAT_MEMBERS = helper.Item() # new_chat_member
|
||||||
|
LEFT_CHAT_MEMBER = helper.Item() # left_chat_member
|
||||||
|
INVOICE = helper.Item() # invoice
|
||||||
|
SUCCESSFUL_PAYMENT = helper.Item() # successful_payment
|
||||||
|
CONNECTED_WEBSITE = helper.Item() # connected_website
|
||||||
|
MIGRATE_TO_CHAT_ID = helper.Item() # migrate_to_chat_id
|
||||||
|
MIGRATE_FROM_CHAT_ID = helper.Item() # migrate_from_chat_id
|
||||||
|
PINNED_MESSAGE = helper.Item() # pinned_message
|
||||||
|
NEW_CHAT_TITLE = helper.Item() # new_chat_title
|
||||||
|
NEW_CHAT_PHOTO = helper.Item() # new_chat_photo
|
||||||
|
DELETE_CHAT_PHOTO = helper.Item() # delete_chat_photo
|
||||||
|
GROUP_CHAT_CREATED = helper.Item() # group_chat_created
|
||||||
|
PASSPORT_DATA = helper.Item() # passport_data
|
||||||
|
|
||||||
|
UNKNOWN = helper.Item() # unknown
|
||||||
|
ANY = helper.Item() # any
|
||||||
|
|
||||||
|
|
||||||
|
class ContentTypes(helper.Helper):
|
||||||
|
"""
|
||||||
|
List of message content types
|
||||||
|
|
||||||
|
WARNING: List elements.
|
||||||
|
|
||||||
:key: TEXT
|
:key: TEXT
|
||||||
:key: AUDIO
|
:key: AUDIO
|
||||||
:key: DOCUMENT
|
:key: DOCUMENT
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from . import base
|
from . import base
|
||||||
from . import fields
|
from . import fields
|
||||||
from .callback_query import CallbackQuery
|
from .callback_query import CallbackQuery
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import babel
|
import babel
|
||||||
|
|
||||||
from . import base
|
from . import base
|
||||||
|
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
"""
|
|
||||||
You need to setup task factory:
|
|
||||||
>>> from aiogram.utils import context
|
|
||||||
>>> loop = asyncio.get_event_loop()
|
|
||||||
>>> loop.set_task_factory(context.task_factory)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import typing
|
|
||||||
|
|
||||||
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
|
|
||||||
|
|
||||||
|
|
||||||
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
|
|
||||||
"""
|
|
||||||
Task factory for implementing context processor
|
|
||||||
|
|
||||||
:param loop:
|
|
||||||
:param coro:
|
|
||||||
:return: new task
|
|
||||||
:rtype: :obj:`asyncio.Task`
|
|
||||||
"""
|
|
||||||
# Is not allowed when loop is closed.
|
|
||||||
if loop.is_closed():
|
|
||||||
raise RuntimeError('Event loop is closed.')
|
|
||||||
|
|
||||||
task = asyncio.Task(coro, loop=loop)
|
|
||||||
|
|
||||||
# Hide factory
|
|
||||||
if task._source_traceback:
|
|
||||||
del task._source_traceback[-1]
|
|
||||||
|
|
||||||
try:
|
|
||||||
task.context = asyncio.Task.current_task().context.copy()
|
|
||||||
except AttributeError:
|
|
||||||
task.context = {CONFIGURED: True}
|
|
||||||
|
|
||||||
return task
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_state() -> typing.Dict:
|
|
||||||
"""
|
|
||||||
Get current execution context from task
|
|
||||||
|
|
||||||
:return: context
|
|
||||||
:rtype: :obj:`dict`
|
|
||||||
"""
|
|
||||||
task = asyncio.Task.current_task()
|
|
||||||
if task is None:
|
|
||||||
raise RuntimeError('Can be used only in Task context.')
|
|
||||||
context_ = getattr(task, 'context', None)
|
|
||||||
if context_ is None:
|
|
||||||
context_ = task.context = {}
|
|
||||||
return context_
|
|
||||||
|
|
||||||
|
|
||||||
def get_value(key, default=None):
|
|
||||||
"""
|
|
||||||
Get value from task
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param default:
|
|
||||||
:return: value
|
|
||||||
"""
|
|
||||||
return get_current_state().get(key, default)
|
|
||||||
|
|
||||||
|
|
||||||
def check_value(key):
|
|
||||||
"""
|
|
||||||
Key in context?
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return key in get_current_state()
|
|
||||||
|
|
||||||
|
|
||||||
def set_value(key, value):
|
|
||||||
"""
|
|
||||||
Set value
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param value:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
get_current_state()[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def del_value(key):
|
|
||||||
"""
|
|
||||||
Remove value from context
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
del get_current_state()[key]
|
|
||||||
|
|
||||||
|
|
||||||
def update_state(data=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Update multiple state items
|
|
||||||
|
|
||||||
:param data:
|
|
||||||
:param kwargs:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if data is None:
|
|
||||||
data = {}
|
|
||||||
state = get_current_state()
|
|
||||||
state.update(data, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def check_configured():
|
|
||||||
"""
|
|
||||||
Check loop is configured
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return get_value(CONFIGURED)
|
|
||||||
|
|
||||||
|
|
||||||
class _Context:
|
|
||||||
"""
|
|
||||||
Other things for interactions with the execution context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return get_value(item)
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
set_value(key, value)
|
|
||||||
|
|
||||||
def __delitem__(self, key):
|
|
||||||
del_value(key)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_context():
|
|
||||||
return get_current_state()
|
|
||||||
|
|
||||||
|
|
||||||
context = _Context()
|
|
||||||
|
|
@ -182,6 +182,10 @@ class MessageToEditNotFound(MessageError):
|
||||||
match = 'message to edit not found'
|
match = 'message to edit not found'
|
||||||
|
|
||||||
|
|
||||||
|
class MessageIsTooLong(MessageError):
|
||||||
|
match = 'message is too long'
|
||||||
|
|
||||||
|
|
||||||
class ToMuchMessages(MessageError):
|
class ToMuchMessages(MessageError):
|
||||||
"""
|
"""
|
||||||
Will be raised when you try to send media group with more than 10 items.
|
Will be raised when you try to send media group with more than 10 items.
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from warnings import warn
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from . import context
|
|
||||||
from ..bot.api import log
|
from ..bot.api import log
|
||||||
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
|
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
|
||||||
|
|
||||||
|
|
@ -104,6 +103,11 @@ class Executor:
|
||||||
|
|
||||||
self._freeze = False
|
self._freeze = False
|
||||||
|
|
||||||
|
from aiogram.bot.bot import bot as ctx_bot
|
||||||
|
from aiogram.dispatcher import dispatcher as ctx_dp
|
||||||
|
ctx_bot.set(dispatcher.bot)
|
||||||
|
ctx_dp.set(dispatcher)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def frozen(self):
|
def frozen(self):
|
||||||
return self._freeze
|
return self._freeze
|
||||||
|
|
@ -176,13 +180,13 @@ class Executor:
|
||||||
self._check_frozen()
|
self._check_frozen()
|
||||||
self._freeze = True
|
self._freeze = True
|
||||||
|
|
||||||
self.loop.set_task_factory(context.task_factory)
|
# self.loop.set_task_factory(context.task_factory)
|
||||||
|
|
||||||
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
|
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
|
||||||
self._check_frozen()
|
self._check_frozen()
|
||||||
self._freeze = True
|
self._freeze = True
|
||||||
|
|
||||||
self.loop.set_task_factory(context.task_factory)
|
# self.loop.set_task_factory(context.task_factory)
|
||||||
|
|
||||||
app = self._web_app
|
app = self._web_app
|
||||||
if app is None:
|
if app is None:
|
||||||
|
|
@ -203,6 +207,7 @@ class Executor:
|
||||||
|
|
||||||
for callback in self._on_startup_webhook:
|
for callback in self._on_startup_webhook:
|
||||||
app.on_startup.append(functools.partial(_wrap_callback, callback))
|
app.on_startup.append(functools.partial(_wrap_callback, callback))
|
||||||
|
|
||||||
# for callback in self._on_shutdown_webhook:
|
# for callback in self._on_shutdown_webhook:
|
||||||
# app.on_shutdown.append(functools.partial(_wrap_callback, callback))
|
# app.on_shutdown.append(functools.partial(_wrap_callback, callback))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,52 @@
|
||||||
import json
|
import os
|
||||||
|
|
||||||
|
JSON = 'json'
|
||||||
|
RAPIDJSON = 'rapidjson'
|
||||||
|
UJSON = 'ujson'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ujson
|
if 'DISABLE_UJSON' not in os.environ:
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
|
mode = UJSON
|
||||||
|
|
||||||
|
|
||||||
|
def dumps(data):
|
||||||
|
return json.dumps(data, ensure_ascii=False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
mode = JSON
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ujson = None
|
mode = JSON
|
||||||
|
|
||||||
_use_ujson = True if ujson else False
|
try:
|
||||||
|
if 'DISABLE_RAPIDJSON' not in os.environ:
|
||||||
|
import rapidjson as json
|
||||||
|
|
||||||
|
mode = RAPIDJSON
|
||||||
|
|
||||||
|
|
||||||
def disable_ujson():
|
def dumps(data):
|
||||||
global _use_ujson
|
return json.dumps(data, ensure_ascii=False, number_mode=json.NM_NATIVE,
|
||||||
_use_ujson = False
|
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
|
||||||
|
|
||||||
|
|
||||||
def dumps(data):
|
def loads(data):
|
||||||
if _use_ujson:
|
return json.loads(data, number_mode=json.NM_NATIVE,
|
||||||
return ujson.dumps(data)
|
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
|
||||||
return json.dumps(data)
|
|
||||||
|
else:
|
||||||
|
mode = JSON
|
||||||
|
except ImportError:
|
||||||
|
mode = JSON
|
||||||
|
|
||||||
|
if mode == JSON:
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
def loads(data):
|
def dumps(data):
|
||||||
if _use_ujson:
|
return json.dumps(data, ensure_ascii=False)
|
||||||
return ujson.loads(data)
|
|
||||||
return json.loads(data)
|
|
||||||
|
def loads(data):
|
||||||
|
return json.loads(data)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from aiogram import types
|
||||||
from . import json
|
from . import json
|
||||||
|
|
||||||
DEFAULT_FILTER = ['self', 'cls']
|
DEFAULT_FILTER = ['self', 'cls']
|
||||||
|
|
@ -56,3 +58,22 @@ def prepare_arg(value):
|
||||||
elif isinstance(value, datetime.datetime):
|
elif isinstance(value, datetime.datetime):
|
||||||
return round(value.timestamp())
|
return round(value.timestamp())
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_file(payload, files, key, file):
|
||||||
|
if isinstance(file, str):
|
||||||
|
payload[key] = file
|
||||||
|
elif file is not None:
|
||||||
|
files[key] = file
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_attachment(payload, files, key, file):
|
||||||
|
if isinstance(file, str):
|
||||||
|
payload[key] = file
|
||||||
|
elif isinstance(file, types.InputFile):
|
||||||
|
payload[key] = file.attach
|
||||||
|
files[file.attachment_key] = file.file
|
||||||
|
elif file is not None:
|
||||||
|
file_attach_name = secrets.token_urlsafe(16)
|
||||||
|
payload[key] = "attach://" + file_attach_name
|
||||||
|
files[file_attach_name] = file
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
-r requirements.txt
|
-r requirements.txt
|
||||||
|
|
||||||
ujson>=1.35
|
ujson>=1.35
|
||||||
|
python-rapidjson>=0.6.3
|
||||||
emoji>=0.5.0
|
emoji>=0.5.0
|
||||||
pytest>=3.5.0
|
pytest>=3.5.0
|
||||||
pytest-asyncio>=0.8.0
|
pytest-asyncio>=0.8.0
|
||||||
|
|
|
||||||
|
|
@ -25,16 +25,16 @@ Next step: interaction with bots starts with one command. Register your first co
|
||||||
|
|
||||||
.. code-block:: python3
|
.. code-block:: python3
|
||||||
|
|
||||||
@dp.message_handler(commands=['start', 'help'])
|
@dp.message_handler(commands=['start', 'help'])
|
||||||
async def send_welcome(message: types.Message):
|
async def send_welcome(message: types.Message):
|
||||||
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||||
|
|
||||||
Last step: run long polling.
|
Last step: run long polling.
|
||||||
|
|
||||||
.. code-block:: python3
|
.. code-block:: python3
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
executor.start_polling(dp)
|
executor.start_polling(dp)
|
||||||
|
|
||||||
Summary
|
Summary
|
||||||
-------
|
-------
|
||||||
|
|
@ -48,9 +48,9 @@ Summary
|
||||||
bot = Bot(token='BOT TOKEN HERE')
|
bot = Bot(token='BOT TOKEN HERE')
|
||||||
dp = Dispatcher(bot)
|
dp = Dispatcher(bot)
|
||||||
|
|
||||||
@dp.message_handler(commands=['start', 'help'])
|
@dp.message_handler(commands=['start', 'help'])
|
||||||
async def send_welcome(message: types.Message):
|
async def send_welcome(message: types.Message):
|
||||||
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
executor.start_polling(dp)
|
executor.start_polling(dp)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
name: py36
|
name: py37
|
||||||
channels:
|
channels:
|
||||||
- conda-forge
|
- conda-forge
|
||||||
dependencies:
|
dependencies:
|
||||||
- python=3.6
|
- python=3.7
|
||||||
- sphinx=1.5.3
|
- sphinx=1.5.3
|
||||||
- sphinx_rtd_theme=0.2.4
|
- sphinx_rtd_theme=0.2.4
|
||||||
- pip
|
- pip
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ async def send_message(user_id: int, text: str, disable_notification: bool = Fal
|
||||||
|
|
||||||
:param user_id:
|
:param user_id:
|
||||||
:param text:
|
:param text:
|
||||||
|
:param disable_notification:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -5,18 +5,14 @@ Babel is required.
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from aiogram import Bot, types
|
from aiogram import Bot, Dispatcher, executor, md, types
|
||||||
from aiogram.dispatcher import Dispatcher
|
|
||||||
from aiogram.types import ParseMode
|
|
||||||
from aiogram.utils.executor import start_polling
|
|
||||||
from aiogram.utils.markdown import *
|
|
||||||
|
|
||||||
API_TOKEN = 'BOT TOKEN HERE'
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
bot = Bot(token=API_TOKEN, loop=loop)
|
bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.MARKDOWN)
|
||||||
dp = Dispatcher(bot)
|
dp = Dispatcher(bot)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,14 +20,14 @@ dp = Dispatcher(bot)
|
||||||
async def check_language(message: types.Message):
|
async def check_language(message: types.Message):
|
||||||
locale = message.from_user.locale
|
locale = message.from_user.locale
|
||||||
|
|
||||||
await message.reply(text(
|
await message.reply(md.text(
|
||||||
bold('Info about your language:'),
|
md.bold('Info about your language:'),
|
||||||
text(' 🔸', bold('Code:'), italic(locale.locale)),
|
md.text(' 🔸', md.bold('Code:'), md.italic(locale.locale)),
|
||||||
text(' 🔸', bold('Territory:'), italic(locale.territory or 'Unknown')),
|
md.text(' 🔸', md.bold('Territory:'), md.italic(locale.territory or 'Unknown')),
|
||||||
text(' 🔸', bold('Language name:'), italic(locale.language_name)),
|
md.text(' 🔸', md.bold('Language name:'), md.italic(locale.language_name)),
|
||||||
text(' 🔸', bold('English language name:'), italic(locale.english_name)),
|
md.text(' 🔸', md.bold('English language name:'), md.italic(locale.english_name)),
|
||||||
sep='\n'), parse_mode=ParseMode.MARKDOWN)
|
sep='\n'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
start_polling(dp, loop=loop, skip_updates=True)
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from aiogram import Bot, types
|
from aiogram import Bot, types, Dispatcher, executor
|
||||||
from aiogram.dispatcher import Dispatcher
|
|
||||||
from aiogram.utils.executor import start_polling
|
|
||||||
|
|
||||||
API_TOKEN = 'BOT TOKEN HERE'
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
|
@ -32,10 +30,4 @@ async def echo(message: types.Message):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
start_polling(dp, loop=loop, skip_updates=True)
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
|
|
||||||
# Also you can use another execution method
|
|
||||||
# >>> try:
|
|
||||||
# >>> loop.run_until_complete(main())
|
|
||||||
# >>> except KeyboardInterrupt:
|
|
||||||
# >>> loop.stop()
|
|
||||||
|
|
|
||||||
|
|
@ -1,47 +0,0 @@
|
||||||
from aiogram import Bot, types
|
|
||||||
from aiogram.contrib.middlewares.context import ContextMiddleware
|
|
||||||
from aiogram.dispatcher import Dispatcher
|
|
||||||
from aiogram.types import ParseMode
|
|
||||||
from aiogram.utils import markdown as md
|
|
||||||
from aiogram.utils.executor import start_polling
|
|
||||||
|
|
||||||
API_TOKEN = 'BOT TOKEN HERE'
|
|
||||||
|
|
||||||
bot = Bot(token=API_TOKEN)
|
|
||||||
dp = Dispatcher(bot)
|
|
||||||
|
|
||||||
# Setup Context middleware
|
|
||||||
data: ContextMiddleware = dp.middleware.setup(ContextMiddleware())
|
|
||||||
|
|
||||||
|
|
||||||
# Write custom filter
|
|
||||||
async def demo_filter(message: types.Message):
|
|
||||||
# Store some data in context
|
|
||||||
command = data['command'] = message.get_command() or ''
|
|
||||||
args = data['args'] = message.get_args() or ''
|
|
||||||
data['has_args'] = bool(args)
|
|
||||||
data['some_random_data'] = 42
|
|
||||||
return command != '/bad_command'
|
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(demo_filter)
|
|
||||||
async def send_welcome(message: types.Message):
|
|
||||||
# Get data from context
|
|
||||||
# All of this is available only in current context and from current update object
|
|
||||||
# `data`- pseudo-alias for `ctx.get_update().conf['_context_data']`
|
|
||||||
command = data['command']
|
|
||||||
args = data['args']
|
|
||||||
rand = data['some_random_data']
|
|
||||||
has_args = data['has_args']
|
|
||||||
|
|
||||||
# Send as pre-formatted code block.
|
|
||||||
await message.reply(md.hpre(f"""command: {command}
|
|
||||||
args: {['Not available', 'available'][has_args]}: {args}
|
|
||||||
some random data: {rand}
|
|
||||||
message ID: {message.message_id}
|
|
||||||
message: {message.html_text}
|
|
||||||
"""), parse_mode=ParseMode.HTML)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
start_polling(dp)
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from aiogram import Bot, types
|
import aiogram.utils.markdown as md
|
||||||
|
from aiogram import Bot, Dispatcher, types
|
||||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
from aiogram.dispatcher import Dispatcher
|
from aiogram.dispatcher import FSMContext
|
||||||
|
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||||
from aiogram.types import ParseMode
|
from aiogram.types import ParseMode
|
||||||
from aiogram.utils import executor
|
from aiogram.utils import executor
|
||||||
from aiogram.utils.markdown import text, bold
|
|
||||||
|
|
||||||
API_TOKEN = 'BOT TOKEN HERE'
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
|
@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop)
|
||||||
storage = MemoryStorage()
|
storage = MemoryStorage()
|
||||||
dp = Dispatcher(bot, storage=storage)
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
|
||||||
|
|
||||||
# States
|
# States
|
||||||
AGE = 'process_age'
|
class Form(StatesGroup):
|
||||||
NAME = 'process_name'
|
name = State() # Will be represented in storage as 'Form:name'
|
||||||
GENDER = 'process_gender'
|
age = State() # Will be represented in storage as 'Form:age'
|
||||||
|
gender = State() # Will be represented in storage as 'Form:gender'
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(commands=['start'])
|
@dp.message_handler(commands=['start'])
|
||||||
|
|
@ -28,48 +32,41 @@ async def cmd_start(message: types.Message):
|
||||||
"""
|
"""
|
||||||
Conversation's entry point
|
Conversation's entry point
|
||||||
"""
|
"""
|
||||||
# Get current state
|
# Set state
|
||||||
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
|
await Form.name.set()
|
||||||
# Update user's state
|
|
||||||
await state.set_state(NAME)
|
|
||||||
|
|
||||||
await message.reply("Hi there! What's your name?")
|
await message.reply("Hi there! What's your name?")
|
||||||
|
|
||||||
|
|
||||||
# You can use state '*' if you need to handle all states
|
# You can use state '*' if you need to handle all states
|
||||||
@dp.message_handler(state='*', commands=['cancel'])
|
@dp.message_handler(state='*', commands=['cancel'])
|
||||||
@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel')
|
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
|
||||||
async def cancel_handler(message: types.Message):
|
async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Allow user to cancel any action
|
Allow user to cancel any action
|
||||||
"""
|
"""
|
||||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
if raw_state is None:
|
||||||
# Ignore command if user is not in any (defined) state
|
return
|
||||||
if await state.get_state() is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Otherwise cancel state and inform user about it
|
# Cancel state and inform user about it
|
||||||
# And remove keyboard (just in case)
|
await state.finish()
|
||||||
await state.reset_state(with_data=True)
|
# And remove keyboard (just in case)
|
||||||
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=NAME)
|
@dp.message_handler(state=Form.name)
|
||||||
async def process_name(message: types.Message):
|
async def process_name(message: types.Message, state: FSMContext):
|
||||||
"""
|
"""
|
||||||
Process user name
|
Process user name
|
||||||
"""
|
"""
|
||||||
# Save name to storage and go to next step
|
await Form.next()
|
||||||
# You can use context manager
|
await state.update_data(name=message.text)
|
||||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
|
||||||
await state.update_data(name=message.text)
|
|
||||||
await state.set_state(AGE)
|
|
||||||
|
|
||||||
await message.reply("How old are you?")
|
await message.reply("How old are you?")
|
||||||
|
|
||||||
|
|
||||||
# Check age. Age gotta be digit
|
# Check age. Age gotta be digit
|
||||||
@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit())
|
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
|
||||||
async def failed_process_age(message: types.Message):
|
async def failed_process_age(message: types.Message):
|
||||||
"""
|
"""
|
||||||
If age is invalid
|
If age is invalid
|
||||||
|
|
@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message):
|
||||||
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit())
|
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
|
||||||
async def process_age(message: types.Message):
|
async def process_age(message: types.Message, state: FSMContext):
|
||||||
# Update state and data
|
# Update state and data
|
||||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
await Form.next()
|
||||||
await state.set_state(GENDER)
|
await state.update_data(age=int(message.text))
|
||||||
await state.update_data(age=int(message.text))
|
|
||||||
|
|
||||||
# Configure ReplyKeyboardMarkup
|
# Configure ReplyKeyboardMarkup
|
||||||
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
||||||
|
|
@ -92,7 +88,7 @@ async def process_age(message: types.Message):
|
||||||
await message.reply("What is your gender?", reply_markup=markup)
|
await message.reply("What is your gender?", reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"])
|
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
|
||||||
async def failed_process_gender(message: types.Message):
|
async def failed_process_gender(message: types.Message):
|
||||||
"""
|
"""
|
||||||
In this example gender has to be one of: Male, Female, Other.
|
In this example gender has to be one of: Male, Female, Other.
|
||||||
|
|
@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message):
|
||||||
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(state=GENDER)
|
@dp.message_handler(state=Form.gender)
|
||||||
async def process_gender(message: types.Message):
|
async def process_gender(message: types.Message, state: FSMContext):
|
||||||
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
|
|
||||||
|
|
||||||
data = await state.get_data()
|
data = await state.get_data()
|
||||||
data['gender'] = message.text
|
data['gender'] = message.text
|
||||||
|
|
||||||
|
|
@ -111,10 +105,10 @@ async def process_gender(message: types.Message):
|
||||||
markup = types.ReplyKeyboardRemove()
|
markup = types.ReplyKeyboardRemove()
|
||||||
|
|
||||||
# And send message
|
# And send message
|
||||||
await bot.send_message(message.chat.id, text(
|
await bot.send_message(message.chat.id, md.text(
|
||||||
text('Hi! Nice to meet you,', bold(data['name'])),
|
md.text('Hi! Nice to meet you,', md.bold(data['name'])),
|
||||||
text('Age:', data['age']),
|
md.text('Age:', data['age']),
|
||||||
text('Gender:', data['gender']),
|
md.text('Gender:', data['gender']),
|
||||||
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
|
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
|
||||||
|
|
||||||
# Finish conversation
|
# Finish conversation
|
||||||
|
|
@ -122,10 +116,5 @@ async def process_gender(message: types.Message):
|
||||||
await state.finish()
|
await state.finish()
|
||||||
|
|
||||||
|
|
||||||
async def shutdown(dispatcher: Dispatcher):
|
|
||||||
await dispatcher.storage.close()
|
|
||||||
await dispatcher.storage.wait_closed()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
executor.start_polling(dp, loop=loop, skip_updates=True, on_shutdown=shutdown)
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
|
|
|
||||||
126
examples/finite_state_machine_example_2.py
Normal file
126
examples/finite_state_machine_example_2.py
Normal file
|
|
@ -0,0 +1,126 @@
|
||||||
|
"""
|
||||||
|
This example is equals with 'finite_state_machine_example.py' but with FSM Middleware
|
||||||
|
|
||||||
|
Note that FSM Middleware implements the more simple methods for working with storage.
|
||||||
|
|
||||||
|
With that middleware all data from storage will be loaded before event will be processed
|
||||||
|
and data will be stored after processing the event.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import aiogram.utils.markdown as md
|
||||||
|
from aiogram import Bot, Dispatcher, types
|
||||||
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
|
from aiogram.contrib.middlewares.fsm import FSMMiddleware, FSMSStorageProxy
|
||||||
|
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||||
|
from aiogram.utils import executor
|
||||||
|
|
||||||
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
bot = Bot(token=API_TOKEN, loop=loop)
|
||||||
|
|
||||||
|
# For example use simple MemoryStorage for Dispatcher.
|
||||||
|
storage = MemoryStorage()
|
||||||
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
dp.middleware.setup(FSMMiddleware())
|
||||||
|
|
||||||
|
|
||||||
|
# States
|
||||||
|
class Form(StatesGroup):
|
||||||
|
name = State() # Will be represented in storage as 'Form:name'
|
||||||
|
age = State() # Will be represented in storage as 'Form:age'
|
||||||
|
gender = State() # Will be represented in storage as 'Form:gender'
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands=['start'])
|
||||||
|
async def cmd_start(message: types.Message):
|
||||||
|
"""
|
||||||
|
Conversation's entry point
|
||||||
|
"""
|
||||||
|
# Set state
|
||||||
|
await Form.first()
|
||||||
|
|
||||||
|
await message.reply("Hi there! What's your name?")
|
||||||
|
|
||||||
|
|
||||||
|
# You can use state '*' if you need to handle all states
|
||||||
|
@dp.message_handler(state='*', commands=['cancel'])
|
||||||
|
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
|
||||||
|
async def cancel_handler(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
"""
|
||||||
|
Allow user to cancel any action
|
||||||
|
"""
|
||||||
|
if state_data.state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cancel state and inform user about it
|
||||||
|
del state_data.state
|
||||||
|
# And remove keyboard (just in case)
|
||||||
|
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(state=Form.name)
|
||||||
|
async def process_name(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
"""
|
||||||
|
Process user name
|
||||||
|
"""
|
||||||
|
state_data.state = Form.age
|
||||||
|
state_data['name'] = message.text
|
||||||
|
|
||||||
|
await message.reply("How old are you?")
|
||||||
|
|
||||||
|
|
||||||
|
# Check age. Age gotta be digit
|
||||||
|
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
|
||||||
|
async def failed_process_age(message: types.Message):
|
||||||
|
"""
|
||||||
|
If age is invalid
|
||||||
|
"""
|
||||||
|
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
|
||||||
|
async def process_age(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
# Update state and data
|
||||||
|
state_data.state = Form.gender
|
||||||
|
state_data['age'] = int(message.text)
|
||||||
|
|
||||||
|
# Configure ReplyKeyboardMarkup
|
||||||
|
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
||||||
|
markup.add("Male", "Female")
|
||||||
|
markup.add("Other")
|
||||||
|
|
||||||
|
await message.reply("What is your gender?", reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
|
||||||
|
async def failed_process_gender(message: types.Message):
|
||||||
|
"""
|
||||||
|
In this example gender has to be one of: Male, Female, Other.
|
||||||
|
"""
|
||||||
|
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(state=Form.gender)
|
||||||
|
async def process_gender(message: types.Message, state_data: FSMSStorageProxy):
|
||||||
|
state_data['gender'] = message.text
|
||||||
|
|
||||||
|
# Remove keyboard
|
||||||
|
markup = types.ReplyKeyboardRemove()
|
||||||
|
|
||||||
|
# And send message
|
||||||
|
await bot.send_message(message.chat.id, md.text(
|
||||||
|
md.text('Hi! Nice to meet you,', md.bold(state_data['name'])),
|
||||||
|
md.text('Age:', state_data['age']),
|
||||||
|
md.text('Gender:', state_data['gender']),
|
||||||
|
sep='\n'), reply_markup=markup, parse_mode=types.ParseMode.MARKDOWN)
|
||||||
|
|
||||||
|
# Finish conversation
|
||||||
|
# WARNING! This method will destroy all data in storage for current user!
|
||||||
|
state_data.clear()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
56
examples/i18n_example.py
Normal file
56
examples/i18n_example.py
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
"""
|
||||||
|
Internalize your bot
|
||||||
|
|
||||||
|
Step 1: extract texts
|
||||||
|
# pybabel extract i18n_example.py -o locales/mybot.pot
|
||||||
|
Step 2: create *.po files. For e.g. create en, ru, uk locales.
|
||||||
|
# echo {en,ru,uk} | xargs -n1 pybabel init -i locales/mybot.pot -d locales -D mybot -l
|
||||||
|
Step 3: translate texts
|
||||||
|
Step 4: compile translations
|
||||||
|
# pybabel compile -d locales -D mybot
|
||||||
|
|
||||||
|
Step 5: When you change the code of your bot you need to update po & mo files.
|
||||||
|
Step 5.1: regenerate pot file:
|
||||||
|
command from step 1
|
||||||
|
Step 5.2: update po files
|
||||||
|
# pybabel update -d locales -D mybot -i locales/mybot.pot
|
||||||
|
Step 5.3: update your translations
|
||||||
|
Step 5.4: compile mo files
|
||||||
|
command from step 4
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from aiogram import Bot, Dispatcher, executor, types
|
||||||
|
from aiogram.contrib.middlewares.i18n import I18nMiddleware
|
||||||
|
|
||||||
|
TOKEN = 'BOT TOKEN HERE'
|
||||||
|
I18N_DOMAIN = 'mybot'
|
||||||
|
|
||||||
|
BASE_DIR = Path(__file__).parent
|
||||||
|
LOCALES_DIR = BASE_DIR / 'locales'
|
||||||
|
|
||||||
|
bot = Bot(TOKEN, parse_mode=types.ParseMode.HTML)
|
||||||
|
dp = Dispatcher(bot)
|
||||||
|
|
||||||
|
# Setup i18n middleware
|
||||||
|
i18n = I18nMiddleware(I18N_DOMAIN, LOCALES_DIR)
|
||||||
|
dp.middleware.setup(i18n)
|
||||||
|
|
||||||
|
# Alias for gettext method
|
||||||
|
_ = i18n.gettext
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands=['start'])
|
||||||
|
async def cmd_start(message: types.Message):
|
||||||
|
# Simply use `_('message')` instead of `'message'` and never use f-strings for translatable texts.
|
||||||
|
await message.reply(_('Hello, <b>{user}</b>!').format(user=message.from_user.full_name))
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands=['lang'])
|
||||||
|
async def cmd_lang(message: types.Message, locale):
|
||||||
|
await message.reply(_('Your current language: <i>{language}</i>').format(language=locale))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
executor.start_polling(dp, skip_updates=True)
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from aiogram import Bot, types
|
from aiogram import Bot, types, Dispatcher, executor
|
||||||
from aiogram.dispatcher import Dispatcher
|
|
||||||
from aiogram.utils.executor import start_polling
|
|
||||||
|
|
||||||
API_TOKEN = 'BOT TOKEN HERE'
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
|
@ -23,4 +21,4 @@ async def inline_echo(inline_query: types.InlineQuery):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
start_polling(dp, loop=loop, skip_updates=True)
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
|
|
|
||||||
28
examples/locales/en/LC_MESSAGES/mybot.po
Normal file
28
examples/locales/en/LC_MESSAGES/mybot.po
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
# English translations for PROJECT.
|
||||||
|
# Copyright (C) 2018 ORGANIZATION
|
||||||
|
# This file is distributed under the same license as the PROJECT project.
|
||||||
|
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||||
|
#
|
||||||
|
msgid ""
|
||||||
|
msgstr ""
|
||||||
|
"Project-Id-Version: PROJECT VERSION\n"
|
||||||
|
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||||
|
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||||
|
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
|
||||||
|
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||||
|
"Language: en\n"
|
||||||
|
"Language-Team: en <LL@li.org>\n"
|
||||||
|
"Plural-Forms: nplurals=2; plural=(n != 1)\n"
|
||||||
|
"MIME-Version: 1.0\n"
|
||||||
|
"Content-Type: text/plain; charset=utf-8\n"
|
||||||
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
"Generated-By: Babel 2.6.0\n"
|
||||||
|
|
||||||
|
#: i18n_example.py:48
|
||||||
|
msgid "Hello, <b>{user}</b>!"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
#: i18n_example.py:53
|
||||||
|
msgid "Your current language: <i>{language}</i>"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
27
examples/locales/mybot.pot
Normal file
27
examples/locales/mybot.pot
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Translations template for PROJECT.
|
||||||
|
# Copyright (C) 2018 ORGANIZATION
|
||||||
|
# This file is distributed under the same license as the PROJECT project.
|
||||||
|
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||||
|
#
|
||||||
|
#, fuzzy
|
||||||
|
msgid ""
|
||||||
|
msgstr ""
|
||||||
|
"Project-Id-Version: PROJECT VERSION\n"
|
||||||
|
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||||
|
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||||
|
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||||
|
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||||
|
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||||
|
"MIME-Version: 1.0\n"
|
||||||
|
"Content-Type: text/plain; charset=utf-8\n"
|
||||||
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
"Generated-By: Babel 2.6.0\n"
|
||||||
|
|
||||||
|
#: i18n_example.py:48
|
||||||
|
msgid "Hello, <b>{user}</b>!"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
#: i18n_example.py:53
|
||||||
|
msgid "Your current language: <i>{language}</i>"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
29
examples/locales/ru/LC_MESSAGES/mybot.po
Normal file
29
examples/locales/ru/LC_MESSAGES/mybot.po
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Russian translations for PROJECT.
|
||||||
|
# Copyright (C) 2018 ORGANIZATION
|
||||||
|
# This file is distributed under the same license as the PROJECT project.
|
||||||
|
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||||
|
#
|
||||||
|
msgid ""
|
||||||
|
msgstr ""
|
||||||
|
"Project-Id-Version: PROJECT VERSION\n"
|
||||||
|
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||||
|
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||||
|
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
|
||||||
|
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||||
|
"Language: ru\n"
|
||||||
|
"Language-Team: ru <LL@li.org>\n"
|
||||||
|
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
|
||||||
|
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
|
||||||
|
"MIME-Version: 1.0\n"
|
||||||
|
"Content-Type: text/plain; charset=utf-8\n"
|
||||||
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
"Generated-By: Babel 2.6.0\n"
|
||||||
|
|
||||||
|
#: i18n_example.py:48
|
||||||
|
msgid "Hello, <b>{user}</b>!"
|
||||||
|
msgstr "Привет, <b>{user}</b>!"
|
||||||
|
|
||||||
|
#: i18n_example.py:53
|
||||||
|
msgid "Your current language: <i>{language}</i>"
|
||||||
|
msgstr "Твой язык: <i>{language}</i>"
|
||||||
|
|
||||||
29
examples/locales/uk/LC_MESSAGES/mybot.po
Normal file
29
examples/locales/uk/LC_MESSAGES/mybot.po
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Ukrainian translations for PROJECT.
|
||||||
|
# Copyright (C) 2018 ORGANIZATION
|
||||||
|
# This file is distributed under the same license as the PROJECT project.
|
||||||
|
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||||
|
#
|
||||||
|
msgid ""
|
||||||
|
msgstr ""
|
||||||
|
"Project-Id-Version: PROJECT VERSION\n"
|
||||||
|
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||||
|
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||||
|
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
|
||||||
|
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||||
|
"Language: uk\n"
|
||||||
|
"Language-Team: uk <LL@li.org>\n"
|
||||||
|
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
|
||||||
|
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
|
||||||
|
"MIME-Version: 1.0\n"
|
||||||
|
"Content-Type: text/plain; charset=utf-8\n"
|
||||||
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
"Generated-By: Babel 2.6.0\n"
|
||||||
|
|
||||||
|
#: i18n_example.py:48
|
||||||
|
msgid "Hello, <b>{user}</b>!"
|
||||||
|
msgstr "Привіт, <b>{user}</b>!"
|
||||||
|
|
||||||
|
#: i18n_example.py:53
|
||||||
|
msgid "Your current language: <i>{language}</i>"
|
||||||
|
msgstr "Твоя мова: <i>{language}</i>"
|
||||||
|
|
||||||
|
|
@ -1,9 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from aiogram import Bot, types
|
from aiogram import Bot, Dispatcher, executor, filters, types
|
||||||
from aiogram.dispatcher import Dispatcher
|
|
||||||
from aiogram.types import ChatActions
|
|
||||||
from aiogram.utils.executor import start_polling
|
|
||||||
|
|
||||||
API_TOKEN = 'BOT TOKEN HERE'
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
|
@ -12,7 +9,7 @@ bot = Bot(token=API_TOKEN, loop=loop)
|
||||||
dp = Dispatcher(bot)
|
dp = Dispatcher(bot)
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(commands=['start'])
|
@dp.message_handler(filters.CommandStart())
|
||||||
async def send_welcome(message: types.Message):
|
async def send_welcome(message: types.Message):
|
||||||
# So... At first I want to send something like this:
|
# So... At first I want to send something like this:
|
||||||
await message.reply("Do you want to see many pussies? Are you ready?")
|
await message.reply("Do you want to see many pussies? Are you ready?")
|
||||||
|
|
@ -21,7 +18,7 @@ async def send_welcome(message: types.Message):
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
# Good bots should send chat actions. Or not.
|
# Good bots should send chat actions. Or not.
|
||||||
await ChatActions.upload_photo()
|
await types.ChatActions.upload_photo()
|
||||||
|
|
||||||
# Create media group
|
# Create media group
|
||||||
media = types.MediaGroup()
|
media = types.MediaGroup()
|
||||||
|
|
@ -39,9 +36,8 @@ async def send_welcome(message: types.Message):
|
||||||
# media.attach_photo('<file_id>', 'cat-cat-cat.')
|
# media.attach_photo('<file_id>', 'cat-cat-cat.')
|
||||||
|
|
||||||
# Done! Send media group
|
# Done! Send media group
|
||||||
await bot.send_media_group(message.chat.id, media=media,
|
await message.reply_media_group(media=media)
|
||||||
reply_to_message_id=message.message_id)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
start_polling(dp, loop=loop, skip_updates=True)
|
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from aiogram import Bot, types
|
from aiogram import Bot, Dispatcher, executor, types
|
||||||
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
||||||
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
|
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
|
||||||
|
from aiogram.dispatcher.handler import CancelHandler
|
||||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||||
from aiogram.utils import context, executor
|
|
||||||
from aiogram.utils.exceptions import Throttled
|
|
||||||
|
|
||||||
TOKEN = 'BOT TOKEN HERE'
|
TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
|
@ -53,10 +52,10 @@ class ThrottlingMiddleware(BaseMiddleware):
|
||||||
:param message:
|
:param message:
|
||||||
"""
|
"""
|
||||||
# Get current handler
|
# Get current handler
|
||||||
handler = context.get_value('handler')
|
# handler = context.get_value('handler')
|
||||||
|
|
||||||
# Get dispatcher from context
|
# Get dispatcher from context
|
||||||
dispatcher = ctx.get_dispatcher()
|
dispatcher = Dispatcher.current()
|
||||||
|
|
||||||
# If handler was configured, get rate limit and key from handler
|
# If handler was configured, get rate limit and key from handler
|
||||||
if handler:
|
if handler:
|
||||||
|
|
@ -83,8 +82,8 @@ class ThrottlingMiddleware(BaseMiddleware):
|
||||||
:param message:
|
:param message:
|
||||||
:param throttled:
|
:param throttled:
|
||||||
"""
|
"""
|
||||||
handler = context.get_value('handler')
|
# handler = context.get_value('handler')
|
||||||
dispatcher = ctx.get_dispatcher()
|
dispatcher = Dispatcher.current()
|
||||||
if handler:
|
if handler:
|
||||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from aiogram import Bot
|
||||||
from aiogram import types
|
from aiogram import types
|
||||||
from aiogram.utils import executor
|
from aiogram.utils import executor
|
||||||
from aiogram.dispatcher import Dispatcher
|
from aiogram.dispatcher import Dispatcher
|
||||||
from aiogram.types.message import ContentType
|
from aiogram.types.message import ContentTypes
|
||||||
|
|
||||||
|
|
||||||
BOT_TOKEN = 'BOT TOKEN HERE'
|
BOT_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
@ -86,7 +86,7 @@ async def checkout(pre_checkout_query: types.PreCheckoutQuery):
|
||||||
" try to pay again in a few minutes, we need a small rest.")
|
" try to pay again in a few minutes, we need a small rest.")
|
||||||
|
|
||||||
|
|
||||||
@dp.message_handler(content_types=ContentType.SUCCESSFUL_PAYMENT)
|
@dp.message_handler(content_types=ContentTypes.SUCCESSFUL_PAYMENT)
|
||||||
async def got_payment(message: types.Message):
|
async def got_payment(message: types.Message):
|
||||||
await bot.send_message(message.chat.id,
|
await bot.send_message(message.chat.id,
|
||||||
'Hoooooray! Thanks for payment! We will proceed your order for `{} {}`'
|
'Hoooooray! Thanks for payment! We will proceed your order for `{} {}`'
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ PROXY_URL = 'http://PROXY_URL' # Or 'socks5://...'
|
||||||
# PROXY_AUTH = aiohttp.BasicAuth(login='login', password='password')
|
# PROXY_AUTH = aiohttp.BasicAuth(login='login', password='password')
|
||||||
# And add `proxy_auth=PROXY_AUTH` argument in line 25, like this:
|
# And add `proxy_auth=PROXY_AUTH` argument in line 25, like this:
|
||||||
# >>> bot = Bot(token=API_TOKEN, loop=loop, proxy=PROXY_URL, proxy_auth=PROXY_AUTH)
|
# >>> bot = Bot(token=API_TOKEN, loop=loop, proxy=PROXY_URL, proxy_auth=PROXY_AUTH)
|
||||||
# Also you can use Socks5 proxy but you need manually install aiosocksy package.
|
# Also you can use Socks5 proxy but you need manually install aiohttp_socks package.
|
||||||
|
|
||||||
# Get my ip URL
|
# Get my ip URL
|
||||||
GET_IP_URL = 'http://bot.whatismyipaddress.com/'
|
GET_IP_URL = 'http://bot.whatismyipaddress.com/'
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from aiogram import Bot, types, Version
|
||||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
from aiogram.dispatcher import Dispatcher
|
from aiogram.dispatcher import Dispatcher
|
||||||
from aiogram.dispatcher.webhook import get_new_configured_app, SendMessage
|
from aiogram.dispatcher.webhook import get_new_configured_app, SendMessage
|
||||||
from aiogram.types import ChatType, ParseMode, ContentType
|
from aiogram.types import ChatType, ParseMode, ContentTypes
|
||||||
from aiogram.utils.markdown import hbold, bold, text, link
|
from aiogram.utils.markdown import hbold, bold, text, link
|
||||||
|
|
||||||
TOKEN = 'BOT TOKEN HERE'
|
TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
@ -31,7 +31,7 @@ WEBHOOK_URL = f"https://{WEBHOOK_HOST}:{WEBHOOK_PORT}{WEBHOOK_URL_PATH}"
|
||||||
WEBAPP_HOST = 'localhost'
|
WEBAPP_HOST = 'localhost'
|
||||||
WEBAPP_PORT = 3001
|
WEBAPP_PORT = 3001
|
||||||
|
|
||||||
BAD_CONTENT = ContentType.PHOTO & ContentType.DOCUMENT & ContentType.STICKER & ContentType.AUDIO
|
BAD_CONTENT = ContentTypes.PHOTO & ContentTypes.DOCUMENT & ContentTypes.STICKER & ContentTypes.AUDIO
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
bot = Bot(TOKEN, loop=loop)
|
bot = Bot(TOKEN, loop=loop)
|
||||||
|
|
|
||||||
6
setup.py
6
setup.py
|
|
@ -13,7 +13,7 @@ except ImportError: # pip >= 10.0.0
|
||||||
WORK_DIR = pathlib.Path(__file__).parent
|
WORK_DIR = pathlib.Path(__file__).parent
|
||||||
|
|
||||||
# Check python version
|
# Check python version
|
||||||
MINIMAL_PY_VERSION = (3, 6)
|
MINIMAL_PY_VERSION = (3, 7)
|
||||||
if sys.version_info < MINIMAL_PY_VERSION:
|
if sys.version_info < MINIMAL_PY_VERSION:
|
||||||
raise RuntimeError('aiogram works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION))))
|
raise RuntimeError('aiogram works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION))))
|
||||||
|
|
||||||
|
|
@ -65,7 +65,7 @@ setup(
|
||||||
url='https://github.com/aiogram/aiogram',
|
url='https://github.com/aiogram/aiogram',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
author='Alex Root Junior',
|
author='Alex Root Junior',
|
||||||
requires_python='>=3.6',
|
requires_python='>=3.7',
|
||||||
author_email='aiogram@illemius.xyz',
|
author_email='aiogram@illemius.xyz',
|
||||||
description='Is a pretty simple and fully asynchronous library for Telegram Bot API',
|
description='Is a pretty simple and fully asynchronous library for Telegram Bot API',
|
||||||
long_description=get_description(),
|
long_description=get_description(),
|
||||||
|
|
@ -76,7 +76,7 @@ setup(
|
||||||
'Intended Audience :: Developers',
|
'Intended Audience :: Developers',
|
||||||
'Intended Audience :: System Administrators',
|
'Intended Audience :: System Administrators',
|
||||||
'License :: OSI Approved :: MIT License',
|
'License :: OSI Approved :: MIT License',
|
||||||
'Programming Language :: Python :: 3.6',
|
'Programming Language :: Python :: 3.7',
|
||||||
'Topic :: Software Development :: Libraries :: Application Frameworks',
|
'Topic :: Software Development :: Libraries :: Application Frameworks',
|
||||||
],
|
],
|
||||||
install_requires=get_requirements()
|
install_requires=get_requirements()
|
||||||
|
|
|
||||||
102
tests/states_group.py
Normal file
102
tests/states_group.py
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from aiogram.dispatcher.filters.state import State, StatesGroup, any_state, default_state
|
||||||
|
|
||||||
|
|
||||||
|
class MyGroup(StatesGroup):
|
||||||
|
state = State()
|
||||||
|
state_1 = State()
|
||||||
|
state_2 = State()
|
||||||
|
|
||||||
|
class MySubGroup(StatesGroup):
|
||||||
|
sub_state = State()
|
||||||
|
sub_state_1 = State()
|
||||||
|
sub_state_2 = State()
|
||||||
|
|
||||||
|
in_custom_group = State(group_name='custom_group')
|
||||||
|
|
||||||
|
class NewGroup(StatesGroup):
|
||||||
|
spam = State()
|
||||||
|
renamed_state = State(state='spam_state')
|
||||||
|
|
||||||
|
|
||||||
|
alone_state = State('alone')
|
||||||
|
alone_in_group = State('alone', group_name='home')
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_state():
|
||||||
|
assert default_state.state is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_state():
|
||||||
|
assert any_state.state == '*'
|
||||||
|
|
||||||
|
|
||||||
|
def test_alone_state():
|
||||||
|
assert alone_state.state == '@:alone'
|
||||||
|
assert alone_in_group.state == 'home:alone'
|
||||||
|
|
||||||
|
|
||||||
|
def test_group_names():
|
||||||
|
assert MyGroup.__group_name__ == 'MyGroup'
|
||||||
|
assert MyGroup.__full_group_name__ == 'MyGroup'
|
||||||
|
|
||||||
|
assert MyGroup.MySubGroup.__group_name__ == 'MySubGroup'
|
||||||
|
assert MyGroup.MySubGroup.__full_group_name__ == 'MyGroup.MySubGroup'
|
||||||
|
|
||||||
|
assert MyGroup.MySubGroup.NewGroup.__group_name__ == 'NewGroup'
|
||||||
|
assert MyGroup.MySubGroup.NewGroup.__full_group_name__ == 'MyGroup.MySubGroup.NewGroup'
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_group_in_group():
|
||||||
|
assert MyGroup.MySubGroup.in_custom_group.state == 'custom_group:in_custom_group'
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_state_name_in_group():
|
||||||
|
assert MyGroup.MySubGroup.NewGroup.renamed_state.state == 'MyGroup.MySubGroup.NewGroup:spam_state'
|
||||||
|
|
||||||
|
|
||||||
|
def test_group_states_names():
|
||||||
|
assert len(MyGroup.states) == 3
|
||||||
|
assert len(MyGroup.all_states) == 9
|
||||||
|
|
||||||
|
assert MyGroup.states_names == ('MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2')
|
||||||
|
assert MyGroup.MySubGroup.states_names == (
|
||||||
|
'MyGroup.MySubGroup:sub_state', 'MyGroup.MySubGroup:sub_state_1', 'MyGroup.MySubGroup:sub_state_2',
|
||||||
|
'custom_group:in_custom_group')
|
||||||
|
assert MyGroup.MySubGroup.NewGroup.states_names == (
|
||||||
|
'MyGroup.MySubGroup.NewGroup:spam', 'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||||
|
|
||||||
|
assert MyGroup.all_states_names == (
|
||||||
|
'MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2',
|
||||||
|
'MyGroup.MySubGroup:sub_state',
|
||||||
|
'MyGroup.MySubGroup:sub_state_1',
|
||||||
|
'MyGroup.MySubGroup:sub_state_2',
|
||||||
|
'custom_group:in_custom_group',
|
||||||
|
'MyGroup.MySubGroup.NewGroup:spam',
|
||||||
|
'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||||
|
|
||||||
|
assert MyGroup.MySubGroup.all_states_names == (
|
||||||
|
'MyGroup.MySubGroup:sub_state',
|
||||||
|
'MyGroup.MySubGroup:sub_state_1',
|
||||||
|
'MyGroup.MySubGroup:sub_state_2',
|
||||||
|
'custom_group:in_custom_group',
|
||||||
|
'MyGroup.MySubGroup.NewGroup:spam',
|
||||||
|
'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||||
|
|
||||||
|
assert MyGroup.MySubGroup.NewGroup.all_states_names == (
|
||||||
|
'MyGroup.MySubGroup.NewGroup:spam',
|
||||||
|
'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||||
|
|
||||||
|
|
||||||
|
def test_root_element():
|
||||||
|
root = MyGroup.MySubGroup.NewGroup.spam.get_root()
|
||||||
|
|
||||||
|
assert issubclass(root, StatesGroup)
|
||||||
|
assert root == MyGroup
|
||||||
|
|
||||||
|
assert root == MyGroup.state.get_root()
|
||||||
|
assert root == MyGroup.MySubGroup.get_root()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
any_state.get_root()
|
||||||
2
tox.ini
2
tox.ini
|
|
@ -1,5 +1,5 @@
|
||||||
[tox]
|
[tox]
|
||||||
envlist = py36
|
envlist = py37
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
deps = -rdev_requirements.txt
|
deps = -rdev_requirements.txt
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue