This commit is contained in:
Kolay 2018-08-17 11:18:04 +00:00 committed by GitHub
commit bfccd6268c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 3455 additions and 2299 deletions

3
.gitignore vendored
View file

@ -57,3 +57,6 @@ experiment.py
# Doc's
docs/html
# i18n/l10n
*.mo

View file

@ -8,7 +8,7 @@
[![Github issues](https://img.shields.io/github/issues/aiogram/aiogram.svg?style=flat-square)](https://github.com/aiogram/aiogram/issues)
[![MIT License](https://img.shields.io/pypi/l/aiogram.svg?style=flat-square)](https://opensource.org/licenses/MIT)
**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.6 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler.
**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.7 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler.
You can [read the docs here](http://aiogram.readthedocs.io/en/latest/).

View file

@ -1,14 +1,42 @@
import asyncio
import os
from . import bot
from . import contrib
from . import dispatcher
from . import types
from . import utils
from .bot import Bot
from .dispatcher import Dispatcher
from .dispatcher import filters
from .dispatcher import middlewares
from .utils import exceptions, executor, helper, markdown as md
try:
import uvloop
except ImportError:
uvloop = None
else:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
if 'DISABLE_UVLOOP' not in os.environ:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
__version__ = '1.4'
__all__ = [
'Bot',
'Dispatcher',
'__api_version__',
'__version__',
'bot',
'contrib',
'dispatcher',
'exceptions',
'executor',
'filters',
'helper',
'md',
'middlewares',
'types',
'utils'
]
__version__ = '2.0.dev1'
__api_version__ = '3.6'

View file

@ -1,8 +1,14 @@
import abc
import asyncio
import logging
import os
import ssl
from asyncio import AbstractEventLoop
from http import HTTPStatus
from typing import Optional, Tuple
import aiohttp
import certifi
from .. import types
from ..utils import exceptions
@ -34,58 +40,73 @@ def check_token(token: str) -> bool:
return True
async def _check_result(method_name, response):
"""
Checks whether `result` is a valid API response.
A result is considered invalid if:
- The server returned an HTTP response code other than 200
- The content of the result is invalid JSON.
- The method call was unsuccessful (The JSON 'ok' field equals False)
async def check_result(method_name: str, content_type: str, status_code: int, body: str):
"""
Checks whether `result` is a valid API response.
A result is considered invalid if:
- The server returned an HTTP response code other than 200
- The content of the result is invalid JSON.
- The method call was unsuccessful (The JSON 'ok' field equals False)
:raises ApiException: if one of the above listed cases is applicable
:param method_name: The name of the method called
:param response: The returned response of the method request
:return: The result parsed to a JSON dictionary.
"""
body = await response.text()
log.debug(f"Response for {method_name}: [{response.status}] {body}")
:param method_name: The name of the method called
:param status_code: status code
:param content_type: content type of result
:param body: result body
:return: The result parsed to a JSON dictionary
:raises ApiException: if one of the above listed cases is applicable
"""
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
if response.content_type != 'application/json':
raise exceptions.NetworkError(f"Invalid response with content type {response.content_type}: \"{body}\"")
if content_type != 'application/json':
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
try:
result_json = json.loads(body)
except ValueError:
result_json = {}
description = result_json.get('description') or body
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
return result_json.get('result')
elif parameters.retry_after:
raise exceptions.RetryAfter(parameters.retry_after)
elif parameters.migrate_to_chat_id:
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
elif status_code == HTTPStatus.BAD_REQUEST:
exceptions.BadRequest.detect(description)
elif status_code == HTTPStatus.NOT_FOUND:
exceptions.NotFound.detect(description)
elif status_code == HTTPStatus.CONFLICT:
exceptions.ConflictError.detect(description)
elif status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
exceptions.Unauthorized.detect(description)
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
raise exceptions.NetworkError('File too large for uploading. '
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
if 'restart' in description:
raise exceptions.RestartingTelegram()
raise exceptions.TelegramAPIError(description)
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
async def make_request(session, token, method, data=None, files=None, **kwargs):
# log.debug(f"Make request: '{method}' with data: {data} and files {files}")
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files)
url = Methods.api_url(token=token, method=method)
req = compose_data(data, files)
try:
result_json = await response.json(loads=json.loads)
except ValueError:
result_json = {}
description = result_json.get('description') or body
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
if HTTPStatus.OK <= response.status <= HTTPStatus.IM_USED:
return result_json.get('result')
elif parameters.retry_after:
raise exceptions.RetryAfter(parameters.retry_after)
elif parameters.migrate_to_chat_id:
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
elif response.status == HTTPStatus.BAD_REQUEST:
exceptions.BadRequest.detect(description)
elif response.status == HTTPStatus.NOT_FOUND:
exceptions.NotFound.detect(description)
elif response.status == HTTPStatus.CONFLICT:
exceptions.ConflictError.detect(description)
elif response.status in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
exceptions.Unauthorized.detect(description)
elif response.status == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
raise exceptions.NetworkError('File too large for uploading. '
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
elif response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
if 'restart' in description:
raise exceptions.RestartingTelegram()
raise exceptions.TelegramAPIError(description)
raise exceptions.TelegramAPIError(f"{description} [{response.status}]")
async with session.post(url, data=req, **kwargs) as response:
return await check_result(method, response.content_type, response.status, await response.text())
except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
def _guess_filename(obj):
def guess_filename(obj):
"""
Get file name from object
@ -97,7 +118,7 @@ def _guess_filename(obj):
return os.path.basename(name)
def _compose_data(params=None, files=None):
def compose_data(params=None, files=None):
"""
Prepare request data
@ -121,47 +142,13 @@ def _compose_data(params=None, files=None):
elif isinstance(f, types.InputFile):
filename, fileobj = f.filename, f.file
else:
filename, fileobj = _guess_filename(f) or key, f
filename, fileobj = guess_filename(f) or key, f
data.add_field(key, fileobj, filename=filename)
return data
async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict:
"""
Make request to API
That make request with Content-Type:
application/x-www-form-urlencoded - For simple request
and multipart/form-data - for files uploading
https://core.telegram.org/bots/api#making-requests
:param session: HTTP Client session
:type session: :obj:`aiohttp.ClientSession`
:param token: BOT token
:type token: :obj:`str`
:param method: API method
:type method: :obj:`str`
:param data: request payload
:type data: :obj:`dict`
:param files: files
:type files: :obj:`dict`
:return: result
:rtype :obj:`bool` or :obj:`dict`
"""
log.debug("Make request: '{0}' with data: {1} and files {2}".format(
method, data or {}, files or {}))
data = _compose_data(data, files)
url = Methods.api_url(token=token, method=method)
try:
async with session.post(url, data=data, **kwargs) as response:
return await _check_result(method, response)
except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
class Methods(Helper):
"""
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api

View file

@ -47,7 +47,6 @@ class BaseBot:
api.check_token(token)
self.__token = token
# Proxy settings
self.proxy = proxy
self.proxy_auth = proxy_auth
@ -59,37 +58,42 @@ class BaseBot:
# aiohttp main session
ssl_context = ssl.create_default_context(cafile=certifi.where())
if isinstance(proxy, str) and proxy.startswith('socks5://'):
from aiosocksy.connector import ProxyClientRequest, ProxyConnector
connector = ProxyConnector(limit=connections_limit, ssl_context=ssl_context, loop=self.loop)
request_class = ProxyClientRequest
if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
from aiohttp_socks import SocksConnector
from aiohttp_socks.helpers import parse_socks_url
socks_ver, host, port, username, password = parse_socks_url(proxy)
if proxy_auth and not username or password:
username = proxy_auth.login
password = proxy_auth.password
connector = SocksConnector(socks_ver=socks_ver, host=host, port=port,
username=username, password=password,
limit=connections_limit, ssl_context=ssl_context,
loop=self.loop)
self.proxy = None
self.proxy_auth = None
else:
connector = aiohttp.TCPConnector(limit=connections_limit, ssl_context=ssl_context,
loop=self.loop)
request_class = aiohttp.ClientRequest
self.session = aiohttp.ClientSession(connector=connector, request_class=request_class,
loop=self.loop, json_serialize=json.dumps)
# Data stored in bot instance
self._data = {}
self.session = aiohttp.ClientSession(connector=connector, loop=self.loop, json_serialize=json.dumps)
self.parse_mode = parse_mode
def __del__(self):
# asyncio.ensure_future(self.close())
pass
# Data stored in bot instance
self._data = {}
async def close(self):
"""
Close all client sessions
"""
if self.session and not self.session.closed:
await self.session.close()
await self.session.close()
async def request(self, method: base.String,
data: Optional[Dict] = None,
files: Optional[Dict] = None) -> Union[List, Dict, base.Boolean]:
files: Optional[Dict] = None, **kwargs) -> Union[List, Dict, base.Boolean]:
"""
Make an request to Telegram Bot API
@ -105,8 +109,8 @@ class BaseBot:
:rtype: Union[List, Dict]
:raise: :obj:`aiogram.exceptions.TelegramApiError`
"""
return await api.request(self.session, self.__token, method, data, files,
proxy=self.proxy, proxy_auth=self.proxy_auth)
return await api.make_request(self.session, self.__token, method, data, files,
proxy=self.proxy, proxy_auth=self.proxy_auth, **kwargs)
async def download_file(self, file_path: base.String,
destination: Optional[base.InputFile] = None,

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
import typing
from ...dispatcher import BaseStorage
from ...dispatcher.storage import BaseStorage
class MemoryStorage(BaseStorage):
@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage):
chat, user = self.check_address(chat=chat, user=user)
user = self._get_user(chat, user)
if data is None:
data = []
data = {}
user['data'].update(data, **kwargs)
async def set_state(self, *,

View file

@ -303,6 +303,8 @@ class RedisStorage2(BaseStorage):
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
if data is None:
data = {}
temp_data = await self.get_data(chat=chat, user=user, default={})
temp_data.update(data, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_data)
@ -330,6 +332,8 @@ class RedisStorage2(BaseStorage):
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
if bucket is None:
bucket = {}
temp_bucket = await self.get_data(chat=chat, user=user)
temp_bucket.update(bucket, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_bucket)

View file

@ -4,7 +4,7 @@ import weakref
import rethinkdb as r
from ...dispatcher import BaseStorage
from ...dispatcher.storage import BaseStorage
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed']

View file

@ -1,115 +0,0 @@
from aiogram import types
from aiogram.dispatcher import ctx
from aiogram.dispatcher.middlewares import BaseMiddleware
OBJ_KEY = '_context_data'
class ContextMiddleware(BaseMiddleware):
"""
Allow to store data at all of lifetime of Update object
"""
async def on_pre_process_update(self, update: types.Update):
"""
Start of Update lifetime
:param update:
:return:
"""
self._configure_update(update)
async def on_post_process_update(self, update: types.Update, result):
"""
On finishing of processing update
:param update:
:param result:
:return:
"""
if OBJ_KEY in update.conf:
del update.conf[OBJ_KEY]
def _configure_update(self, update: types.Update = None):
"""
Setup data storage
:param update:
:return:
"""
obj = update.conf[OBJ_KEY] = {}
return obj
def _get_dict(self):
"""
Get data from update stored in current context
:return:
"""
update = ctx.get_update()
obj = update.conf.get(OBJ_KEY, None)
if obj is None:
obj = self._configure_update(update)
return obj
def __getitem__(self, item):
"""
Item getter
:param item:
:return:
"""
return self._get_dict()[item]
def __setitem__(self, key, value):
"""
Item setter
:param key:
:param value:
:return:
"""
data = self._get_dict()
data[key] = value
def __iter__(self):
"""
Iterate over dict
:return:
"""
return self._get_dict().__iter__()
def keys(self):
"""
Iterate over dict keys
:return:
"""
return self._get_dict().keys()
def values(self):
"""
Iterate over dict values
:return:
"""
return self._get_dict().values()
def get(self, key, default=None):
"""
Get item from dit or return default value
:param key:
:param default:
:return:
"""
return self._get_dict().get(key, default)
def export(self):
"""
Export all data s dict
:return:
"""
return self._get_dict()

View file

@ -0,0 +1,25 @@
from aiogram.dispatcher.middlewares import BaseMiddleware
class EnvironmentMiddleware(BaseMiddleware):
def __init__(self, context=None):
super(EnvironmentMiddleware, self).__init__()
if context is None:
context = {}
self.context = context
def update_data(self, data):
dp = self.manager.dispatcher
data.update(
bot=dp.bot,
dispatcher=dp,
loop=dp.loop
)
if self.context:
data.update(self.context)
async def trigger(self, action, args):
if 'error' not in action and action.startswith('pre_process_'):
self.update_data(args[-1])
return True

View file

@ -0,0 +1,80 @@
import copy
import weakref
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
from aiogram.dispatcher.storage import FSMContext
class FSMMiddleware(LifetimeControllerMiddleware):
skip_patterns = ['error', 'update']
def __init__(self):
super(FSMMiddleware, self).__init__()
self._proxies = weakref.WeakKeyDictionary()
async def pre_process(self, obj, data, *args):
proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state())
data['state_data'] = proxy
async def post_process(self, obj, data, *args):
proxy = data.get('state_data', None)
if isinstance(proxy, FSMSStorageProxy):
await proxy.save()
class FSMSStorageProxy(dict):
def __init__(self, fsm_context: FSMContext):
super(FSMSStorageProxy, self).__init__()
self.fsm_context = fsm_context
self._copy = {}
self._data = {}
self._state = None
self._is_dirty = False
@classmethod
async def create(cls, fsm_context: FSMContext):
"""
:param fsm_context:
:return:
"""
proxy = cls(fsm_context)
await proxy.load()
return proxy
async def load(self):
self.clear()
self._state = await self.fsm_context.get_state()
self.update(await self.fsm_context.get_data())
self._copy = copy.deepcopy(self)
self._is_dirty = False
@property
def state(self):
return self._state
@state.setter
def state(self, value):
self._state = value
self._is_dirty = True
@state.deleter
def state(self):
self._state = None
self._is_dirty = True
async def save(self, force=False):
if self._copy != self or force:
await self.fsm_context.set_data(data=self)
if self._is_dirty or force:
await self.fsm_context.set_state(self.state)
self._is_dirty = False
self._copy = copy.deepcopy(self)
def __str__(self):
s = super(FSMSStorageProxy, self).__str__()
readable_state = f"'{self.state}'" if self.state else "''"
return f"<{self.__class__.__name__}(state={readable_state}, data={s})>"
def clear(self):
del self.state
return super(FSMSStorageProxy, self).clear()

View file

@ -0,0 +1,140 @@
import gettext
import os
from contextvars import ContextVar
from typing import Any, Dict, Tuple
from babel import Locale
from ... import types
from ...dispatcher.middlewares import BaseMiddleware
class I18nMiddleware(BaseMiddleware):
"""
I18n middleware based on gettext util
>>> dp = Dispatcher(bot)
>>> i18n = I18nMiddleware(DOMAIN, LOCALES_DIR)
>>> dp.middleware.setup(i18n)
and then
>>> _ = i18n.gettext
or
>>> _ = i18n = I18nMiddleware(DOMAIN_NAME, LOCALES_DIR)
"""
ctx_locale = ContextVar('ctx_user_locale', default=None)
def __init__(self, domain, path=None, default='en'):
"""
:param domain: domain
:param path: path where located all *.mo files
:param default: default locale name
"""
super(I18nMiddleware, self).__init__()
if path is None:
path = os.path.join(os.getcwd(), 'locales')
self.domain = domain
self.path = path
self.default = default
self.locales = self.find_locales()
def find_locales(self) -> Dict[str, gettext.GNUTranslations]:
"""
Load all compiled locales from path
:return: dict with locales
"""
translations = {}
for name in os.listdir(self.path):
if not os.path.isdir(os.path.join(self.path, name)):
continue
mo_path = os.path.join(self.path, name, 'LC_MESSAGES', self.domain + '.mo')
if os.path.exists(mo_path):
with open(mo_path, 'rb') as fp:
translations[name] = gettext.GNUTranslations(fp)
elif os.path.exists(mo_path[:-2] + 'po'):
raise RuntimeError(f"Found locale '{name} but this language is not compiled!")
return translations
def reload(self):
"""
Hot reload locles
"""
self.locales = self.find_locales()
@property
def available_locales(self) -> Tuple[str]:
"""
list of loaded locales
:return:
"""
return tuple(self.locales.keys())
def __call__(self, singular, plural=None, n=1, locale=None) -> str:
return self.gettext(singular, plural, n, locale)
def gettext(self, singular, plural=None, n=1, locale=None) -> str:
"""
Get text
:param singular:
:param plural:
:param n:
:param locale:
:return:
"""
if locale is None:
locale = self.ctx_locale.get()
if locale not in self.locales:
if n is 1:
return singular
else:
return plural
translator = self.locales[locale]
if plural is None:
return translator.gettext(singular)
else:
return translator.ngettext(singular, plural, n)
# noinspection PyMethodMayBeStatic,PyUnusedLocal
async def get_user_locale(self, action: str, args: Tuple[Any]) -> str:
"""
User locale getter
You can override the method if you want to use different way of getting user language.
:param action: event name
:param args: event arguments
:return: locale name
"""
user: types.User = types.User.current()
locale: Locale = user.locale
if locale:
*_, data = args
language = data['locale'] = locale.language
return language
async def trigger(self, action, args):
"""
Event trigger
:param action: event name
:param args: event arguments
:return:
"""
if 'update' not in action \
and 'error' not in action \
and action.startswith('pre_process'):
locale = await self.get_user_locale(action, args)
self.ctx_locale.set(locale)
return True

View file

@ -23,70 +23,70 @@ class LoggingMiddleware(BaseMiddleware):
return round((time.time() - start) * 1000)
return -1
async def on_pre_process_update(self, update: types.Update):
async def on_pre_process_update(self, update: types.Update, data: dict):
update.conf['_start'] = time.time()
self.logger.debug(f"Received update [ID:{update.update_id}]")
async def on_post_process_update(self, update: types.Update, result):
async def on_post_process_update(self, update: types.Update, result, data: dict):
timeout = self.check_timeout(update)
if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
async def on_pre_process_message(self, message: types.Message):
async def on_pre_process_message(self, message: types.Message, data: dict):
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_post_process_message(self, message: types.Message, results):
async def on_post_process_message(self, message: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_pre_process_edited_message(self, edited_message):
async def on_pre_process_edited_message(self, edited_message, data: dict):
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_post_process_edited_message(self, edited_message, results):
async def on_post_process_edited_message(self, edited_message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_pre_process_channel_post(self, channel_post: types.Message):
async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict):
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
f"in channel [ID:{channel_post.chat.id}]")
async def on_post_process_channel_post(self, channel_post: types.Message, results):
async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"channel post [ID:{channel_post.message_id}] "
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message):
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict):
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results):
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery):
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict):
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results):
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict):
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results):
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery):
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict):
if callback_query.message:
if callback_query.message.from_user:
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
@ -100,7 +100,7 @@ class LoggingMiddleware(BaseMiddleware):
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
async def on_post_process_callback_query(self, callback_query, results):
async def on_post_process_callback_query(self, callback_query, results, data: dict):
if callback_query.message:
if callback_query.message.from_user:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
@ -117,25 +117,25 @@ class LoggingMiddleware(BaseMiddleware):
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery):
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict):
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
async def on_post_process_shipping_query(self, shipping_query, results):
async def on_post_process_shipping_query(self, shipping_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict):
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results):
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_pre_process_error(self, dispatcher, update, error):
async def on_pre_process_error(self, dispatcher, update, error, data: dict):
timeout = self.check_timeout(update)
if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")

File diff suppressed because it is too large Load diff

View file

@ -1,42 +0,0 @@
from . import Bot
from .. import types
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
from ..utils import context
def _get(key, default=None, no_error=False):
result = context.get_value(key, default)
if not no_error and result is None:
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
f"Maybe asyncio task factory is not configured!\n"
f"\t>>> from aiogram.utils import context\n"
f"\t>>> loop.set_task_factory(context.task_factory)")
return result
def get_bot() -> Bot:
return _get('bot')
def get_dispatcher() -> Dispatcher:
return _get('dispatcher')
def get_update() -> types.Update:
return _get(UPDATE_OBJECT)
def get_mode() -> str:
return _get(MODE, 'unknown')
def get_chat() -> int:
return _get('chat', no_error=True)
def get_user() -> int:
return _get('user', no_error=True)
def get_state() -> FSMContext:
return get_dispatcher().current_state()

File diff suppressed because it is too large Load diff

View file

@ -1,289 +0,0 @@
import asyncio
import inspect
import re
from ..types import CallbackQuery, ContentType, Message
from ..utils import context
from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE'
async def check_filter(filter_, args):
"""
Helper for executing filter
:param filter_:
:param args:
:param kwargs:
:return:
"""
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
return await filter_(*args)
else:
return filter_(*args)
async def check_filters(filters, args):
"""
Check list of filters
:param filters:
:param args:
:return:
"""
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args)
if not f:
return False
return True
class Filter:
"""
Base class for filters
"""
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
def check(self, *args, **kwargs):
raise NotImplementedError
class AsyncFilter(Filter):
"""
Base class for asynchronous filters
"""
def __aiter__(self):
return None
def __await__(self):
return self.check
async def check(self, *args, **kwargs):
pass
class AnyFilter(AsyncFilter):
"""
One filter from many
"""
def __init__(self, *filters: callable):
self.filters = filters
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
class NotFilter(AsyncFilter):
"""
Reverse filter
"""
def __init__(self, filter_: callable):
self.filter = filter_
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(AsyncFilter):
"""
Check commands in message
"""
def __init__(self, commands):
self.commands = commands
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
if command not in self.commands:
return False
return True
class RegexpFilter(Filter):
"""
Regexp filter for messages
"""
def __init__(self, regexp):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
def check(self, obj):
if isinstance(obj, Message) and obj.text:
return bool(self.regexp.search(obj.text))
elif isinstance(obj, CallbackQuery) and obj.data:
return bool(self.regexp.search(obj.data))
return False
class RegexpCommandsFilter(AsyncFilter):
"""
Check commands by regexp in message
"""
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
for command in self.regexp_commands:
search = command.search(message.text)
if search:
message.conf['regexp_command'] = search
return True
return False
class ContentTypeFilter(Filter):
"""
Check message content type
"""
def __init__(self, content_types):
self.content_types = content_types
def check(self, message):
return ContentType.ANY[0] in self.content_types or \
message.content_type in self.content_types
class CancelFilter(Filter):
"""
Find cancel in message text
"""
def __init__(self, cancel_set=None):
if cancel_set is None:
cancel_set = ['/cancel', 'cancel', 'cancel.']
self.cancel_set = cancel_set
def check(self, message):
if message.text:
return message.text.lower() in self.cancel_set
class StateFilter(AsyncFilter):
"""
Check user state
"""
def __init__(self, dispatcher, state):
self.dispatcher = dispatcher
self.state = state
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
if self.state == '*':
return True
if context.check_value(USER_STATE):
context_state = context.get_value(USER_STATE)
return self.state == context_state
else:
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
return False
class StatesListFilter(StateFilter):
"""
List of states
"""
async def check(self, obj):
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
class ExceptionsFilter(Filter):
"""
Filter for exceptions
"""
def __init__(self, exception):
self.exception = exception
def check(self, dispatcher, update, exception):
return isinstance(exception, self.exception)
def generate_default_filters(dispatcher, *args, **kwargs):
"""
Prepare filters
:param dispatcher:
:param args:
:param kwargs:
:return:
"""
filters_set = []
for name, filter_ in kwargs.items():
if filter_ is None and name != DefaultFilters.STATE:
continue
if name == DefaultFilters.COMMANDS:
if isinstance(filter_, str):
filters_set.append(CommandsFilter([filter_]))
else:
filters_set.append(CommandsFilter(filter_))
elif name == DefaultFilters.REGEXP:
filters_set.append(RegexpFilter(filter_))
elif name == DefaultFilters.CONTENT_TYPES:
filters_set.append(ContentTypeFilter(filter_))
elif name == DefaultFilters.FUNC:
filters_set.append(filter_)
elif name == DefaultFilters.STATE:
if isinstance(filter_, (list, set, tuple)):
filters_set.append(StatesListFilter(dispatcher, filter_))
else:
filters_set.append(StateFilter(dispatcher, filter_))
elif isinstance(filter_, Filter):
filters_set.append(filter_)
filters_set += list(args)
return filters_set
class DefaultFilters(Helper):
mode = HelperMode.snake_case
COMMANDS = Item() # commands
REGEXP = Item() # regexp
CONTENT_TYPES = Item() # content_type
FUNC = Item() # func
STATE = Item() # state

View file

@ -0,0 +1,24 @@
from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
RegexpCommandsFilter, StateFilter, Text
from .factory import FiltersFactory
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
__all__ = [
'AbstractFilter',
'BoundFilter',
'Command',
'CommandStart',
'CommandHelp',
'ContentTypeFilter',
'ExceptionsFilter',
'Filter',
'FilterNotPassed',
'FilterRecord',
'FiltersFactory',
'RegexpCommandsFilter',
'Regexp',
'StateFilter',
'Text',
'check_filter',
'check_filters'
]

View file

@ -0,0 +1,313 @@
import inspect
import re
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Optional, Union
from aiogram import types
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
from aiogram.types import CallbackQuery, Message
class Command(Filter):
"""
You can handle commands by using this filter
"""
def __init__(self, commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = '/',
ignore_case: bool = True,
ignore_mention: bool = False):
"""
Filter can be initialized from filters factory or by simply creating instance of this class
:param commands: command or list of commands
:param prefixes:
:param ignore_case:
:param ignore_mention:
"""
if isinstance(commands, str):
commands = (commands,)
self.commands = list(map(str.lower, commands)) if ignore_case else commands
self.prefixes = prefixes
self.ignore_case = ignore_case
self.ignore_mention = ignore_mention
@classmethod
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Validator for filters factory
:param full_config:
:return: config or empty dict
"""
config = {}
if 'commands' in full_config:
config['commands'] = full_config.pop('commands')
if 'commands_prefix' in full_config:
config['prefixes'] = full_config.pop('commands_prefix')
if 'commands_ignore_mention' in full_config:
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
return config
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
@staticmethod
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
full_command = message.text.split()[0]
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
return False
elif prefix not in prefixes:
return False
elif (command.lower() if ignore_case else command) not in commands:
return False
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
@dataclass
class CommandObj:
prefix: str = '/'
command: str = ''
mention: str = None
args: str = field(repr=False, default=None)
@property
def mentioned(self) -> bool:
return bool(self.mention)
@property
def text(self) -> str:
line = self.prefix + self.command
if self.mentioned:
line += '@' + self.mention
if self.args:
line += ' ' + self.args
return line
class CommandStart(Command):
def __init__(self):
super(CommandStart, self).__init__(['start'])
class CommandHelp(Command):
def __init__(self):
super(CommandHelp, self).__init__(['help'])
class Text(Filter):
"""
Simple text filter
"""
def __init__(self,
equals: Optional[str] = None,
contains: Optional[str] = None,
startswith: Optional[str] = None,
endswith: Optional[str] = None,
ignore_case=False):
"""
Check text for one of pattern. Only one mode can be used in one filter.
:param equals:
:param contains:
:param startswith:
:param endswith:
:param ignore_case: case insensitive
"""
# Only one mode can be used. check it.
check = sum(map(bool, (equals, contains, startswith, endswith)))
if check > 1:
args = "' and '".join([arg[0] for arg in [('equals', equals),
('contains', contains),
('startswith', startswith),
('endswith', endswith)
] if arg[1]])
raise ValueError(f"Arguments '{args}' cannot be used together.")
elif check == 0:
raise ValueError(f"No one mode is specified!")
self.equals = equals
self.contains = contains
self.endswith = endswith
self.startswith = startswith
self.ignore_case = ignore_case
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'text' in full_config:
return {'equals': full_config.pop('text')}
elif 'text_contains' in full_config:
return {'contains': full_config.pop('text_contains')}
elif 'text_startswith' in full_config:
return {'startswith': full_config.pop('text_startswith')}
elif 'text_endswith' in full_config:
return {'endswith': full_config.pop('text_endswith')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
text = obj.text or obj.caption or ''
elif isinstance(obj, CallbackQuery):
text = obj.data
else:
return False
if self.ignore_case:
text = text.lower()
if self.equals:
return text == self.equals
elif self.contains:
return self.contains in text
elif self.startswith:
return text.startswith(self.startswith)
elif self.endswith:
return text.endswith(self.endswith)
return False
class Regexp(Filter):
"""
Regexp filter for messages and callback query
"""
def __init__(self, regexp):
if not isinstance(regexp, re.Pattern):
regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
self.regexp = regexp
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'regexp' in full_config:
return {'regexp': full_config.pop('regexp')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
match = self.regexp.search(obj.text or obj.caption or '')
elif isinstance(obj, CallbackQuery) and obj.data:
match = self.regexp.search(obj.data)
else:
return False
if match:
return {'regexp': match}
return False
class RegexpCommandsFilter(BoundFilter):
"""
Check commands by regexp in message
"""
key = 'regexp_commands'
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
for command in self.regexp_commands:
search = command.search(message.text)
if search:
return {'regexp_command': search}
return False
class ContentTypeFilter(BoundFilter):
"""
Check message content type
"""
key = 'content_types'
required = True
default = types.ContentTypes.TEXT
def __init__(self, content_types):
self.content_types = content_types
async def check(self, message):
return types.ContentType.ANY in self.content_types or \
message.content_type in self.content_types
class StateFilter(BoundFilter):
"""
Check user state
"""
key = 'state'
required = True
ctx_state = ContextVar('user_state')
def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State, StatesGroup
self.dispatcher = dispatcher
states = []
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
state = [state, ]
for item in state:
if isinstance(item, State):
states.append(item.state)
elif inspect.isclass(item) and issubclass(item, StatesGroup):
states.extend(item.all_states_names)
else:
states.append(item)
self.states = states
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
if '*' in self.states:
return {'state': self.dispatcher.current_state()}
try:
state = self.ctx_state.get()
except LookupError:
chat, user = self.get_target(obj)
if chat or user:
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
self.ctx_state.set(state)
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
else:
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
return False
class ExceptionsFilter(BoundFilter):
"""
Filter for exceptions
"""
key = 'exception'
def __init__(self, dispatcher, exception):
super().__init__(dispatcher)
self.exception = exception
async def check(self, dispatcher, update, exception):
try:
raise exception
except self.exception:
return True
except:
return False

View file

@ -0,0 +1,73 @@
import typing
from .filters import AbstractFilter, FilterRecord
from ..handler import Handler
class FiltersFactory:
"""
Default filters factory
"""
def __init__(self, dispatcher):
self._dispatcher = dispatcher
self._registered: typing.List[FilterRecord] = []
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.List[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
"""
Register filter
:param callback: callable or subclass of :obj:`AbstractFilter`
:param validator: custom validator.
:param event_handlers: list of instances of :obj:`Handler`
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
"""
record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
self._registered.append(record)
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
"""
Unregister callback
:param callback: callable of subclass of :obj:`AbstractFilter`
"""
for record in self._registered:
if record.callback == callback:
self._registered.remove(record)
def resolve(self, event_handler, *custom_filters, **full_config
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
"""
Resolve filters to filters-set
:param event_handler:
:param custom_filters:
:param full_config:
:return:
"""
filters_set = []
filters_set.extend(self._resolve_registered(event_handler,
{k: v for k, v in full_config.items() if v is not None}))
if custom_filters:
filters_set.extend(custom_filters)
return filters_set
def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
"""
Resolve registered filters
:param event_handler:
:param full_config:
:return:
"""
for record in self._registered:
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
if filter_:
yield filter_
if full_config:
raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'')

View file

@ -0,0 +1,250 @@
import abc
import inspect
import typing
from ..handler import Handler
from ...types.base import TelegramObject
class FilterNotPassed(Exception):
pass
def wrap_async(func):
async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)
if inspect.isawaitable(func) \
or inspect.iscoroutinefunction(func) \
or isinstance(func, AbstractFilter):
return func
return async_wrapper
async def check_filter(dispatcher, filter_, args):
"""
Helper for executing filter
:param dispatcher:
:param filter_:
:param args:
:return:
"""
kwargs = {}
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
spec = inspect.getfullargspec(filter_)
if 'dispatcher' in spec:
kwargs['dispatcher'] = dispatcher
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, AbstractFilter):
return await filter_(*args, **kwargs)
else:
return filter_(*args, **kwargs)
async def check_filters(dispatcher, filters, args):
"""
Check list of filters
:param dispatcher:
:param filters:
:param args:
:return:
"""
data = {}
if filters is not None:
for filter_ in filters:
f = await check_filter(dispatcher, filter_, args)
if not f:
raise FilterNotPassed()
elif isinstance(f, dict):
data.update(f)
return data
class FilterRecord:
"""
Filters record for factory
"""
def __init__(self, callback: typing.Callable,
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
if event_handlers and exclude_event_handlers:
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
self.callback = callback
self.event_handlers = event_handlers
self.exclude_event_handlers = exclude_event_handlers
if validator is not None:
if not callable(validator):
raise TypeError(f"validator must be callable, not {type(validator)}")
self.resolver = validator
elif issubclass(callback, AbstractFilter):
self.resolver = callback.validate
else:
raise RuntimeError('validator is required!')
def resolve(self, dispatcher, event_handler, full_config):
if not self._check_event_handler(event_handler):
return
config = self.resolver(full_config)
if config:
if 'dispatcher' not in config:
spec = inspect.getfullargspec(self.callback)
if 'dispatcher' in spec.args:
config['dispatcher'] = dispatcher
for key in config:
if key in full_config:
full_config.pop(key)
return self.callback(**config)
def _check_event_handler(self, event_handler) -> bool:
if self.event_handlers:
return event_handler in self.event_handlers
elif self.exclude_event_handlers:
return event_handler not in self.exclude_event_handlers
return True
class AbstractFilter(abc.ABC):
"""
Abstract class for custom filters
"""
@classmethod
@abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""
Validate and parse config
:param full_config:
:return: config
"""
pass
@abc.abstractmethod
async def check(self, *args) -> bool:
"""
Check object
:param args:
:return:
"""
pass
async def __call__(self, obj: TelegramObject) -> bool:
return await self.check(obj)
def __invert__(self):
return NotFilter(self)
def __and__(self, other):
if isinstance(self, AndFilter):
self.append(other)
return self
return AndFilter(self, other)
def __or__(self, other):
if isinstance(self, OrFilter):
self.append(other)
return self
return OrFilter(self, other)
class Filter(AbstractFilter):
"""
You can make subclasses of that class for custom filters
"""
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
pass
class BoundFilter(Filter):
"""
Base class for filters with default validator
"""
key = None
required = False
default = None
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None:
if cls.key in full_config:
return {cls.key: full_config[cls.key]}
elif cls.required:
return {cls.key: cls.default}
class _LogicFilter(Filter):
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise ValueError('That filter can\'t be used in filters factory!')
class NotFilter(_LogicFilter):
def __init__(self, target):
self.target = wrap_async(target)
async def check(self, *args):
return not bool(await self.target(*args))
class AndFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = list(wrap_async(target) for target in targets)
async def check(self, *args):
"""
All filters must return a positive result
:param args:
:return:
"""
data = {}
for target in self.targets:
result = await target(*args)
if not result:
return False
if isinstance(result, dict):
data.update(result)
if not data:
return True
return data
def append(self, target):
self.targets.append(wrap_async(target))
class OrFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = list(wrap_async(target) for target in targets)
async def check(self, *args):
"""
One of filters must return a positive result
:param args:
:return:
"""
for target in self.targets:
result = await target(*args)
if result:
if isinstance(result, dict):
return result
return True
return False
def append(self, target):
self.targets.append(wrap_async(target))

View file

@ -0,0 +1,198 @@
import inspect
from typing import Optional
from ..dispatcher import Dispatcher
class State:
"""
State object
"""
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None):
self._state = state
self._group_name = group_name
self._group = None
@property
def group(self):
if not self._group:
raise RuntimeError('This state is not in any group.')
return self._group
def get_root(self):
return self.group.get_root()
@property
def state(self):
if self._state is None:
return None
elif self._state == '*':
return self._state
elif self._group_name is None and self._group:
group = self._group.__full_group_name__
elif self._group_name:
group = self._group_name
else:
group = '@'
return f"{group}:{self._state}"
def set_parent(self, group):
if not issubclass(group, StatesGroup):
raise ValueError('Group must be subclass of StatesGroup')
self._group = group
def __set_name__(self, owner, name):
if self._state is None:
self._state = name
self.set_parent(owner)
def __str__(self):
return f"<State '{self.state or ''}'>"
__repr__ = __str__
async def set(self):
state = Dispatcher.current().current_state()
await state.set_state(self.state)
class StatesGroupMeta(type):
def __new__(mcs, name, bases, namespace, **kwargs):
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
states = []
childs = []
cls._group_name = name
for name, prop in namespace.items():
if isinstance(prop, State):
states.append(prop)
elif inspect.isclass(prop) and issubclass(prop, StatesGroup):
childs.append(prop)
prop._parent = cls
# continue
cls._parent = None
cls._childs = tuple(childs)
cls._states = tuple(states)
cls._state_names = tuple(state.state for state in states)
return cls
@property
def __group_name__(cls):
return cls._group_name
@property
def __full_group_name__(cls):
if cls._parent:
return cls._parent.__full_group_name__ + '.' + cls._group_name
return cls._group_name
@property
def states(cls) -> tuple:
return cls._states
@property
def childs(cls):
return cls._childs
@property
def all_childs(cls):
result = cls.childs
for child in cls.childs:
result += child.childs
return result
@property
def all_states(cls):
result = cls.states
for group in cls.childs:
result += group.all_states
return result
@property
def all_states_names(cls):
return tuple(state.state for state in cls.all_states)
@property
def states_names(cls) -> tuple:
return tuple(state.state for state in cls.states)
def get_root(cls):
if cls._parent is None:
return cls
return cls._parent.get_root()
def __contains__(cls, item):
if isinstance(item, str):
return item in cls.all_states_names
elif isinstance(item, State):
return item in cls.all_states
elif isinstance(item, StatesGroup):
return item in cls.all_childs
return False
def __str__(self):
return f"<StatesGroup '{self.__full_group_name__}'>"
class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def next(cls) -> str:
state = Dispatcher.current().current_state()
state_name = await state.get_state()
try:
next_step = cls.states_names.index(state_name) + 1
except ValueError:
next_step = 0
try:
next_state_name = cls.states[next_step].state
except IndexError:
next_state_name = None
await state.set_state(next_state_name)
return next_state_name
@classmethod
async def previous(cls) -> str:
state = Dispatcher.current().current_state()
state_name = await state.get_state()
try:
previous_step = cls.states_names.index(state_name) - 1
except ValueError:
previous_step = 0
if previous_step < 0:
previous_state_name = None
else:
previous_state_name = cls.states[previous_step].state
await state.set_state(previous_state_name)
return previous_state_name
@classmethod
async def first(cls) -> str:
state = Dispatcher.current().current_state()
first_step_name = cls.states_names[0]
await state.set_state(first_step_name)
return first_step_name
@classmethod
async def last(cls) -> str:
state = Dispatcher.current().current_state()
last_step_name = cls.states_names[-1]
await state.set_state(last_step_name)
return last_step_name
default_state = State()
any_state = State(state='*')

View file

@ -1,5 +1,7 @@
from .filters import check_filters
from ..utils import context
import inspect
from contextvars import ContextVar
ctx_data = ContextVar('ctx_handler_data')
class SkipHandler(BaseException):
@ -10,6 +12,14 @@ class CancelHandler(BaseException):
pass
def _check_spec(func: callable, kwargs: dict):
spec = inspect.getfullargspec(func)
if spec.varkw:
return kwargs
return {k: v for k, v in kwargs.items() if k in spec.args}
class Handler:
def __init__(self, dispatcher, once=True, middleware_key=None):
self.dispatcher = dispatcher
@ -57,31 +67,43 @@ class Handler:
:param args:
:return:
"""
from .filters import check_filters, FilterNotPassed
results = []
data = {}
ctx_data.set(data)
if self.middleware_key:
try:
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,))
except CancelHandler: # Allow to cancel current event
return results
for filters, handler in self.handlers:
if await check_filters(filters, args):
try:
for filters, handler in self.handlers:
try:
if self.middleware_key:
context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
data.update(await check_filters(self.dispatcher, filters, args))
except FilterNotPassed:
continue
except CancelHandler:
break
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
else:
try:
if self.middleware_key:
# context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,))
partial_data = _check_spec(handler, data)
response = await handler(*args, **partial_data)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
continue
except CancelHandler:
break
finally:
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results, data,))
return results

View file

@ -101,3 +101,28 @@ class BaseMiddleware:
if not handler:
return None
await handler(*args)
class LifetimeControllerMiddleware(BaseMiddleware):
# TODO: Rename class
skip_patterns = None
async def pre_process(self, obj, data, *args):
pass
async def post_process(self, obj, data, *args):
pass
async def trigger(self, action, args):
if self.skip_patterns is not None and any(item in action for item in self.skip_patterns):
return False
obj, *args, data = args
if action.startswith('pre_process_'):
await self.pre_process(obj, data, *args)
elif action.startswith('post_process_'):
await self.post_process(obj, data, *args)
else:
return False
return True

View file

@ -281,8 +281,20 @@ class FSMContext:
def __exit__(self, exc_type, exc_val, exc_tb):
pass
@staticmethod
def _resolve_state(value):
from .filters.state import State
if value is None:
return
elif isinstance(value, str):
return value
elif isinstance(value, State):
return value.state
return str(value)
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
return await self.storage.get_state(chat=self.chat, user=self.user, default=default)
return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default))
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
@ -291,7 +303,7 @@ class FSMContext:
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
async def set_state(self, state: typing.Union[typing.AnyStr, None] = None):
await self.storage.set_state(chat=self.chat, user=self.user, state=state)
await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state))
async def set_data(self, data: typing.Dict = None):
await self.storage.set_data(chat=self.chat, user=self.user, data=data)

View file

@ -9,11 +9,11 @@ from typing import Dict, List, Optional, Union
from aiohttp import web
from aiohttp.web_exceptions import HTTPGone
from .. import types
from ..bot import api
from ..types import ParseMode
from ..types.base import Boolean, Float, Integer, String
from ..utils import context
from ..utils import helper, markdown
from ..utils import json
from ..utils.deprecated import warn_deprecated as warn
@ -89,8 +89,10 @@ class WebhookRequestHandler(web.View):
"""
dp = self.request.app[BOT_DISPATCHER_KEY]
try:
context.set_value('dispatcher', dp)
context.set_value('bot', dp.bot)
from aiogram.bot import bot
from aiogram.dispatcher import dispatcher
dispatcher.set(dp)
bot.bot.set(dp.bot)
except RuntimeError:
pass
return dp
@ -117,9 +119,9 @@ class WebhookRequestHandler(web.View):
"""
self.validate_ip()
context.update_state({'CALLER': WEBHOOK,
WEBHOOK_CONNECTION: True,
WEBHOOK_REQUEST: self.request})
# context.update_state({'CALLER': WEBHOOK,
# WEBHOOK_CONNECTION: True,
# WEBHOOK_REQUEST: self.request})
dispatcher = self.get_dispatcher()
update = await self.parse_update(dispatcher.bot)
@ -177,7 +179,7 @@ class WebhookRequestHandler(web.View):
if fut.done():
return fut.result()
else:
context.set_value(WEBHOOK_CONNECTION, False)
# context.set_value(WEBHOOK_CONNECTION, False)
fut.remove_done_callback(cb)
fut.add_done_callback(self.respond_via_request)
finally:
@ -202,7 +204,7 @@ class WebhookRequestHandler(web.View):
results = task.result()
except Exception as e:
loop.create_task(
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
else:
response = self.get_response(results)
if response is not None:
@ -249,7 +251,7 @@ class WebhookRequestHandler(web.View):
ip_address, accept = self.check_ip()
if not accept:
raise web.HTTPUnauthorized()
context.set_value('TELEGRAM_IP', ip_address)
# context.set_value('TELEGRAM_IP', ip_address)
class GoneRequestHandler(web.View):
@ -352,8 +354,8 @@ class BaseResponse:
async def __call__(self, bot=None):
if bot is None:
from aiogram.dispatcher import ctx
bot = ctx.get_bot()
from aiogram import Bot
bot = Bot.current()
return await self.execute_response(bot)
async def __aenter__(self):
@ -446,7 +448,8 @@ class ParseModeMixin:
:return:
"""
bot = context.get_value('bot', None)
from aiogram import Bot
bot = Bot.current()
if bot is not None:
return bot.parse_mode
@ -952,7 +955,7 @@ class SendMediaGroup(BaseResponse, ReplyToMixin, DisableNotificationMixin):
self.reply_to_message_id = reply_to_message_id
def prepare(self):
files = self.media.get_files()
files = dict(self.media.get_files())
if files:
raise TypeError('Allowed only file ID or URL\'s')

View file

@ -34,7 +34,7 @@ from .invoice import Invoice
from .labeled_price import LabeledPrice
from .location import Location
from .mask_position import MaskPosition
from .message import ContentType, Message, ParseMode
from .message import ContentType, ContentTypes, Message, ParseMode
from .message_entity import MessageEntity, MessageEntityType
from .order_info import OrderInfo
from .passport_data import PassportData
@ -77,6 +77,7 @@ __all__ = (
'ChosenInlineResult',
'Contact',
'ContentType',
'ContentTypes',
'Document',
'EncryptedCredentials',
'EncryptedPassportElement',

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import io
import typing
from contextvars import ContextVar
from typing import TypeVar
from .fields import BaseField
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
setattr(cls, ALIASES_ATTR_NAME, aliases)
mcs._objects[cls.__name__] = cls
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
return cls
@property
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
if value.default and key not in self.values:
self.values[key] = value.default
@classmethod
def current(cls):
return cls._current.get()
@classmethod
def set_current(cls, obj: TelegramObject):
return cls._current.set(obj)
@property
def conf(self) -> typing.Dict[str, typing.Any]:
return self._conf
@ -137,8 +150,8 @@ class TelegramObject(metaclass=MetaTelegramObject):
@property
def bot(self):
from ..dispatcher import ctx
return ctx.get_bot()
from ..bot.bot import Bot
return Bot.current()
def to_python(self) -> typing.Dict:
"""

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import asyncio
import typing
from contextvars import ContextVar
from . import base
from . import fields
@ -64,7 +67,7 @@ class Chat(base.TelegramObject):
if as_html:
return markdown.hlink(name, self.user_url)
return markdown.link(name, self.user_url)
async def get_url(self):
"""
Use this method to get chat link.
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
@classmethod
async def _do(cls, action: str, sleep=None):
from ..dispatcher.ctx import get_bot, get_chat
await get_bot().send_chat_action(get_chat(), action)
from aiogram import Bot
await Bot.current().send_chat_action(Chat.current(), action)
if sleep:
await asyncio.sleep(sleep)

View file

@ -9,7 +9,7 @@ class BaseField(metaclass=abc.ABCMeta):
Base field (prop)
"""
def __init__(self, *, base=None, default=None, alias=None):
def __init__(self, *, base=None, default=None, alias=None, on_change=None):
"""
Init prop
@ -17,10 +17,12 @@ class BaseField(metaclass=abc.ABCMeta):
:param default: default value
:param alias: alias name (for e.g. field 'from' has to be named 'from_user'
as 'from' is a builtin Python keyword
:param on_change: callback will be called when value is changed
"""
self.base_object = base
self.default = default
self.alias = alias
self.on_change = on_change
def __set_name__(self, owner, name):
if self.alias is None:
@ -53,6 +55,13 @@ class BaseField(metaclass=abc.ABCMeta):
self.resolve_base(instance)
value = self.deserialize(value, parent)
instance.values[self.alias] = value
self._trigger_changed(instance, value)
def _trigger_changed(self, instance, value):
if not self.on_change and instance is not None:
return
callback = getattr(instance, self.on_change)
callback(value)
def __get__(self, instance, owner):
return self.get_value(instance)
@ -154,7 +163,7 @@ class ListOfLists(Field):
return result
class DateTimeField(BaseField):
class DateTimeField(Field):
"""
In this field st_ored datetime
@ -167,3 +176,24 @@ class DateTimeField(BaseField):
def deserialize(self, value, parent=None):
return datetime.datetime.fromtimestamp(value)
class TextField(Field):
def __init__(self, *, prefix=None, suffix=None, default=None, alias=None):
super(TextField, self).__init__(default=default, alias=alias)
self.prefix = prefix
self.suffix = suffix
def serialize(self, value):
if value is None:
return value
if self.prefix:
value = self.prefix + value
if self.suffix:
value += self.suffix
return value
def deserialize(self, value, parent=None):
if value is not None and not isinstance(value, str):
raise TypeError(f"Field '{self.alias}' should be str not {type(value).__name__}")
return value

View file

@ -1,6 +1,7 @@
import io
import logging
import os
import secrets
import time
import aiohttp
@ -45,6 +46,8 @@ class InputFile(base.TelegramObject):
self._filename = filename
self.attachment_key = secrets.token_urlsafe(16)
def __del__(self):
"""
Close file descriptor
@ -54,13 +57,17 @@ class InputFile(base.TelegramObject):
@property
def filename(self):
if self._filename is None:
self._filename = api._guess_filename(self._file)
self._filename = api.guess_filename(self._file)
return self._filename
@filename.setter
def filename(self, value):
self._filename = value
@property
def attach(self):
return f"attach://{self.attachment_key}"
def get_filename(self) -> str:
"""
Get file name
@ -159,6 +166,9 @@ class InputFile(base.TelegramObject):
return writer
def __str__(self):
return f"<InputFile 'attach://{self.attachment_key}' with file='{self.file}'>"
def to_python(self):
raise TypeError('Object of this type is not exportable!')

View file

@ -12,6 +12,9 @@ ATTACHMENT_PREFIX = 'attach://'
class InputMedia(base.TelegramObject):
"""
This object represents the content of a media message to be sent. It should be one of
- InputMediaAnimation
- InputMediaDocument
- InputMediaAudio
- InputMediaPhoto
- InputMediaVideo
@ -20,36 +23,76 @@ class InputMedia(base.TelegramObject):
https://core.telegram.org/bots/api#inputmedia
"""
type: base.String = fields.Field(default='photo')
media: base.String = fields.Field()
thumb: typing.Union[base.InputFile, base.String] = fields.Field()
media: base.String = fields.Field(alias='media', on_change='_media_changed')
thumb: typing.Union[base.InputFile, base.String] = fields.Field(alias='thumb', on_change='_thumb_changed')
caption: base.String = fields.Field()
parse_mode: base.Boolean = fields.Field()
def __init__(self, *args, **kwargs):
self._thumb_file = None
self._media_file = None
media = kwargs.pop('media', None)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
elif media is not None:
self.media = media
thumb = kwargs.pop('thumb', None)
if isinstance(thumb, (io.IOBase, InputFile)):
self.thumb_file = thumb
elif thumb is not None:
self.thumb = thumb
super(InputMedia, self).__init__(*args, **kwargs)
try:
if self.parse_mode is None and self.bot.parse_mode:
if self.parse_mode is None and self.bot and self.bot.parse_mode:
self.parse_mode = self.bot.parse_mode
except RuntimeError:
pass
@property
def file(self):
return getattr(self, '_file', None)
return self._media_file
@file.setter
def file(self, file: io.IOBase):
setattr(self, '_file', file)
attachment_key = self.attachment_key = secrets.token_urlsafe(16)
self.media = ATTACHMENT_PREFIX + attachment_key
self.media = 'attach://' + secrets.token_urlsafe(16)
self._media_file = file
@file.deleter
def file(self):
self.media = None
self._media_file = None
def _media_changed(self, value):
if value is None or isinstance(value, str) and not value.startswith('attach://'):
self._media_file = None
@property
def attachment_key(self):
return self.conf.get('attachment_key', None)
def thumb_file(self):
return self._thumb_file
@attachment_key.setter
def attachment_key(self, value):
self.conf['attachment_key'] = value
@thumb_file.setter
def thumb_file(self, file: io.IOBase):
self.thumb = 'attach://' + secrets.token_urlsafe(16)
self._thumb_file = file
@thumb_file.deleter
def thumb_file(self):
self.thumb = None
self._thumb_file = None
def _thumb_changed(self, value):
if value is None or isinstance(value, str) and not value.startswith('attach://'):
self._thumb_file = None
def get_files(self):
if self._media_file:
yield self.media[9:], self._media_file
if self._thumb_file:
yield self.thumb[9:], self._thumb_file
class InputMediaAnimation(InputMedia):
@ -72,9 +115,6 @@ class InputMediaAnimation(InputMedia):
width=width, height=height, duration=duration,
parse_mode=parse_mode, conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaDocument(InputMedia):
"""
@ -89,9 +129,6 @@ class InputMediaDocument(InputMedia):
caption=caption, parse_mode=parse_mode,
conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaAudio(InputMedia):
"""
@ -119,9 +156,6 @@ class InputMediaAudio(InputMedia):
performer=performer, title=title,
parse_mode=parse_mode, conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaPhoto(InputMedia):
"""
@ -136,9 +170,6 @@ class InputMediaPhoto(InputMedia):
caption=caption, parse_mode=parse_mode,
conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaVideo(InputMedia):
"""
@ -151,18 +182,17 @@ class InputMediaVideo(InputMedia):
duration: base.Integer = fields.Field()
supports_streaming: base.Boolean = fields.Field()
def __init__(self, media: base.InputFile, caption: base.String = None,
def __init__(self, media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None,
parse_mode: base.Boolean = None,
supports_streaming: base.Boolean = None, **kwargs):
super(InputMediaVideo, self).__init__(type='video', media=media, caption=caption,
super(InputMediaVideo, self).__init__(type='video', media=media, thumb=thumb, caption=caption,
width=width, height=height, duration=duration,
parse_mode=parse_mode,
supports_streaming=supports_streaming, conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class MediaGroup(base.TelegramObject):
"""
@ -296,6 +326,7 @@ class MediaGroup(base.TelegramObject):
self.attach(photo)
def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile],
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None):
"""
@ -308,7 +339,7 @@ class MediaGroup(base.TelegramObject):
:param duration:
"""
if not isinstance(video, InputMedia):
video = InputMediaVideo(media=video, caption=caption,
video = InputMediaVideo(media=video, thumb=thumb, caption=caption,
width=width, height=height, duration=duration)
self.attach(video)
@ -327,6 +358,7 @@ class MediaGroup(base.TelegramObject):
return result
def get_files(self):
return {inputmedia.attachment_key: inputmedia.file
for inputmedia in self.media
if isinstance(inputmedia, InputMedia) and inputmedia.file}
for inputmedia in self.media:
if not isinstance(inputmedia, InputMedia) or not inputmedia.file:
continue
yield from inputmedia.get_files()

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime
import functools
import sys
@ -42,7 +44,7 @@ class Message(base.TelegramObject):
forward_from_message_id: base.Integer = fields.Field()
forward_signature: base.String = fields.Field()
forward_date: datetime.datetime = fields.DateTimeField()
reply_to_message: 'Message' = fields.Field(base='Message')
reply_to_message: Message = fields.Field(base='Message')
edit_date: datetime.datetime = fields.DateTimeField()
media_group_id: base.String = fields.Field()
author_signature: base.String = fields.Field()
@ -72,7 +74,7 @@ class Message(base.TelegramObject):
channel_chat_created: base.Boolean = fields.Field()
migrate_to_chat_id: base.Integer = fields.Field()
migrate_from_chat_id: base.Integer = fields.Field()
pinned_message: 'Message' = fields.Field(base='Message')
pinned_message: Message = fields.Field(base='Message')
invoice: Invoice = fields.Field(base=Invoice)
successful_payment: SuccessfulPayment = fields.Field(base=SuccessfulPayment)
connected_website: base.String = fields.Field()
@ -82,59 +84,59 @@ class Message(base.TelegramObject):
@functools.lru_cache()
def content_type(self):
if self.text:
return ContentType.TEXT[0]
return ContentType.TEXT
elif self.audio:
return ContentType.AUDIO[0]
return ContentType.AUDIO
elif self.animation:
return ContentType.ANIMATION[0]
return ContentType.ANIMATION
elif self.document:
return ContentType.DOCUMENT[0]
return ContentType.DOCUMENT
elif self.game:
return ContentType.GAME[0]
return ContentType.GAME
elif self.photo:
return ContentType.PHOTO[0]
return ContentType.PHOTO
elif self.sticker:
return ContentType.STICKER[0]
return ContentType.STICKER
elif self.video:
return ContentType.VIDEO[0]
return ContentType.VIDEO
elif self.video_note:
return ContentType.VIDEO_NOTE[0]
return ContentType.VIDEO_NOTE
elif self.voice:
return ContentType.VOICE[0]
return ContentType.VOICE
elif self.contact:
return ContentType.CONTACT[0]
return ContentType.CONTACT
elif self.venue:
return ContentType.VENUE[0]
return ContentType.VENUE
elif self.location:
return ContentType.LOCATION[0]
return ContentType.LOCATION
elif self.new_chat_members:
return ContentType.NEW_CHAT_MEMBERS[0]
return ContentType.NEW_CHAT_MEMBERS
elif self.left_chat_member:
return ContentType.LEFT_CHAT_MEMBER[0]
return ContentType.LEFT_CHAT_MEMBER
elif self.invoice:
return ContentType.INVOICE[0]
return ContentType.INVOICE
elif self.successful_payment:
return ContentType.SUCCESSFUL_PAYMENT[0]
return ContentType.SUCCESSFUL_PAYMENT
elif self.connected_website:
return ContentType.CONNECTED_WEBSITE[0]
return ContentType.CONNECTED_WEBSITE
elif self.migrate_from_chat_id:
return ContentType.MIGRATE_FROM_CHAT_ID[0]
return ContentType.MIGRATE_FROM_CHAT_ID
elif self.migrate_to_chat_id:
return ContentType.MIGRATE_TO_CHAT_ID[0]
return ContentType.MIGRATE_TO_CHAT_ID
elif self.pinned_message:
return ContentType.PINNED_MESSAGE[0]
return ContentType.PINNED_MESSAGE
elif self.new_chat_title:
return ContentType.NEW_CHAT_TITLE[0]
return ContentType.NEW_CHAT_TITLE
elif self.new_chat_photo:
return ContentType.NEW_CHAT_PHOTO[0]
return ContentType.NEW_CHAT_PHOTO
elif self.delete_chat_photo:
return ContentType.DELETE_CHAT_PHOTO[0]
return ContentType.DELETE_CHAT_PHOTO
elif self.group_chat_created:
return ContentType.GROUP_CHAT_CREATED[0]
return ContentType.GROUP_CHAT_CREATED
elif self.passport_data:
return ContentType.PASSPORT_DATA[0]
return ContentType.PASSPORT_DATA
else:
return ContentType.UNKNOWN[0]
return ContentType.UNKNOWN
def is_command(self):
"""
@ -239,7 +241,7 @@ class Message(base.TelegramObject):
return self.parse_entities()
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
disable_notification=None, reply_markup=None, reply=True) -> 'Message':
disable_notification=None, reply_markup=None, reply=True) -> Message:
"""
Reply to this message
@ -729,6 +731,69 @@ class ContentType(helper.Helper):
"""
List of message content types
WARNING: Single elements
:key: TEXT
:key: AUDIO
:key: DOCUMENT
:key: GAME
:key: PHOTO
:key: STICKER
:key: VIDEO
:key: VIDEO_NOTE
:key: VOICE
:key: CONTACT
:key: LOCATION
:key: VENUE
:key: NEW_CHAT_MEMBERS
:key: LEFT_CHAT_MEMBER
:key: INVOICE
:key: SUCCESSFUL_PAYMENT
:key: CONNECTED_WEBSITE
:key: MIGRATE_TO_CHAT_ID
:key: MIGRATE_FROM_CHAT_ID
:key: UNKNOWN
:key: ANY
"""
mode = helper.HelperMode.snake_case
TEXT = helper.Item() # text
AUDIO = helper.Item() # audio
DOCUMENT = helper.Item() # document
ANIMATION = helper.Item() # animation
GAME = helper.Item() # game
PHOTO = helper.Item() # photo
STICKER = helper.Item() # sticker
VIDEO = helper.Item() # video
VIDEO_NOTE = helper.Item() # video_note
VOICE = helper.Item() # voice
CONTACT = helper.Item() # contact
LOCATION = helper.Item() # location
VENUE = helper.Item() # venue
NEW_CHAT_MEMBERS = helper.Item() # new_chat_member
LEFT_CHAT_MEMBER = helper.Item() # left_chat_member
INVOICE = helper.Item() # invoice
SUCCESSFUL_PAYMENT = helper.Item() # successful_payment
CONNECTED_WEBSITE = helper.Item() # connected_website
MIGRATE_TO_CHAT_ID = helper.Item() # migrate_to_chat_id
MIGRATE_FROM_CHAT_ID = helper.Item() # migrate_from_chat_id
PINNED_MESSAGE = helper.Item() # pinned_message
NEW_CHAT_TITLE = helper.Item() # new_chat_title
NEW_CHAT_PHOTO = helper.Item() # new_chat_photo
DELETE_CHAT_PHOTO = helper.Item() # delete_chat_photo
GROUP_CHAT_CREATED = helper.Item() # group_chat_created
PASSPORT_DATA = helper.Item() # passport_data
UNKNOWN = helper.Item() # unknown
ANY = helper.Item() # any
class ContentTypes(helper.Helper):
"""
List of message content types
WARNING: List elements.
:key: TEXT
:key: AUDIO
:key: DOCUMENT

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from . import base
from . import fields
from .callback_query import CallbackQuery

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import babel
from . import base

View file

@ -1,140 +0,0 @@
"""
You need to setup task factory:
>>> from aiogram.utils import context
>>> loop = asyncio.get_event_loop()
>>> loop.set_task_factory(context.task_factory)
"""
import asyncio
import typing
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
"""
Task factory for implementing context processor
:param loop:
:param coro:
:return: new task
:rtype: :obj:`asyncio.Task`
"""
# Is not allowed when loop is closed.
if loop.is_closed():
raise RuntimeError('Event loop is closed.')
task = asyncio.Task(coro, loop=loop)
# Hide factory
if task._source_traceback:
del task._source_traceback[-1]
try:
task.context = asyncio.Task.current_task().context.copy()
except AttributeError:
task.context = {CONFIGURED: True}
return task
def get_current_state() -> typing.Dict:
"""
Get current execution context from task
:return: context
:rtype: :obj:`dict`
"""
task = asyncio.Task.current_task()
if task is None:
raise RuntimeError('Can be used only in Task context.')
context_ = getattr(task, 'context', None)
if context_ is None:
context_ = task.context = {}
return context_
def get_value(key, default=None):
"""
Get value from task
:param key:
:param default:
:return: value
"""
return get_current_state().get(key, default)
def check_value(key):
"""
Key in context?
:param key:
:return:
"""
return key in get_current_state()
def set_value(key, value):
"""
Set value
:param key:
:param value:
:return:
"""
get_current_state()[key] = value
def del_value(key):
"""
Remove value from context
:param key:
:return:
"""
del get_current_state()[key]
def update_state(data=None, **kwargs):
"""
Update multiple state items
:param data:
:param kwargs:
:return:
"""
if data is None:
data = {}
state = get_current_state()
state.update(data, **kwargs)
def check_configured():
"""
Check loop is configured
:return:
"""
return get_value(CONFIGURED)
class _Context:
"""
Other things for interactions with the execution context.
"""
def __getitem__(self, item):
return get_value(item)
def __setitem__(self, key, value):
set_value(key, value)
def __delitem__(self, key):
del_value(key)
@staticmethod
def get_context():
return get_current_state()
context = _Context()

View file

@ -182,6 +182,10 @@ class MessageToEditNotFound(MessageError):
match = 'message to edit not found'
class MessageIsTooLong(MessageError):
match = 'message is too long'
class ToMuchMessages(MessageError):
"""
Will be raised when you try to send media group with more than 10 items.

View file

@ -6,7 +6,6 @@ from warnings import warn
from aiohttp import web
from . import context
from ..bot.api import log
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
@ -104,6 +103,11 @@ class Executor:
self._freeze = False
from aiogram.bot.bot import bot as ctx_bot
from aiogram.dispatcher import dispatcher as ctx_dp
ctx_bot.set(dispatcher.bot)
ctx_dp.set(dispatcher)
@property
def frozen(self):
return self._freeze
@ -176,13 +180,13 @@ class Executor:
self._check_frozen()
self._freeze = True
self.loop.set_task_factory(context.task_factory)
# self.loop.set_task_factory(context.task_factory)
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
self._check_frozen()
self._freeze = True
self.loop.set_task_factory(context.task_factory)
# self.loop.set_task_factory(context.task_factory)
app = self._web_app
if app is None:
@ -203,6 +207,7 @@ class Executor:
for callback in self._on_startup_webhook:
app.on_startup.append(functools.partial(_wrap_callback, callback))
# for callback in self._on_shutdown_webhook:
# app.on_shutdown.append(functools.partial(_wrap_callback, callback))

View file

@ -1,26 +1,52 @@
import json
import os
JSON = 'json'
RAPIDJSON = 'rapidjson'
UJSON = 'ujson'
try:
import ujson
if 'DISABLE_UJSON' not in os.environ:
import ujson as json
mode = UJSON
def dumps(data):
return json.dumps(data, ensure_ascii=False)
else:
mode = JSON
except ImportError:
ujson = None
mode = JSON
_use_ujson = True if ujson else False
try:
if 'DISABLE_RAPIDJSON' not in os.environ:
import rapidjson as json
mode = RAPIDJSON
def disable_ujson():
global _use_ujson
_use_ujson = False
def dumps(data):
return json.dumps(data, ensure_ascii=False, number_mode=json.NM_NATIVE,
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
def dumps(data):
if _use_ujson:
return ujson.dumps(data)
return json.dumps(data)
def loads(data):
return json.loads(data, number_mode=json.NM_NATIVE,
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
else:
mode = JSON
except ImportError:
mode = JSON
if mode == JSON:
import json
def loads(data):
if _use_ujson:
return ujson.loads(data)
return json.loads(data)
def dumps(data):
return json.dumps(data, ensure_ascii=False)
def loads(data):
return json.loads(data)

View file

@ -1,5 +1,7 @@
import datetime
import secrets
from aiogram import types
from . import json
DEFAULT_FILTER = ['self', 'cls']
@ -56,3 +58,22 @@ def prepare_arg(value):
elif isinstance(value, datetime.datetime):
return round(value.timestamp())
return value
def prepare_file(payload, files, key, file):
if isinstance(file, str):
payload[key] = file
elif file is not None:
files[key] = file
def prepare_attachment(payload, files, key, file):
if isinstance(file, str):
payload[key] = file
elif isinstance(file, types.InputFile):
payload[key] = file.attach
files[file.attachment_key] = file.file
elif file is not None:
file_attach_name = secrets.token_urlsafe(16)
payload[key] = "attach://" + file_attach_name
files[file_attach_name] = file

View file

@ -1,6 +1,7 @@
-r requirements.txt
ujson>=1.35
python-rapidjson>=0.6.3
emoji>=0.5.0
pytest>=3.5.0
pytest-asyncio>=0.8.0

View file

@ -25,16 +25,16 @@ Next step: interaction with bots starts with one command. Register your first co
.. code-block:: python3
@dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message):
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
@dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message):
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
Last step: run long polling.
.. code-block:: python3
if __name__ == '__main__':
executor.start_polling(dp)
if __name__ == '__main__':
executor.start_polling(dp)
Summary
-------
@ -48,9 +48,9 @@ Summary
bot = Bot(token='BOT TOKEN HERE')
dp = Dispatcher(bot)
@dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message):
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
@dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message):
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
if __name__ == '__main__':
executor.start_polling(dp)
if __name__ == '__main__':
executor.start_polling(dp)

View file

@ -1,8 +1,8 @@
name: py36
name: py37
channels:
- conda-forge
dependencies:
- python=3.6
- python=3.7
- sphinx=1.5.3
- sphinx_rtd_theme=0.2.4
- pip

View file

@ -29,6 +29,7 @@ async def send_message(user_id: int, text: str, disable_notification: bool = Fal
:param user_id:
:param text:
:param disable_notification:
:return:
"""
try:

View file

@ -5,18 +5,14 @@ Babel is required.
import asyncio
import logging
from aiogram import Bot, types
from aiogram.dispatcher import Dispatcher
from aiogram.types import ParseMode
from aiogram.utils.executor import start_polling
from aiogram.utils.markdown import *
from aiogram import Bot, Dispatcher, executor, md, types
API_TOKEN = 'BOT TOKEN HERE'
logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop)
bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.MARKDOWN)
dp = Dispatcher(bot)
@ -24,14 +20,14 @@ dp = Dispatcher(bot)
async def check_language(message: types.Message):
locale = message.from_user.locale
await message.reply(text(
bold('Info about your language:'),
text(' 🔸', bold('Code:'), italic(locale.locale)),
text(' 🔸', bold('Territory:'), italic(locale.territory or 'Unknown')),
text(' 🔸', bold('Language name:'), italic(locale.language_name)),
text(' 🔸', bold('English language name:'), italic(locale.english_name)),
sep='\n'), parse_mode=ParseMode.MARKDOWN)
await message.reply(md.text(
md.bold('Info about your language:'),
md.text(' 🔸', md.bold('Code:'), md.italic(locale.locale)),
md.text(' 🔸', md.bold('Territory:'), md.italic(locale.territory or 'Unknown')),
md.text(' 🔸', md.bold('Language name:'), md.italic(locale.language_name)),
md.text(' 🔸', md.bold('English language name:'), md.italic(locale.english_name)),
sep='\n'))
if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True)
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -1,9 +1,7 @@
import asyncio
import logging
from aiogram import Bot, types
from aiogram.dispatcher import Dispatcher
from aiogram.utils.executor import start_polling
from aiogram import Bot, types, Dispatcher, executor
API_TOKEN = 'BOT TOKEN HERE'
@ -32,10 +30,4 @@ async def echo(message: types.Message):
if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True)
# Also you can use another execution method
# >>> try:
# >>> loop.run_until_complete(main())
# >>> except KeyboardInterrupt:
# >>> loop.stop()
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -1,47 +0,0 @@
from aiogram import Bot, types
from aiogram.contrib.middlewares.context import ContextMiddleware
from aiogram.dispatcher import Dispatcher
from aiogram.types import ParseMode
from aiogram.utils import markdown as md
from aiogram.utils.executor import start_polling
API_TOKEN = 'BOT TOKEN HERE'
bot = Bot(token=API_TOKEN)
dp = Dispatcher(bot)
# Setup Context middleware
data: ContextMiddleware = dp.middleware.setup(ContextMiddleware())
# Write custom filter
async def demo_filter(message: types.Message):
# Store some data in context
command = data['command'] = message.get_command() or ''
args = data['args'] = message.get_args() or ''
data['has_args'] = bool(args)
data['some_random_data'] = 42
return command != '/bad_command'
@dp.message_handler(demo_filter)
async def send_welcome(message: types.Message):
# Get data from context
# All of this is available only in current context and from current update object
# `data`- pseudo-alias for `ctx.get_update().conf['_context_data']`
command = data['command']
args = data['args']
rand = data['some_random_data']
has_args = data['has_args']
# Send as pre-formatted code block.
await message.reply(md.hpre(f"""command: {command}
args: {['Not available', 'available'][has_args]}: {args}
some random data: {rand}
message ID: {message.message_id}
message: {message.html_text}
"""), parse_mode=ParseMode.HTML)
if __name__ == '__main__':
start_polling(dp)

View file

@ -1,11 +1,13 @@
import asyncio
from typing import Optional
from aiogram import Bot, types
import aiogram.utils.markdown as md
from aiogram import Bot, Dispatcher, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import Dispatcher
from aiogram.dispatcher import FSMContext
from aiogram.dispatcher.filters.state import State, StatesGroup
from aiogram.types import ParseMode
from aiogram.utils import executor
from aiogram.utils.markdown import text, bold
API_TOKEN = 'BOT TOKEN HERE'
@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop)
storage = MemoryStorage()
dp = Dispatcher(bot, storage=storage)
# States
AGE = 'process_age'
NAME = 'process_name'
GENDER = 'process_gender'
class Form(StatesGroup):
name = State() # Will be represented in storage as 'Form:name'
age = State() # Will be represented in storage as 'Form:age'
gender = State() # Will be represented in storage as 'Form:gender'
@dp.message_handler(commands=['start'])
@ -28,48 +32,41 @@ async def cmd_start(message: types.Message):
"""
Conversation's entry point
"""
# Get current state
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
# Update user's state
await state.set_state(NAME)
# Set state
await Form.name.set()
await message.reply("Hi there! What's your name?")
# You can use state '*' if you need to handle all states
@dp.message_handler(state='*', commands=['cancel'])
@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel')
async def cancel_handler(message: types.Message):
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None):
"""
Allow user to cancel any action
"""
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
# Ignore command if user is not in any (defined) state
if await state.get_state() is None:
return
if raw_state is None:
return
# Otherwise cancel state and inform user about it
# And remove keyboard (just in case)
await state.reset_state(with_data=True)
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
# Cancel state and inform user about it
await state.finish()
# And remove keyboard (just in case)
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
@dp.message_handler(state=NAME)
async def process_name(message: types.Message):
@dp.message_handler(state=Form.name)
async def process_name(message: types.Message, state: FSMContext):
"""
Process user name
"""
# Save name to storage and go to next step
# You can use context manager
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
await state.update_data(name=message.text)
await state.set_state(AGE)
await Form.next()
await state.update_data(name=message.text)
await message.reply("How old are you?")
# Check age. Age gotta be digit
@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit())
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
async def failed_process_age(message: types.Message):
"""
If age is invalid
@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message):
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit())
async def process_age(message: types.Message):
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
async def process_age(message: types.Message, state: FSMContext):
# Update state and data
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
await state.set_state(GENDER)
await state.update_data(age=int(message.text))
await Form.next()
await state.update_data(age=int(message.text))
# Configure ReplyKeyboardMarkup
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
@ -92,7 +88,7 @@ async def process_age(message: types.Message):
await message.reply("What is your gender?", reply_markup=markup)
@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"])
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
async def failed_process_gender(message: types.Message):
"""
In this example gender has to be one of: Male, Female, Other.
@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message):
return await message.reply("Bad gender name. Choose you gender from keyboard.")
@dp.message_handler(state=GENDER)
async def process_gender(message: types.Message):
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
@dp.message_handler(state=Form.gender)
async def process_gender(message: types.Message, state: FSMContext):
data = await state.get_data()
data['gender'] = message.text
@ -111,10 +105,10 @@ async def process_gender(message: types.Message):
markup = types.ReplyKeyboardRemove()
# And send message
await bot.send_message(message.chat.id, text(
text('Hi! Nice to meet you,', bold(data['name'])),
text('Age:', data['age']),
text('Gender:', data['gender']),
await bot.send_message(message.chat.id, md.text(
md.text('Hi! Nice to meet you,', md.bold(data['name'])),
md.text('Age:', data['age']),
md.text('Gender:', data['gender']),
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
# Finish conversation
@ -122,10 +116,5 @@ async def process_gender(message: types.Message):
await state.finish()
async def shutdown(dispatcher: Dispatcher):
await dispatcher.storage.close()
await dispatcher.storage.wait_closed()
if __name__ == '__main__':
executor.start_polling(dp, loop=loop, skip_updates=True, on_shutdown=shutdown)
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -0,0 +1,126 @@
"""
This example is equals with 'finite_state_machine_example.py' but with FSM Middleware
Note that FSM Middleware implements the more simple methods for working with storage.
With that middleware all data from storage will be loaded before event will be processed
and data will be stored after processing the event.
"""
import asyncio
import aiogram.utils.markdown as md
from aiogram import Bot, Dispatcher, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.contrib.middlewares.fsm import FSMMiddleware, FSMSStorageProxy
from aiogram.dispatcher.filters.state import State, StatesGroup
from aiogram.utils import executor
API_TOKEN = 'BOT TOKEN HERE'
loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop)
# For example use simple MemoryStorage for Dispatcher.
storage = MemoryStorage()
dp = Dispatcher(bot, storage=storage)
dp.middleware.setup(FSMMiddleware())
# States
class Form(StatesGroup):
name = State() # Will be represented in storage as 'Form:name'
age = State() # Will be represented in storage as 'Form:age'
gender = State() # Will be represented in storage as 'Form:gender'
@dp.message_handler(commands=['start'])
async def cmd_start(message: types.Message):
"""
Conversation's entry point
"""
# Set state
await Form.first()
await message.reply("Hi there! What's your name?")
# You can use state '*' if you need to handle all states
@dp.message_handler(state='*', commands=['cancel'])
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
async def cancel_handler(message: types.Message, state_data: FSMSStorageProxy):
"""
Allow user to cancel any action
"""
if state_data.state is None:
return
# Cancel state and inform user about it
del state_data.state
# And remove keyboard (just in case)
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
@dp.message_handler(state=Form.name)
async def process_name(message: types.Message, state_data: FSMSStorageProxy):
"""
Process user name
"""
state_data.state = Form.age
state_data['name'] = message.text
await message.reply("How old are you?")
# Check age. Age gotta be digit
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
async def failed_process_age(message: types.Message):
"""
If age is invalid
"""
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
async def process_age(message: types.Message, state_data: FSMSStorageProxy):
# Update state and data
state_data.state = Form.gender
state_data['age'] = int(message.text)
# Configure ReplyKeyboardMarkup
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
markup.add("Male", "Female")
markup.add("Other")
await message.reply("What is your gender?", reply_markup=markup)
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
async def failed_process_gender(message: types.Message):
"""
In this example gender has to be one of: Male, Female, Other.
"""
return await message.reply("Bad gender name. Choose you gender from keyboard.")
@dp.message_handler(state=Form.gender)
async def process_gender(message: types.Message, state_data: FSMSStorageProxy):
state_data['gender'] = message.text
# Remove keyboard
markup = types.ReplyKeyboardRemove()
# And send message
await bot.send_message(message.chat.id, md.text(
md.text('Hi! Nice to meet you,', md.bold(state_data['name'])),
md.text('Age:', state_data['age']),
md.text('Gender:', state_data['gender']),
sep='\n'), reply_markup=markup, parse_mode=types.ParseMode.MARKDOWN)
# Finish conversation
# WARNING! This method will destroy all data in storage for current user!
state_data.clear()
if __name__ == '__main__':
executor.start_polling(dp, loop=loop, skip_updates=True)

56
examples/i18n_example.py Normal file
View file

@ -0,0 +1,56 @@
"""
Internalize your bot
Step 1: extract texts
# pybabel extract i18n_example.py -o locales/mybot.pot
Step 2: create *.po files. For e.g. create en, ru, uk locales.
# echo {en,ru,uk} | xargs -n1 pybabel init -i locales/mybot.pot -d locales -D mybot -l
Step 3: translate texts
Step 4: compile translations
# pybabel compile -d locales -D mybot
Step 5: When you change the code of your bot you need to update po & mo files.
Step 5.1: regenerate pot file:
command from step 1
Step 5.2: update po files
# pybabel update -d locales -D mybot -i locales/mybot.pot
Step 5.3: update your translations
Step 5.4: compile mo files
command from step 4
"""
from pathlib import Path
from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.middlewares.i18n import I18nMiddleware
TOKEN = 'BOT TOKEN HERE'
I18N_DOMAIN = 'mybot'
BASE_DIR = Path(__file__).parent
LOCALES_DIR = BASE_DIR / 'locales'
bot = Bot(TOKEN, parse_mode=types.ParseMode.HTML)
dp = Dispatcher(bot)
# Setup i18n middleware
i18n = I18nMiddleware(I18N_DOMAIN, LOCALES_DIR)
dp.middleware.setup(i18n)
# Alias for gettext method
_ = i18n.gettext
@dp.message_handler(commands=['start'])
async def cmd_start(message: types.Message):
# Simply use `_('message')` instead of `'message'` and never use f-strings for translatable texts.
await message.reply(_('Hello, <b>{user}</b>!').format(user=message.from_user.full_name))
@dp.message_handler(commands=['lang'])
async def cmd_lang(message: types.Message, locale):
await message.reply(_('Your current language: <i>{language}</i>').format(language=locale))
if __name__ == '__main__':
executor.start_polling(dp, skip_updates=True)

View file

@ -1,9 +1,7 @@
import asyncio
import logging
from aiogram import Bot, types
from aiogram.dispatcher import Dispatcher
from aiogram.utils.executor import start_polling
from aiogram import Bot, types, Dispatcher, executor
API_TOKEN = 'BOT TOKEN HERE'
@ -23,4 +21,4 @@ async def inline_echo(inline_query: types.InlineQuery):
if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True)
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -0,0 +1,28 @@
# English translations for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n"
"Language-Team: en <LL@li.org>\n"
"Plural-Forms: nplurals=2; plural=(n != 1)\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr ""
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr ""

View file

@ -0,0 +1,27 @@
# Translations template for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr ""
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr ""

View file

@ -0,0 +1,29 @@
# Russian translations for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: ru\n"
"Language-Team: ru <LL@li.org>\n"
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr "Привет, <b>{user}</b>!"
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr "Твой язык: <i>{language}</i>"

View file

@ -0,0 +1,29 @@
# Ukrainian translations for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: uk\n"
"Language-Team: uk <LL@li.org>\n"
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr "Привіт, <b>{user}</b>!"
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr "Твоя мова: <i>{language}</i>"

View file

@ -1,9 +1,6 @@
import asyncio
from aiogram import Bot, types
from aiogram.dispatcher import Dispatcher
from aiogram.types import ChatActions
from aiogram.utils.executor import start_polling
from aiogram import Bot, Dispatcher, executor, filters, types
API_TOKEN = 'BOT TOKEN HERE'
@ -12,7 +9,7 @@ bot = Bot(token=API_TOKEN, loop=loop)
dp = Dispatcher(bot)
@dp.message_handler(commands=['start'])
@dp.message_handler(filters.CommandStart())
async def send_welcome(message: types.Message):
# So... At first I want to send something like this:
await message.reply("Do you want to see many pussies? Are you ready?")
@ -21,7 +18,7 @@ async def send_welcome(message: types.Message):
await asyncio.sleep(1)
# Good bots should send chat actions. Or not.
await ChatActions.upload_photo()
await types.ChatActions.upload_photo()
# Create media group
media = types.MediaGroup()
@ -39,9 +36,8 @@ async def send_welcome(message: types.Message):
# media.attach_photo('<file_id>', 'cat-cat-cat.')
# Done! Send media group
await bot.send_media_group(message.chat.id, media=media,
reply_to_message_id=message.message_id)
await message.reply_media_group(media=media)
if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True)
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -1,11 +1,10 @@
import asyncio
from aiogram import Bot, types
from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils import context, executor
from aiogram.utils.exceptions import Throttled
TOKEN = 'BOT TOKEN HERE'
@ -53,10 +52,10 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message:
"""
# Get current handler
handler = context.get_value('handler')
# handler = context.get_value('handler')
# Get dispatcher from context
dispatcher = ctx.get_dispatcher()
dispatcher = Dispatcher.current()
# If handler was configured, get rate limit and key from handler
if handler:
@ -83,8 +82,8 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message:
:param throttled:
"""
handler = context.get_value('handler')
dispatcher = ctx.get_dispatcher()
# handler = context.get_value('handler')
dispatcher = Dispatcher.current()
if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else:

View file

@ -4,7 +4,7 @@ from aiogram import Bot
from aiogram import types
from aiogram.utils import executor
from aiogram.dispatcher import Dispatcher
from aiogram.types.message import ContentType
from aiogram.types.message import ContentTypes
BOT_TOKEN = 'BOT TOKEN HERE'
@ -86,7 +86,7 @@ async def checkout(pre_checkout_query: types.PreCheckoutQuery):
" try to pay again in a few minutes, we need a small rest.")
@dp.message_handler(content_types=ContentType.SUCCESSFUL_PAYMENT)
@dp.message_handler(content_types=ContentTypes.SUCCESSFUL_PAYMENT)
async def got_payment(message: types.Message):
await bot.send_message(message.chat.id,
'Hoooooray! Thanks for payment! We will proceed your order for `{} {}`'

View file

@ -18,7 +18,7 @@ PROXY_URL = 'http://PROXY_URL' # Or 'socks5://...'
# PROXY_AUTH = aiohttp.BasicAuth(login='login', password='password')
# And add `proxy_auth=PROXY_AUTH` argument in line 25, like this:
# >>> bot = Bot(token=API_TOKEN, loop=loop, proxy=PROXY_URL, proxy_auth=PROXY_AUTH)
# Also you can use Socks5 proxy but you need manually install aiosocksy package.
# Also you can use Socks5 proxy but you need manually install aiohttp_socks package.
# Get my ip URL
GET_IP_URL = 'http://bot.whatismyipaddress.com/'

View file

@ -9,7 +9,7 @@ from aiogram import Bot, types, Version
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import Dispatcher
from aiogram.dispatcher.webhook import get_new_configured_app, SendMessage
from aiogram.types import ChatType, ParseMode, ContentType
from aiogram.types import ChatType, ParseMode, ContentTypes
from aiogram.utils.markdown import hbold, bold, text, link
TOKEN = 'BOT TOKEN HERE'
@ -31,7 +31,7 @@ WEBHOOK_URL = f"https://{WEBHOOK_HOST}:{WEBHOOK_PORT}{WEBHOOK_URL_PATH}"
WEBAPP_HOST = 'localhost'
WEBAPP_PORT = 3001
BAD_CONTENT = ContentType.PHOTO & ContentType.DOCUMENT & ContentType.STICKER & ContentType.AUDIO
BAD_CONTENT = ContentTypes.PHOTO & ContentTypes.DOCUMENT & ContentTypes.STICKER & ContentTypes.AUDIO
loop = asyncio.get_event_loop()
bot = Bot(TOKEN, loop=loop)

View file

@ -13,7 +13,7 @@ except ImportError: # pip >= 10.0.0
WORK_DIR = pathlib.Path(__file__).parent
# Check python version
MINIMAL_PY_VERSION = (3, 6)
MINIMAL_PY_VERSION = (3, 7)
if sys.version_info < MINIMAL_PY_VERSION:
raise RuntimeError('aiogram works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION))))
@ -65,7 +65,7 @@ setup(
url='https://github.com/aiogram/aiogram',
license='MIT',
author='Alex Root Junior',
requires_python='>=3.6',
requires_python='>=3.7',
author_email='aiogram@illemius.xyz',
description='Is a pretty simple and fully asynchronous library for Telegram Bot API',
long_description=get_description(),
@ -76,7 +76,7 @@ setup(
'Intended Audience :: Developers',
'Intended Audience :: System Administrators',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Topic :: Software Development :: Libraries :: Application Frameworks',
],
install_requires=get_requirements()

102
tests/states_group.py Normal file
View file

@ -0,0 +1,102 @@
import pytest
from aiogram.dispatcher.filters.state import State, StatesGroup, any_state, default_state
class MyGroup(StatesGroup):
state = State()
state_1 = State()
state_2 = State()
class MySubGroup(StatesGroup):
sub_state = State()
sub_state_1 = State()
sub_state_2 = State()
in_custom_group = State(group_name='custom_group')
class NewGroup(StatesGroup):
spam = State()
renamed_state = State(state='spam_state')
alone_state = State('alone')
alone_in_group = State('alone', group_name='home')
def test_default_state():
assert default_state.state is None
def test_any_state():
assert any_state.state == '*'
def test_alone_state():
assert alone_state.state == '@:alone'
assert alone_in_group.state == 'home:alone'
def test_group_names():
assert MyGroup.__group_name__ == 'MyGroup'
assert MyGroup.__full_group_name__ == 'MyGroup'
assert MyGroup.MySubGroup.__group_name__ == 'MySubGroup'
assert MyGroup.MySubGroup.__full_group_name__ == 'MyGroup.MySubGroup'
assert MyGroup.MySubGroup.NewGroup.__group_name__ == 'NewGroup'
assert MyGroup.MySubGroup.NewGroup.__full_group_name__ == 'MyGroup.MySubGroup.NewGroup'
def test_custom_group_in_group():
assert MyGroup.MySubGroup.in_custom_group.state == 'custom_group:in_custom_group'
def test_custom_state_name_in_group():
assert MyGroup.MySubGroup.NewGroup.renamed_state.state == 'MyGroup.MySubGroup.NewGroup:spam_state'
def test_group_states_names():
assert len(MyGroup.states) == 3
assert len(MyGroup.all_states) == 9
assert MyGroup.states_names == ('MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2')
assert MyGroup.MySubGroup.states_names == (
'MyGroup.MySubGroup:sub_state', 'MyGroup.MySubGroup:sub_state_1', 'MyGroup.MySubGroup:sub_state_2',
'custom_group:in_custom_group')
assert MyGroup.MySubGroup.NewGroup.states_names == (
'MyGroup.MySubGroup.NewGroup:spam', 'MyGroup.MySubGroup.NewGroup:spam_state')
assert MyGroup.all_states_names == (
'MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2',
'MyGroup.MySubGroup:sub_state',
'MyGroup.MySubGroup:sub_state_1',
'MyGroup.MySubGroup:sub_state_2',
'custom_group:in_custom_group',
'MyGroup.MySubGroup.NewGroup:spam',
'MyGroup.MySubGroup.NewGroup:spam_state')
assert MyGroup.MySubGroup.all_states_names == (
'MyGroup.MySubGroup:sub_state',
'MyGroup.MySubGroup:sub_state_1',
'MyGroup.MySubGroup:sub_state_2',
'custom_group:in_custom_group',
'MyGroup.MySubGroup.NewGroup:spam',
'MyGroup.MySubGroup.NewGroup:spam_state')
assert MyGroup.MySubGroup.NewGroup.all_states_names == (
'MyGroup.MySubGroup.NewGroup:spam',
'MyGroup.MySubGroup.NewGroup:spam_state')
def test_root_element():
root = MyGroup.MySubGroup.NewGroup.spam.get_root()
assert issubclass(root, StatesGroup)
assert root == MyGroup
assert root == MyGroup.state.get_root()
assert root == MyGroup.MySubGroup.get_root()
with pytest.raises(RuntimeError):
any_state.get_root()

View file

@ -1,5 +1,5 @@
[tox]
envlist = py36
envlist = py37
[testenv]
deps = -rdev_requirements.txt