mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge c733257be5 into 5f7a495b10
This commit is contained in:
commit
bfccd6268c
64 changed files with 3455 additions and 2299 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -57,3 +57,6 @@ experiment.py
|
|||
|
||||
# Doc's
|
||||
docs/html
|
||||
|
||||
# i18n/l10n
|
||||
*.mo
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
[](https://github.com/aiogram/aiogram/issues)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.6 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler.
|
||||
**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.7 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler.
|
||||
|
||||
You can [read the docs here](http://aiogram.readthedocs.io/en/latest/).
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,42 @@
|
|||
import asyncio
|
||||
import os
|
||||
|
||||
from . import bot
|
||||
from . import contrib
|
||||
from . import dispatcher
|
||||
from . import types
|
||||
from . import utils
|
||||
from .bot import Bot
|
||||
from .dispatcher import Dispatcher
|
||||
from .dispatcher import filters
|
||||
from .dispatcher import middlewares
|
||||
from .utils import exceptions, executor, helper, markdown as md
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
except ImportError:
|
||||
uvloop = None
|
||||
else:
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
if 'DISABLE_UVLOOP' not in os.environ:
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
__version__ = '1.4'
|
||||
__all__ = [
|
||||
'Bot',
|
||||
'Dispatcher',
|
||||
'__api_version__',
|
||||
'__version__',
|
||||
'bot',
|
||||
'contrib',
|
||||
'dispatcher',
|
||||
'exceptions',
|
||||
'executor',
|
||||
'filters',
|
||||
'helper',
|
||||
'md',
|
||||
'middlewares',
|
||||
'types',
|
||||
'utils'
|
||||
]
|
||||
|
||||
__version__ = '2.0.dev1'
|
||||
__api_version__ = '3.6'
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from asyncio import AbstractEventLoop
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
import certifi
|
||||
|
||||
from .. import types
|
||||
from ..utils import exceptions
|
||||
|
|
@ -34,58 +40,73 @@ def check_token(token: str) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
async def _check_result(method_name, response):
|
||||
"""
|
||||
Checks whether `result` is a valid API response.
|
||||
A result is considered invalid if:
|
||||
- The server returned an HTTP response code other than 200
|
||||
- The content of the result is invalid JSON.
|
||||
- The method call was unsuccessful (The JSON 'ok' field equals False)
|
||||
async def check_result(method_name: str, content_type: str, status_code: int, body: str):
|
||||
"""
|
||||
Checks whether `result` is a valid API response.
|
||||
A result is considered invalid if:
|
||||
- The server returned an HTTP response code other than 200
|
||||
- The content of the result is invalid JSON.
|
||||
- The method call was unsuccessful (The JSON 'ok' field equals False)
|
||||
|
||||
:raises ApiException: if one of the above listed cases is applicable
|
||||
:param method_name: The name of the method called
|
||||
:param response: The returned response of the method request
|
||||
:return: The result parsed to a JSON dictionary.
|
||||
"""
|
||||
body = await response.text()
|
||||
log.debug(f"Response for {method_name}: [{response.status}] {body}")
|
||||
:param method_name: The name of the method called
|
||||
:param status_code: status code
|
||||
:param content_type: content type of result
|
||||
:param body: result body
|
||||
:return: The result parsed to a JSON dictionary
|
||||
:raises ApiException: if one of the above listed cases is applicable
|
||||
"""
|
||||
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
|
||||
|
||||
if response.content_type != 'application/json':
|
||||
raise exceptions.NetworkError(f"Invalid response with content type {response.content_type}: \"{body}\"")
|
||||
if content_type != 'application/json':
|
||||
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
|
||||
|
||||
try:
|
||||
result_json = json.loads(body)
|
||||
except ValueError:
|
||||
result_json = {}
|
||||
|
||||
description = result_json.get('description') or body
|
||||
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
|
||||
|
||||
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
|
||||
return result_json.get('result')
|
||||
elif parameters.retry_after:
|
||||
raise exceptions.RetryAfter(parameters.retry_after)
|
||||
elif parameters.migrate_to_chat_id:
|
||||
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
|
||||
elif status_code == HTTPStatus.BAD_REQUEST:
|
||||
exceptions.BadRequest.detect(description)
|
||||
elif status_code == HTTPStatus.NOT_FOUND:
|
||||
exceptions.NotFound.detect(description)
|
||||
elif status_code == HTTPStatus.CONFLICT:
|
||||
exceptions.ConflictError.detect(description)
|
||||
elif status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
|
||||
exceptions.Unauthorized.detect(description)
|
||||
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
|
||||
raise exceptions.NetworkError('File too large for uploading. '
|
||||
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
|
||||
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
||||
if 'restart' in description:
|
||||
raise exceptions.RestartingTelegram()
|
||||
raise exceptions.TelegramAPIError(description)
|
||||
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
|
||||
|
||||
|
||||
async def make_request(session, token, method, data=None, files=None, **kwargs):
|
||||
# log.debug(f"Make request: '{method}' with data: {data} and files {files}")
|
||||
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files)
|
||||
|
||||
url = Methods.api_url(token=token, method=method)
|
||||
|
||||
req = compose_data(data, files)
|
||||
try:
|
||||
result_json = await response.json(loads=json.loads)
|
||||
except ValueError:
|
||||
result_json = {}
|
||||
|
||||
description = result_json.get('description') or body
|
||||
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
|
||||
|
||||
if HTTPStatus.OK <= response.status <= HTTPStatus.IM_USED:
|
||||
return result_json.get('result')
|
||||
elif parameters.retry_after:
|
||||
raise exceptions.RetryAfter(parameters.retry_after)
|
||||
elif parameters.migrate_to_chat_id:
|
||||
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
|
||||
elif response.status == HTTPStatus.BAD_REQUEST:
|
||||
exceptions.BadRequest.detect(description)
|
||||
elif response.status == HTTPStatus.NOT_FOUND:
|
||||
exceptions.NotFound.detect(description)
|
||||
elif response.status == HTTPStatus.CONFLICT:
|
||||
exceptions.ConflictError.detect(description)
|
||||
elif response.status in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
|
||||
exceptions.Unauthorized.detect(description)
|
||||
elif response.status == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
|
||||
raise exceptions.NetworkError('File too large for uploading. '
|
||||
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
|
||||
elif response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
||||
if 'restart' in description:
|
||||
raise exceptions.RestartingTelegram()
|
||||
raise exceptions.TelegramAPIError(description)
|
||||
raise exceptions.TelegramAPIError(f"{description} [{response.status}]")
|
||||
async with session.post(url, data=req, **kwargs) as response:
|
||||
return await check_result(method, response.content_type, response.status, await response.text())
|
||||
except aiohttp.ClientError as e:
|
||||
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
||||
|
||||
|
||||
def _guess_filename(obj):
|
||||
def guess_filename(obj):
|
||||
"""
|
||||
Get file name from object
|
||||
|
||||
|
|
@ -97,7 +118,7 @@ def _guess_filename(obj):
|
|||
return os.path.basename(name)
|
||||
|
||||
|
||||
def _compose_data(params=None, files=None):
|
||||
def compose_data(params=None, files=None):
|
||||
"""
|
||||
Prepare request data
|
||||
|
||||
|
|
@ -121,47 +142,13 @@ def _compose_data(params=None, files=None):
|
|||
elif isinstance(f, types.InputFile):
|
||||
filename, fileobj = f.filename, f.file
|
||||
else:
|
||||
filename, fileobj = _guess_filename(f) or key, f
|
||||
filename, fileobj = guess_filename(f) or key, f
|
||||
|
||||
data.add_field(key, fileobj, filename=filename)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict:
|
||||
"""
|
||||
Make request to API
|
||||
|
||||
That make request with Content-Type:
|
||||
application/x-www-form-urlencoded - For simple request
|
||||
and multipart/form-data - for files uploading
|
||||
|
||||
https://core.telegram.org/bots/api#making-requests
|
||||
|
||||
:param session: HTTP Client session
|
||||
:type session: :obj:`aiohttp.ClientSession`
|
||||
:param token: BOT token
|
||||
:type token: :obj:`str`
|
||||
:param method: API method
|
||||
:type method: :obj:`str`
|
||||
:param data: request payload
|
||||
:type data: :obj:`dict`
|
||||
:param files: files
|
||||
:type files: :obj:`dict`
|
||||
:return: result
|
||||
:rtype :obj:`bool` or :obj:`dict`
|
||||
"""
|
||||
log.debug("Make request: '{0}' with data: {1} and files {2}".format(
|
||||
method, data or {}, files or {}))
|
||||
data = _compose_data(data, files)
|
||||
url = Methods.api_url(token=token, method=method)
|
||||
try:
|
||||
async with session.post(url, data=data, **kwargs) as response:
|
||||
return await _check_result(method, response)
|
||||
except aiohttp.ClientError as e:
|
||||
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
||||
|
||||
|
||||
class Methods(Helper):
|
||||
"""
|
||||
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ class BaseBot:
|
|||
api.check_token(token)
|
||||
self.__token = token
|
||||
|
||||
# Proxy settings
|
||||
self.proxy = proxy
|
||||
self.proxy_auth = proxy_auth
|
||||
|
||||
|
|
@ -59,37 +58,42 @@ class BaseBot:
|
|||
# aiohttp main session
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
if isinstance(proxy, str) and proxy.startswith('socks5://'):
|
||||
from aiosocksy.connector import ProxyClientRequest, ProxyConnector
|
||||
connector = ProxyConnector(limit=connections_limit, ssl_context=ssl_context, loop=self.loop)
|
||||
request_class = ProxyClientRequest
|
||||
if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
|
||||
from aiohttp_socks import SocksConnector
|
||||
from aiohttp_socks.helpers import parse_socks_url
|
||||
|
||||
socks_ver, host, port, username, password = parse_socks_url(proxy)
|
||||
if proxy_auth and not username or password:
|
||||
username = proxy_auth.login
|
||||
password = proxy_auth.password
|
||||
|
||||
connector = SocksConnector(socks_ver=socks_ver, host=host, port=port,
|
||||
username=username, password=password,
|
||||
limit=connections_limit, ssl_context=ssl_context,
|
||||
loop=self.loop)
|
||||
|
||||
self.proxy = None
|
||||
self.proxy_auth = None
|
||||
else:
|
||||
connector = aiohttp.TCPConnector(limit=connections_limit, ssl_context=ssl_context,
|
||||
loop=self.loop)
|
||||
request_class = aiohttp.ClientRequest
|
||||
|
||||
self.session = aiohttp.ClientSession(connector=connector, request_class=request_class,
|
||||
loop=self.loop, json_serialize=json.dumps)
|
||||
|
||||
# Data stored in bot instance
|
||||
self._data = {}
|
||||
self.session = aiohttp.ClientSession(connector=connector, loop=self.loop, json_serialize=json.dumps)
|
||||
|
||||
self.parse_mode = parse_mode
|
||||
|
||||
def __del__(self):
|
||||
# asyncio.ensure_future(self.close())
|
||||
pass
|
||||
# Data stored in bot instance
|
||||
self._data = {}
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Close all client sessions
|
||||
"""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
await self.session.close()
|
||||
|
||||
async def request(self, method: base.String,
|
||||
data: Optional[Dict] = None,
|
||||
files: Optional[Dict] = None) -> Union[List, Dict, base.Boolean]:
|
||||
files: Optional[Dict] = None, **kwargs) -> Union[List, Dict, base.Boolean]:
|
||||
"""
|
||||
Make an request to Telegram Bot API
|
||||
|
||||
|
|
@ -105,8 +109,8 @@ class BaseBot:
|
|||
:rtype: Union[List, Dict]
|
||||
:raise: :obj:`aiogram.exceptions.TelegramApiError`
|
||||
"""
|
||||
return await api.request(self.session, self.__token, method, data, files,
|
||||
proxy=self.proxy, proxy_auth=self.proxy_auth)
|
||||
return await api.make_request(self.session, self.__token, method, data, files,
|
||||
proxy=self.proxy, proxy_auth=self.proxy_auth, **kwargs)
|
||||
|
||||
async def download_file(self, file_path: base.String,
|
||||
destination: Optional[base.InputFile] = None,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
import typing
|
||||
|
||||
from ...dispatcher import BaseStorage
|
||||
from ...dispatcher.storage import BaseStorage
|
||||
|
||||
|
||||
class MemoryStorage(BaseStorage):
|
||||
|
|
@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage):
|
|||
chat, user = self.check_address(chat=chat, user=user)
|
||||
user = self._get_user(chat, user)
|
||||
if data is None:
|
||||
data = []
|
||||
data = {}
|
||||
user['data'].update(data, **kwargs)
|
||||
|
||||
async def set_state(self, *,
|
||||
|
|
|
|||
|
|
@ -303,6 +303,8 @@ class RedisStorage2(BaseStorage):
|
|||
|
||||
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None, **kwargs):
|
||||
if data is None:
|
||||
data = {}
|
||||
temp_data = await self.get_data(chat=chat, user=user, default={})
|
||||
temp_data.update(data, **kwargs)
|
||||
await self.set_data(chat=chat, user=user, data=temp_data)
|
||||
|
|
@ -330,6 +332,8 @@ class RedisStorage2(BaseStorage):
|
|||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None, **kwargs):
|
||||
if bucket is None:
|
||||
bucket = {}
|
||||
temp_bucket = await self.get_data(chat=chat, user=user)
|
||||
temp_bucket.update(bucket, **kwargs)
|
||||
await self.set_data(chat=chat, user=user, data=temp_bucket)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import weakref
|
|||
|
||||
import rethinkdb as r
|
||||
|
||||
from ...dispatcher import BaseStorage
|
||||
from ...dispatcher.storage import BaseStorage
|
||||
|
||||
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed']
|
||||
|
||||
|
|
|
|||
|
|
@ -1,115 +0,0 @@
|
|||
from aiogram import types
|
||||
from aiogram.dispatcher import ctx
|
||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
|
||||
OBJ_KEY = '_context_data'
|
||||
|
||||
|
||||
class ContextMiddleware(BaseMiddleware):
|
||||
"""
|
||||
Allow to store data at all of lifetime of Update object
|
||||
"""
|
||||
|
||||
async def on_pre_process_update(self, update: types.Update):
|
||||
"""
|
||||
Start of Update lifetime
|
||||
|
||||
:param update:
|
||||
:return:
|
||||
"""
|
||||
self._configure_update(update)
|
||||
|
||||
async def on_post_process_update(self, update: types.Update, result):
|
||||
"""
|
||||
On finishing of processing update
|
||||
|
||||
:param update:
|
||||
:param result:
|
||||
:return:
|
||||
"""
|
||||
if OBJ_KEY in update.conf:
|
||||
del update.conf[OBJ_KEY]
|
||||
|
||||
def _configure_update(self, update: types.Update = None):
|
||||
"""
|
||||
Setup data storage
|
||||
|
||||
:param update:
|
||||
:return:
|
||||
"""
|
||||
obj = update.conf[OBJ_KEY] = {}
|
||||
return obj
|
||||
|
||||
def _get_dict(self):
|
||||
"""
|
||||
Get data from update stored in current context
|
||||
|
||||
:return:
|
||||
"""
|
||||
update = ctx.get_update()
|
||||
obj = update.conf.get(OBJ_KEY, None)
|
||||
if obj is None:
|
||||
obj = self._configure_update(update)
|
||||
return obj
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""
|
||||
Item getter
|
||||
|
||||
:param item:
|
||||
:return:
|
||||
"""
|
||||
return self._get_dict()[item]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""
|
||||
Item setter
|
||||
|
||||
:param key:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
data = self._get_dict()
|
||||
data[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Iterate over dict
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self._get_dict().__iter__()
|
||||
|
||||
def keys(self):
|
||||
"""
|
||||
Iterate over dict keys
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self._get_dict().keys()
|
||||
|
||||
def values(self):
|
||||
"""
|
||||
Iterate over dict values
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self._get_dict().values()
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Get item from dit or return default value
|
||||
|
||||
:param key:
|
||||
:param default:
|
||||
:return:
|
||||
"""
|
||||
return self._get_dict().get(key, default)
|
||||
|
||||
def export(self):
|
||||
"""
|
||||
Export all data s dict
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self._get_dict()
|
||||
25
aiogram/contrib/middlewares/environment.py
Normal file
25
aiogram/contrib/middlewares/environment.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
|
||||
|
||||
class EnvironmentMiddleware(BaseMiddleware):
|
||||
def __init__(self, context=None):
|
||||
super(EnvironmentMiddleware, self).__init__()
|
||||
|
||||
if context is None:
|
||||
context = {}
|
||||
self.context = context
|
||||
|
||||
def update_data(self, data):
|
||||
dp = self.manager.dispatcher
|
||||
data.update(
|
||||
bot=dp.bot,
|
||||
dispatcher=dp,
|
||||
loop=dp.loop
|
||||
)
|
||||
if self.context:
|
||||
data.update(self.context)
|
||||
|
||||
async def trigger(self, action, args):
|
||||
if 'error' not in action and action.startswith('pre_process_'):
|
||||
self.update_data(args[-1])
|
||||
return True
|
||||
80
aiogram/contrib/middlewares/fsm.py
Normal file
80
aiogram/contrib/middlewares/fsm.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import copy
|
||||
import weakref
|
||||
|
||||
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
|
||||
from aiogram.dispatcher.storage import FSMContext
|
||||
|
||||
|
||||
class FSMMiddleware(LifetimeControllerMiddleware):
|
||||
skip_patterns = ['error', 'update']
|
||||
|
||||
def __init__(self):
|
||||
super(FSMMiddleware, self).__init__()
|
||||
self._proxies = weakref.WeakKeyDictionary()
|
||||
|
||||
async def pre_process(self, obj, data, *args):
|
||||
proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state())
|
||||
data['state_data'] = proxy
|
||||
|
||||
async def post_process(self, obj, data, *args):
|
||||
proxy = data.get('state_data', None)
|
||||
if isinstance(proxy, FSMSStorageProxy):
|
||||
await proxy.save()
|
||||
|
||||
|
||||
class FSMSStorageProxy(dict):
|
||||
def __init__(self, fsm_context: FSMContext):
|
||||
super(FSMSStorageProxy, self).__init__()
|
||||
self.fsm_context = fsm_context
|
||||
self._copy = {}
|
||||
self._data = {}
|
||||
self._state = None
|
||||
self._is_dirty = False
|
||||
|
||||
@classmethod
|
||||
async def create(cls, fsm_context: FSMContext):
|
||||
"""
|
||||
:param fsm_context:
|
||||
:return:
|
||||
"""
|
||||
proxy = cls(fsm_context)
|
||||
await proxy.load()
|
||||
return proxy
|
||||
|
||||
async def load(self):
|
||||
self.clear()
|
||||
self._state = await self.fsm_context.get_state()
|
||||
self.update(await self.fsm_context.get_data())
|
||||
self._copy = copy.deepcopy(self)
|
||||
self._is_dirty = False
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, value):
|
||||
self._state = value
|
||||
self._is_dirty = True
|
||||
|
||||
@state.deleter
|
||||
def state(self):
|
||||
self._state = None
|
||||
self._is_dirty = True
|
||||
|
||||
async def save(self, force=False):
|
||||
if self._copy != self or force:
|
||||
await self.fsm_context.set_data(data=self)
|
||||
if self._is_dirty or force:
|
||||
await self.fsm_context.set_state(self.state)
|
||||
self._is_dirty = False
|
||||
self._copy = copy.deepcopy(self)
|
||||
|
||||
def __str__(self):
|
||||
s = super(FSMSStorageProxy, self).__str__()
|
||||
readable_state = f"'{self.state}'" if self.state else "''"
|
||||
return f"<{self.__class__.__name__}(state={readable_state}, data={s})>"
|
||||
|
||||
def clear(self):
|
||||
del self.state
|
||||
return super(FSMSStorageProxy, self).clear()
|
||||
140
aiogram/contrib/middlewares/i18n.py
Normal file
140
aiogram/contrib/middlewares/i18n.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
import gettext
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from babel import Locale
|
||||
|
||||
from ... import types
|
||||
from ...dispatcher.middlewares import BaseMiddleware
|
||||
|
||||
|
||||
class I18nMiddleware(BaseMiddleware):
|
||||
"""
|
||||
I18n middleware based on gettext util
|
||||
|
||||
>>> dp = Dispatcher(bot)
|
||||
>>> i18n = I18nMiddleware(DOMAIN, LOCALES_DIR)
|
||||
>>> dp.middleware.setup(i18n)
|
||||
and then
|
||||
>>> _ = i18n.gettext
|
||||
or
|
||||
>>> _ = i18n = I18nMiddleware(DOMAIN_NAME, LOCALES_DIR)
|
||||
"""
|
||||
|
||||
ctx_locale = ContextVar('ctx_user_locale', default=None)
|
||||
|
||||
def __init__(self, domain, path=None, default='en'):
|
||||
"""
|
||||
:param domain: domain
|
||||
:param path: path where located all *.mo files
|
||||
:param default: default locale name
|
||||
"""
|
||||
super(I18nMiddleware, self).__init__()
|
||||
|
||||
if path is None:
|
||||
path = os.path.join(os.getcwd(), 'locales')
|
||||
|
||||
self.domain = domain
|
||||
self.path = path
|
||||
self.default = default
|
||||
|
||||
self.locales = self.find_locales()
|
||||
|
||||
def find_locales(self) -> Dict[str, gettext.GNUTranslations]:
|
||||
"""
|
||||
Load all compiled locales from path
|
||||
|
||||
:return: dict with locales
|
||||
"""
|
||||
translations = {}
|
||||
|
||||
for name in os.listdir(self.path):
|
||||
if not os.path.isdir(os.path.join(self.path, name)):
|
||||
continue
|
||||
mo_path = os.path.join(self.path, name, 'LC_MESSAGES', self.domain + '.mo')
|
||||
|
||||
if os.path.exists(mo_path):
|
||||
with open(mo_path, 'rb') as fp:
|
||||
translations[name] = gettext.GNUTranslations(fp)
|
||||
elif os.path.exists(mo_path[:-2] + 'po'):
|
||||
raise RuntimeError(f"Found locale '{name} but this language is not compiled!")
|
||||
|
||||
return translations
|
||||
|
||||
def reload(self):
|
||||
"""
|
||||
Hot reload locles
|
||||
"""
|
||||
self.locales = self.find_locales()
|
||||
|
||||
@property
|
||||
def available_locales(self) -> Tuple[str]:
|
||||
"""
|
||||
list of loaded locales
|
||||
|
||||
:return:
|
||||
"""
|
||||
return tuple(self.locales.keys())
|
||||
|
||||
def __call__(self, singular, plural=None, n=1, locale=None) -> str:
|
||||
return self.gettext(singular, plural, n, locale)
|
||||
|
||||
def gettext(self, singular, plural=None, n=1, locale=None) -> str:
|
||||
"""
|
||||
Get text
|
||||
|
||||
:param singular:
|
||||
:param plural:
|
||||
:param n:
|
||||
:param locale:
|
||||
:return:
|
||||
"""
|
||||
if locale is None:
|
||||
locale = self.ctx_locale.get()
|
||||
|
||||
if locale not in self.locales:
|
||||
if n is 1:
|
||||
return singular
|
||||
else:
|
||||
return plural
|
||||
|
||||
translator = self.locales[locale]
|
||||
|
||||
if plural is None:
|
||||
return translator.gettext(singular)
|
||||
else:
|
||||
return translator.ngettext(singular, plural, n)
|
||||
|
||||
# noinspection PyMethodMayBeStatic,PyUnusedLocal
|
||||
async def get_user_locale(self, action: str, args: Tuple[Any]) -> str:
|
||||
"""
|
||||
User locale getter
|
||||
You can override the method if you want to use different way of getting user language.
|
||||
|
||||
:param action: event name
|
||||
:param args: event arguments
|
||||
:return: locale name
|
||||
"""
|
||||
user: types.User = types.User.current()
|
||||
locale: Locale = user.locale
|
||||
|
||||
if locale:
|
||||
*_, data = args
|
||||
language = data['locale'] = locale.language
|
||||
return language
|
||||
|
||||
async def trigger(self, action, args):
|
||||
"""
|
||||
Event trigger
|
||||
|
||||
:param action: event name
|
||||
:param args: event arguments
|
||||
:return:
|
||||
"""
|
||||
if 'update' not in action \
|
||||
and 'error' not in action \
|
||||
and action.startswith('pre_process'):
|
||||
locale = await self.get_user_locale(action, args)
|
||||
self.ctx_locale.set(locale)
|
||||
return True
|
||||
|
|
@ -23,70 +23,70 @@ class LoggingMiddleware(BaseMiddleware):
|
|||
return round((time.time() - start) * 1000)
|
||||
return -1
|
||||
|
||||
async def on_pre_process_update(self, update: types.Update):
|
||||
async def on_pre_process_update(self, update: types.Update, data: dict):
|
||||
update.conf['_start'] = time.time()
|
||||
self.logger.debug(f"Received update [ID:{update.update_id}]")
|
||||
|
||||
async def on_post_process_update(self, update: types.Update, result):
|
||||
async def on_post_process_update(self, update: types.Update, result, data: dict):
|
||||
timeout = self.check_timeout(update)
|
||||
if timeout > 0:
|
||||
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
|
||||
|
||||
async def on_pre_process_message(self, message: types.Message):
|
||||
async def on_pre_process_message(self, message: types.Message, data: dict):
|
||||
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||
|
||||
async def on_post_process_message(self, message: types.Message, results):
|
||||
async def on_post_process_message(self, message: types.Message, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||
|
||||
async def on_pre_process_edited_message(self, edited_message):
|
||||
async def on_pre_process_edited_message(self, edited_message, data: dict):
|
||||
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
|
||||
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||
|
||||
async def on_post_process_edited_message(self, edited_message, results):
|
||||
async def on_post_process_edited_message(self, edited_message, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"edited message [ID:{edited_message.message_id}] "
|
||||
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||
|
||||
async def on_pre_process_channel_post(self, channel_post: types.Message):
|
||||
async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict):
|
||||
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
|
||||
f"in channel [ID:{channel_post.chat.id}]")
|
||||
|
||||
async def on_post_process_channel_post(self, channel_post: types.Message, results):
|
||||
async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"channel post [ID:{channel_post.message_id}] "
|
||||
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
|
||||
|
||||
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message):
|
||||
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict):
|
||||
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
|
||||
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||
|
||||
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results):
|
||||
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"edited channel post [ID:{edited_channel_post.message_id}] "
|
||||
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||
|
||||
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery):
|
||||
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict):
|
||||
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
|
||||
f"from user [ID:{inline_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results):
|
||||
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"inline query [ID:{inline_query.id}] "
|
||||
f"from user [ID:{inline_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
|
||||
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict):
|
||||
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||
f"result [ID:{chosen_inline_result.result_id}]")
|
||||
|
||||
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results):
|
||||
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||
f"result [ID:{chosen_inline_result.result_id}]")
|
||||
|
||||
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery):
|
||||
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict):
|
||||
if callback_query.message:
|
||||
if callback_query.message.from_user:
|
||||
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
|
||||
|
|
@ -100,7 +100,7 @@ class LoggingMiddleware(BaseMiddleware):
|
|||
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||
f"from user [ID:{callback_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_callback_query(self, callback_query, results):
|
||||
async def on_post_process_callback_query(self, callback_query, results, data: dict):
|
||||
if callback_query.message:
|
||||
if callback_query.message.from_user:
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
|
|
@ -117,25 +117,25 @@ class LoggingMiddleware(BaseMiddleware):
|
|||
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||
f"from user [ID:{callback_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery):
|
||||
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict):
|
||||
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
|
||||
f"from user [ID:{shipping_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_shipping_query(self, shipping_query, results):
|
||||
async def on_post_process_shipping_query(self, shipping_query, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"shipping query [ID:{shipping_query.id}] "
|
||||
f"from user [ID:{shipping_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
|
||||
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict):
|
||||
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results):
|
||||
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_error(self, dispatcher, update, error):
|
||||
async def on_pre_process_error(self, dispatcher, update, error, data: dict):
|
||||
timeout = self.check_timeout(update)
|
||||
if timeout > 0:
|
||||
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,42 +0,0 @@
|
|||
from . import Bot
|
||||
from .. import types
|
||||
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
|
||||
from ..utils import context
|
||||
|
||||
|
||||
def _get(key, default=None, no_error=False):
|
||||
result = context.get_value(key, default)
|
||||
if not no_error and result is None:
|
||||
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
|
||||
f"Maybe asyncio task factory is not configured!\n"
|
||||
f"\t>>> from aiogram.utils import context\n"
|
||||
f"\t>>> loop.set_task_factory(context.task_factory)")
|
||||
return result
|
||||
|
||||
|
||||
def get_bot() -> Bot:
|
||||
return _get('bot')
|
||||
|
||||
|
||||
def get_dispatcher() -> Dispatcher:
|
||||
return _get('dispatcher')
|
||||
|
||||
|
||||
def get_update() -> types.Update:
|
||||
return _get(UPDATE_OBJECT)
|
||||
|
||||
|
||||
def get_mode() -> str:
|
||||
return _get(MODE, 'unknown')
|
||||
|
||||
|
||||
def get_chat() -> int:
|
||||
return _get('chat', no_error=True)
|
||||
|
||||
|
||||
def get_user() -> int:
|
||||
return _get('user', no_error=True)
|
||||
|
||||
|
||||
def get_state() -> FSMContext:
|
||||
return get_dispatcher().current_state()
|
||||
1007
aiogram/dispatcher/dispatcher.py
Normal file
1007
aiogram/dispatcher/dispatcher.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,289 +0,0 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from ..types import CallbackQuery, ContentType, Message
|
||||
from ..utils import context
|
||||
from ..utils.helper import Helper, HelperMode, Item
|
||||
|
||||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
async def check_filter(filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param filter_:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
||||
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
||||
return await filter_(*args)
|
||||
else:
|
||||
return filter_(*args)
|
||||
|
||||
|
||||
async def check_filters(filters, args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param filters:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(filter_, args)
|
||||
if not f:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Filter:
|
||||
"""
|
||||
Base class for filters
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.check(*args, **kwargs)
|
||||
|
||||
def check(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncFilter(Filter):
|
||||
"""
|
||||
Base class for asynchronous filters
|
||||
"""
|
||||
|
||||
def __aiter__(self):
|
||||
return None
|
||||
|
||||
def __await__(self):
|
||||
return self.check
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class AnyFilter(AsyncFilter):
|
||||
"""
|
||||
One filter from many
|
||||
"""
|
||||
|
||||
def __init__(self, *filters: callable):
|
||||
self.filters = filters
|
||||
|
||||
async def check(self, *args):
|
||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||
return any(await asyncio.gather(*f))
|
||||
|
||||
|
||||
class NotFilter(AsyncFilter):
|
||||
"""
|
||||
Reverse filter
|
||||
"""
|
||||
|
||||
def __init__(self, filter_: callable):
|
||||
self.filter = filter_
|
||||
|
||||
async def check(self, *args):
|
||||
return not await check_filter(self.filter, args)
|
||||
|
||||
|
||||
class CommandsFilter(AsyncFilter):
|
||||
"""
|
||||
Check commands in message
|
||||
"""
|
||||
|
||||
def __init__(self, commands):
|
||||
self.commands = commands
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
||||
if command not in self.commands:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class RegexpFilter(Filter):
|
||||
"""
|
||||
Regexp filter for messages
|
||||
"""
|
||||
|
||||
def __init__(self, regexp):
|
||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
def check(self, obj):
|
||||
if isinstance(obj, Message) and obj.text:
|
||||
return bool(self.regexp.search(obj.text))
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
return bool(self.regexp.search(obj.data))
|
||||
return False
|
||||
|
||||
|
||||
class RegexpCommandsFilter(AsyncFilter):
|
||||
"""
|
||||
Check commands by regexp in message
|
||||
"""
|
||||
|
||||
def __init__(self, regexp_commands):
|
||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
||||
for command in self.regexp_commands:
|
||||
search = command.search(message.text)
|
||||
if search:
|
||||
message.conf['regexp_command'] = search
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ContentTypeFilter(Filter):
|
||||
"""
|
||||
Check message content type
|
||||
"""
|
||||
|
||||
def __init__(self, content_types):
|
||||
self.content_types = content_types
|
||||
|
||||
def check(self, message):
|
||||
return ContentType.ANY[0] in self.content_types or \
|
||||
message.content_type in self.content_types
|
||||
|
||||
|
||||
class CancelFilter(Filter):
|
||||
"""
|
||||
Find cancel in message text
|
||||
"""
|
||||
|
||||
def __init__(self, cancel_set=None):
|
||||
if cancel_set is None:
|
||||
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
||||
self.cancel_set = cancel_set
|
||||
|
||||
def check(self, message):
|
||||
if message.text:
|
||||
return message.text.lower() in self.cancel_set
|
||||
|
||||
|
||||
class StateFilter(AsyncFilter):
|
||||
"""
|
||||
Check user state
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher, state):
|
||||
self.dispatcher = dispatcher
|
||||
self.state = state
|
||||
|
||||
def get_target(self, obj):
|
||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
|
||||
async def check(self, obj):
|
||||
if self.state == '*':
|
||||
return True
|
||||
|
||||
if context.check_value(USER_STATE):
|
||||
context_state = context.get_value(USER_STATE)
|
||||
return self.state == context_state
|
||||
else:
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
|
||||
return False
|
||||
|
||||
|
||||
class StatesListFilter(StateFilter):
|
||||
"""
|
||||
List of states
|
||||
"""
|
||||
|
||||
async def check(self, obj):
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
return False
|
||||
|
||||
|
||||
class ExceptionsFilter(Filter):
|
||||
"""
|
||||
Filter for exceptions
|
||||
"""
|
||||
|
||||
def __init__(self, exception):
|
||||
self.exception = exception
|
||||
|
||||
def check(self, dispatcher, update, exception):
|
||||
return isinstance(exception, self.exception)
|
||||
|
||||
|
||||
def generate_default_filters(dispatcher, *args, **kwargs):
|
||||
"""
|
||||
Prepare filters
|
||||
|
||||
:param dispatcher:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
filters_set = []
|
||||
|
||||
for name, filter_ in kwargs.items():
|
||||
if filter_ is None and name != DefaultFilters.STATE:
|
||||
continue
|
||||
if name == DefaultFilters.COMMANDS:
|
||||
if isinstance(filter_, str):
|
||||
filters_set.append(CommandsFilter([filter_]))
|
||||
else:
|
||||
filters_set.append(CommandsFilter(filter_))
|
||||
elif name == DefaultFilters.REGEXP:
|
||||
filters_set.append(RegexpFilter(filter_))
|
||||
elif name == DefaultFilters.CONTENT_TYPES:
|
||||
filters_set.append(ContentTypeFilter(filter_))
|
||||
elif name == DefaultFilters.FUNC:
|
||||
filters_set.append(filter_)
|
||||
elif name == DefaultFilters.STATE:
|
||||
if isinstance(filter_, (list, set, tuple)):
|
||||
filters_set.append(StatesListFilter(dispatcher, filter_))
|
||||
else:
|
||||
filters_set.append(StateFilter(dispatcher, filter_))
|
||||
elif isinstance(filter_, Filter):
|
||||
filters_set.append(filter_)
|
||||
|
||||
filters_set += list(args)
|
||||
|
||||
return filters_set
|
||||
|
||||
|
||||
class DefaultFilters(Helper):
|
||||
mode = HelperMode.snake_case
|
||||
|
||||
COMMANDS = Item() # commands
|
||||
REGEXP = Item() # regexp
|
||||
CONTENT_TYPES = Item() # content_type
|
||||
FUNC = Item() # func
|
||||
STATE = Item() # state
|
||||
24
aiogram/dispatcher/filters/__init__.py
Normal file
24
aiogram/dispatcher/filters/__init__.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
|
||||
RegexpCommandsFilter, StateFilter, Text
|
||||
from .factory import FiltersFactory
|
||||
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||
|
||||
__all__ = [
|
||||
'AbstractFilter',
|
||||
'BoundFilter',
|
||||
'Command',
|
||||
'CommandStart',
|
||||
'CommandHelp',
|
||||
'ContentTypeFilter',
|
||||
'ExceptionsFilter',
|
||||
'Filter',
|
||||
'FilterNotPassed',
|
||||
'FilterRecord',
|
||||
'FiltersFactory',
|
||||
'RegexpCommandsFilter',
|
||||
'Regexp',
|
||||
'StateFilter',
|
||||
'Text',
|
||||
'check_filter',
|
||||
'check_filters'
|
||||
]
|
||||
313
aiogram/dispatcher/filters/builtin.py
Normal file
313
aiogram/dispatcher/filters/builtin.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
import inspect
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, Optional, Union
|
||||
|
||||
from aiogram import types
|
||||
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
|
||||
from aiogram.types import CallbackQuery, Message
|
||||
|
||||
|
||||
class Command(Filter):
|
||||
"""
|
||||
You can handle commands by using this filter
|
||||
"""
|
||||
|
||||
def __init__(self, commands: Union[Iterable, str],
|
||||
prefixes: Union[Iterable, str] = '/',
|
||||
ignore_case: bool = True,
|
||||
ignore_mention: bool = False):
|
||||
"""
|
||||
Filter can be initialized from filters factory or by simply creating instance of this class
|
||||
|
||||
:param commands: command or list of commands
|
||||
:param prefixes:
|
||||
:param ignore_case:
|
||||
:param ignore_mention:
|
||||
"""
|
||||
if isinstance(commands, str):
|
||||
commands = (commands,)
|
||||
|
||||
self.commands = list(map(str.lower, commands)) if ignore_case else commands
|
||||
self.prefixes = prefixes
|
||||
self.ignore_case = ignore_case
|
||||
self.ignore_mention = ignore_mention
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Validator for filters factory
|
||||
|
||||
:param full_config:
|
||||
:return: config or empty dict
|
||||
"""
|
||||
config = {}
|
||||
if 'commands' in full_config:
|
||||
config['commands'] = full_config.pop('commands')
|
||||
if 'commands_prefix' in full_config:
|
||||
config['prefixes'] = full_config.pop('commands_prefix')
|
||||
if 'commands_ignore_mention' in full_config:
|
||||
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
|
||||
return config
|
||||
|
||||
async def check(self, message: types.Message):
|
||||
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
|
||||
|
||||
@staticmethod
|
||||
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
|
||||
full_command = message.text.split()[0]
|
||||
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
|
||||
|
||||
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
|
||||
return False
|
||||
elif prefix not in prefixes:
|
||||
return False
|
||||
elif (command.lower() if ignore_case else command) not in commands:
|
||||
return False
|
||||
|
||||
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
|
||||
|
||||
@dataclass
|
||||
class CommandObj:
|
||||
prefix: str = '/'
|
||||
command: str = ''
|
||||
mention: str = None
|
||||
args: str = field(repr=False, default=None)
|
||||
|
||||
@property
|
||||
def mentioned(self) -> bool:
|
||||
return bool(self.mention)
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
line = self.prefix + self.command
|
||||
if self.mentioned:
|
||||
line += '@' + self.mention
|
||||
if self.args:
|
||||
line += ' ' + self.args
|
||||
return line
|
||||
|
||||
|
||||
class CommandStart(Command):
|
||||
def __init__(self):
|
||||
super(CommandStart, self).__init__(['start'])
|
||||
|
||||
|
||||
class CommandHelp(Command):
|
||||
def __init__(self):
|
||||
super(CommandHelp, self).__init__(['help'])
|
||||
|
||||
|
||||
class Text(Filter):
|
||||
"""
|
||||
Simple text filter
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
equals: Optional[str] = None,
|
||||
contains: Optional[str] = None,
|
||||
startswith: Optional[str] = None,
|
||||
endswith: Optional[str] = None,
|
||||
ignore_case=False):
|
||||
"""
|
||||
Check text for one of pattern. Only one mode can be used in one filter.
|
||||
|
||||
:param equals:
|
||||
:param contains:
|
||||
:param startswith:
|
||||
:param endswith:
|
||||
:param ignore_case: case insensitive
|
||||
"""
|
||||
# Only one mode can be used. check it.
|
||||
check = sum(map(bool, (equals, contains, startswith, endswith)))
|
||||
if check > 1:
|
||||
args = "' and '".join([arg[0] for arg in [('equals', equals),
|
||||
('contains', contains),
|
||||
('startswith', startswith),
|
||||
('endswith', endswith)
|
||||
] if arg[1]])
|
||||
raise ValueError(f"Arguments '{args}' cannot be used together.")
|
||||
elif check == 0:
|
||||
raise ValueError(f"No one mode is specified!")
|
||||
|
||||
self.equals = equals
|
||||
self.contains = contains
|
||||
self.endswith = endswith
|
||||
self.startswith = startswith
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]):
|
||||
if 'text' in full_config:
|
||||
return {'equals': full_config.pop('text')}
|
||||
elif 'text_contains' in full_config:
|
||||
return {'contains': full_config.pop('text_contains')}
|
||||
elif 'text_startswith' in full_config:
|
||||
return {'startswith': full_config.pop('text_startswith')}
|
||||
elif 'text_endswith' in full_config:
|
||||
return {'endswith': full_config.pop('text_endswith')}
|
||||
|
||||
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||
if isinstance(obj, Message):
|
||||
text = obj.text or obj.caption or ''
|
||||
elif isinstance(obj, CallbackQuery):
|
||||
text = obj.data
|
||||
else:
|
||||
return False
|
||||
|
||||
if self.ignore_case:
|
||||
text = text.lower()
|
||||
|
||||
if self.equals:
|
||||
return text == self.equals
|
||||
elif self.contains:
|
||||
return self.contains in text
|
||||
elif self.startswith:
|
||||
return text.startswith(self.startswith)
|
||||
elif self.endswith:
|
||||
return text.endswith(self.endswith)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class Regexp(Filter):
|
||||
"""
|
||||
Regexp filter for messages and callback query
|
||||
"""
|
||||
|
||||
def __init__(self, regexp):
|
||||
if not isinstance(regexp, re.Pattern):
|
||||
regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
self.regexp = regexp
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]):
|
||||
if 'regexp' in full_config:
|
||||
return {'regexp': full_config.pop('regexp')}
|
||||
|
||||
async def check(self, obj: Union[Message, CallbackQuery]):
|
||||
if isinstance(obj, Message):
|
||||
match = self.regexp.search(obj.text or obj.caption or '')
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
match = self.regexp.search(obj.data)
|
||||
else:
|
||||
return False
|
||||
|
||||
if match:
|
||||
return {'regexp': match}
|
||||
return False
|
||||
|
||||
|
||||
class RegexpCommandsFilter(BoundFilter):
|
||||
"""
|
||||
Check commands by regexp in message
|
||||
"""
|
||||
|
||||
key = 'regexp_commands'
|
||||
|
||||
def __init__(self, regexp_commands):
|
||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
||||
for command in self.regexp_commands:
|
||||
search = command.search(message.text)
|
||||
if search:
|
||||
return {'regexp_command': search}
|
||||
return False
|
||||
|
||||
|
||||
class ContentTypeFilter(BoundFilter):
|
||||
"""
|
||||
Check message content type
|
||||
"""
|
||||
|
||||
key = 'content_types'
|
||||
required = True
|
||||
default = types.ContentTypes.TEXT
|
||||
|
||||
def __init__(self, content_types):
|
||||
self.content_types = content_types
|
||||
|
||||
async def check(self, message):
|
||||
return types.ContentType.ANY in self.content_types or \
|
||||
message.content_type in self.content_types
|
||||
|
||||
|
||||
class StateFilter(BoundFilter):
|
||||
"""
|
||||
Check user state
|
||||
"""
|
||||
key = 'state'
|
||||
required = True
|
||||
|
||||
ctx_state = ContextVar('user_state')
|
||||
|
||||
def __init__(self, dispatcher, state):
|
||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||
|
||||
self.dispatcher = dispatcher
|
||||
states = []
|
||||
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
||||
state = [state, ]
|
||||
for item in state:
|
||||
if isinstance(item, State):
|
||||
states.append(item.state)
|
||||
elif inspect.isclass(item) and issubclass(item, StatesGroup):
|
||||
states.extend(item.all_states_names)
|
||||
else:
|
||||
states.append(item)
|
||||
self.states = states
|
||||
|
||||
def get_target(self, obj):
|
||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
|
||||
async def check(self, obj):
|
||||
if '*' in self.states:
|
||||
return {'state': self.dispatcher.current_state()}
|
||||
|
||||
try:
|
||||
state = self.ctx_state.get()
|
||||
except LookupError:
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
|
||||
self.ctx_state.set(state)
|
||||
if state in self.states:
|
||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||
|
||||
else:
|
||||
if state in self.states:
|
||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class ExceptionsFilter(BoundFilter):
|
||||
"""
|
||||
Filter for exceptions
|
||||
"""
|
||||
|
||||
key = 'exception'
|
||||
|
||||
def __init__(self, dispatcher, exception):
|
||||
super().__init__(dispatcher)
|
||||
self.exception = exception
|
||||
|
||||
async def check(self, dispatcher, update, exception):
|
||||
try:
|
||||
raise exception
|
||||
except self.exception:
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
73
aiogram/dispatcher/filters/factory.py
Normal file
73
aiogram/dispatcher/filters/factory.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
import typing
|
||||
|
||||
from .filters import AbstractFilter, FilterRecord
|
||||
from ..handler import Handler
|
||||
|
||||
|
||||
class FiltersFactory:
|
||||
"""
|
||||
Default filters factory
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher):
|
||||
self._dispatcher = dispatcher
|
||||
self._registered: typing.List[FilterRecord] = []
|
||||
|
||||
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
|
||||
validator: typing.Optional[typing.Callable] = None,
|
||||
event_handlers: typing.Optional[typing.List[Handler]] = None,
|
||||
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||
"""
|
||||
Register filter
|
||||
|
||||
:param callback: callable or subclass of :obj:`AbstractFilter`
|
||||
:param validator: custom validator.
|
||||
:param event_handlers: list of instances of :obj:`Handler`
|
||||
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
|
||||
"""
|
||||
record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
|
||||
self._registered.append(record)
|
||||
|
||||
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
|
||||
"""
|
||||
Unregister callback
|
||||
|
||||
:param callback: callable of subclass of :obj:`AbstractFilter`
|
||||
"""
|
||||
for record in self._registered:
|
||||
if record.callback == callback:
|
||||
self._registered.remove(record)
|
||||
|
||||
def resolve(self, event_handler, *custom_filters, **full_config
|
||||
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
|
||||
"""
|
||||
Resolve filters to filters-set
|
||||
|
||||
:param event_handler:
|
||||
:param custom_filters:
|
||||
:param full_config:
|
||||
:return:
|
||||
"""
|
||||
filters_set = []
|
||||
filters_set.extend(self._resolve_registered(event_handler,
|
||||
{k: v for k, v in full_config.items() if v is not None}))
|
||||
if custom_filters:
|
||||
filters_set.extend(custom_filters)
|
||||
|
||||
return filters_set
|
||||
|
||||
def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
|
||||
"""
|
||||
Resolve registered filters
|
||||
|
||||
:param event_handler:
|
||||
:param full_config:
|
||||
:return:
|
||||
"""
|
||||
for record in self._registered:
|
||||
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
|
||||
if filter_:
|
||||
yield filter_
|
||||
|
||||
if full_config:
|
||||
raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'')
|
||||
250
aiogram/dispatcher/filters/filters.py
Normal file
250
aiogram/dispatcher/filters/filters.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
import abc
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
from ..handler import Handler
|
||||
from ...types.base import TelegramObject
|
||||
|
||||
|
||||
class FilterNotPassed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def wrap_async(func):
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isawaitable(func) \
|
||||
or inspect.iscoroutinefunction(func) \
|
||||
or isinstance(func, AbstractFilter):
|
||||
return func
|
||||
return async_wrapper
|
||||
|
||||
|
||||
async def check_filter(dispatcher, filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param dispatcher:
|
||||
:param filter_:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
kwargs = {}
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
||||
spec = inspect.getfullargspec(filter_)
|
||||
if 'dispatcher' in spec:
|
||||
kwargs['dispatcher'] = dispatcher
|
||||
if inspect.isawaitable(filter_) \
|
||||
or inspect.iscoroutinefunction(filter_) \
|
||||
or isinstance(filter_, AbstractFilter):
|
||||
return await filter_(*args, **kwargs)
|
||||
else:
|
||||
return filter_(*args, **kwargs)
|
||||
|
||||
|
||||
async def check_filters(dispatcher, filters, args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param dispatcher:
|
||||
:param filters:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
data = {}
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(dispatcher, filter_, args)
|
||||
if not f:
|
||||
raise FilterNotPassed()
|
||||
elif isinstance(f, dict):
|
||||
data.update(f)
|
||||
return data
|
||||
|
||||
|
||||
class FilterRecord:
|
||||
"""
|
||||
Filters record for factory
|
||||
"""
|
||||
|
||||
def __init__(self, callback: typing.Callable,
|
||||
validator: typing.Optional[typing.Callable] = None,
|
||||
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
||||
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||
if event_handlers and exclude_event_handlers:
|
||||
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
|
||||
|
||||
self.callback = callback
|
||||
self.event_handlers = event_handlers
|
||||
self.exclude_event_handlers = exclude_event_handlers
|
||||
|
||||
if validator is not None:
|
||||
if not callable(validator):
|
||||
raise TypeError(f"validator must be callable, not {type(validator)}")
|
||||
self.resolver = validator
|
||||
elif issubclass(callback, AbstractFilter):
|
||||
self.resolver = callback.validate
|
||||
else:
|
||||
raise RuntimeError('validator is required!')
|
||||
|
||||
def resolve(self, dispatcher, event_handler, full_config):
|
||||
if not self._check_event_handler(event_handler):
|
||||
return
|
||||
config = self.resolver(full_config)
|
||||
if config:
|
||||
if 'dispatcher' not in config:
|
||||
spec = inspect.getfullargspec(self.callback)
|
||||
if 'dispatcher' in spec.args:
|
||||
config['dispatcher'] = dispatcher
|
||||
|
||||
for key in config:
|
||||
if key in full_config:
|
||||
full_config.pop(key)
|
||||
|
||||
return self.callback(**config)
|
||||
|
||||
def _check_event_handler(self, event_handler) -> bool:
|
||||
if self.event_handlers:
|
||||
return event_handler in self.event_handlers
|
||||
elif self.exclude_event_handlers:
|
||||
return event_handler not in self.exclude_event_handlers
|
||||
return True
|
||||
|
||||
|
||||
class AbstractFilter(abc.ABC):
|
||||
"""
|
||||
Abstract class for custom filters
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
"""
|
||||
Validate and parse config
|
||||
|
||||
:param full_config:
|
||||
:return: config
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def check(self, *args) -> bool:
|
||||
"""
|
||||
Check object
|
||||
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
async def __call__(self, obj: TelegramObject) -> bool:
|
||||
return await self.check(obj)
|
||||
|
||||
def __invert__(self):
|
||||
return NotFilter(self)
|
||||
|
||||
def __and__(self, other):
|
||||
if isinstance(self, AndFilter):
|
||||
self.append(other)
|
||||
return self
|
||||
return AndFilter(self, other)
|
||||
|
||||
def __or__(self, other):
|
||||
if isinstance(self, OrFilter):
|
||||
self.append(other)
|
||||
return self
|
||||
return OrFilter(self, other)
|
||||
|
||||
|
||||
class Filter(AbstractFilter):
|
||||
"""
|
||||
You can make subclasses of that class for custom filters
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
pass
|
||||
|
||||
|
||||
class BoundFilter(Filter):
|
||||
"""
|
||||
Base class for filters with default validator
|
||||
"""
|
||||
key = None
|
||||
required = False
|
||||
default = None
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||
if cls.key is not None:
|
||||
if cls.key in full_config:
|
||||
return {cls.key: full_config[cls.key]}
|
||||
elif cls.required:
|
||||
return {cls.key: cls.default}
|
||||
|
||||
|
||||
class _LogicFilter(Filter):
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
||||
raise ValueError('That filter can\'t be used in filters factory!')
|
||||
|
||||
|
||||
class NotFilter(_LogicFilter):
|
||||
def __init__(self, target):
|
||||
self.target = wrap_async(target)
|
||||
|
||||
async def check(self, *args):
|
||||
return not bool(await self.target(*args))
|
||||
|
||||
|
||||
class AndFilter(_LogicFilter):
|
||||
|
||||
def __init__(self, *targets):
|
||||
self.targets = list(wrap_async(target) for target in targets)
|
||||
|
||||
async def check(self, *args):
|
||||
"""
|
||||
All filters must return a positive result
|
||||
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
data = {}
|
||||
for target in self.targets:
|
||||
result = await target(*args)
|
||||
if not result:
|
||||
return False
|
||||
if isinstance(result, dict):
|
||||
data.update(result)
|
||||
if not data:
|
||||
return True
|
||||
return data
|
||||
|
||||
def append(self, target):
|
||||
self.targets.append(wrap_async(target))
|
||||
|
||||
|
||||
class OrFilter(_LogicFilter):
|
||||
def __init__(self, *targets):
|
||||
self.targets = list(wrap_async(target) for target in targets)
|
||||
|
||||
async def check(self, *args):
|
||||
"""
|
||||
One of filters must return a positive result
|
||||
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
for target in self.targets:
|
||||
result = await target(*args)
|
||||
if result:
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return True
|
||||
return False
|
||||
|
||||
def append(self, target):
|
||||
self.targets.append(wrap_async(target))
|
||||
198
aiogram/dispatcher/filters/state.py
Normal file
198
aiogram/dispatcher/filters/state.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from ..dispatcher import Dispatcher
|
||||
|
||||
|
||||
class State:
|
||||
"""
|
||||
State object
|
||||
"""
|
||||
|
||||
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None):
|
||||
self._state = state
|
||||
self._group_name = group_name
|
||||
self._group = None
|
||||
|
||||
@property
|
||||
def group(self):
|
||||
if not self._group:
|
||||
raise RuntimeError('This state is not in any group.')
|
||||
return self._group
|
||||
|
||||
def get_root(self):
|
||||
return self.group.get_root()
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self._state is None:
|
||||
return None
|
||||
elif self._state == '*':
|
||||
return self._state
|
||||
elif self._group_name is None and self._group:
|
||||
group = self._group.__full_group_name__
|
||||
elif self._group_name:
|
||||
group = self._group_name
|
||||
else:
|
||||
group = '@'
|
||||
return f"{group}:{self._state}"
|
||||
|
||||
def set_parent(self, group):
|
||||
if not issubclass(group, StatesGroup):
|
||||
raise ValueError('Group must be subclass of StatesGroup')
|
||||
self._group = group
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
if self._state is None:
|
||||
self._state = name
|
||||
self.set_parent(owner)
|
||||
|
||||
def __str__(self):
|
||||
return f"<State '{self.state or ''}'>"
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
async def set(self):
|
||||
state = Dispatcher.current().current_state()
|
||||
await state.set_state(self.state)
|
||||
|
||||
|
||||
class StatesGroupMeta(type):
|
||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
|
||||
|
||||
states = []
|
||||
childs = []
|
||||
|
||||
cls._group_name = name
|
||||
|
||||
for name, prop in namespace.items():
|
||||
|
||||
if isinstance(prop, State):
|
||||
states.append(prop)
|
||||
elif inspect.isclass(prop) and issubclass(prop, StatesGroup):
|
||||
childs.append(prop)
|
||||
prop._parent = cls
|
||||
# continue
|
||||
|
||||
cls._parent = None
|
||||
cls._childs = tuple(childs)
|
||||
cls._states = tuple(states)
|
||||
cls._state_names = tuple(state.state for state in states)
|
||||
|
||||
return cls
|
||||
|
||||
@property
|
||||
def __group_name__(cls):
|
||||
return cls._group_name
|
||||
|
||||
@property
|
||||
def __full_group_name__(cls):
|
||||
if cls._parent:
|
||||
return cls._parent.__full_group_name__ + '.' + cls._group_name
|
||||
return cls._group_name
|
||||
|
||||
@property
|
||||
def states(cls) -> tuple:
|
||||
return cls._states
|
||||
|
||||
@property
|
||||
def childs(cls):
|
||||
return cls._childs
|
||||
|
||||
@property
|
||||
def all_childs(cls):
|
||||
result = cls.childs
|
||||
for child in cls.childs:
|
||||
result += child.childs
|
||||
return result
|
||||
|
||||
@property
|
||||
def all_states(cls):
|
||||
result = cls.states
|
||||
for group in cls.childs:
|
||||
result += group.all_states
|
||||
return result
|
||||
|
||||
@property
|
||||
def all_states_names(cls):
|
||||
return tuple(state.state for state in cls.all_states)
|
||||
|
||||
@property
|
||||
def states_names(cls) -> tuple:
|
||||
return tuple(state.state for state in cls.states)
|
||||
|
||||
def get_root(cls):
|
||||
if cls._parent is None:
|
||||
return cls
|
||||
return cls._parent.get_root()
|
||||
|
||||
def __contains__(cls, item):
|
||||
if isinstance(item, str):
|
||||
return item in cls.all_states_names
|
||||
elif isinstance(item, State):
|
||||
return item in cls.all_states
|
||||
elif isinstance(item, StatesGroup):
|
||||
return item in cls.all_childs
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return f"<StatesGroup '{self.__full_group_name__}'>"
|
||||
|
||||
|
||||
class StatesGroup(metaclass=StatesGroupMeta):
|
||||
@classmethod
|
||||
async def next(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
next_step = cls.states_names.index(state_name) + 1
|
||||
except ValueError:
|
||||
next_step = 0
|
||||
|
||||
try:
|
||||
next_state_name = cls.states[next_step].state
|
||||
except IndexError:
|
||||
next_state_name = None
|
||||
|
||||
await state.set_state(next_state_name)
|
||||
return next_state_name
|
||||
|
||||
@classmethod
|
||||
async def previous(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
previous_step = cls.states_names.index(state_name) - 1
|
||||
except ValueError:
|
||||
previous_step = 0
|
||||
|
||||
if previous_step < 0:
|
||||
previous_state_name = None
|
||||
else:
|
||||
previous_state_name = cls.states[previous_step].state
|
||||
|
||||
await state.set_state(previous_state_name)
|
||||
return previous_state_name
|
||||
|
||||
@classmethod
|
||||
async def first(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
first_step_name = cls.states_names[0]
|
||||
|
||||
await state.set_state(first_step_name)
|
||||
return first_step_name
|
||||
|
||||
@classmethod
|
||||
async def last(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
last_step_name = cls.states_names[-1]
|
||||
|
||||
await state.set_state(last_step_name)
|
||||
return last_step_name
|
||||
|
||||
|
||||
default_state = State()
|
||||
any_state = State(state='*')
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
from .filters import check_filters
|
||||
from ..utils import context
|
||||
import inspect
|
||||
from contextvars import ContextVar
|
||||
|
||||
ctx_data = ContextVar('ctx_handler_data')
|
||||
|
||||
|
||||
class SkipHandler(BaseException):
|
||||
|
|
@ -10,6 +12,14 @@ class CancelHandler(BaseException):
|
|||
pass
|
||||
|
||||
|
||||
def _check_spec(func: callable, kwargs: dict):
|
||||
spec = inspect.getfullargspec(func)
|
||||
if spec.varkw:
|
||||
return kwargs
|
||||
|
||||
return {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
|
||||
|
||||
class Handler:
|
||||
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||
self.dispatcher = dispatcher
|
||||
|
|
@ -57,31 +67,43 @@ class Handler:
|
|||
:param args:
|
||||
:return:
|
||||
"""
|
||||
from .filters import check_filters, FilterNotPassed
|
||||
|
||||
results = []
|
||||
|
||||
data = {}
|
||||
ctx_data.set(data)
|
||||
|
||||
if self.middleware_key:
|
||||
try:
|
||||
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
|
||||
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,))
|
||||
except CancelHandler: # Allow to cancel current event
|
||||
return results
|
||||
|
||||
for filters, handler in self.handlers:
|
||||
if await check_filters(filters, args):
|
||||
try:
|
||||
for filters, handler in self.handlers:
|
||||
try:
|
||||
if self.middleware_key:
|
||||
context.set_value('handler', handler)
|
||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||
response = await handler(*args)
|
||||
if response is not None:
|
||||
results.append(response)
|
||||
if self.once:
|
||||
break
|
||||
except SkipHandler:
|
||||
data.update(await check_filters(self.dispatcher, filters, args))
|
||||
except FilterNotPassed:
|
||||
continue
|
||||
except CancelHandler:
|
||||
break
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||
args + (results,))
|
||||
else:
|
||||
try:
|
||||
if self.middleware_key:
|
||||
# context.set_value('handler', handler)
|
||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,))
|
||||
partial_data = _check_spec(handler, data)
|
||||
response = await handler(*args, **partial_data)
|
||||
if response is not None:
|
||||
results.append(response)
|
||||
if self.once:
|
||||
break
|
||||
except SkipHandler:
|
||||
continue
|
||||
except CancelHandler:
|
||||
break
|
||||
finally:
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||
args + (results, data,))
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -101,3 +101,28 @@ class BaseMiddleware:
|
|||
if not handler:
|
||||
return None
|
||||
await handler(*args)
|
||||
|
||||
|
||||
class LifetimeControllerMiddleware(BaseMiddleware):
|
||||
# TODO: Rename class
|
||||
|
||||
skip_patterns = None
|
||||
|
||||
async def pre_process(self, obj, data, *args):
|
||||
pass
|
||||
|
||||
async def post_process(self, obj, data, *args):
|
||||
pass
|
||||
|
||||
async def trigger(self, action, args):
|
||||
if self.skip_patterns is not None and any(item in action for item in self.skip_patterns):
|
||||
return False
|
||||
|
||||
obj, *args, data = args
|
||||
if action.startswith('pre_process_'):
|
||||
await self.pre_process(obj, data, *args)
|
||||
elif action.startswith('post_process_'):
|
||||
await self.post_process(obj, data, *args)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -281,8 +281,20 @@ class FSMContext:
|
|||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _resolve_state(value):
|
||||
from .filters.state import State
|
||||
|
||||
if value is None:
|
||||
return
|
||||
elif isinstance(value, str):
|
||||
return value
|
||||
elif isinstance(value, State):
|
||||
return value.state
|
||||
return str(value)
|
||||
|
||||
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
return await self.storage.get_state(chat=self.chat, user=self.user, default=default)
|
||||
return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default))
|
||||
|
||||
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
|
||||
return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
|
||||
|
|
@ -291,7 +303,7 @@ class FSMContext:
|
|||
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
|
||||
|
||||
async def set_state(self, state: typing.Union[typing.AnyStr, None] = None):
|
||||
await self.storage.set_state(chat=self.chat, user=self.user, state=state)
|
||||
await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state))
|
||||
|
||||
async def set_data(self, data: typing.Dict = None):
|
||||
await self.storage.set_data(chat=self.chat, user=self.user, data=data)
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@ from typing import Dict, List, Optional, Union
|
|||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPGone
|
||||
|
||||
|
||||
from .. import types
|
||||
from ..bot import api
|
||||
from ..types import ParseMode
|
||||
from ..types.base import Boolean, Float, Integer, String
|
||||
from ..utils import context
|
||||
from ..utils import helper, markdown
|
||||
from ..utils import json
|
||||
from ..utils.deprecated import warn_deprecated as warn
|
||||
|
|
@ -89,8 +89,10 @@ class WebhookRequestHandler(web.View):
|
|||
"""
|
||||
dp = self.request.app[BOT_DISPATCHER_KEY]
|
||||
try:
|
||||
context.set_value('dispatcher', dp)
|
||||
context.set_value('bot', dp.bot)
|
||||
from aiogram.bot import bot
|
||||
from aiogram.dispatcher import dispatcher
|
||||
dispatcher.set(dp)
|
||||
bot.bot.set(dp.bot)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return dp
|
||||
|
|
@ -117,9 +119,9 @@ class WebhookRequestHandler(web.View):
|
|||
"""
|
||||
self.validate_ip()
|
||||
|
||||
context.update_state({'CALLER': WEBHOOK,
|
||||
WEBHOOK_CONNECTION: True,
|
||||
WEBHOOK_REQUEST: self.request})
|
||||
# context.update_state({'CALLER': WEBHOOK,
|
||||
# WEBHOOK_CONNECTION: True,
|
||||
# WEBHOOK_REQUEST: self.request})
|
||||
|
||||
dispatcher = self.get_dispatcher()
|
||||
update = await self.parse_update(dispatcher.bot)
|
||||
|
|
@ -177,7 +179,7 @@ class WebhookRequestHandler(web.View):
|
|||
if fut.done():
|
||||
return fut.result()
|
||||
else:
|
||||
context.set_value(WEBHOOK_CONNECTION, False)
|
||||
# context.set_value(WEBHOOK_CONNECTION, False)
|
||||
fut.remove_done_callback(cb)
|
||||
fut.add_done_callback(self.respond_via_request)
|
||||
finally:
|
||||
|
|
@ -202,7 +204,7 @@ class WebhookRequestHandler(web.View):
|
|||
results = task.result()
|
||||
except Exception as e:
|
||||
loop.create_task(
|
||||
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
|
||||
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
|
||||
else:
|
||||
response = self.get_response(results)
|
||||
if response is not None:
|
||||
|
|
@ -249,7 +251,7 @@ class WebhookRequestHandler(web.View):
|
|||
ip_address, accept = self.check_ip()
|
||||
if not accept:
|
||||
raise web.HTTPUnauthorized()
|
||||
context.set_value('TELEGRAM_IP', ip_address)
|
||||
# context.set_value('TELEGRAM_IP', ip_address)
|
||||
|
||||
|
||||
class GoneRequestHandler(web.View):
|
||||
|
|
@ -352,8 +354,8 @@ class BaseResponse:
|
|||
|
||||
async def __call__(self, bot=None):
|
||||
if bot is None:
|
||||
from aiogram.dispatcher import ctx
|
||||
bot = ctx.get_bot()
|
||||
from aiogram import Bot
|
||||
bot = Bot.current()
|
||||
return await self.execute_response(bot)
|
||||
|
||||
async def __aenter__(self):
|
||||
|
|
@ -446,7 +448,8 @@ class ParseModeMixin:
|
|||
|
||||
:return:
|
||||
"""
|
||||
bot = context.get_value('bot', None)
|
||||
from aiogram import Bot
|
||||
bot = Bot.current()
|
||||
if bot is not None:
|
||||
return bot.parse_mode
|
||||
|
||||
|
|
@ -952,7 +955,7 @@ class SendMediaGroup(BaseResponse, ReplyToMixin, DisableNotificationMixin):
|
|||
self.reply_to_message_id = reply_to_message_id
|
||||
|
||||
def prepare(self):
|
||||
files = self.media.get_files()
|
||||
files = dict(self.media.get_files())
|
||||
if files:
|
||||
raise TypeError('Allowed only file ID or URL\'s')
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from .invoice import Invoice
|
|||
from .labeled_price import LabeledPrice
|
||||
from .location import Location
|
||||
from .mask_position import MaskPosition
|
||||
from .message import ContentType, Message, ParseMode
|
||||
from .message import ContentType, ContentTypes, Message, ParseMode
|
||||
from .message_entity import MessageEntity, MessageEntityType
|
||||
from .order_info import OrderInfo
|
||||
from .passport_data import PassportData
|
||||
|
|
@ -77,6 +77,7 @@ __all__ = (
|
|||
'ChosenInlineResult',
|
||||
'Contact',
|
||||
'ContentType',
|
||||
'ContentTypes',
|
||||
'Document',
|
||||
'EncryptedCredentials',
|
||||
'EncryptedPassportElement',
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import typing
|
||||
from contextvars import ContextVar
|
||||
from typing import TypeVar
|
||||
|
||||
from .fields import BaseField
|
||||
|
|
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
|
|||
setattr(cls, ALIASES_ATTR_NAME, aliases)
|
||||
|
||||
mcs._objects[cls.__name__] = cls
|
||||
|
||||
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
|
||||
return cls
|
||||
|
||||
@property
|
||||
|
|
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
|||
if value.default and key not in self.values:
|
||||
self.values[key] = value.default
|
||||
|
||||
@classmethod
|
||||
def current(cls):
|
||||
return cls._current.get()
|
||||
|
||||
@classmethod
|
||||
def set_current(cls, obj: TelegramObject):
|
||||
return cls._current.set(obj)
|
||||
|
||||
@property
|
||||
def conf(self) -> typing.Dict[str, typing.Any]:
|
||||
return self._conf
|
||||
|
|
@ -137,8 +150,8 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
|||
|
||||
@property
|
||||
def bot(self):
|
||||
from ..dispatcher import ctx
|
||||
return ctx.get_bot()
|
||||
from ..bot.bot import Bot
|
||||
return Bot.current()
|
||||
|
||||
def to_python(self) -> typing.Dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
from contextvars import ContextVar
|
||||
|
||||
from . import base
|
||||
from . import fields
|
||||
|
|
@ -64,7 +67,7 @@ class Chat(base.TelegramObject):
|
|||
if as_html:
|
||||
return markdown.hlink(name, self.user_url)
|
||||
return markdown.link(name, self.user_url)
|
||||
|
||||
|
||||
async def get_url(self):
|
||||
"""
|
||||
Use this method to get chat link.
|
||||
|
|
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
|
|||
|
||||
@classmethod
|
||||
async def _do(cls, action: str, sleep=None):
|
||||
from ..dispatcher.ctx import get_bot, get_chat
|
||||
await get_bot().send_chat_action(get_chat(), action)
|
||||
from aiogram import Bot
|
||||
await Bot.current().send_chat_action(Chat.current(), action)
|
||||
if sleep:
|
||||
await asyncio.sleep(sleep)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class BaseField(metaclass=abc.ABCMeta):
|
|||
Base field (prop)
|
||||
"""
|
||||
|
||||
def __init__(self, *, base=None, default=None, alias=None):
|
||||
def __init__(self, *, base=None, default=None, alias=None, on_change=None):
|
||||
"""
|
||||
Init prop
|
||||
|
||||
|
|
@ -17,10 +17,12 @@ class BaseField(metaclass=abc.ABCMeta):
|
|||
:param default: default value
|
||||
:param alias: alias name (for e.g. field 'from' has to be named 'from_user'
|
||||
as 'from' is a builtin Python keyword
|
||||
:param on_change: callback will be called when value is changed
|
||||
"""
|
||||
self.base_object = base
|
||||
self.default = default
|
||||
self.alias = alias
|
||||
self.on_change = on_change
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
if self.alias is None:
|
||||
|
|
@ -53,6 +55,13 @@ class BaseField(metaclass=abc.ABCMeta):
|
|||
self.resolve_base(instance)
|
||||
value = self.deserialize(value, parent)
|
||||
instance.values[self.alias] = value
|
||||
self._trigger_changed(instance, value)
|
||||
|
||||
def _trigger_changed(self, instance, value):
|
||||
if not self.on_change and instance is not None:
|
||||
return
|
||||
callback = getattr(instance, self.on_change)
|
||||
callback(value)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.get_value(instance)
|
||||
|
|
@ -154,7 +163,7 @@ class ListOfLists(Field):
|
|||
return result
|
||||
|
||||
|
||||
class DateTimeField(BaseField):
|
||||
class DateTimeField(Field):
|
||||
"""
|
||||
In this field st_ored datetime
|
||||
|
||||
|
|
@ -167,3 +176,24 @@ class DateTimeField(BaseField):
|
|||
|
||||
def deserialize(self, value, parent=None):
|
||||
return datetime.datetime.fromtimestamp(value)
|
||||
|
||||
|
||||
class TextField(Field):
|
||||
def __init__(self, *, prefix=None, suffix=None, default=None, alias=None):
|
||||
super(TextField, self).__init__(default=default, alias=alias)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
|
||||
def serialize(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if self.prefix:
|
||||
value = self.prefix + value
|
||||
if self.suffix:
|
||||
value += self.suffix
|
||||
return value
|
||||
|
||||
def deserialize(self, value, parent=None):
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError(f"Field '{self.alias}' should be str not {type(value).__name__}")
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import io
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -45,6 +46,8 @@ class InputFile(base.TelegramObject):
|
|||
|
||||
self._filename = filename
|
||||
|
||||
self.attachment_key = secrets.token_urlsafe(16)
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Close file descriptor
|
||||
|
|
@ -54,13 +57,17 @@ class InputFile(base.TelegramObject):
|
|||
@property
|
||||
def filename(self):
|
||||
if self._filename is None:
|
||||
self._filename = api._guess_filename(self._file)
|
||||
self._filename = api.guess_filename(self._file)
|
||||
return self._filename
|
||||
|
||||
@filename.setter
|
||||
def filename(self, value):
|
||||
self._filename = value
|
||||
|
||||
@property
|
||||
def attach(self):
|
||||
return f"attach://{self.attachment_key}"
|
||||
|
||||
def get_filename(self) -> str:
|
||||
"""
|
||||
Get file name
|
||||
|
|
@ -159,6 +166,9 @@ class InputFile(base.TelegramObject):
|
|||
|
||||
return writer
|
||||
|
||||
def __str__(self):
|
||||
return f"<InputFile 'attach://{self.attachment_key}' with file='{self.file}'>"
|
||||
|
||||
def to_python(self):
|
||||
raise TypeError('Object of this type is not exportable!')
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ ATTACHMENT_PREFIX = 'attach://'
|
|||
class InputMedia(base.TelegramObject):
|
||||
"""
|
||||
This object represents the content of a media message to be sent. It should be one of
|
||||
- InputMediaAnimation
|
||||
- InputMediaDocument
|
||||
- InputMediaAudio
|
||||
- InputMediaPhoto
|
||||
- InputMediaVideo
|
||||
|
||||
|
|
@ -20,36 +23,76 @@ class InputMedia(base.TelegramObject):
|
|||
https://core.telegram.org/bots/api#inputmedia
|
||||
"""
|
||||
type: base.String = fields.Field(default='photo')
|
||||
media: base.String = fields.Field()
|
||||
thumb: typing.Union[base.InputFile, base.String] = fields.Field()
|
||||
media: base.String = fields.Field(alias='media', on_change='_media_changed')
|
||||
thumb: typing.Union[base.InputFile, base.String] = fields.Field(alias='thumb', on_change='_thumb_changed')
|
||||
caption: base.String = fields.Field()
|
||||
parse_mode: base.Boolean = fields.Field()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._thumb_file = None
|
||||
self._media_file = None
|
||||
|
||||
media = kwargs.pop('media', None)
|
||||
if isinstance(media, (io.IOBase, InputFile)):
|
||||
self.file = media
|
||||
elif media is not None:
|
||||
self.media = media
|
||||
|
||||
thumb = kwargs.pop('thumb', None)
|
||||
if isinstance(thumb, (io.IOBase, InputFile)):
|
||||
self.thumb_file = thumb
|
||||
elif thumb is not None:
|
||||
self.thumb = thumb
|
||||
|
||||
super(InputMedia, self).__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
if self.parse_mode is None and self.bot.parse_mode:
|
||||
if self.parse_mode is None and self.bot and self.bot.parse_mode:
|
||||
self.parse_mode = self.bot.parse_mode
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def file(self):
|
||||
return getattr(self, '_file', None)
|
||||
return self._media_file
|
||||
|
||||
@file.setter
|
||||
def file(self, file: io.IOBase):
|
||||
setattr(self, '_file', file)
|
||||
attachment_key = self.attachment_key = secrets.token_urlsafe(16)
|
||||
self.media = ATTACHMENT_PREFIX + attachment_key
|
||||
self.media = 'attach://' + secrets.token_urlsafe(16)
|
||||
self._media_file = file
|
||||
|
||||
@file.deleter
|
||||
def file(self):
|
||||
self.media = None
|
||||
self._media_file = None
|
||||
|
||||
def _media_changed(self, value):
|
||||
if value is None or isinstance(value, str) and not value.startswith('attach://'):
|
||||
self._media_file = None
|
||||
|
||||
@property
|
||||
def attachment_key(self):
|
||||
return self.conf.get('attachment_key', None)
|
||||
def thumb_file(self):
|
||||
return self._thumb_file
|
||||
|
||||
@attachment_key.setter
|
||||
def attachment_key(self, value):
|
||||
self.conf['attachment_key'] = value
|
||||
@thumb_file.setter
|
||||
def thumb_file(self, file: io.IOBase):
|
||||
self.thumb = 'attach://' + secrets.token_urlsafe(16)
|
||||
self._thumb_file = file
|
||||
|
||||
@thumb_file.deleter
|
||||
def thumb_file(self):
|
||||
self.thumb = None
|
||||
self._thumb_file = None
|
||||
|
||||
def _thumb_changed(self, value):
|
||||
if value is None or isinstance(value, str) and not value.startswith('attach://'):
|
||||
self._thumb_file = None
|
||||
|
||||
def get_files(self):
|
||||
if self._media_file:
|
||||
yield self.media[9:], self._media_file
|
||||
if self._thumb_file:
|
||||
yield self.thumb[9:], self._thumb_file
|
||||
|
||||
|
||||
class InputMediaAnimation(InputMedia):
|
||||
|
|
@ -72,9 +115,6 @@ class InputMediaAnimation(InputMedia):
|
|||
width=width, height=height, duration=duration,
|
||||
parse_mode=parse_mode, conf=kwargs)
|
||||
|
||||
if isinstance(media, (io.IOBase, InputFile)):
|
||||
self.file = media
|
||||
|
||||
|
||||
class InputMediaDocument(InputMedia):
|
||||
"""
|
||||
|
|
@ -89,9 +129,6 @@ class InputMediaDocument(InputMedia):
|
|||
caption=caption, parse_mode=parse_mode,
|
||||
conf=kwargs)
|
||||
|
||||
if isinstance(media, (io.IOBase, InputFile)):
|
||||
self.file = media
|
||||
|
||||
|
||||
class InputMediaAudio(InputMedia):
|
||||
"""
|
||||
|
|
@ -119,9 +156,6 @@ class InputMediaAudio(InputMedia):
|
|||
performer=performer, title=title,
|
||||
parse_mode=parse_mode, conf=kwargs)
|
||||
|
||||
if isinstance(media, (io.IOBase, InputFile)):
|
||||
self.file = media
|
||||
|
||||
|
||||
class InputMediaPhoto(InputMedia):
|
||||
"""
|
||||
|
|
@ -136,9 +170,6 @@ class InputMediaPhoto(InputMedia):
|
|||
caption=caption, parse_mode=parse_mode,
|
||||
conf=kwargs)
|
||||
|
||||
if isinstance(media, (io.IOBase, InputFile)):
|
||||
self.file = media
|
||||
|
||||
|
||||
class InputMediaVideo(InputMedia):
|
||||
"""
|
||||
|
|
@ -151,18 +182,17 @@ class InputMediaVideo(InputMedia):
|
|||
duration: base.Integer = fields.Field()
|
||||
supports_streaming: base.Boolean = fields.Field()
|
||||
|
||||
def __init__(self, media: base.InputFile, caption: base.String = None,
|
||||
def __init__(self, media: base.InputFile,
|
||||
thumb: typing.Union[base.InputFile, base.String] = None,
|
||||
caption: base.String = None,
|
||||
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None,
|
||||
parse_mode: base.Boolean = None,
|
||||
supports_streaming: base.Boolean = None, **kwargs):
|
||||
super(InputMediaVideo, self).__init__(type='video', media=media, caption=caption,
|
||||
super(InputMediaVideo, self).__init__(type='video', media=media, thumb=thumb, caption=caption,
|
||||
width=width, height=height, duration=duration,
|
||||
parse_mode=parse_mode,
|
||||
supports_streaming=supports_streaming, conf=kwargs)
|
||||
|
||||
if isinstance(media, (io.IOBase, InputFile)):
|
||||
self.file = media
|
||||
|
||||
|
||||
class MediaGroup(base.TelegramObject):
|
||||
"""
|
||||
|
|
@ -296,6 +326,7 @@ class MediaGroup(base.TelegramObject):
|
|||
self.attach(photo)
|
||||
|
||||
def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile],
|
||||
thumb: typing.Union[base.InputFile, base.String] = None,
|
||||
caption: base.String = None,
|
||||
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None):
|
||||
"""
|
||||
|
|
@ -308,7 +339,7 @@ class MediaGroup(base.TelegramObject):
|
|||
:param duration:
|
||||
"""
|
||||
if not isinstance(video, InputMedia):
|
||||
video = InputMediaVideo(media=video, caption=caption,
|
||||
video = InputMediaVideo(media=video, thumb=thumb, caption=caption,
|
||||
width=width, height=height, duration=duration)
|
||||
self.attach(video)
|
||||
|
||||
|
|
@ -327,6 +358,7 @@ class MediaGroup(base.TelegramObject):
|
|||
return result
|
||||
|
||||
def get_files(self):
|
||||
return {inputmedia.attachment_key: inputmedia.file
|
||||
for inputmedia in self.media
|
||||
if isinstance(inputmedia, InputMedia) and inputmedia.file}
|
||||
for inputmedia in self.media:
|
||||
if not isinstance(inputmedia, InputMedia) or not inputmedia.file:
|
||||
continue
|
||||
yield from inputmedia.get_files()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import sys
|
||||
|
|
@ -42,7 +44,7 @@ class Message(base.TelegramObject):
|
|||
forward_from_message_id: base.Integer = fields.Field()
|
||||
forward_signature: base.String = fields.Field()
|
||||
forward_date: datetime.datetime = fields.DateTimeField()
|
||||
reply_to_message: 'Message' = fields.Field(base='Message')
|
||||
reply_to_message: Message = fields.Field(base='Message')
|
||||
edit_date: datetime.datetime = fields.DateTimeField()
|
||||
media_group_id: base.String = fields.Field()
|
||||
author_signature: base.String = fields.Field()
|
||||
|
|
@ -72,7 +74,7 @@ class Message(base.TelegramObject):
|
|||
channel_chat_created: base.Boolean = fields.Field()
|
||||
migrate_to_chat_id: base.Integer = fields.Field()
|
||||
migrate_from_chat_id: base.Integer = fields.Field()
|
||||
pinned_message: 'Message' = fields.Field(base='Message')
|
||||
pinned_message: Message = fields.Field(base='Message')
|
||||
invoice: Invoice = fields.Field(base=Invoice)
|
||||
successful_payment: SuccessfulPayment = fields.Field(base=SuccessfulPayment)
|
||||
connected_website: base.String = fields.Field()
|
||||
|
|
@ -82,59 +84,59 @@ class Message(base.TelegramObject):
|
|||
@functools.lru_cache()
|
||||
def content_type(self):
|
||||
if self.text:
|
||||
return ContentType.TEXT[0]
|
||||
return ContentType.TEXT
|
||||
elif self.audio:
|
||||
return ContentType.AUDIO[0]
|
||||
return ContentType.AUDIO
|
||||
elif self.animation:
|
||||
return ContentType.ANIMATION[0]
|
||||
return ContentType.ANIMATION
|
||||
elif self.document:
|
||||
return ContentType.DOCUMENT[0]
|
||||
return ContentType.DOCUMENT
|
||||
elif self.game:
|
||||
return ContentType.GAME[0]
|
||||
return ContentType.GAME
|
||||
elif self.photo:
|
||||
return ContentType.PHOTO[0]
|
||||
return ContentType.PHOTO
|
||||
elif self.sticker:
|
||||
return ContentType.STICKER[0]
|
||||
return ContentType.STICKER
|
||||
elif self.video:
|
||||
return ContentType.VIDEO[0]
|
||||
return ContentType.VIDEO
|
||||
elif self.video_note:
|
||||
return ContentType.VIDEO_NOTE[0]
|
||||
return ContentType.VIDEO_NOTE
|
||||
elif self.voice:
|
||||
return ContentType.VOICE[0]
|
||||
return ContentType.VOICE
|
||||
elif self.contact:
|
||||
return ContentType.CONTACT[0]
|
||||
return ContentType.CONTACT
|
||||
elif self.venue:
|
||||
return ContentType.VENUE[0]
|
||||
return ContentType.VENUE
|
||||
elif self.location:
|
||||
return ContentType.LOCATION[0]
|
||||
return ContentType.LOCATION
|
||||
elif self.new_chat_members:
|
||||
return ContentType.NEW_CHAT_MEMBERS[0]
|
||||
return ContentType.NEW_CHAT_MEMBERS
|
||||
elif self.left_chat_member:
|
||||
return ContentType.LEFT_CHAT_MEMBER[0]
|
||||
return ContentType.LEFT_CHAT_MEMBER
|
||||
elif self.invoice:
|
||||
return ContentType.INVOICE[0]
|
||||
return ContentType.INVOICE
|
||||
elif self.successful_payment:
|
||||
return ContentType.SUCCESSFUL_PAYMENT[0]
|
||||
return ContentType.SUCCESSFUL_PAYMENT
|
||||
elif self.connected_website:
|
||||
return ContentType.CONNECTED_WEBSITE[0]
|
||||
return ContentType.CONNECTED_WEBSITE
|
||||
elif self.migrate_from_chat_id:
|
||||
return ContentType.MIGRATE_FROM_CHAT_ID[0]
|
||||
return ContentType.MIGRATE_FROM_CHAT_ID
|
||||
elif self.migrate_to_chat_id:
|
||||
return ContentType.MIGRATE_TO_CHAT_ID[0]
|
||||
return ContentType.MIGRATE_TO_CHAT_ID
|
||||
elif self.pinned_message:
|
||||
return ContentType.PINNED_MESSAGE[0]
|
||||
return ContentType.PINNED_MESSAGE
|
||||
elif self.new_chat_title:
|
||||
return ContentType.NEW_CHAT_TITLE[0]
|
||||
return ContentType.NEW_CHAT_TITLE
|
||||
elif self.new_chat_photo:
|
||||
return ContentType.NEW_CHAT_PHOTO[0]
|
||||
return ContentType.NEW_CHAT_PHOTO
|
||||
elif self.delete_chat_photo:
|
||||
return ContentType.DELETE_CHAT_PHOTO[0]
|
||||
return ContentType.DELETE_CHAT_PHOTO
|
||||
elif self.group_chat_created:
|
||||
return ContentType.GROUP_CHAT_CREATED[0]
|
||||
return ContentType.GROUP_CHAT_CREATED
|
||||
elif self.passport_data:
|
||||
return ContentType.PASSPORT_DATA[0]
|
||||
return ContentType.PASSPORT_DATA
|
||||
else:
|
||||
return ContentType.UNKNOWN[0]
|
||||
return ContentType.UNKNOWN
|
||||
|
||||
def is_command(self):
|
||||
"""
|
||||
|
|
@ -239,7 +241,7 @@ class Message(base.TelegramObject):
|
|||
return self.parse_entities()
|
||||
|
||||
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
|
||||
disable_notification=None, reply_markup=None, reply=True) -> 'Message':
|
||||
disable_notification=None, reply_markup=None, reply=True) -> Message:
|
||||
"""
|
||||
Reply to this message
|
||||
|
||||
|
|
@ -729,6 +731,69 @@ class ContentType(helper.Helper):
|
|||
"""
|
||||
List of message content types
|
||||
|
||||
WARNING: Single elements
|
||||
|
||||
:key: TEXT
|
||||
:key: AUDIO
|
||||
:key: DOCUMENT
|
||||
:key: GAME
|
||||
:key: PHOTO
|
||||
:key: STICKER
|
||||
:key: VIDEO
|
||||
:key: VIDEO_NOTE
|
||||
:key: VOICE
|
||||
:key: CONTACT
|
||||
:key: LOCATION
|
||||
:key: VENUE
|
||||
:key: NEW_CHAT_MEMBERS
|
||||
:key: LEFT_CHAT_MEMBER
|
||||
:key: INVOICE
|
||||
:key: SUCCESSFUL_PAYMENT
|
||||
:key: CONNECTED_WEBSITE
|
||||
:key: MIGRATE_TO_CHAT_ID
|
||||
:key: MIGRATE_FROM_CHAT_ID
|
||||
:key: UNKNOWN
|
||||
:key: ANY
|
||||
"""
|
||||
mode = helper.HelperMode.snake_case
|
||||
|
||||
TEXT = helper.Item() # text
|
||||
AUDIO = helper.Item() # audio
|
||||
DOCUMENT = helper.Item() # document
|
||||
ANIMATION = helper.Item() # animation
|
||||
GAME = helper.Item() # game
|
||||
PHOTO = helper.Item() # photo
|
||||
STICKER = helper.Item() # sticker
|
||||
VIDEO = helper.Item() # video
|
||||
VIDEO_NOTE = helper.Item() # video_note
|
||||
VOICE = helper.Item() # voice
|
||||
CONTACT = helper.Item() # contact
|
||||
LOCATION = helper.Item() # location
|
||||
VENUE = helper.Item() # venue
|
||||
NEW_CHAT_MEMBERS = helper.Item() # new_chat_member
|
||||
LEFT_CHAT_MEMBER = helper.Item() # left_chat_member
|
||||
INVOICE = helper.Item() # invoice
|
||||
SUCCESSFUL_PAYMENT = helper.Item() # successful_payment
|
||||
CONNECTED_WEBSITE = helper.Item() # connected_website
|
||||
MIGRATE_TO_CHAT_ID = helper.Item() # migrate_to_chat_id
|
||||
MIGRATE_FROM_CHAT_ID = helper.Item() # migrate_from_chat_id
|
||||
PINNED_MESSAGE = helper.Item() # pinned_message
|
||||
NEW_CHAT_TITLE = helper.Item() # new_chat_title
|
||||
NEW_CHAT_PHOTO = helper.Item() # new_chat_photo
|
||||
DELETE_CHAT_PHOTO = helper.Item() # delete_chat_photo
|
||||
GROUP_CHAT_CREATED = helper.Item() # group_chat_created
|
||||
PASSPORT_DATA = helper.Item() # passport_data
|
||||
|
||||
UNKNOWN = helper.Item() # unknown
|
||||
ANY = helper.Item() # any
|
||||
|
||||
|
||||
class ContentTypes(helper.Helper):
|
||||
"""
|
||||
List of message content types
|
||||
|
||||
WARNING: List elements.
|
||||
|
||||
:key: TEXT
|
||||
:key: AUDIO
|
||||
:key: DOCUMENT
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from . import base
|
||||
from . import fields
|
||||
from .callback_query import CallbackQuery
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import babel
|
||||
|
||||
from . import base
|
||||
|
|
|
|||
|
|
@ -1,140 +0,0 @@
|
|||
"""
|
||||
You need to setup task factory:
|
||||
>>> from aiogram.utils import context
|
||||
>>> loop = asyncio.get_event_loop()
|
||||
>>> loop.set_task_factory(context.task_factory)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
|
||||
|
||||
|
||||
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
|
||||
"""
|
||||
Task factory for implementing context processor
|
||||
|
||||
:param loop:
|
||||
:param coro:
|
||||
:return: new task
|
||||
:rtype: :obj:`asyncio.Task`
|
||||
"""
|
||||
# Is not allowed when loop is closed.
|
||||
if loop.is_closed():
|
||||
raise RuntimeError('Event loop is closed.')
|
||||
|
||||
task = asyncio.Task(coro, loop=loop)
|
||||
|
||||
# Hide factory
|
||||
if task._source_traceback:
|
||||
del task._source_traceback[-1]
|
||||
|
||||
try:
|
||||
task.context = asyncio.Task.current_task().context.copy()
|
||||
except AttributeError:
|
||||
task.context = {CONFIGURED: True}
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def get_current_state() -> typing.Dict:
|
||||
"""
|
||||
Get current execution context from task
|
||||
|
||||
:return: context
|
||||
:rtype: :obj:`dict`
|
||||
"""
|
||||
task = asyncio.Task.current_task()
|
||||
if task is None:
|
||||
raise RuntimeError('Can be used only in Task context.')
|
||||
context_ = getattr(task, 'context', None)
|
||||
if context_ is None:
|
||||
context_ = task.context = {}
|
||||
return context_
|
||||
|
||||
|
||||
def get_value(key, default=None):
|
||||
"""
|
||||
Get value from task
|
||||
|
||||
:param key:
|
||||
:param default:
|
||||
:return: value
|
||||
"""
|
||||
return get_current_state().get(key, default)
|
||||
|
||||
|
||||
def check_value(key):
|
||||
"""
|
||||
Key in context?
|
||||
|
||||
:param key:
|
||||
:return:
|
||||
"""
|
||||
return key in get_current_state()
|
||||
|
||||
|
||||
def set_value(key, value):
|
||||
"""
|
||||
Set value
|
||||
|
||||
:param key:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
get_current_state()[key] = value
|
||||
|
||||
|
||||
def del_value(key):
|
||||
"""
|
||||
Remove value from context
|
||||
|
||||
:param key:
|
||||
:return:
|
||||
"""
|
||||
del get_current_state()[key]
|
||||
|
||||
|
||||
def update_state(data=None, **kwargs):
|
||||
"""
|
||||
Update multiple state items
|
||||
|
||||
:param data:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if data is None:
|
||||
data = {}
|
||||
state = get_current_state()
|
||||
state.update(data, **kwargs)
|
||||
|
||||
|
||||
def check_configured():
|
||||
"""
|
||||
Check loop is configured
|
||||
:return:
|
||||
"""
|
||||
return get_value(CONFIGURED)
|
||||
|
||||
|
||||
class _Context:
|
||||
"""
|
||||
Other things for interactions with the execution context.
|
||||
"""
|
||||
|
||||
def __getitem__(self, item):
|
||||
return get_value(item)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
set_value(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del_value(key)
|
||||
|
||||
@staticmethod
|
||||
def get_context():
|
||||
return get_current_state()
|
||||
|
||||
|
||||
context = _Context()
|
||||
|
|
@ -182,6 +182,10 @@ class MessageToEditNotFound(MessageError):
|
|||
match = 'message to edit not found'
|
||||
|
||||
|
||||
class MessageIsTooLong(MessageError):
|
||||
match = 'message is too long'
|
||||
|
||||
|
||||
class ToMuchMessages(MessageError):
|
||||
"""
|
||||
Will be raised when you try to send media group with more than 10 items.
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from warnings import warn
|
|||
|
||||
from aiohttp import web
|
||||
|
||||
from . import context
|
||||
from ..bot.api import log
|
||||
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
|
||||
|
||||
|
|
@ -104,6 +103,11 @@ class Executor:
|
|||
|
||||
self._freeze = False
|
||||
|
||||
from aiogram.bot.bot import bot as ctx_bot
|
||||
from aiogram.dispatcher import dispatcher as ctx_dp
|
||||
ctx_bot.set(dispatcher.bot)
|
||||
ctx_dp.set(dispatcher)
|
||||
|
||||
@property
|
||||
def frozen(self):
|
||||
return self._freeze
|
||||
|
|
@ -176,13 +180,13 @@ class Executor:
|
|||
self._check_frozen()
|
||||
self._freeze = True
|
||||
|
||||
self.loop.set_task_factory(context.task_factory)
|
||||
# self.loop.set_task_factory(context.task_factory)
|
||||
|
||||
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
|
||||
self._check_frozen()
|
||||
self._freeze = True
|
||||
|
||||
self.loop.set_task_factory(context.task_factory)
|
||||
# self.loop.set_task_factory(context.task_factory)
|
||||
|
||||
app = self._web_app
|
||||
if app is None:
|
||||
|
|
@ -203,6 +207,7 @@ class Executor:
|
|||
|
||||
for callback in self._on_startup_webhook:
|
||||
app.on_startup.append(functools.partial(_wrap_callback, callback))
|
||||
|
||||
# for callback in self._on_shutdown_webhook:
|
||||
# app.on_shutdown.append(functools.partial(_wrap_callback, callback))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,26 +1,52 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
JSON = 'json'
|
||||
RAPIDJSON = 'rapidjson'
|
||||
UJSON = 'ujson'
|
||||
|
||||
try:
|
||||
import ujson
|
||||
if 'DISABLE_UJSON' not in os.environ:
|
||||
import ujson as json
|
||||
|
||||
mode = UJSON
|
||||
|
||||
|
||||
def dumps(data):
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
|
||||
else:
|
||||
mode = JSON
|
||||
except ImportError:
|
||||
ujson = None
|
||||
mode = JSON
|
||||
|
||||
_use_ujson = True if ujson else False
|
||||
try:
|
||||
if 'DISABLE_RAPIDJSON' not in os.environ:
|
||||
import rapidjson as json
|
||||
|
||||
mode = RAPIDJSON
|
||||
|
||||
|
||||
def disable_ujson():
|
||||
global _use_ujson
|
||||
_use_ujson = False
|
||||
def dumps(data):
|
||||
return json.dumps(data, ensure_ascii=False, number_mode=json.NM_NATIVE,
|
||||
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
|
||||
|
||||
|
||||
def dumps(data):
|
||||
if _use_ujson:
|
||||
return ujson.dumps(data)
|
||||
return json.dumps(data)
|
||||
def loads(data):
|
||||
return json.loads(data, number_mode=json.NM_NATIVE,
|
||||
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
|
||||
|
||||
else:
|
||||
mode = JSON
|
||||
except ImportError:
|
||||
mode = JSON
|
||||
|
||||
if mode == JSON:
|
||||
import json
|
||||
|
||||
|
||||
def loads(data):
|
||||
if _use_ujson:
|
||||
return ujson.loads(data)
|
||||
return json.loads(data)
|
||||
def dumps(data):
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
def loads(data):
|
||||
return json.loads(data)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import datetime
|
||||
import secrets
|
||||
|
||||
from aiogram import types
|
||||
from . import json
|
||||
|
||||
DEFAULT_FILTER = ['self', 'cls']
|
||||
|
|
@ -56,3 +58,22 @@ def prepare_arg(value):
|
|||
elif isinstance(value, datetime.datetime):
|
||||
return round(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def prepare_file(payload, files, key, file):
|
||||
if isinstance(file, str):
|
||||
payload[key] = file
|
||||
elif file is not None:
|
||||
files[key] = file
|
||||
|
||||
|
||||
def prepare_attachment(payload, files, key, file):
|
||||
if isinstance(file, str):
|
||||
payload[key] = file
|
||||
elif isinstance(file, types.InputFile):
|
||||
payload[key] = file.attach
|
||||
files[file.attachment_key] = file.file
|
||||
elif file is not None:
|
||||
file_attach_name = secrets.token_urlsafe(16)
|
||||
payload[key] = "attach://" + file_attach_name
|
||||
files[file_attach_name] = file
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
-r requirements.txt
|
||||
|
||||
ujson>=1.35
|
||||
python-rapidjson>=0.6.3
|
||||
emoji>=0.5.0
|
||||
pytest>=3.5.0
|
||||
pytest-asyncio>=0.8.0
|
||||
|
|
|
|||
|
|
@ -25,16 +25,16 @@ Next step: interaction with bots starts with one command. Register your first co
|
|||
|
||||
.. code-block:: python3
|
||||
|
||||
@dp.message_handler(commands=['start', 'help'])
|
||||
async def send_welcome(message: types.Message):
|
||||
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||
@dp.message_handler(commands=['start', 'help'])
|
||||
async def send_welcome(message: types.Message):
|
||||
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||
|
||||
Last step: run long polling.
|
||||
|
||||
.. code-block:: python3
|
||||
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp)
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp)
|
||||
|
||||
Summary
|
||||
-------
|
||||
|
|
@ -48,9 +48,9 @@ Summary
|
|||
bot = Bot(token='BOT TOKEN HERE')
|
||||
dp = Dispatcher(bot)
|
||||
|
||||
@dp.message_handler(commands=['start', 'help'])
|
||||
async def send_welcome(message: types.Message):
|
||||
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||
@dp.message_handler(commands=['start', 'help'])
|
||||
async def send_welcome(message: types.Message):
|
||||
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp)
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
name: py36
|
||||
name: py37
|
||||
channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- python=3.6
|
||||
- python=3.7
|
||||
- sphinx=1.5.3
|
||||
- sphinx_rtd_theme=0.2.4
|
||||
- pip
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def send_message(user_id: int, text: str, disable_notification: bool = Fal
|
|||
|
||||
:param user_id:
|
||||
:param text:
|
||||
:param disable_notification:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,18 +5,14 @@ Babel is required.
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiogram import Bot, types
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.types import ParseMode
|
||||
from aiogram.utils.executor import start_polling
|
||||
from aiogram.utils.markdown import *
|
||||
from aiogram import Bot, Dispatcher, executor, md, types
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
bot = Bot(token=API_TOKEN, loop=loop)
|
||||
bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.MARKDOWN)
|
||||
dp = Dispatcher(bot)
|
||||
|
||||
|
||||
|
|
@ -24,14 +20,14 @@ dp = Dispatcher(bot)
|
|||
async def check_language(message: types.Message):
|
||||
locale = message.from_user.locale
|
||||
|
||||
await message.reply(text(
|
||||
bold('Info about your language:'),
|
||||
text(' 🔸', bold('Code:'), italic(locale.locale)),
|
||||
text(' 🔸', bold('Territory:'), italic(locale.territory or 'Unknown')),
|
||||
text(' 🔸', bold('Language name:'), italic(locale.language_name)),
|
||||
text(' 🔸', bold('English language name:'), italic(locale.english_name)),
|
||||
sep='\n'), parse_mode=ParseMode.MARKDOWN)
|
||||
await message.reply(md.text(
|
||||
md.bold('Info about your language:'),
|
||||
md.text(' 🔸', md.bold('Code:'), md.italic(locale.locale)),
|
||||
md.text(' 🔸', md.bold('Territory:'), md.italic(locale.territory or 'Unknown')),
|
||||
md.text(' 🔸', md.bold('Language name:'), md.italic(locale.language_name)),
|
||||
md.text(' 🔸', md.bold('English language name:'), md.italic(locale.english_name)),
|
||||
sep='\n'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_polling(dp, loop=loop, skip_updates=True)
|
||||
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiogram import Bot, types
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.utils.executor import start_polling
|
||||
from aiogram import Bot, types, Dispatcher, executor
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
|
|
@ -32,10 +30,4 @@ async def echo(message: types.Message):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_polling(dp, loop=loop, skip_updates=True)
|
||||
|
||||
# Also you can use another execution method
|
||||
# >>> try:
|
||||
# >>> loop.run_until_complete(main())
|
||||
# >>> except KeyboardInterrupt:
|
||||
# >>> loop.stop()
|
||||
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
from aiogram import Bot, types
|
||||
from aiogram.contrib.middlewares.context import ContextMiddleware
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.types import ParseMode
|
||||
from aiogram.utils import markdown as md
|
||||
from aiogram.utils.executor import start_polling
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
bot = Bot(token=API_TOKEN)
|
||||
dp = Dispatcher(bot)
|
||||
|
||||
# Setup Context middleware
|
||||
data: ContextMiddleware = dp.middleware.setup(ContextMiddleware())
|
||||
|
||||
|
||||
# Write custom filter
|
||||
async def demo_filter(message: types.Message):
|
||||
# Store some data in context
|
||||
command = data['command'] = message.get_command() or ''
|
||||
args = data['args'] = message.get_args() or ''
|
||||
data['has_args'] = bool(args)
|
||||
data['some_random_data'] = 42
|
||||
return command != '/bad_command'
|
||||
|
||||
|
||||
@dp.message_handler(demo_filter)
|
||||
async def send_welcome(message: types.Message):
|
||||
# Get data from context
|
||||
# All of this is available only in current context and from current update object
|
||||
# `data`- pseudo-alias for `ctx.get_update().conf['_context_data']`
|
||||
command = data['command']
|
||||
args = data['args']
|
||||
rand = data['some_random_data']
|
||||
has_args = data['has_args']
|
||||
|
||||
# Send as pre-formatted code block.
|
||||
await message.reply(md.hpre(f"""command: {command}
|
||||
args: {['Not available', 'available'][has_args]}: {args}
|
||||
some random data: {rand}
|
||||
message ID: {message.message_id}
|
||||
message: {message.html_text}
|
||||
"""), parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_polling(dp)
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from aiogram import Bot, types
|
||||
import aiogram.utils.markdown as md
|
||||
from aiogram import Bot, Dispatcher, types
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.dispatcher import FSMContext
|
||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||
from aiogram.types import ParseMode
|
||||
from aiogram.utils import executor
|
||||
from aiogram.utils.markdown import text, bold
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
|
|
@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop)
|
|||
storage = MemoryStorage()
|
||||
dp = Dispatcher(bot, storage=storage)
|
||||
|
||||
|
||||
# States
|
||||
AGE = 'process_age'
|
||||
NAME = 'process_name'
|
||||
GENDER = 'process_gender'
|
||||
class Form(StatesGroup):
|
||||
name = State() # Will be represented in storage as 'Form:name'
|
||||
age = State() # Will be represented in storage as 'Form:age'
|
||||
gender = State() # Will be represented in storage as 'Form:gender'
|
||||
|
||||
|
||||
@dp.message_handler(commands=['start'])
|
||||
|
|
@ -28,48 +32,41 @@ async def cmd_start(message: types.Message):
|
|||
"""
|
||||
Conversation's entry point
|
||||
"""
|
||||
# Get current state
|
||||
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
|
||||
# Update user's state
|
||||
await state.set_state(NAME)
|
||||
# Set state
|
||||
await Form.name.set()
|
||||
|
||||
await message.reply("Hi there! What's your name?")
|
||||
|
||||
|
||||
# You can use state '*' if you need to handle all states
|
||||
@dp.message_handler(state='*', commands=['cancel'])
|
||||
@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel')
|
||||
async def cancel_handler(message: types.Message):
|
||||
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
|
||||
async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None):
|
||||
"""
|
||||
Allow user to cancel any action
|
||||
"""
|
||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
||||
# Ignore command if user is not in any (defined) state
|
||||
if await state.get_state() is None:
|
||||
return
|
||||
if raw_state is None:
|
||||
return
|
||||
|
||||
# Otherwise cancel state and inform user about it
|
||||
# And remove keyboard (just in case)
|
||||
await state.reset_state(with_data=True)
|
||||
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
||||
# Cancel state and inform user about it
|
||||
await state.finish()
|
||||
# And remove keyboard (just in case)
|
||||
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
||||
|
||||
|
||||
@dp.message_handler(state=NAME)
|
||||
async def process_name(message: types.Message):
|
||||
@dp.message_handler(state=Form.name)
|
||||
async def process_name(message: types.Message, state: FSMContext):
|
||||
"""
|
||||
Process user name
|
||||
"""
|
||||
# Save name to storage and go to next step
|
||||
# You can use context manager
|
||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
||||
await state.update_data(name=message.text)
|
||||
await state.set_state(AGE)
|
||||
await Form.next()
|
||||
await state.update_data(name=message.text)
|
||||
|
||||
await message.reply("How old are you?")
|
||||
|
||||
|
||||
# Check age. Age gotta be digit
|
||||
@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit())
|
||||
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
|
||||
async def failed_process_age(message: types.Message):
|
||||
"""
|
||||
If age is invalid
|
||||
|
|
@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message):
|
|||
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
||||
|
||||
|
||||
@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit())
|
||||
async def process_age(message: types.Message):
|
||||
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
|
||||
async def process_age(message: types.Message, state: FSMContext):
|
||||
# Update state and data
|
||||
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
|
||||
await state.set_state(GENDER)
|
||||
await state.update_data(age=int(message.text))
|
||||
await Form.next()
|
||||
await state.update_data(age=int(message.text))
|
||||
|
||||
# Configure ReplyKeyboardMarkup
|
||||
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
||||
|
|
@ -92,7 +88,7 @@ async def process_age(message: types.Message):
|
|||
await message.reply("What is your gender?", reply_markup=markup)
|
||||
|
||||
|
||||
@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"])
|
||||
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
|
||||
async def failed_process_gender(message: types.Message):
|
||||
"""
|
||||
In this example gender has to be one of: Male, Female, Other.
|
||||
|
|
@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message):
|
|||
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
||||
|
||||
|
||||
@dp.message_handler(state=GENDER)
|
||||
async def process_gender(message: types.Message):
|
||||
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
|
||||
|
||||
@dp.message_handler(state=Form.gender)
|
||||
async def process_gender(message: types.Message, state: FSMContext):
|
||||
data = await state.get_data()
|
||||
data['gender'] = message.text
|
||||
|
||||
|
|
@ -111,10 +105,10 @@ async def process_gender(message: types.Message):
|
|||
markup = types.ReplyKeyboardRemove()
|
||||
|
||||
# And send message
|
||||
await bot.send_message(message.chat.id, text(
|
||||
text('Hi! Nice to meet you,', bold(data['name'])),
|
||||
text('Age:', data['age']),
|
||||
text('Gender:', data['gender']),
|
||||
await bot.send_message(message.chat.id, md.text(
|
||||
md.text('Hi! Nice to meet you,', md.bold(data['name'])),
|
||||
md.text('Age:', data['age']),
|
||||
md.text('Gender:', data['gender']),
|
||||
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
|
||||
|
||||
# Finish conversation
|
||||
|
|
@ -122,10 +116,5 @@ async def process_gender(message: types.Message):
|
|||
await state.finish()
|
||||
|
||||
|
||||
async def shutdown(dispatcher: Dispatcher):
|
||||
await dispatcher.storage.close()
|
||||
await dispatcher.storage.wait_closed()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp, loop=loop, skip_updates=True, on_shutdown=shutdown)
|
||||
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||
|
|
|
|||
126
examples/finite_state_machine_example_2.py
Normal file
126
examples/finite_state_machine_example_2.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""
|
||||
This example is equals with 'finite_state_machine_example.py' but with FSM Middleware
|
||||
|
||||
Note that FSM Middleware implements the more simple methods for working with storage.
|
||||
|
||||
With that middleware all data from storage will be loaded before event will be processed
|
||||
and data will be stored after processing the event.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import aiogram.utils.markdown as md
|
||||
from aiogram import Bot, Dispatcher, types
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.contrib.middlewares.fsm import FSMMiddleware, FSMSStorageProxy
|
||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||
from aiogram.utils import executor
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
bot = Bot(token=API_TOKEN, loop=loop)
|
||||
|
||||
# For example use simple MemoryStorage for Dispatcher.
|
||||
storage = MemoryStorage()
|
||||
dp = Dispatcher(bot, storage=storage)
|
||||
dp.middleware.setup(FSMMiddleware())
|
||||
|
||||
|
||||
# States
|
||||
class Form(StatesGroup):
|
||||
name = State() # Will be represented in storage as 'Form:name'
|
||||
age = State() # Will be represented in storage as 'Form:age'
|
||||
gender = State() # Will be represented in storage as 'Form:gender'
|
||||
|
||||
|
||||
@dp.message_handler(commands=['start'])
|
||||
async def cmd_start(message: types.Message):
|
||||
"""
|
||||
Conversation's entry point
|
||||
"""
|
||||
# Set state
|
||||
await Form.first()
|
||||
|
||||
await message.reply("Hi there! What's your name?")
|
||||
|
||||
|
||||
# You can use state '*' if you need to handle all states
|
||||
@dp.message_handler(state='*', commands=['cancel'])
|
||||
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
|
||||
async def cancel_handler(message: types.Message, state_data: FSMSStorageProxy):
|
||||
"""
|
||||
Allow user to cancel any action
|
||||
"""
|
||||
if state_data.state is None:
|
||||
return
|
||||
|
||||
# Cancel state and inform user about it
|
||||
del state_data.state
|
||||
# And remove keyboard (just in case)
|
||||
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
|
||||
|
||||
|
||||
@dp.message_handler(state=Form.name)
|
||||
async def process_name(message: types.Message, state_data: FSMSStorageProxy):
|
||||
"""
|
||||
Process user name
|
||||
"""
|
||||
state_data.state = Form.age
|
||||
state_data['name'] = message.text
|
||||
|
||||
await message.reply("How old are you?")
|
||||
|
||||
|
||||
# Check age. Age gotta be digit
|
||||
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
|
||||
async def failed_process_age(message: types.Message):
|
||||
"""
|
||||
If age is invalid
|
||||
"""
|
||||
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
||||
|
||||
|
||||
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
|
||||
async def process_age(message: types.Message, state_data: FSMSStorageProxy):
|
||||
# Update state and data
|
||||
state_data.state = Form.gender
|
||||
state_data['age'] = int(message.text)
|
||||
|
||||
# Configure ReplyKeyboardMarkup
|
||||
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
||||
markup.add("Male", "Female")
|
||||
markup.add("Other")
|
||||
|
||||
await message.reply("What is your gender?", reply_markup=markup)
|
||||
|
||||
|
||||
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
|
||||
async def failed_process_gender(message: types.Message):
|
||||
"""
|
||||
In this example gender has to be one of: Male, Female, Other.
|
||||
"""
|
||||
return await message.reply("Bad gender name. Choose you gender from keyboard.")
|
||||
|
||||
|
||||
@dp.message_handler(state=Form.gender)
|
||||
async def process_gender(message: types.Message, state_data: FSMSStorageProxy):
|
||||
state_data['gender'] = message.text
|
||||
|
||||
# Remove keyboard
|
||||
markup = types.ReplyKeyboardRemove()
|
||||
|
||||
# And send message
|
||||
await bot.send_message(message.chat.id, md.text(
|
||||
md.text('Hi! Nice to meet you,', md.bold(state_data['name'])),
|
||||
md.text('Age:', state_data['age']),
|
||||
md.text('Gender:', state_data['gender']),
|
||||
sep='\n'), reply_markup=markup, parse_mode=types.ParseMode.MARKDOWN)
|
||||
|
||||
# Finish conversation
|
||||
# WARNING! This method will destroy all data in storage for current user!
|
||||
state_data.clear()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||
56
examples/i18n_example.py
Normal file
56
examples/i18n_example.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
Internalize your bot
|
||||
|
||||
Step 1: extract texts
|
||||
# pybabel extract i18n_example.py -o locales/mybot.pot
|
||||
Step 2: create *.po files. For e.g. create en, ru, uk locales.
|
||||
# echo {en,ru,uk} | xargs -n1 pybabel init -i locales/mybot.pot -d locales -D mybot -l
|
||||
Step 3: translate texts
|
||||
Step 4: compile translations
|
||||
# pybabel compile -d locales -D mybot
|
||||
|
||||
Step 5: When you change the code of your bot you need to update po & mo files.
|
||||
Step 5.1: regenerate pot file:
|
||||
command from step 1
|
||||
Step 5.2: update po files
|
||||
# pybabel update -d locales -D mybot -i locales/mybot.pot
|
||||
Step 5.3: update your translations
|
||||
Step 5.4: compile mo files
|
||||
command from step 4
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from aiogram import Bot, Dispatcher, executor, types
|
||||
from aiogram.contrib.middlewares.i18n import I18nMiddleware
|
||||
|
||||
TOKEN = 'BOT TOKEN HERE'
|
||||
I18N_DOMAIN = 'mybot'
|
||||
|
||||
BASE_DIR = Path(__file__).parent
|
||||
LOCALES_DIR = BASE_DIR / 'locales'
|
||||
|
||||
bot = Bot(TOKEN, parse_mode=types.ParseMode.HTML)
|
||||
dp = Dispatcher(bot)
|
||||
|
||||
# Setup i18n middleware
|
||||
i18n = I18nMiddleware(I18N_DOMAIN, LOCALES_DIR)
|
||||
dp.middleware.setup(i18n)
|
||||
|
||||
# Alias for gettext method
|
||||
_ = i18n.gettext
|
||||
|
||||
|
||||
@dp.message_handler(commands=['start'])
|
||||
async def cmd_start(message: types.Message):
|
||||
# Simply use `_('message')` instead of `'message'` and never use f-strings for translatable texts.
|
||||
await message.reply(_('Hello, <b>{user}</b>!').format(user=message.from_user.full_name))
|
||||
|
||||
|
||||
@dp.message_handler(commands=['lang'])
|
||||
async def cmd_lang(message: types.Message, locale):
|
||||
await message.reply(_('Your current language: <i>{language}</i>').format(language=locale))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp, skip_updates=True)
|
||||
|
|
@ -1,9 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiogram import Bot, types
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.utils.executor import start_polling
|
||||
from aiogram import Bot, types, Dispatcher, executor
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
|
|
@ -23,4 +21,4 @@ async def inline_echo(inline_query: types.InlineQuery):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_polling(dp, loop=loop, skip_updates=True)
|
||||
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||
|
|
|
|||
28
examples/locales/en/LC_MESSAGES/mybot.po
Normal file
28
examples/locales/en/LC_MESSAGES/mybot.po
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# English translations for PROJECT.
|
||||
# Copyright (C) 2018 ORGANIZATION
|
||||
# This file is distributed under the same license as the PROJECT project.
|
||||
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||
#
|
||||
msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: PROJECT VERSION\n"
|
||||
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language: en\n"
|
||||
"Language-Team: en <LL@li.org>\n"
|
||||
"Plural-Forms: nplurals=2; plural=(n != 1)\n"
|
||||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Generated-By: Babel 2.6.0\n"
|
||||
|
||||
#: i18n_example.py:48
|
||||
msgid "Hello, <b>{user}</b>!"
|
||||
msgstr ""
|
||||
|
||||
#: i18n_example.py:53
|
||||
msgid "Your current language: <i>{language}</i>"
|
||||
msgstr ""
|
||||
|
||||
27
examples/locales/mybot.pot
Normal file
27
examples/locales/mybot.pot
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Translations template for PROJECT.
|
||||
# Copyright (C) 2018 ORGANIZATION
|
||||
# This file is distributed under the same license as the PROJECT project.
|
||||
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||
#
|
||||
#, fuzzy
|
||||
msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: PROJECT VERSION\n"
|
||||
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Generated-By: Babel 2.6.0\n"
|
||||
|
||||
#: i18n_example.py:48
|
||||
msgid "Hello, <b>{user}</b>!"
|
||||
msgstr ""
|
||||
|
||||
#: i18n_example.py:53
|
||||
msgid "Your current language: <i>{language}</i>"
|
||||
msgstr ""
|
||||
|
||||
29
examples/locales/ru/LC_MESSAGES/mybot.po
Normal file
29
examples/locales/ru/LC_MESSAGES/mybot.po
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Russian translations for PROJECT.
|
||||
# Copyright (C) 2018 ORGANIZATION
|
||||
# This file is distributed under the same license as the PROJECT project.
|
||||
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||
#
|
||||
msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: PROJECT VERSION\n"
|
||||
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language: ru\n"
|
||||
"Language-Team: ru <LL@li.org>\n"
|
||||
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
|
||||
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
|
||||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Generated-By: Babel 2.6.0\n"
|
||||
|
||||
#: i18n_example.py:48
|
||||
msgid "Hello, <b>{user}</b>!"
|
||||
msgstr "Привет, <b>{user}</b>!"
|
||||
|
||||
#: i18n_example.py:53
|
||||
msgid "Your current language: <i>{language}</i>"
|
||||
msgstr "Твой язык: <i>{language}</i>"
|
||||
|
||||
29
examples/locales/uk/LC_MESSAGES/mybot.po
Normal file
29
examples/locales/uk/LC_MESSAGES/mybot.po
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Ukrainian translations for PROJECT.
|
||||
# Copyright (C) 2018 ORGANIZATION
|
||||
# This file is distributed under the same license as the PROJECT project.
|
||||
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
|
||||
#
|
||||
msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: PROJECT VERSION\n"
|
||||
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
|
||||
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
|
||||
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language: uk\n"
|
||||
"Language-Team: uk <LL@li.org>\n"
|
||||
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
|
||||
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
|
||||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Generated-By: Babel 2.6.0\n"
|
||||
|
||||
#: i18n_example.py:48
|
||||
msgid "Hello, <b>{user}</b>!"
|
||||
msgstr "Привіт, <b>{user}</b>!"
|
||||
|
||||
#: i18n_example.py:53
|
||||
msgid "Your current language: <i>{language}</i>"
|
||||
msgstr "Твоя мова: <i>{language}</i>"
|
||||
|
||||
|
|
@ -1,9 +1,6 @@
|
|||
import asyncio
|
||||
|
||||
from aiogram import Bot, types
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.types import ChatActions
|
||||
from aiogram.utils.executor import start_polling
|
||||
from aiogram import Bot, Dispatcher, executor, filters, types
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
|
|
@ -12,7 +9,7 @@ bot = Bot(token=API_TOKEN, loop=loop)
|
|||
dp = Dispatcher(bot)
|
||||
|
||||
|
||||
@dp.message_handler(commands=['start'])
|
||||
@dp.message_handler(filters.CommandStart())
|
||||
async def send_welcome(message: types.Message):
|
||||
# So... At first I want to send something like this:
|
||||
await message.reply("Do you want to see many pussies? Are you ready?")
|
||||
|
|
@ -21,7 +18,7 @@ async def send_welcome(message: types.Message):
|
|||
await asyncio.sleep(1)
|
||||
|
||||
# Good bots should send chat actions. Or not.
|
||||
await ChatActions.upload_photo()
|
||||
await types.ChatActions.upload_photo()
|
||||
|
||||
# Create media group
|
||||
media = types.MediaGroup()
|
||||
|
|
@ -39,9 +36,8 @@ async def send_welcome(message: types.Message):
|
|||
# media.attach_photo('<file_id>', 'cat-cat-cat.')
|
||||
|
||||
# Done! Send media group
|
||||
await bot.send_media_group(message.chat.id, media=media,
|
||||
reply_to_message_id=message.message_id)
|
||||
await message.reply_media_group(media=media)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_polling(dp, loop=loop, skip_updates=True)
|
||||
executor.start_polling(dp, loop=loop, skip_updates=True)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import asyncio
|
||||
|
||||
from aiogram import Bot, types
|
||||
from aiogram import Bot, Dispatcher, executor, types
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
||||
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
|
||||
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
|
||||
from aiogram.dispatcher.handler import CancelHandler
|
||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
from aiogram.utils import context, executor
|
||||
from aiogram.utils.exceptions import Throttled
|
||||
|
||||
TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
|
|
@ -53,10 +52,10 @@ class ThrottlingMiddleware(BaseMiddleware):
|
|||
:param message:
|
||||
"""
|
||||
# Get current handler
|
||||
handler = context.get_value('handler')
|
||||
# handler = context.get_value('handler')
|
||||
|
||||
# Get dispatcher from context
|
||||
dispatcher = ctx.get_dispatcher()
|
||||
dispatcher = Dispatcher.current()
|
||||
|
||||
# If handler was configured, get rate limit and key from handler
|
||||
if handler:
|
||||
|
|
@ -83,8 +82,8 @@ class ThrottlingMiddleware(BaseMiddleware):
|
|||
:param message:
|
||||
:param throttled:
|
||||
"""
|
||||
handler = context.get_value('handler')
|
||||
dispatcher = ctx.get_dispatcher()
|
||||
# handler = context.get_value('handler')
|
||||
dispatcher = Dispatcher.current()
|
||||
if handler:
|
||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from aiogram import Bot
|
|||
from aiogram import types
|
||||
from aiogram.utils import executor
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.types.message import ContentType
|
||||
from aiogram.types.message import ContentTypes
|
||||
|
||||
|
||||
BOT_TOKEN = 'BOT TOKEN HERE'
|
||||
|
|
@ -86,7 +86,7 @@ async def checkout(pre_checkout_query: types.PreCheckoutQuery):
|
|||
" try to pay again in a few minutes, we need a small rest.")
|
||||
|
||||
|
||||
@dp.message_handler(content_types=ContentType.SUCCESSFUL_PAYMENT)
|
||||
@dp.message_handler(content_types=ContentTypes.SUCCESSFUL_PAYMENT)
|
||||
async def got_payment(message: types.Message):
|
||||
await bot.send_message(message.chat.id,
|
||||
'Hoooooray! Thanks for payment! We will proceed your order for `{} {}`'
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ PROXY_URL = 'http://PROXY_URL' # Or 'socks5://...'
|
|||
# PROXY_AUTH = aiohttp.BasicAuth(login='login', password='password')
|
||||
# And add `proxy_auth=PROXY_AUTH` argument in line 25, like this:
|
||||
# >>> bot = Bot(token=API_TOKEN, loop=loop, proxy=PROXY_URL, proxy_auth=PROXY_AUTH)
|
||||
# Also you can use Socks5 proxy but you need manually install aiosocksy package.
|
||||
# Also you can use Socks5 proxy but you need manually install aiohttp_socks package.
|
||||
|
||||
# Get my ip URL
|
||||
GET_IP_URL = 'http://bot.whatismyipaddress.com/'
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from aiogram import Bot, types, Version
|
|||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.dispatcher.webhook import get_new_configured_app, SendMessage
|
||||
from aiogram.types import ChatType, ParseMode, ContentType
|
||||
from aiogram.types import ChatType, ParseMode, ContentTypes
|
||||
from aiogram.utils.markdown import hbold, bold, text, link
|
||||
|
||||
TOKEN = 'BOT TOKEN HERE'
|
||||
|
|
@ -31,7 +31,7 @@ WEBHOOK_URL = f"https://{WEBHOOK_HOST}:{WEBHOOK_PORT}{WEBHOOK_URL_PATH}"
|
|||
WEBAPP_HOST = 'localhost'
|
||||
WEBAPP_PORT = 3001
|
||||
|
||||
BAD_CONTENT = ContentType.PHOTO & ContentType.DOCUMENT & ContentType.STICKER & ContentType.AUDIO
|
||||
BAD_CONTENT = ContentTypes.PHOTO & ContentTypes.DOCUMENT & ContentTypes.STICKER & ContentTypes.AUDIO
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
bot = Bot(TOKEN, loop=loop)
|
||||
|
|
|
|||
6
setup.py
6
setup.py
|
|
@ -13,7 +13,7 @@ except ImportError: # pip >= 10.0.0
|
|||
WORK_DIR = pathlib.Path(__file__).parent
|
||||
|
||||
# Check python version
|
||||
MINIMAL_PY_VERSION = (3, 6)
|
||||
MINIMAL_PY_VERSION = (3, 7)
|
||||
if sys.version_info < MINIMAL_PY_VERSION:
|
||||
raise RuntimeError('aiogram works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION))))
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ setup(
|
|||
url='https://github.com/aiogram/aiogram',
|
||||
license='MIT',
|
||||
author='Alex Root Junior',
|
||||
requires_python='>=3.6',
|
||||
requires_python='>=3.7',
|
||||
author_email='aiogram@illemius.xyz',
|
||||
description='Is a pretty simple and fully asynchronous library for Telegram Bot API',
|
||||
long_description=get_description(),
|
||||
|
|
@ -76,7 +76,7 @@ setup(
|
|||
'Intended Audience :: Developers',
|
||||
'Intended Audience :: System Administrators',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Topic :: Software Development :: Libraries :: Application Frameworks',
|
||||
],
|
||||
install_requires=get_requirements()
|
||||
|
|
|
|||
102
tests/states_group.py
Normal file
102
tests/states_group.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.dispatcher.filters.state import State, StatesGroup, any_state, default_state
|
||||
|
||||
|
||||
class MyGroup(StatesGroup):
|
||||
state = State()
|
||||
state_1 = State()
|
||||
state_2 = State()
|
||||
|
||||
class MySubGroup(StatesGroup):
|
||||
sub_state = State()
|
||||
sub_state_1 = State()
|
||||
sub_state_2 = State()
|
||||
|
||||
in_custom_group = State(group_name='custom_group')
|
||||
|
||||
class NewGroup(StatesGroup):
|
||||
spam = State()
|
||||
renamed_state = State(state='spam_state')
|
||||
|
||||
|
||||
alone_state = State('alone')
|
||||
alone_in_group = State('alone', group_name='home')
|
||||
|
||||
|
||||
def test_default_state():
|
||||
assert default_state.state is None
|
||||
|
||||
|
||||
def test_any_state():
|
||||
assert any_state.state == '*'
|
||||
|
||||
|
||||
def test_alone_state():
|
||||
assert alone_state.state == '@:alone'
|
||||
assert alone_in_group.state == 'home:alone'
|
||||
|
||||
|
||||
def test_group_names():
|
||||
assert MyGroup.__group_name__ == 'MyGroup'
|
||||
assert MyGroup.__full_group_name__ == 'MyGroup'
|
||||
|
||||
assert MyGroup.MySubGroup.__group_name__ == 'MySubGroup'
|
||||
assert MyGroup.MySubGroup.__full_group_name__ == 'MyGroup.MySubGroup'
|
||||
|
||||
assert MyGroup.MySubGroup.NewGroup.__group_name__ == 'NewGroup'
|
||||
assert MyGroup.MySubGroup.NewGroup.__full_group_name__ == 'MyGroup.MySubGroup.NewGroup'
|
||||
|
||||
|
||||
def test_custom_group_in_group():
|
||||
assert MyGroup.MySubGroup.in_custom_group.state == 'custom_group:in_custom_group'
|
||||
|
||||
|
||||
def test_custom_state_name_in_group():
|
||||
assert MyGroup.MySubGroup.NewGroup.renamed_state.state == 'MyGroup.MySubGroup.NewGroup:spam_state'
|
||||
|
||||
|
||||
def test_group_states_names():
|
||||
assert len(MyGroup.states) == 3
|
||||
assert len(MyGroup.all_states) == 9
|
||||
|
||||
assert MyGroup.states_names == ('MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2')
|
||||
assert MyGroup.MySubGroup.states_names == (
|
||||
'MyGroup.MySubGroup:sub_state', 'MyGroup.MySubGroup:sub_state_1', 'MyGroup.MySubGroup:sub_state_2',
|
||||
'custom_group:in_custom_group')
|
||||
assert MyGroup.MySubGroup.NewGroup.states_names == (
|
||||
'MyGroup.MySubGroup.NewGroup:spam', 'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||
|
||||
assert MyGroup.all_states_names == (
|
||||
'MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2',
|
||||
'MyGroup.MySubGroup:sub_state',
|
||||
'MyGroup.MySubGroup:sub_state_1',
|
||||
'MyGroup.MySubGroup:sub_state_2',
|
||||
'custom_group:in_custom_group',
|
||||
'MyGroup.MySubGroup.NewGroup:spam',
|
||||
'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||
|
||||
assert MyGroup.MySubGroup.all_states_names == (
|
||||
'MyGroup.MySubGroup:sub_state',
|
||||
'MyGroup.MySubGroup:sub_state_1',
|
||||
'MyGroup.MySubGroup:sub_state_2',
|
||||
'custom_group:in_custom_group',
|
||||
'MyGroup.MySubGroup.NewGroup:spam',
|
||||
'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||
|
||||
assert MyGroup.MySubGroup.NewGroup.all_states_names == (
|
||||
'MyGroup.MySubGroup.NewGroup:spam',
|
||||
'MyGroup.MySubGroup.NewGroup:spam_state')
|
||||
|
||||
|
||||
def test_root_element():
|
||||
root = MyGroup.MySubGroup.NewGroup.spam.get_root()
|
||||
|
||||
assert issubclass(root, StatesGroup)
|
||||
assert root == MyGroup
|
||||
|
||||
assert root == MyGroup.state.get_root()
|
||||
assert root == MyGroup.MySubGroup.get_root()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
any_state.get_root()
|
||||
2
tox.ini
2
tox.ini
|
|
@ -1,5 +1,5 @@
|
|||
[tox]
|
||||
envlist = py36
|
||||
envlist = py37
|
||||
|
||||
[testenv]
|
||||
deps = -rdev_requirements.txt
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue