This commit is contained in:
Kolay 2018-08-17 11:18:04 +00:00 committed by GitHub
commit bfccd6268c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 3455 additions and 2299 deletions

3
.gitignore vendored
View file

@ -57,3 +57,6 @@ experiment.py
# Doc's # Doc's
docs/html docs/html
# i18n/l10n
*.mo

View file

@ -8,7 +8,7 @@
[![Github issues](https://img.shields.io/github/issues/aiogram/aiogram.svg?style=flat-square)](https://github.com/aiogram/aiogram/issues) [![Github issues](https://img.shields.io/github/issues/aiogram/aiogram.svg?style=flat-square)](https://github.com/aiogram/aiogram/issues)
[![MIT License](https://img.shields.io/pypi/l/aiogram.svg?style=flat-square)](https://opensource.org/licenses/MIT) [![MIT License](https://img.shields.io/pypi/l/aiogram.svg?style=flat-square)](https://opensource.org/licenses/MIT)
**aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.6 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler. **aiogram** is a pretty simple and fully asynchronous library for [Telegram Bot API](https://core.telegram.org/bots/api) written in Python 3.7 with [asyncio](https://docs.python.org/3/library/asyncio.html) and [aiohttp](https://github.com/aio-libs/aiohttp). It helps you to make your bots faster and simpler.
You can [read the docs here](http://aiogram.readthedocs.io/en/latest/). You can [read the docs here](http://aiogram.readthedocs.io/en/latest/).

View file

@ -1,14 +1,42 @@
import asyncio import asyncio
import os
from . import bot
from . import contrib
from . import dispatcher
from . import types
from . import utils
from .bot import Bot from .bot import Bot
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .dispatcher import filters
from .dispatcher import middlewares
from .utils import exceptions, executor, helper, markdown as md
try: try:
import uvloop import uvloop
except ImportError: except ImportError:
uvloop = None uvloop = None
else: else:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) if 'DISABLE_UVLOOP' not in os.environ:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
__version__ = '1.4' __all__ = [
'Bot',
'Dispatcher',
'__api_version__',
'__version__',
'bot',
'contrib',
'dispatcher',
'exceptions',
'executor',
'filters',
'helper',
'md',
'middlewares',
'types',
'utils'
]
__version__ = '2.0.dev1'
__api_version__ = '3.6' __api_version__ = '3.6'

View file

@ -1,8 +1,14 @@
import abc
import asyncio
import logging import logging
import os import os
import ssl
from asyncio import AbstractEventLoop
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Tuple
import aiohttp import aiohttp
import certifi
from .. import types from .. import types
from ..utils import exceptions from ..utils import exceptions
@ -34,58 +40,73 @@ def check_token(token: str) -> bool:
return True return True
async def _check_result(method_name, response): async def check_result(method_name: str, content_type: str, status_code: int, body: str):
""" """
Checks whether `result` is a valid API response. Checks whether `result` is a valid API response.
A result is considered invalid if: A result is considered invalid if:
- The server returned an HTTP response code other than 200 - The server returned an HTTP response code other than 200
- The content of the result is invalid JSON. - The content of the result is invalid JSON.
- The method call was unsuccessful (The JSON 'ok' field equals False) - The method call was unsuccessful (The JSON 'ok' field equals False)
:raises ApiException: if one of the above listed cases is applicable :param method_name: The name of the method called
:param method_name: The name of the method called :param status_code: status code
:param response: The returned response of the method request :param content_type: content type of result
:return: The result parsed to a JSON dictionary. :param body: result body
""" :return: The result parsed to a JSON dictionary
body = await response.text() :raises ApiException: if one of the above listed cases is applicable
log.debug(f"Response for {method_name}: [{response.status}] {body}") """
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
if response.content_type != 'application/json': if content_type != 'application/json':
raise exceptions.NetworkError(f"Invalid response with content type {response.content_type}: \"{body}\"") raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
try:
result_json = json.loads(body)
except ValueError:
result_json = {}
description = result_json.get('description') or body
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
return result_json.get('result')
elif parameters.retry_after:
raise exceptions.RetryAfter(parameters.retry_after)
elif parameters.migrate_to_chat_id:
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
elif status_code == HTTPStatus.BAD_REQUEST:
exceptions.BadRequest.detect(description)
elif status_code == HTTPStatus.NOT_FOUND:
exceptions.NotFound.detect(description)
elif status_code == HTTPStatus.CONFLICT:
exceptions.ConflictError.detect(description)
elif status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
exceptions.Unauthorized.detect(description)
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
raise exceptions.NetworkError('File too large for uploading. '
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
if 'restart' in description:
raise exceptions.RestartingTelegram()
raise exceptions.TelegramAPIError(description)
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
async def make_request(session, token, method, data=None, files=None, **kwargs):
# log.debug(f"Make request: '{method}' with data: {data} and files {files}")
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files)
url = Methods.api_url(token=token, method=method)
req = compose_data(data, files)
try: try:
result_json = await response.json(loads=json.loads) async with session.post(url, data=req, **kwargs) as response:
except ValueError: return await check_result(method, response.content_type, response.status, await response.text())
result_json = {} except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
description = result_json.get('description') or body
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
if HTTPStatus.OK <= response.status <= HTTPStatus.IM_USED:
return result_json.get('result')
elif parameters.retry_after:
raise exceptions.RetryAfter(parameters.retry_after)
elif parameters.migrate_to_chat_id:
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
elif response.status == HTTPStatus.BAD_REQUEST:
exceptions.BadRequest.detect(description)
elif response.status == HTTPStatus.NOT_FOUND:
exceptions.NotFound.detect(description)
elif response.status == HTTPStatus.CONFLICT:
exceptions.ConflictError.detect(description)
elif response.status in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
exceptions.Unauthorized.detect(description)
elif response.status == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
raise exceptions.NetworkError('File too large for uploading. '
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
elif response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
if 'restart' in description:
raise exceptions.RestartingTelegram()
raise exceptions.TelegramAPIError(description)
raise exceptions.TelegramAPIError(f"{description} [{response.status}]")
def _guess_filename(obj): def guess_filename(obj):
""" """
Get file name from object Get file name from object
@ -97,7 +118,7 @@ def _guess_filename(obj):
return os.path.basename(name) return os.path.basename(name)
def _compose_data(params=None, files=None): def compose_data(params=None, files=None):
""" """
Prepare request data Prepare request data
@ -121,47 +142,13 @@ def _compose_data(params=None, files=None):
elif isinstance(f, types.InputFile): elif isinstance(f, types.InputFile):
filename, fileobj = f.filename, f.file filename, fileobj = f.filename, f.file
else: else:
filename, fileobj = _guess_filename(f) or key, f filename, fileobj = guess_filename(f) or key, f
data.add_field(key, fileobj, filename=filename) data.add_field(key, fileobj, filename=filename)
return data return data
async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict:
"""
Make request to API
That make request with Content-Type:
application/x-www-form-urlencoded - For simple request
and multipart/form-data - for files uploading
https://core.telegram.org/bots/api#making-requests
:param session: HTTP Client session
:type session: :obj:`aiohttp.ClientSession`
:param token: BOT token
:type token: :obj:`str`
:param method: API method
:type method: :obj:`str`
:param data: request payload
:type data: :obj:`dict`
:param files: files
:type files: :obj:`dict`
:return: result
:rtype :obj:`bool` or :obj:`dict`
"""
log.debug("Make request: '{0}' with data: {1} and files {2}".format(
method, data or {}, files or {}))
data = _compose_data(data, files)
url = Methods.api_url(token=token, method=method)
try:
async with session.post(url, data=data, **kwargs) as response:
return await _check_result(method, response)
except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
class Methods(Helper): class Methods(Helper):
""" """
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api Helper for Telegram API Methods listed on https://core.telegram.org/bots/api

View file

@ -47,7 +47,6 @@ class BaseBot:
api.check_token(token) api.check_token(token)
self.__token = token self.__token = token
# Proxy settings
self.proxy = proxy self.proxy = proxy
self.proxy_auth = proxy_auth self.proxy_auth = proxy_auth
@ -59,37 +58,42 @@ class BaseBot:
# aiohttp main session # aiohttp main session
ssl_context = ssl.create_default_context(cafile=certifi.where()) ssl_context = ssl.create_default_context(cafile=certifi.where())
if isinstance(proxy, str) and proxy.startswith('socks5://'): if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
from aiosocksy.connector import ProxyClientRequest, ProxyConnector from aiohttp_socks import SocksConnector
connector = ProxyConnector(limit=connections_limit, ssl_context=ssl_context, loop=self.loop) from aiohttp_socks.helpers import parse_socks_url
request_class = ProxyClientRequest
socks_ver, host, port, username, password = parse_socks_url(proxy)
if proxy_auth and not username or password:
username = proxy_auth.login
password = proxy_auth.password
connector = SocksConnector(socks_ver=socks_ver, host=host, port=port,
username=username, password=password,
limit=connections_limit, ssl_context=ssl_context,
loop=self.loop)
self.proxy = None
self.proxy_auth = None
else: else:
connector = aiohttp.TCPConnector(limit=connections_limit, ssl_context=ssl_context, connector = aiohttp.TCPConnector(limit=connections_limit, ssl_context=ssl_context,
loop=self.loop) loop=self.loop)
request_class = aiohttp.ClientRequest
self.session = aiohttp.ClientSession(connector=connector, request_class=request_class, self.session = aiohttp.ClientSession(connector=connector, loop=self.loop, json_serialize=json.dumps)
loop=self.loop, json_serialize=json.dumps)
# Data stored in bot instance
self._data = {}
self.parse_mode = parse_mode self.parse_mode = parse_mode
def __del__(self): # Data stored in bot instance
# asyncio.ensure_future(self.close()) self._data = {}
pass
async def close(self): async def close(self):
""" """
Close all client sessions Close all client sessions
""" """
if self.session and not self.session.closed: await self.session.close()
await self.session.close()
async def request(self, method: base.String, async def request(self, method: base.String,
data: Optional[Dict] = None, data: Optional[Dict] = None,
files: Optional[Dict] = None) -> Union[List, Dict, base.Boolean]: files: Optional[Dict] = None, **kwargs) -> Union[List, Dict, base.Boolean]:
""" """
Make an request to Telegram Bot API Make an request to Telegram Bot API
@ -105,8 +109,8 @@ class BaseBot:
:rtype: Union[List, Dict] :rtype: Union[List, Dict]
:raise: :obj:`aiogram.exceptions.TelegramApiError` :raise: :obj:`aiogram.exceptions.TelegramApiError`
""" """
return await api.request(self.session, self.__token, method, data, files, return await api.make_request(self.session, self.__token, method, data, files,
proxy=self.proxy, proxy_auth=self.proxy_auth) proxy=self.proxy, proxy_auth=self.proxy_auth, **kwargs)
async def download_file(self, file_path: base.String, async def download_file(self, file_path: base.String,
destination: Optional[base.InputFile] = None, destination: Optional[base.InputFile] = None,

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
import typing import typing
from ...dispatcher import BaseStorage from ...dispatcher.storage import BaseStorage
class MemoryStorage(BaseStorage): class MemoryStorage(BaseStorage):
@ -56,7 +56,7 @@ class MemoryStorage(BaseStorage):
chat, user = self.check_address(chat=chat, user=user) chat, user = self.check_address(chat=chat, user=user)
user = self._get_user(chat, user) user = self._get_user(chat, user)
if data is None: if data is None:
data = [] data = {}
user['data'].update(data, **kwargs) user['data'].update(data, **kwargs)
async def set_state(self, *, async def set_state(self, *,

View file

@ -303,6 +303,8 @@ class RedisStorage2(BaseStorage):
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs): data: typing.Dict = None, **kwargs):
if data is None:
data = {}
temp_data = await self.get_data(chat=chat, user=user, default={}) temp_data = await self.get_data(chat=chat, user=user, default={})
temp_data.update(data, **kwargs) temp_data.update(data, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_data) await self.set_data(chat=chat, user=user, data=temp_data)
@ -330,6 +332,8 @@ class RedisStorage2(BaseStorage):
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs): bucket: typing.Dict = None, **kwargs):
if bucket is None:
bucket = {}
temp_bucket = await self.get_data(chat=chat, user=user) temp_bucket = await self.get_data(chat=chat, user=user)
temp_bucket.update(bucket, **kwargs) temp_bucket.update(bucket, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_bucket) await self.set_data(chat=chat, user=user, data=temp_bucket)

View file

@ -4,7 +4,7 @@ import weakref
import rethinkdb as r import rethinkdb as r
from ...dispatcher import BaseStorage from ...dispatcher.storage import BaseStorage
__all__ = ['RethinkDBStorage', 'ConnectionNotClosed'] __all__ = ['RethinkDBStorage', 'ConnectionNotClosed']

View file

@ -1,115 +0,0 @@
from aiogram import types
from aiogram.dispatcher import ctx
from aiogram.dispatcher.middlewares import BaseMiddleware
OBJ_KEY = '_context_data'
class ContextMiddleware(BaseMiddleware):
"""
Allow to store data at all of lifetime of Update object
"""
async def on_pre_process_update(self, update: types.Update):
"""
Start of Update lifetime
:param update:
:return:
"""
self._configure_update(update)
async def on_post_process_update(self, update: types.Update, result):
"""
On finishing of processing update
:param update:
:param result:
:return:
"""
if OBJ_KEY in update.conf:
del update.conf[OBJ_KEY]
def _configure_update(self, update: types.Update = None):
"""
Setup data storage
:param update:
:return:
"""
obj = update.conf[OBJ_KEY] = {}
return obj
def _get_dict(self):
"""
Get data from update stored in current context
:return:
"""
update = ctx.get_update()
obj = update.conf.get(OBJ_KEY, None)
if obj is None:
obj = self._configure_update(update)
return obj
def __getitem__(self, item):
"""
Item getter
:param item:
:return:
"""
return self._get_dict()[item]
def __setitem__(self, key, value):
"""
Item setter
:param key:
:param value:
:return:
"""
data = self._get_dict()
data[key] = value
def __iter__(self):
"""
Iterate over dict
:return:
"""
return self._get_dict().__iter__()
def keys(self):
"""
Iterate over dict keys
:return:
"""
return self._get_dict().keys()
def values(self):
"""
Iterate over dict values
:return:
"""
return self._get_dict().values()
def get(self, key, default=None):
"""
Get item from dit or return default value
:param key:
:param default:
:return:
"""
return self._get_dict().get(key, default)
def export(self):
"""
Export all data s dict
:return:
"""
return self._get_dict()

View file

@ -0,0 +1,25 @@
from aiogram.dispatcher.middlewares import BaseMiddleware
class EnvironmentMiddleware(BaseMiddleware):
def __init__(self, context=None):
super(EnvironmentMiddleware, self).__init__()
if context is None:
context = {}
self.context = context
def update_data(self, data):
dp = self.manager.dispatcher
data.update(
bot=dp.bot,
dispatcher=dp,
loop=dp.loop
)
if self.context:
data.update(self.context)
async def trigger(self, action, args):
if 'error' not in action and action.startswith('pre_process_'):
self.update_data(args[-1])
return True

View file

@ -0,0 +1,80 @@
import copy
import weakref
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
from aiogram.dispatcher.storage import FSMContext
class FSMMiddleware(LifetimeControllerMiddleware):
skip_patterns = ['error', 'update']
def __init__(self):
super(FSMMiddleware, self).__init__()
self._proxies = weakref.WeakKeyDictionary()
async def pre_process(self, obj, data, *args):
proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state())
data['state_data'] = proxy
async def post_process(self, obj, data, *args):
proxy = data.get('state_data', None)
if isinstance(proxy, FSMSStorageProxy):
await proxy.save()
class FSMSStorageProxy(dict):
def __init__(self, fsm_context: FSMContext):
super(FSMSStorageProxy, self).__init__()
self.fsm_context = fsm_context
self._copy = {}
self._data = {}
self._state = None
self._is_dirty = False
@classmethod
async def create(cls, fsm_context: FSMContext):
"""
:param fsm_context:
:return:
"""
proxy = cls(fsm_context)
await proxy.load()
return proxy
async def load(self):
self.clear()
self._state = await self.fsm_context.get_state()
self.update(await self.fsm_context.get_data())
self._copy = copy.deepcopy(self)
self._is_dirty = False
@property
def state(self):
return self._state
@state.setter
def state(self, value):
self._state = value
self._is_dirty = True
@state.deleter
def state(self):
self._state = None
self._is_dirty = True
async def save(self, force=False):
if self._copy != self or force:
await self.fsm_context.set_data(data=self)
if self._is_dirty or force:
await self.fsm_context.set_state(self.state)
self._is_dirty = False
self._copy = copy.deepcopy(self)
def __str__(self):
s = super(FSMSStorageProxy, self).__str__()
readable_state = f"'{self.state}'" if self.state else "''"
return f"<{self.__class__.__name__}(state={readable_state}, data={s})>"
def clear(self):
del self.state
return super(FSMSStorageProxy, self).clear()

View file

@ -0,0 +1,140 @@
import gettext
import os
from contextvars import ContextVar
from typing import Any, Dict, Tuple
from babel import Locale
from ... import types
from ...dispatcher.middlewares import BaseMiddleware
class I18nMiddleware(BaseMiddleware):
"""
I18n middleware based on gettext util
>>> dp = Dispatcher(bot)
>>> i18n = I18nMiddleware(DOMAIN, LOCALES_DIR)
>>> dp.middleware.setup(i18n)
and then
>>> _ = i18n.gettext
or
>>> _ = i18n = I18nMiddleware(DOMAIN_NAME, LOCALES_DIR)
"""
ctx_locale = ContextVar('ctx_user_locale', default=None)
def __init__(self, domain, path=None, default='en'):
"""
:param domain: domain
:param path: path where located all *.mo files
:param default: default locale name
"""
super(I18nMiddleware, self).__init__()
if path is None:
path = os.path.join(os.getcwd(), 'locales')
self.domain = domain
self.path = path
self.default = default
self.locales = self.find_locales()
def find_locales(self) -> Dict[str, gettext.GNUTranslations]:
"""
Load all compiled locales from path
:return: dict with locales
"""
translations = {}
for name in os.listdir(self.path):
if not os.path.isdir(os.path.join(self.path, name)):
continue
mo_path = os.path.join(self.path, name, 'LC_MESSAGES', self.domain + '.mo')
if os.path.exists(mo_path):
with open(mo_path, 'rb') as fp:
translations[name] = gettext.GNUTranslations(fp)
elif os.path.exists(mo_path[:-2] + 'po'):
raise RuntimeError(f"Found locale '{name} but this language is not compiled!")
return translations
def reload(self):
"""
Hot reload locles
"""
self.locales = self.find_locales()
@property
def available_locales(self) -> Tuple[str]:
"""
list of loaded locales
:return:
"""
return tuple(self.locales.keys())
def __call__(self, singular, plural=None, n=1, locale=None) -> str:
return self.gettext(singular, plural, n, locale)
def gettext(self, singular, plural=None, n=1, locale=None) -> str:
"""
Get text
:param singular:
:param plural:
:param n:
:param locale:
:return:
"""
if locale is None:
locale = self.ctx_locale.get()
if locale not in self.locales:
if n is 1:
return singular
else:
return plural
translator = self.locales[locale]
if plural is None:
return translator.gettext(singular)
else:
return translator.ngettext(singular, plural, n)
# noinspection PyMethodMayBeStatic,PyUnusedLocal
async def get_user_locale(self, action: str, args: Tuple[Any]) -> str:
"""
User locale getter
You can override the method if you want to use different way of getting user language.
:param action: event name
:param args: event arguments
:return: locale name
"""
user: types.User = types.User.current()
locale: Locale = user.locale
if locale:
*_, data = args
language = data['locale'] = locale.language
return language
async def trigger(self, action, args):
"""
Event trigger
:param action: event name
:param args: event arguments
:return:
"""
if 'update' not in action \
and 'error' not in action \
and action.startswith('pre_process'):
locale = await self.get_user_locale(action, args)
self.ctx_locale.set(locale)
return True

View file

@ -23,70 +23,70 @@ class LoggingMiddleware(BaseMiddleware):
return round((time.time() - start) * 1000) return round((time.time() - start) * 1000)
return -1 return -1
async def on_pre_process_update(self, update: types.Update): async def on_pre_process_update(self, update: types.Update, data: dict):
update.conf['_start'] = time.time() update.conf['_start'] = time.time()
self.logger.debug(f"Received update [ID:{update.update_id}]") self.logger.debug(f"Received update [ID:{update.update_id}]")
async def on_post_process_update(self, update: types.Update, result): async def on_post_process_update(self, update: types.Update, result, data: dict):
timeout = self.check_timeout(update) timeout = self.check_timeout(update)
if timeout > 0: if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)") self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
async def on_pre_process_message(self, message: types.Message): async def on_pre_process_message(self, message: types.Message, data: dict):
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_post_process_message(self, message: types.Message, results): async def on_post_process_message(self, message: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_pre_process_edited_message(self, edited_message): async def on_pre_process_edited_message(self, edited_message, data: dict):
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] " self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_post_process_edited_message(self, edited_message, results): async def on_post_process_edited_message(self, edited_message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited message [ID:{edited_message.message_id}] " f"edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_pre_process_channel_post(self, channel_post: types.Message): async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict):
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] " self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
f"in channel [ID:{channel_post.chat.id}]") f"in channel [ID:{channel_post.chat.id}]")
async def on_post_process_channel_post(self, channel_post: types.Message, results): async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"channel post [ID:{channel_post.message_id}] " f"channel post [ID:{channel_post.message_id}] "
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]") f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message): async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict):
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] " self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]") f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results): async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited channel post [ID:{edited_channel_post.message_id}] " f"edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]") f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery): async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict):
self.logger.info(f"Received inline query [ID:{inline_query.id}] " self.logger.info(f"Received inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]") f"from user [ID:{inline_query.from_user.id}]")
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results): async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"inline query [ID:{inline_query.id}] " f"inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]") f"from user [ID:{inline_query.from_user.id}]")
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult): async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict):
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] " f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]") f"result [ID:{chosen_inline_result.result_id}]")
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results): async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] " f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]") f"result [ID:{chosen_inline_result.result_id}]")
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery): async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict):
if callback_query.message: if callback_query.message:
if callback_query.message.from_user: if callback_query.message.from_user:
self.logger.info(f"Received callback query [ID:{callback_query.id}] " self.logger.info(f"Received callback query [ID:{callback_query.id}] "
@ -100,7 +100,7 @@ class LoggingMiddleware(BaseMiddleware):
f"from inline message [ID:{callback_query.inline_message_id}] " f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]") f"from user [ID:{callback_query.from_user.id}]")
async def on_post_process_callback_query(self, callback_query, results): async def on_post_process_callback_query(self, callback_query, results, data: dict):
if callback_query.message: if callback_query.message:
if callback_query.message.from_user: if callback_query.message.from_user:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
@ -117,25 +117,25 @@ class LoggingMiddleware(BaseMiddleware):
f"from inline message [ID:{callback_query.inline_message_id}] " f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]") f"from user [ID:{callback_query.from_user.id}]")
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery): async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict):
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] " self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]") f"from user [ID:{shipping_query.from_user.id}]")
async def on_post_process_shipping_query(self, shipping_query, results): async def on_post_process_shipping_query(self, shipping_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"shipping query [ID:{shipping_query.id}] " f"shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]") f"from user [ID:{shipping_query.from_user.id}]")
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery): async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict):
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] " self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]") f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results): async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"pre-checkout query [ID:{pre_checkout_query.id}] " f"pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]") f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_pre_process_error(self, dispatcher, update, error): async def on_pre_process_error(self, dispatcher, update, error, data: dict):
timeout = self.check_timeout(update) timeout = self.check_timeout(update)
if timeout > 0: if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)") self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")

File diff suppressed because it is too large Load diff

View file

@ -1,42 +0,0 @@
from . import Bot
from .. import types
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
from ..utils import context
def _get(key, default=None, no_error=False):
result = context.get_value(key, default)
if not no_error and result is None:
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
f"Maybe asyncio task factory is not configured!\n"
f"\t>>> from aiogram.utils import context\n"
f"\t>>> loop.set_task_factory(context.task_factory)")
return result
def get_bot() -> Bot:
return _get('bot')
def get_dispatcher() -> Dispatcher:
return _get('dispatcher')
def get_update() -> types.Update:
return _get(UPDATE_OBJECT)
def get_mode() -> str:
return _get(MODE, 'unknown')
def get_chat() -> int:
return _get('chat', no_error=True)
def get_user() -> int:
return _get('user', no_error=True)
def get_state() -> FSMContext:
return get_dispatcher().current_state()

File diff suppressed because it is too large Load diff

View file

@ -1,289 +0,0 @@
import asyncio
import inspect
import re
from ..types import CallbackQuery, ContentType, Message
from ..utils import context
from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE'
async def check_filter(filter_, args):
"""
Helper for executing filter
:param filter_:
:param args:
:param kwargs:
:return:
"""
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
return await filter_(*args)
else:
return filter_(*args)
async def check_filters(filters, args):
"""
Check list of filters
:param filters:
:param args:
:return:
"""
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args)
if not f:
return False
return True
class Filter:
"""
Base class for filters
"""
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
def check(self, *args, **kwargs):
raise NotImplementedError
class AsyncFilter(Filter):
"""
Base class for asynchronous filters
"""
def __aiter__(self):
return None
def __await__(self):
return self.check
async def check(self, *args, **kwargs):
pass
class AnyFilter(AsyncFilter):
"""
One filter from many
"""
def __init__(self, *filters: callable):
self.filters = filters
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
class NotFilter(AsyncFilter):
"""
Reverse filter
"""
def __init__(self, filter_: callable):
self.filter = filter_
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(AsyncFilter):
"""
Check commands in message
"""
def __init__(self, commands):
self.commands = commands
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
if command not in self.commands:
return False
return True
class RegexpFilter(Filter):
"""
Regexp filter for messages
"""
def __init__(self, regexp):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
def check(self, obj):
if isinstance(obj, Message) and obj.text:
return bool(self.regexp.search(obj.text))
elif isinstance(obj, CallbackQuery) and obj.data:
return bool(self.regexp.search(obj.data))
return False
class RegexpCommandsFilter(AsyncFilter):
"""
Check commands by regexp in message
"""
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
for command in self.regexp_commands:
search = command.search(message.text)
if search:
message.conf['regexp_command'] = search
return True
return False
class ContentTypeFilter(Filter):
"""
Check message content type
"""
def __init__(self, content_types):
self.content_types = content_types
def check(self, message):
return ContentType.ANY[0] in self.content_types or \
message.content_type in self.content_types
class CancelFilter(Filter):
"""
Find cancel in message text
"""
def __init__(self, cancel_set=None):
if cancel_set is None:
cancel_set = ['/cancel', 'cancel', 'cancel.']
self.cancel_set = cancel_set
def check(self, message):
if message.text:
return message.text.lower() in self.cancel_set
class StateFilter(AsyncFilter):
"""
Check user state
"""
def __init__(self, dispatcher, state):
self.dispatcher = dispatcher
self.state = state
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
if self.state == '*':
return True
if context.check_value(USER_STATE):
context_state = context.get_value(USER_STATE)
return self.state == context_state
else:
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
return False
class StatesListFilter(StateFilter):
"""
List of states
"""
async def check(self, obj):
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
class ExceptionsFilter(Filter):
"""
Filter for exceptions
"""
def __init__(self, exception):
self.exception = exception
def check(self, dispatcher, update, exception):
return isinstance(exception, self.exception)
def generate_default_filters(dispatcher, *args, **kwargs):
"""
Prepare filters
:param dispatcher:
:param args:
:param kwargs:
:return:
"""
filters_set = []
for name, filter_ in kwargs.items():
if filter_ is None and name != DefaultFilters.STATE:
continue
if name == DefaultFilters.COMMANDS:
if isinstance(filter_, str):
filters_set.append(CommandsFilter([filter_]))
else:
filters_set.append(CommandsFilter(filter_))
elif name == DefaultFilters.REGEXP:
filters_set.append(RegexpFilter(filter_))
elif name == DefaultFilters.CONTENT_TYPES:
filters_set.append(ContentTypeFilter(filter_))
elif name == DefaultFilters.FUNC:
filters_set.append(filter_)
elif name == DefaultFilters.STATE:
if isinstance(filter_, (list, set, tuple)):
filters_set.append(StatesListFilter(dispatcher, filter_))
else:
filters_set.append(StateFilter(dispatcher, filter_))
elif isinstance(filter_, Filter):
filters_set.append(filter_)
filters_set += list(args)
return filters_set
class DefaultFilters(Helper):
mode = HelperMode.snake_case
COMMANDS = Item() # commands
REGEXP = Item() # regexp
CONTENT_TYPES = Item() # content_type
FUNC = Item() # func
STATE = Item() # state

View file

@ -0,0 +1,24 @@
from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
RegexpCommandsFilter, StateFilter, Text
from .factory import FiltersFactory
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
__all__ = [
'AbstractFilter',
'BoundFilter',
'Command',
'CommandStart',
'CommandHelp',
'ContentTypeFilter',
'ExceptionsFilter',
'Filter',
'FilterNotPassed',
'FilterRecord',
'FiltersFactory',
'RegexpCommandsFilter',
'Regexp',
'StateFilter',
'Text',
'check_filter',
'check_filters'
]

View file

@ -0,0 +1,313 @@
import inspect
import re
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Optional, Union
from aiogram import types
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
from aiogram.types import CallbackQuery, Message
class Command(Filter):
"""
You can handle commands by using this filter
"""
def __init__(self, commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = '/',
ignore_case: bool = True,
ignore_mention: bool = False):
"""
Filter can be initialized from filters factory or by simply creating instance of this class
:param commands: command or list of commands
:param prefixes:
:param ignore_case:
:param ignore_mention:
"""
if isinstance(commands, str):
commands = (commands,)
self.commands = list(map(str.lower, commands)) if ignore_case else commands
self.prefixes = prefixes
self.ignore_case = ignore_case
self.ignore_mention = ignore_mention
@classmethod
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Validator for filters factory
:param full_config:
:return: config or empty dict
"""
config = {}
if 'commands' in full_config:
config['commands'] = full_config.pop('commands')
if 'commands_prefix' in full_config:
config['prefixes'] = full_config.pop('commands_prefix')
if 'commands_ignore_mention' in full_config:
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
return config
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
@staticmethod
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
full_command = message.text.split()[0]
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
return False
elif prefix not in prefixes:
return False
elif (command.lower() if ignore_case else command) not in commands:
return False
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
@dataclass
class CommandObj:
prefix: str = '/'
command: str = ''
mention: str = None
args: str = field(repr=False, default=None)
@property
def mentioned(self) -> bool:
return bool(self.mention)
@property
def text(self) -> str:
line = self.prefix + self.command
if self.mentioned:
line += '@' + self.mention
if self.args:
line += ' ' + self.args
return line
class CommandStart(Command):
def __init__(self):
super(CommandStart, self).__init__(['start'])
class CommandHelp(Command):
def __init__(self):
super(CommandHelp, self).__init__(['help'])
class Text(Filter):
"""
Simple text filter
"""
def __init__(self,
equals: Optional[str] = None,
contains: Optional[str] = None,
startswith: Optional[str] = None,
endswith: Optional[str] = None,
ignore_case=False):
"""
Check text for one of pattern. Only one mode can be used in one filter.
:param equals:
:param contains:
:param startswith:
:param endswith:
:param ignore_case: case insensitive
"""
# Only one mode can be used. check it.
check = sum(map(bool, (equals, contains, startswith, endswith)))
if check > 1:
args = "' and '".join([arg[0] for arg in [('equals', equals),
('contains', contains),
('startswith', startswith),
('endswith', endswith)
] if arg[1]])
raise ValueError(f"Arguments '{args}' cannot be used together.")
elif check == 0:
raise ValueError(f"No one mode is specified!")
self.equals = equals
self.contains = contains
self.endswith = endswith
self.startswith = startswith
self.ignore_case = ignore_case
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'text' in full_config:
return {'equals': full_config.pop('text')}
elif 'text_contains' in full_config:
return {'contains': full_config.pop('text_contains')}
elif 'text_startswith' in full_config:
return {'startswith': full_config.pop('text_startswith')}
elif 'text_endswith' in full_config:
return {'endswith': full_config.pop('text_endswith')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
text = obj.text or obj.caption or ''
elif isinstance(obj, CallbackQuery):
text = obj.data
else:
return False
if self.ignore_case:
text = text.lower()
if self.equals:
return text == self.equals
elif self.contains:
return self.contains in text
elif self.startswith:
return text.startswith(self.startswith)
elif self.endswith:
return text.endswith(self.endswith)
return False
class Regexp(Filter):
"""
Regexp filter for messages and callback query
"""
def __init__(self, regexp):
if not isinstance(regexp, re.Pattern):
regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
self.regexp = regexp
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'regexp' in full_config:
return {'regexp': full_config.pop('regexp')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
match = self.regexp.search(obj.text or obj.caption or '')
elif isinstance(obj, CallbackQuery) and obj.data:
match = self.regexp.search(obj.data)
else:
return False
if match:
return {'regexp': match}
return False
class RegexpCommandsFilter(BoundFilter):
"""
Check commands by regexp in message
"""
key = 'regexp_commands'
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username:
return False
for command in self.regexp_commands:
search = command.search(message.text)
if search:
return {'regexp_command': search}
return False
class ContentTypeFilter(BoundFilter):
"""
Check message content type
"""
key = 'content_types'
required = True
default = types.ContentTypes.TEXT
def __init__(self, content_types):
self.content_types = content_types
async def check(self, message):
return types.ContentType.ANY in self.content_types or \
message.content_type in self.content_types
class StateFilter(BoundFilter):
"""
Check user state
"""
key = 'state'
required = True
ctx_state = ContextVar('user_state')
def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State, StatesGroup
self.dispatcher = dispatcher
states = []
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
state = [state, ]
for item in state:
if isinstance(item, State):
states.append(item.state)
elif inspect.isclass(item) and issubclass(item, StatesGroup):
states.extend(item.all_states_names)
else:
states.append(item)
self.states = states
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
async def check(self, obj):
if '*' in self.states:
return {'state': self.dispatcher.current_state()}
try:
state = self.ctx_state.get()
except LookupError:
chat, user = self.get_target(obj)
if chat or user:
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
self.ctx_state.set(state)
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
else:
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
return False
class ExceptionsFilter(BoundFilter):
"""
Filter for exceptions
"""
key = 'exception'
def __init__(self, dispatcher, exception):
super().__init__(dispatcher)
self.exception = exception
async def check(self, dispatcher, update, exception):
try:
raise exception
except self.exception:
return True
except:
return False

View file

@ -0,0 +1,73 @@
import typing
from .filters import AbstractFilter, FilterRecord
from ..handler import Handler
class FiltersFactory:
"""
Default filters factory
"""
def __init__(self, dispatcher):
self._dispatcher = dispatcher
self._registered: typing.List[FilterRecord] = []
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.List[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
"""
Register filter
:param callback: callable or subclass of :obj:`AbstractFilter`
:param validator: custom validator.
:param event_handlers: list of instances of :obj:`Handler`
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
"""
record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
self._registered.append(record)
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
"""
Unregister callback
:param callback: callable of subclass of :obj:`AbstractFilter`
"""
for record in self._registered:
if record.callback == callback:
self._registered.remove(record)
def resolve(self, event_handler, *custom_filters, **full_config
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
"""
Resolve filters to filters-set
:param event_handler:
:param custom_filters:
:param full_config:
:return:
"""
filters_set = []
filters_set.extend(self._resolve_registered(event_handler,
{k: v for k, v in full_config.items() if v is not None}))
if custom_filters:
filters_set.extend(custom_filters)
return filters_set
def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
"""
Resolve registered filters
:param event_handler:
:param full_config:
:return:
"""
for record in self._registered:
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
if filter_:
yield filter_
if full_config:
raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'')

View file

@ -0,0 +1,250 @@
import abc
import inspect
import typing
from ..handler import Handler
from ...types.base import TelegramObject
class FilterNotPassed(Exception):
pass
def wrap_async(func):
async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)
if inspect.isawaitable(func) \
or inspect.iscoroutinefunction(func) \
or isinstance(func, AbstractFilter):
return func
return async_wrapper
async def check_filter(dispatcher, filter_, args):
"""
Helper for executing filter
:param dispatcher:
:param filter_:
:param args:
:return:
"""
kwargs = {}
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
spec = inspect.getfullargspec(filter_)
if 'dispatcher' in spec:
kwargs['dispatcher'] = dispatcher
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, AbstractFilter):
return await filter_(*args, **kwargs)
else:
return filter_(*args, **kwargs)
async def check_filters(dispatcher, filters, args):
"""
Check list of filters
:param dispatcher:
:param filters:
:param args:
:return:
"""
data = {}
if filters is not None:
for filter_ in filters:
f = await check_filter(dispatcher, filter_, args)
if not f:
raise FilterNotPassed()
elif isinstance(f, dict):
data.update(f)
return data
class FilterRecord:
"""
Filters record for factory
"""
def __init__(self, callback: typing.Callable,
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
if event_handlers and exclude_event_handlers:
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
self.callback = callback
self.event_handlers = event_handlers
self.exclude_event_handlers = exclude_event_handlers
if validator is not None:
if not callable(validator):
raise TypeError(f"validator must be callable, not {type(validator)}")
self.resolver = validator
elif issubclass(callback, AbstractFilter):
self.resolver = callback.validate
else:
raise RuntimeError('validator is required!')
def resolve(self, dispatcher, event_handler, full_config):
if not self._check_event_handler(event_handler):
return
config = self.resolver(full_config)
if config:
if 'dispatcher' not in config:
spec = inspect.getfullargspec(self.callback)
if 'dispatcher' in spec.args:
config['dispatcher'] = dispatcher
for key in config:
if key in full_config:
full_config.pop(key)
return self.callback(**config)
def _check_event_handler(self, event_handler) -> bool:
if self.event_handlers:
return event_handler in self.event_handlers
elif self.exclude_event_handlers:
return event_handler not in self.exclude_event_handlers
return True
class AbstractFilter(abc.ABC):
"""
Abstract class for custom filters
"""
@classmethod
@abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""
Validate and parse config
:param full_config:
:return: config
"""
pass
@abc.abstractmethod
async def check(self, *args) -> bool:
"""
Check object
:param args:
:return:
"""
pass
async def __call__(self, obj: TelegramObject) -> bool:
return await self.check(obj)
def __invert__(self):
return NotFilter(self)
def __and__(self, other):
if isinstance(self, AndFilter):
self.append(other)
return self
return AndFilter(self, other)
def __or__(self, other):
if isinstance(self, OrFilter):
self.append(other)
return self
return OrFilter(self, other)
class Filter(AbstractFilter):
"""
You can make subclasses of that class for custom filters
"""
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
pass
class BoundFilter(Filter):
"""
Base class for filters with default validator
"""
key = None
required = False
default = None
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None:
if cls.key in full_config:
return {cls.key: full_config[cls.key]}
elif cls.required:
return {cls.key: cls.default}
class _LogicFilter(Filter):
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise ValueError('That filter can\'t be used in filters factory!')
class NotFilter(_LogicFilter):
def __init__(self, target):
self.target = wrap_async(target)
async def check(self, *args):
return not bool(await self.target(*args))
class AndFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = list(wrap_async(target) for target in targets)
async def check(self, *args):
"""
All filters must return a positive result
:param args:
:return:
"""
data = {}
for target in self.targets:
result = await target(*args)
if not result:
return False
if isinstance(result, dict):
data.update(result)
if not data:
return True
return data
def append(self, target):
self.targets.append(wrap_async(target))
class OrFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = list(wrap_async(target) for target in targets)
async def check(self, *args):
"""
One of filters must return a positive result
:param args:
:return:
"""
for target in self.targets:
result = await target(*args)
if result:
if isinstance(result, dict):
return result
return True
return False
def append(self, target):
self.targets.append(wrap_async(target))

View file

@ -0,0 +1,198 @@
import inspect
from typing import Optional
from ..dispatcher import Dispatcher
class State:
"""
State object
"""
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None):
self._state = state
self._group_name = group_name
self._group = None
@property
def group(self):
if not self._group:
raise RuntimeError('This state is not in any group.')
return self._group
def get_root(self):
return self.group.get_root()
@property
def state(self):
if self._state is None:
return None
elif self._state == '*':
return self._state
elif self._group_name is None and self._group:
group = self._group.__full_group_name__
elif self._group_name:
group = self._group_name
else:
group = '@'
return f"{group}:{self._state}"
def set_parent(self, group):
if not issubclass(group, StatesGroup):
raise ValueError('Group must be subclass of StatesGroup')
self._group = group
def __set_name__(self, owner, name):
if self._state is None:
self._state = name
self.set_parent(owner)
def __str__(self):
return f"<State '{self.state or ''}'>"
__repr__ = __str__
async def set(self):
state = Dispatcher.current().current_state()
await state.set_state(self.state)
class StatesGroupMeta(type):
def __new__(mcs, name, bases, namespace, **kwargs):
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
states = []
childs = []
cls._group_name = name
for name, prop in namespace.items():
if isinstance(prop, State):
states.append(prop)
elif inspect.isclass(prop) and issubclass(prop, StatesGroup):
childs.append(prop)
prop._parent = cls
# continue
cls._parent = None
cls._childs = tuple(childs)
cls._states = tuple(states)
cls._state_names = tuple(state.state for state in states)
return cls
@property
def __group_name__(cls):
return cls._group_name
@property
def __full_group_name__(cls):
if cls._parent:
return cls._parent.__full_group_name__ + '.' + cls._group_name
return cls._group_name
@property
def states(cls) -> tuple:
return cls._states
@property
def childs(cls):
return cls._childs
@property
def all_childs(cls):
result = cls.childs
for child in cls.childs:
result += child.childs
return result
@property
def all_states(cls):
result = cls.states
for group in cls.childs:
result += group.all_states
return result
@property
def all_states_names(cls):
return tuple(state.state for state in cls.all_states)
@property
def states_names(cls) -> tuple:
return tuple(state.state for state in cls.states)
def get_root(cls):
if cls._parent is None:
return cls
return cls._parent.get_root()
def __contains__(cls, item):
if isinstance(item, str):
return item in cls.all_states_names
elif isinstance(item, State):
return item in cls.all_states
elif isinstance(item, StatesGroup):
return item in cls.all_childs
return False
def __str__(self):
return f"<StatesGroup '{self.__full_group_name__}'>"
class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def next(cls) -> str:
state = Dispatcher.current().current_state()
state_name = await state.get_state()
try:
next_step = cls.states_names.index(state_name) + 1
except ValueError:
next_step = 0
try:
next_state_name = cls.states[next_step].state
except IndexError:
next_state_name = None
await state.set_state(next_state_name)
return next_state_name
@classmethod
async def previous(cls) -> str:
state = Dispatcher.current().current_state()
state_name = await state.get_state()
try:
previous_step = cls.states_names.index(state_name) - 1
except ValueError:
previous_step = 0
if previous_step < 0:
previous_state_name = None
else:
previous_state_name = cls.states[previous_step].state
await state.set_state(previous_state_name)
return previous_state_name
@classmethod
async def first(cls) -> str:
state = Dispatcher.current().current_state()
first_step_name = cls.states_names[0]
await state.set_state(first_step_name)
return first_step_name
@classmethod
async def last(cls) -> str:
state = Dispatcher.current().current_state()
last_step_name = cls.states_names[-1]
await state.set_state(last_step_name)
return last_step_name
default_state = State()
any_state = State(state='*')

View file

@ -1,5 +1,7 @@
from .filters import check_filters import inspect
from ..utils import context from contextvars import ContextVar
ctx_data = ContextVar('ctx_handler_data')
class SkipHandler(BaseException): class SkipHandler(BaseException):
@ -10,6 +12,14 @@ class CancelHandler(BaseException):
pass pass
def _check_spec(func: callable, kwargs: dict):
spec = inspect.getfullargspec(func)
if spec.varkw:
return kwargs
return {k: v for k, v in kwargs.items() if k in spec.args}
class Handler: class Handler:
def __init__(self, dispatcher, once=True, middleware_key=None): def __init__(self, dispatcher, once=True, middleware_key=None):
self.dispatcher = dispatcher self.dispatcher = dispatcher
@ -57,31 +67,43 @@ class Handler:
:param args: :param args:
:return: :return:
""" """
from .filters import check_filters, FilterNotPassed
results = [] results = []
data = {}
ctx_data.set(data)
if self.middleware_key: if self.middleware_key:
try: try:
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args) await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,))
except CancelHandler: # Allow to cancel current event except CancelHandler: # Allow to cancel current event
return results return results
for filters, handler in self.handlers: try:
if await check_filters(filters, args): for filters, handler in self.handlers:
try: try:
if self.middleware_key: data.update(await check_filters(self.dispatcher, filters, args))
context.set_value('handler', handler) except FilterNotPassed:
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
continue continue
except CancelHandler: else:
break try:
if self.middleware_key: if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", # context.set_value('handler', handler)
args + (results,)) await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,))
partial_data = _check_spec(handler, data)
response = await handler(*args, **partial_data)
if response is not None:
results.append(response)
if self.once:
break
except SkipHandler:
continue
except CancelHandler:
break
finally:
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results, data,))
return results return results

View file

@ -101,3 +101,28 @@ class BaseMiddleware:
if not handler: if not handler:
return None return None
await handler(*args) await handler(*args)
class LifetimeControllerMiddleware(BaseMiddleware):
# TODO: Rename class
skip_patterns = None
async def pre_process(self, obj, data, *args):
pass
async def post_process(self, obj, data, *args):
pass
async def trigger(self, action, args):
if self.skip_patterns is not None and any(item in action for item in self.skip_patterns):
return False
obj, *args, data = args
if action.startswith('pre_process_'):
await self.pre_process(obj, data, *args)
elif action.startswith('post_process_'):
await self.post_process(obj, data, *args)
else:
return False
return True

View file

@ -281,8 +281,20 @@ class FSMContext:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
pass pass
@staticmethod
def _resolve_state(value):
from .filters.state import State
if value is None:
return
elif isinstance(value, str):
return value
elif isinstance(value, State):
return value.state
return str(value)
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]: async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
return await self.storage.get_state(chat=self.chat, user=self.user, default=default) return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default))
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict: async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
return await self.storage.get_data(chat=self.chat, user=self.user, default=default) return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
@ -291,7 +303,7 @@ class FSMContext:
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs) await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
async def set_state(self, state: typing.Union[typing.AnyStr, None] = None): async def set_state(self, state: typing.Union[typing.AnyStr, None] = None):
await self.storage.set_state(chat=self.chat, user=self.user, state=state) await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state))
async def set_data(self, data: typing.Dict = None): async def set_data(self, data: typing.Dict = None):
await self.storage.set_data(chat=self.chat, user=self.user, data=data) await self.storage.set_data(chat=self.chat, user=self.user, data=data)

View file

@ -9,11 +9,11 @@ from typing import Dict, List, Optional, Union
from aiohttp import web from aiohttp import web
from aiohttp.web_exceptions import HTTPGone from aiohttp.web_exceptions import HTTPGone
from .. import types from .. import types
from ..bot import api from ..bot import api
from ..types import ParseMode from ..types import ParseMode
from ..types.base import Boolean, Float, Integer, String from ..types.base import Boolean, Float, Integer, String
from ..utils import context
from ..utils import helper, markdown from ..utils import helper, markdown
from ..utils import json from ..utils import json
from ..utils.deprecated import warn_deprecated as warn from ..utils.deprecated import warn_deprecated as warn
@ -89,8 +89,10 @@ class WebhookRequestHandler(web.View):
""" """
dp = self.request.app[BOT_DISPATCHER_KEY] dp = self.request.app[BOT_DISPATCHER_KEY]
try: try:
context.set_value('dispatcher', dp) from aiogram.bot import bot
context.set_value('bot', dp.bot) from aiogram.dispatcher import dispatcher
dispatcher.set(dp)
bot.bot.set(dp.bot)
except RuntimeError: except RuntimeError:
pass pass
return dp return dp
@ -117,9 +119,9 @@ class WebhookRequestHandler(web.View):
""" """
self.validate_ip() self.validate_ip()
context.update_state({'CALLER': WEBHOOK, # context.update_state({'CALLER': WEBHOOK,
WEBHOOK_CONNECTION: True, # WEBHOOK_CONNECTION: True,
WEBHOOK_REQUEST: self.request}) # WEBHOOK_REQUEST: self.request})
dispatcher = self.get_dispatcher() dispatcher = self.get_dispatcher()
update = await self.parse_update(dispatcher.bot) update = await self.parse_update(dispatcher.bot)
@ -177,7 +179,7 @@ class WebhookRequestHandler(web.View):
if fut.done(): if fut.done():
return fut.result() return fut.result()
else: else:
context.set_value(WEBHOOK_CONNECTION, False) # context.set_value(WEBHOOK_CONNECTION, False)
fut.remove_done_callback(cb) fut.remove_done_callback(cb)
fut.add_done_callback(self.respond_via_request) fut.add_done_callback(self.respond_via_request)
finally: finally:
@ -202,7 +204,7 @@ class WebhookRequestHandler(web.View):
results = task.result() results = task.result()
except Exception as e: except Exception as e:
loop.create_task( loop.create_task(
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e)) dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
else: else:
response = self.get_response(results) response = self.get_response(results)
if response is not None: if response is not None:
@ -249,7 +251,7 @@ class WebhookRequestHandler(web.View):
ip_address, accept = self.check_ip() ip_address, accept = self.check_ip()
if not accept: if not accept:
raise web.HTTPUnauthorized() raise web.HTTPUnauthorized()
context.set_value('TELEGRAM_IP', ip_address) # context.set_value('TELEGRAM_IP', ip_address)
class GoneRequestHandler(web.View): class GoneRequestHandler(web.View):
@ -352,8 +354,8 @@ class BaseResponse:
async def __call__(self, bot=None): async def __call__(self, bot=None):
if bot is None: if bot is None:
from aiogram.dispatcher import ctx from aiogram import Bot
bot = ctx.get_bot() bot = Bot.current()
return await self.execute_response(bot) return await self.execute_response(bot)
async def __aenter__(self): async def __aenter__(self):
@ -446,7 +448,8 @@ class ParseModeMixin:
:return: :return:
""" """
bot = context.get_value('bot', None) from aiogram import Bot
bot = Bot.current()
if bot is not None: if bot is not None:
return bot.parse_mode return bot.parse_mode
@ -952,7 +955,7 @@ class SendMediaGroup(BaseResponse, ReplyToMixin, DisableNotificationMixin):
self.reply_to_message_id = reply_to_message_id self.reply_to_message_id = reply_to_message_id
def prepare(self): def prepare(self):
files = self.media.get_files() files = dict(self.media.get_files())
if files: if files:
raise TypeError('Allowed only file ID or URL\'s') raise TypeError('Allowed only file ID or URL\'s')

View file

@ -34,7 +34,7 @@ from .invoice import Invoice
from .labeled_price import LabeledPrice from .labeled_price import LabeledPrice
from .location import Location from .location import Location
from .mask_position import MaskPosition from .mask_position import MaskPosition
from .message import ContentType, Message, ParseMode from .message import ContentType, ContentTypes, Message, ParseMode
from .message_entity import MessageEntity, MessageEntityType from .message_entity import MessageEntity, MessageEntityType
from .order_info import OrderInfo from .order_info import OrderInfo
from .passport_data import PassportData from .passport_data import PassportData
@ -77,6 +77,7 @@ __all__ = (
'ChosenInlineResult', 'ChosenInlineResult',
'Contact', 'Contact',
'ContentType', 'ContentType',
'ContentTypes',
'Document', 'Document',
'EncryptedCredentials', 'EncryptedCredentials',
'EncryptedPassportElement', 'EncryptedPassportElement',

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import io import io
import typing import typing
from contextvars import ContextVar
from typing import TypeVar from typing import TypeVar
from .fields import BaseField from .fields import BaseField
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
setattr(cls, ALIASES_ATTR_NAME, aliases) setattr(cls, ALIASES_ATTR_NAME, aliases)
mcs._objects[cls.__name__] = cls mcs._objects[cls.__name__] = cls
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
return cls return cls
@property @property
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
if value.default and key not in self.values: if value.default and key not in self.values:
self.values[key] = value.default self.values[key] = value.default
@classmethod
def current(cls):
return cls._current.get()
@classmethod
def set_current(cls, obj: TelegramObject):
return cls._current.set(obj)
@property @property
def conf(self) -> typing.Dict[str, typing.Any]: def conf(self) -> typing.Dict[str, typing.Any]:
return self._conf return self._conf
@ -137,8 +150,8 @@ class TelegramObject(metaclass=MetaTelegramObject):
@property @property
def bot(self): def bot(self):
from ..dispatcher import ctx from ..bot.bot import Bot
return ctx.get_bot() return Bot.current()
def to_python(self) -> typing.Dict: def to_python(self) -> typing.Dict:
""" """

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import asyncio import asyncio
import typing import typing
from contextvars import ContextVar
from . import base from . import base
from . import fields from . import fields
@ -64,7 +67,7 @@ class Chat(base.TelegramObject):
if as_html: if as_html:
return markdown.hlink(name, self.user_url) return markdown.hlink(name, self.user_url)
return markdown.link(name, self.user_url) return markdown.link(name, self.user_url)
async def get_url(self): async def get_url(self):
""" """
Use this method to get chat link. Use this method to get chat link.
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
@classmethod @classmethod
async def _do(cls, action: str, sleep=None): async def _do(cls, action: str, sleep=None):
from ..dispatcher.ctx import get_bot, get_chat from aiogram import Bot
await get_bot().send_chat_action(get_chat(), action) await Bot.current().send_chat_action(Chat.current(), action)
if sleep: if sleep:
await asyncio.sleep(sleep) await asyncio.sleep(sleep)

View file

@ -9,7 +9,7 @@ class BaseField(metaclass=abc.ABCMeta):
Base field (prop) Base field (prop)
""" """
def __init__(self, *, base=None, default=None, alias=None): def __init__(self, *, base=None, default=None, alias=None, on_change=None):
""" """
Init prop Init prop
@ -17,10 +17,12 @@ class BaseField(metaclass=abc.ABCMeta):
:param default: default value :param default: default value
:param alias: alias name (for e.g. field 'from' has to be named 'from_user' :param alias: alias name (for e.g. field 'from' has to be named 'from_user'
as 'from' is a builtin Python keyword as 'from' is a builtin Python keyword
:param on_change: callback will be called when value is changed
""" """
self.base_object = base self.base_object = base
self.default = default self.default = default
self.alias = alias self.alias = alias
self.on_change = on_change
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
if self.alias is None: if self.alias is None:
@ -53,6 +55,13 @@ class BaseField(metaclass=abc.ABCMeta):
self.resolve_base(instance) self.resolve_base(instance)
value = self.deserialize(value, parent) value = self.deserialize(value, parent)
instance.values[self.alias] = value instance.values[self.alias] = value
self._trigger_changed(instance, value)
def _trigger_changed(self, instance, value):
if not self.on_change and instance is not None:
return
callback = getattr(instance, self.on_change)
callback(value)
def __get__(self, instance, owner): def __get__(self, instance, owner):
return self.get_value(instance) return self.get_value(instance)
@ -154,7 +163,7 @@ class ListOfLists(Field):
return result return result
class DateTimeField(BaseField): class DateTimeField(Field):
""" """
In this field st_ored datetime In this field st_ored datetime
@ -167,3 +176,24 @@ class DateTimeField(BaseField):
def deserialize(self, value, parent=None): def deserialize(self, value, parent=None):
return datetime.datetime.fromtimestamp(value) return datetime.datetime.fromtimestamp(value)
class TextField(Field):
def __init__(self, *, prefix=None, suffix=None, default=None, alias=None):
super(TextField, self).__init__(default=default, alias=alias)
self.prefix = prefix
self.suffix = suffix
def serialize(self, value):
if value is None:
return value
if self.prefix:
value = self.prefix + value
if self.suffix:
value += self.suffix
return value
def deserialize(self, value, parent=None):
if value is not None and not isinstance(value, str):
raise TypeError(f"Field '{self.alias}' should be str not {type(value).__name__}")
return value

View file

@ -1,6 +1,7 @@
import io import io
import logging import logging
import os import os
import secrets
import time import time
import aiohttp import aiohttp
@ -45,6 +46,8 @@ class InputFile(base.TelegramObject):
self._filename = filename self._filename = filename
self.attachment_key = secrets.token_urlsafe(16)
def __del__(self): def __del__(self):
""" """
Close file descriptor Close file descriptor
@ -54,13 +57,17 @@ class InputFile(base.TelegramObject):
@property @property
def filename(self): def filename(self):
if self._filename is None: if self._filename is None:
self._filename = api._guess_filename(self._file) self._filename = api.guess_filename(self._file)
return self._filename return self._filename
@filename.setter @filename.setter
def filename(self, value): def filename(self, value):
self._filename = value self._filename = value
@property
def attach(self):
return f"attach://{self.attachment_key}"
def get_filename(self) -> str: def get_filename(self) -> str:
""" """
Get file name Get file name
@ -159,6 +166,9 @@ class InputFile(base.TelegramObject):
return writer return writer
def __str__(self):
return f"<InputFile 'attach://{self.attachment_key}' with file='{self.file}'>"
def to_python(self): def to_python(self):
raise TypeError('Object of this type is not exportable!') raise TypeError('Object of this type is not exportable!')

View file

@ -12,6 +12,9 @@ ATTACHMENT_PREFIX = 'attach://'
class InputMedia(base.TelegramObject): class InputMedia(base.TelegramObject):
""" """
This object represents the content of a media message to be sent. It should be one of This object represents the content of a media message to be sent. It should be one of
- InputMediaAnimation
- InputMediaDocument
- InputMediaAudio
- InputMediaPhoto - InputMediaPhoto
- InputMediaVideo - InputMediaVideo
@ -20,36 +23,76 @@ class InputMedia(base.TelegramObject):
https://core.telegram.org/bots/api#inputmedia https://core.telegram.org/bots/api#inputmedia
""" """
type: base.String = fields.Field(default='photo') type: base.String = fields.Field(default='photo')
media: base.String = fields.Field() media: base.String = fields.Field(alias='media', on_change='_media_changed')
thumb: typing.Union[base.InputFile, base.String] = fields.Field() thumb: typing.Union[base.InputFile, base.String] = fields.Field(alias='thumb', on_change='_thumb_changed')
caption: base.String = fields.Field() caption: base.String = fields.Field()
parse_mode: base.Boolean = fields.Field() parse_mode: base.Boolean = fields.Field()
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._thumb_file = None
self._media_file = None
media = kwargs.pop('media', None)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
elif media is not None:
self.media = media
thumb = kwargs.pop('thumb', None)
if isinstance(thumb, (io.IOBase, InputFile)):
self.thumb_file = thumb
elif thumb is not None:
self.thumb = thumb
super(InputMedia, self).__init__(*args, **kwargs) super(InputMedia, self).__init__(*args, **kwargs)
try: try:
if self.parse_mode is None and self.bot.parse_mode: if self.parse_mode is None and self.bot and self.bot.parse_mode:
self.parse_mode = self.bot.parse_mode self.parse_mode = self.bot.parse_mode
except RuntimeError: except RuntimeError:
pass pass
@property @property
def file(self): def file(self):
return getattr(self, '_file', None) return self._media_file
@file.setter @file.setter
def file(self, file: io.IOBase): def file(self, file: io.IOBase):
setattr(self, '_file', file) self.media = 'attach://' + secrets.token_urlsafe(16)
attachment_key = self.attachment_key = secrets.token_urlsafe(16) self._media_file = file
self.media = ATTACHMENT_PREFIX + attachment_key
@file.deleter
def file(self):
self.media = None
self._media_file = None
def _media_changed(self, value):
if value is None or isinstance(value, str) and not value.startswith('attach://'):
self._media_file = None
@property @property
def attachment_key(self): def thumb_file(self):
return self.conf.get('attachment_key', None) return self._thumb_file
@attachment_key.setter @thumb_file.setter
def attachment_key(self, value): def thumb_file(self, file: io.IOBase):
self.conf['attachment_key'] = value self.thumb = 'attach://' + secrets.token_urlsafe(16)
self._thumb_file = file
@thumb_file.deleter
def thumb_file(self):
self.thumb = None
self._thumb_file = None
def _thumb_changed(self, value):
if value is None or isinstance(value, str) and not value.startswith('attach://'):
self._thumb_file = None
def get_files(self):
if self._media_file:
yield self.media[9:], self._media_file
if self._thumb_file:
yield self.thumb[9:], self._thumb_file
class InputMediaAnimation(InputMedia): class InputMediaAnimation(InputMedia):
@ -72,9 +115,6 @@ class InputMediaAnimation(InputMedia):
width=width, height=height, duration=duration, width=width, height=height, duration=duration,
parse_mode=parse_mode, conf=kwargs) parse_mode=parse_mode, conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaDocument(InputMedia): class InputMediaDocument(InputMedia):
""" """
@ -89,9 +129,6 @@ class InputMediaDocument(InputMedia):
caption=caption, parse_mode=parse_mode, caption=caption, parse_mode=parse_mode,
conf=kwargs) conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaAudio(InputMedia): class InputMediaAudio(InputMedia):
""" """
@ -119,9 +156,6 @@ class InputMediaAudio(InputMedia):
performer=performer, title=title, performer=performer, title=title,
parse_mode=parse_mode, conf=kwargs) parse_mode=parse_mode, conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaPhoto(InputMedia): class InputMediaPhoto(InputMedia):
""" """
@ -136,9 +170,6 @@ class InputMediaPhoto(InputMedia):
caption=caption, parse_mode=parse_mode, caption=caption, parse_mode=parse_mode,
conf=kwargs) conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class InputMediaVideo(InputMedia): class InputMediaVideo(InputMedia):
""" """
@ -151,18 +182,17 @@ class InputMediaVideo(InputMedia):
duration: base.Integer = fields.Field() duration: base.Integer = fields.Field()
supports_streaming: base.Boolean = fields.Field() supports_streaming: base.Boolean = fields.Field()
def __init__(self, media: base.InputFile, caption: base.String = None, def __init__(self, media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None, width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None,
parse_mode: base.Boolean = None, parse_mode: base.Boolean = None,
supports_streaming: base.Boolean = None, **kwargs): supports_streaming: base.Boolean = None, **kwargs):
super(InputMediaVideo, self).__init__(type='video', media=media, caption=caption, super(InputMediaVideo, self).__init__(type='video', media=media, thumb=thumb, caption=caption,
width=width, height=height, duration=duration, width=width, height=height, duration=duration,
parse_mode=parse_mode, parse_mode=parse_mode,
supports_streaming=supports_streaming, conf=kwargs) supports_streaming=supports_streaming, conf=kwargs)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
class MediaGroup(base.TelegramObject): class MediaGroup(base.TelegramObject):
""" """
@ -296,6 +326,7 @@ class MediaGroup(base.TelegramObject):
self.attach(photo) self.attach(photo)
def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile], def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile],
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None, caption: base.String = None,
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None): width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None):
""" """
@ -308,7 +339,7 @@ class MediaGroup(base.TelegramObject):
:param duration: :param duration:
""" """
if not isinstance(video, InputMedia): if not isinstance(video, InputMedia):
video = InputMediaVideo(media=video, caption=caption, video = InputMediaVideo(media=video, thumb=thumb, caption=caption,
width=width, height=height, duration=duration) width=width, height=height, duration=duration)
self.attach(video) self.attach(video)
@ -327,6 +358,7 @@ class MediaGroup(base.TelegramObject):
return result return result
def get_files(self): def get_files(self):
return {inputmedia.attachment_key: inputmedia.file for inputmedia in self.media:
for inputmedia in self.media if not isinstance(inputmedia, InputMedia) or not inputmedia.file:
if isinstance(inputmedia, InputMedia) and inputmedia.file} continue
yield from inputmedia.get_files()

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime import datetime
import functools import functools
import sys import sys
@ -42,7 +44,7 @@ class Message(base.TelegramObject):
forward_from_message_id: base.Integer = fields.Field() forward_from_message_id: base.Integer = fields.Field()
forward_signature: base.String = fields.Field() forward_signature: base.String = fields.Field()
forward_date: datetime.datetime = fields.DateTimeField() forward_date: datetime.datetime = fields.DateTimeField()
reply_to_message: 'Message' = fields.Field(base='Message') reply_to_message: Message = fields.Field(base='Message')
edit_date: datetime.datetime = fields.DateTimeField() edit_date: datetime.datetime = fields.DateTimeField()
media_group_id: base.String = fields.Field() media_group_id: base.String = fields.Field()
author_signature: base.String = fields.Field() author_signature: base.String = fields.Field()
@ -72,7 +74,7 @@ class Message(base.TelegramObject):
channel_chat_created: base.Boolean = fields.Field() channel_chat_created: base.Boolean = fields.Field()
migrate_to_chat_id: base.Integer = fields.Field() migrate_to_chat_id: base.Integer = fields.Field()
migrate_from_chat_id: base.Integer = fields.Field() migrate_from_chat_id: base.Integer = fields.Field()
pinned_message: 'Message' = fields.Field(base='Message') pinned_message: Message = fields.Field(base='Message')
invoice: Invoice = fields.Field(base=Invoice) invoice: Invoice = fields.Field(base=Invoice)
successful_payment: SuccessfulPayment = fields.Field(base=SuccessfulPayment) successful_payment: SuccessfulPayment = fields.Field(base=SuccessfulPayment)
connected_website: base.String = fields.Field() connected_website: base.String = fields.Field()
@ -82,59 +84,59 @@ class Message(base.TelegramObject):
@functools.lru_cache() @functools.lru_cache()
def content_type(self): def content_type(self):
if self.text: if self.text:
return ContentType.TEXT[0] return ContentType.TEXT
elif self.audio: elif self.audio:
return ContentType.AUDIO[0] return ContentType.AUDIO
elif self.animation: elif self.animation:
return ContentType.ANIMATION[0] return ContentType.ANIMATION
elif self.document: elif self.document:
return ContentType.DOCUMENT[0] return ContentType.DOCUMENT
elif self.game: elif self.game:
return ContentType.GAME[0] return ContentType.GAME
elif self.photo: elif self.photo:
return ContentType.PHOTO[0] return ContentType.PHOTO
elif self.sticker: elif self.sticker:
return ContentType.STICKER[0] return ContentType.STICKER
elif self.video: elif self.video:
return ContentType.VIDEO[0] return ContentType.VIDEO
elif self.video_note: elif self.video_note:
return ContentType.VIDEO_NOTE[0] return ContentType.VIDEO_NOTE
elif self.voice: elif self.voice:
return ContentType.VOICE[0] return ContentType.VOICE
elif self.contact: elif self.contact:
return ContentType.CONTACT[0] return ContentType.CONTACT
elif self.venue: elif self.venue:
return ContentType.VENUE[0] return ContentType.VENUE
elif self.location: elif self.location:
return ContentType.LOCATION[0] return ContentType.LOCATION
elif self.new_chat_members: elif self.new_chat_members:
return ContentType.NEW_CHAT_MEMBERS[0] return ContentType.NEW_CHAT_MEMBERS
elif self.left_chat_member: elif self.left_chat_member:
return ContentType.LEFT_CHAT_MEMBER[0] return ContentType.LEFT_CHAT_MEMBER
elif self.invoice: elif self.invoice:
return ContentType.INVOICE[0] return ContentType.INVOICE
elif self.successful_payment: elif self.successful_payment:
return ContentType.SUCCESSFUL_PAYMENT[0] return ContentType.SUCCESSFUL_PAYMENT
elif self.connected_website: elif self.connected_website:
return ContentType.CONNECTED_WEBSITE[0] return ContentType.CONNECTED_WEBSITE
elif self.migrate_from_chat_id: elif self.migrate_from_chat_id:
return ContentType.MIGRATE_FROM_CHAT_ID[0] return ContentType.MIGRATE_FROM_CHAT_ID
elif self.migrate_to_chat_id: elif self.migrate_to_chat_id:
return ContentType.MIGRATE_TO_CHAT_ID[0] return ContentType.MIGRATE_TO_CHAT_ID
elif self.pinned_message: elif self.pinned_message:
return ContentType.PINNED_MESSAGE[0] return ContentType.PINNED_MESSAGE
elif self.new_chat_title: elif self.new_chat_title:
return ContentType.NEW_CHAT_TITLE[0] return ContentType.NEW_CHAT_TITLE
elif self.new_chat_photo: elif self.new_chat_photo:
return ContentType.NEW_CHAT_PHOTO[0] return ContentType.NEW_CHAT_PHOTO
elif self.delete_chat_photo: elif self.delete_chat_photo:
return ContentType.DELETE_CHAT_PHOTO[0] return ContentType.DELETE_CHAT_PHOTO
elif self.group_chat_created: elif self.group_chat_created:
return ContentType.GROUP_CHAT_CREATED[0] return ContentType.GROUP_CHAT_CREATED
elif self.passport_data: elif self.passport_data:
return ContentType.PASSPORT_DATA[0] return ContentType.PASSPORT_DATA
else: else:
return ContentType.UNKNOWN[0] return ContentType.UNKNOWN
def is_command(self): def is_command(self):
""" """
@ -239,7 +241,7 @@ class Message(base.TelegramObject):
return self.parse_entities() return self.parse_entities()
async def reply(self, text, parse_mode=None, disable_web_page_preview=None, async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
disable_notification=None, reply_markup=None, reply=True) -> 'Message': disable_notification=None, reply_markup=None, reply=True) -> Message:
""" """
Reply to this message Reply to this message
@ -729,6 +731,69 @@ class ContentType(helper.Helper):
""" """
List of message content types List of message content types
WARNING: Single elements
:key: TEXT
:key: AUDIO
:key: DOCUMENT
:key: GAME
:key: PHOTO
:key: STICKER
:key: VIDEO
:key: VIDEO_NOTE
:key: VOICE
:key: CONTACT
:key: LOCATION
:key: VENUE
:key: NEW_CHAT_MEMBERS
:key: LEFT_CHAT_MEMBER
:key: INVOICE
:key: SUCCESSFUL_PAYMENT
:key: CONNECTED_WEBSITE
:key: MIGRATE_TO_CHAT_ID
:key: MIGRATE_FROM_CHAT_ID
:key: UNKNOWN
:key: ANY
"""
mode = helper.HelperMode.snake_case
TEXT = helper.Item() # text
AUDIO = helper.Item() # audio
DOCUMENT = helper.Item() # document
ANIMATION = helper.Item() # animation
GAME = helper.Item() # game
PHOTO = helper.Item() # photo
STICKER = helper.Item() # sticker
VIDEO = helper.Item() # video
VIDEO_NOTE = helper.Item() # video_note
VOICE = helper.Item() # voice
CONTACT = helper.Item() # contact
LOCATION = helper.Item() # location
VENUE = helper.Item() # venue
NEW_CHAT_MEMBERS = helper.Item() # new_chat_member
LEFT_CHAT_MEMBER = helper.Item() # left_chat_member
INVOICE = helper.Item() # invoice
SUCCESSFUL_PAYMENT = helper.Item() # successful_payment
CONNECTED_WEBSITE = helper.Item() # connected_website
MIGRATE_TO_CHAT_ID = helper.Item() # migrate_to_chat_id
MIGRATE_FROM_CHAT_ID = helper.Item() # migrate_from_chat_id
PINNED_MESSAGE = helper.Item() # pinned_message
NEW_CHAT_TITLE = helper.Item() # new_chat_title
NEW_CHAT_PHOTO = helper.Item() # new_chat_photo
DELETE_CHAT_PHOTO = helper.Item() # delete_chat_photo
GROUP_CHAT_CREATED = helper.Item() # group_chat_created
PASSPORT_DATA = helper.Item() # passport_data
UNKNOWN = helper.Item() # unknown
ANY = helper.Item() # any
class ContentTypes(helper.Helper):
"""
List of message content types
WARNING: List elements.
:key: TEXT :key: TEXT
:key: AUDIO :key: AUDIO
:key: DOCUMENT :key: DOCUMENT

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from . import base from . import base
from . import fields from . import fields
from .callback_query import CallbackQuery from .callback_query import CallbackQuery

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import babel import babel
from . import base from . import base

View file

@ -1,140 +0,0 @@
"""
You need to setup task factory:
>>> from aiogram.utils import context
>>> loop = asyncio.get_event_loop()
>>> loop.set_task_factory(context.task_factory)
"""
import asyncio
import typing
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
"""
Task factory for implementing context processor
:param loop:
:param coro:
:return: new task
:rtype: :obj:`asyncio.Task`
"""
# Is not allowed when loop is closed.
if loop.is_closed():
raise RuntimeError('Event loop is closed.')
task = asyncio.Task(coro, loop=loop)
# Hide factory
if task._source_traceback:
del task._source_traceback[-1]
try:
task.context = asyncio.Task.current_task().context.copy()
except AttributeError:
task.context = {CONFIGURED: True}
return task
def get_current_state() -> typing.Dict:
"""
Get current execution context from task
:return: context
:rtype: :obj:`dict`
"""
task = asyncio.Task.current_task()
if task is None:
raise RuntimeError('Can be used only in Task context.')
context_ = getattr(task, 'context', None)
if context_ is None:
context_ = task.context = {}
return context_
def get_value(key, default=None):
"""
Get value from task
:param key:
:param default:
:return: value
"""
return get_current_state().get(key, default)
def check_value(key):
"""
Key in context?
:param key:
:return:
"""
return key in get_current_state()
def set_value(key, value):
"""
Set value
:param key:
:param value:
:return:
"""
get_current_state()[key] = value
def del_value(key):
"""
Remove value from context
:param key:
:return:
"""
del get_current_state()[key]
def update_state(data=None, **kwargs):
"""
Update multiple state items
:param data:
:param kwargs:
:return:
"""
if data is None:
data = {}
state = get_current_state()
state.update(data, **kwargs)
def check_configured():
"""
Check loop is configured
:return:
"""
return get_value(CONFIGURED)
class _Context:
"""
Other things for interactions with the execution context.
"""
def __getitem__(self, item):
return get_value(item)
def __setitem__(self, key, value):
set_value(key, value)
def __delitem__(self, key):
del_value(key)
@staticmethod
def get_context():
return get_current_state()
context = _Context()

View file

@ -182,6 +182,10 @@ class MessageToEditNotFound(MessageError):
match = 'message to edit not found' match = 'message to edit not found'
class MessageIsTooLong(MessageError):
match = 'message is too long'
class ToMuchMessages(MessageError): class ToMuchMessages(MessageError):
""" """
Will be raised when you try to send media group with more than 10 items. Will be raised when you try to send media group with more than 10 items.

View file

@ -6,7 +6,6 @@ from warnings import warn
from aiohttp import web from aiohttp import web
from . import context
from ..bot.api import log from ..bot.api import log
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
@ -104,6 +103,11 @@ class Executor:
self._freeze = False self._freeze = False
from aiogram.bot.bot import bot as ctx_bot
from aiogram.dispatcher import dispatcher as ctx_dp
ctx_bot.set(dispatcher.bot)
ctx_dp.set(dispatcher)
@property @property
def frozen(self): def frozen(self):
return self._freeze return self._freeze
@ -176,13 +180,13 @@ class Executor:
self._check_frozen() self._check_frozen()
self._freeze = True self._freeze = True
self.loop.set_task_factory(context.task_factory) # self.loop.set_task_factory(context.task_factory)
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler): def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
self._check_frozen() self._check_frozen()
self._freeze = True self._freeze = True
self.loop.set_task_factory(context.task_factory) # self.loop.set_task_factory(context.task_factory)
app = self._web_app app = self._web_app
if app is None: if app is None:
@ -203,6 +207,7 @@ class Executor:
for callback in self._on_startup_webhook: for callback in self._on_startup_webhook:
app.on_startup.append(functools.partial(_wrap_callback, callback)) app.on_startup.append(functools.partial(_wrap_callback, callback))
# for callback in self._on_shutdown_webhook: # for callback in self._on_shutdown_webhook:
# app.on_shutdown.append(functools.partial(_wrap_callback, callback)) # app.on_shutdown.append(functools.partial(_wrap_callback, callback))

View file

@ -1,26 +1,52 @@
import json import os
JSON = 'json'
RAPIDJSON = 'rapidjson'
UJSON = 'ujson'
try: try:
import ujson if 'DISABLE_UJSON' not in os.environ:
import ujson as json
mode = UJSON
def dumps(data):
return json.dumps(data, ensure_ascii=False)
else:
mode = JSON
except ImportError: except ImportError:
ujson = None mode = JSON
_use_ujson = True if ujson else False try:
if 'DISABLE_RAPIDJSON' not in os.environ:
import rapidjson as json
mode = RAPIDJSON
def disable_ujson(): def dumps(data):
global _use_ujson return json.dumps(data, ensure_ascii=False, number_mode=json.NM_NATIVE,
_use_ujson = False datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
def dumps(data): def loads(data):
if _use_ujson: return json.loads(data, number_mode=json.NM_NATIVE,
return ujson.dumps(data) datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
return json.dumps(data)
else:
mode = JSON
except ImportError:
mode = JSON
if mode == JSON:
import json
def loads(data): def dumps(data):
if _use_ujson: return json.dumps(data, ensure_ascii=False)
return ujson.loads(data)
return json.loads(data)
def loads(data):
return json.loads(data)

View file

@ -1,5 +1,7 @@
import datetime import datetime
import secrets
from aiogram import types
from . import json from . import json
DEFAULT_FILTER = ['self', 'cls'] DEFAULT_FILTER = ['self', 'cls']
@ -56,3 +58,22 @@ def prepare_arg(value):
elif isinstance(value, datetime.datetime): elif isinstance(value, datetime.datetime):
return round(value.timestamp()) return round(value.timestamp())
return value return value
def prepare_file(payload, files, key, file):
if isinstance(file, str):
payload[key] = file
elif file is not None:
files[key] = file
def prepare_attachment(payload, files, key, file):
if isinstance(file, str):
payload[key] = file
elif isinstance(file, types.InputFile):
payload[key] = file.attach
files[file.attachment_key] = file.file
elif file is not None:
file_attach_name = secrets.token_urlsafe(16)
payload[key] = "attach://" + file_attach_name
files[file_attach_name] = file

View file

@ -1,6 +1,7 @@
-r requirements.txt -r requirements.txt
ujson>=1.35 ujson>=1.35
python-rapidjson>=0.6.3
emoji>=0.5.0 emoji>=0.5.0
pytest>=3.5.0 pytest>=3.5.0
pytest-asyncio>=0.8.0 pytest-asyncio>=0.8.0

View file

@ -25,16 +25,16 @@ Next step: interaction with bots starts with one command. Register your first co
.. code-block:: python3 .. code-block:: python3
@dp.message_handler(commands=['start', 'help']) @dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message): async def send_welcome(message: types.Message):
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
Last step: run long polling. Last step: run long polling.
.. code-block:: python3 .. code-block:: python3
if __name__ == '__main__': if __name__ == '__main__':
executor.start_polling(dp) executor.start_polling(dp)
Summary Summary
------- -------
@ -48,9 +48,9 @@ Summary
bot = Bot(token='BOT TOKEN HERE') bot = Bot(token='BOT TOKEN HERE')
dp = Dispatcher(bot) dp = Dispatcher(bot)
@dp.message_handler(commands=['start', 'help']) @dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message): async def send_welcome(message: types.Message):
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
if __name__ == '__main__': if __name__ == '__main__':
executor.start_polling(dp) executor.start_polling(dp)

View file

@ -1,8 +1,8 @@
name: py36 name: py37
channels: channels:
- conda-forge - conda-forge
dependencies: dependencies:
- python=3.6 - python=3.7
- sphinx=1.5.3 - sphinx=1.5.3
- sphinx_rtd_theme=0.2.4 - sphinx_rtd_theme=0.2.4
- pip - pip

View file

@ -29,6 +29,7 @@ async def send_message(user_id: int, text: str, disable_notification: bool = Fal
:param user_id: :param user_id:
:param text: :param text:
:param disable_notification:
:return: :return:
""" """
try: try:

View file

@ -5,18 +5,14 @@ Babel is required.
import asyncio import asyncio
import logging import logging
from aiogram import Bot, types from aiogram import Bot, Dispatcher, executor, md, types
from aiogram.dispatcher import Dispatcher
from aiogram.types import ParseMode
from aiogram.utils.executor import start_polling
from aiogram.utils.markdown import *
API_TOKEN = 'BOT TOKEN HERE' API_TOKEN = 'BOT TOKEN HERE'
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop) bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.MARKDOWN)
dp = Dispatcher(bot) dp = Dispatcher(bot)
@ -24,14 +20,14 @@ dp = Dispatcher(bot)
async def check_language(message: types.Message): async def check_language(message: types.Message):
locale = message.from_user.locale locale = message.from_user.locale
await message.reply(text( await message.reply(md.text(
bold('Info about your language:'), md.bold('Info about your language:'),
text(' 🔸', bold('Code:'), italic(locale.locale)), md.text(' 🔸', md.bold('Code:'), md.italic(locale.locale)),
text(' 🔸', bold('Territory:'), italic(locale.territory or 'Unknown')), md.text(' 🔸', md.bold('Territory:'), md.italic(locale.territory or 'Unknown')),
text(' 🔸', bold('Language name:'), italic(locale.language_name)), md.text(' 🔸', md.bold('Language name:'), md.italic(locale.language_name)),
text(' 🔸', bold('English language name:'), italic(locale.english_name)), md.text(' 🔸', md.bold('English language name:'), md.italic(locale.english_name)),
sep='\n'), parse_mode=ParseMode.MARKDOWN) sep='\n'))
if __name__ == '__main__': if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True) executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -1,9 +1,7 @@
import asyncio import asyncio
import logging import logging
from aiogram import Bot, types from aiogram import Bot, types, Dispatcher, executor
from aiogram.dispatcher import Dispatcher
from aiogram.utils.executor import start_polling
API_TOKEN = 'BOT TOKEN HERE' API_TOKEN = 'BOT TOKEN HERE'
@ -32,10 +30,4 @@ async def echo(message: types.Message):
if __name__ == '__main__': if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True) executor.start_polling(dp, loop=loop, skip_updates=True)
# Also you can use another execution method
# >>> try:
# >>> loop.run_until_complete(main())
# >>> except KeyboardInterrupt:
# >>> loop.stop()

View file

@ -1,47 +0,0 @@
from aiogram import Bot, types
from aiogram.contrib.middlewares.context import ContextMiddleware
from aiogram.dispatcher import Dispatcher
from aiogram.types import ParseMode
from aiogram.utils import markdown as md
from aiogram.utils.executor import start_polling
API_TOKEN = 'BOT TOKEN HERE'
bot = Bot(token=API_TOKEN)
dp = Dispatcher(bot)
# Setup Context middleware
data: ContextMiddleware = dp.middleware.setup(ContextMiddleware())
# Write custom filter
async def demo_filter(message: types.Message):
# Store some data in context
command = data['command'] = message.get_command() or ''
args = data['args'] = message.get_args() or ''
data['has_args'] = bool(args)
data['some_random_data'] = 42
return command != '/bad_command'
@dp.message_handler(demo_filter)
async def send_welcome(message: types.Message):
# Get data from context
# All of this is available only in current context and from current update object
# `data`- pseudo-alias for `ctx.get_update().conf['_context_data']`
command = data['command']
args = data['args']
rand = data['some_random_data']
has_args = data['has_args']
# Send as pre-formatted code block.
await message.reply(md.hpre(f"""command: {command}
args: {['Not available', 'available'][has_args]}: {args}
some random data: {rand}
message ID: {message.message_id}
message: {message.html_text}
"""), parse_mode=ParseMode.HTML)
if __name__ == '__main__':
start_polling(dp)

View file

@ -1,11 +1,13 @@
import asyncio import asyncio
from typing import Optional
from aiogram import Bot, types import aiogram.utils.markdown as md
from aiogram import Bot, Dispatcher, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import Dispatcher from aiogram.dispatcher import FSMContext
from aiogram.dispatcher.filters.state import State, StatesGroup
from aiogram.types import ParseMode from aiogram.types import ParseMode
from aiogram.utils import executor from aiogram.utils import executor
from aiogram.utils.markdown import text, bold
API_TOKEN = 'BOT TOKEN HERE' API_TOKEN = 'BOT TOKEN HERE'
@ -17,10 +19,12 @@ bot = Bot(token=API_TOKEN, loop=loop)
storage = MemoryStorage() storage = MemoryStorage()
dp = Dispatcher(bot, storage=storage) dp = Dispatcher(bot, storage=storage)
# States # States
AGE = 'process_age' class Form(StatesGroup):
NAME = 'process_name' name = State() # Will be represented in storage as 'Form:name'
GENDER = 'process_gender' age = State() # Will be represented in storage as 'Form:age'
gender = State() # Will be represented in storage as 'Form:gender'
@dp.message_handler(commands=['start']) @dp.message_handler(commands=['start'])
@ -28,48 +32,41 @@ async def cmd_start(message: types.Message):
""" """
Conversation's entry point Conversation's entry point
""" """
# Get current state # Set state
state = dp.current_state(chat=message.chat.id, user=message.from_user.id) await Form.name.set()
# Update user's state
await state.set_state(NAME)
await message.reply("Hi there! What's your name?") await message.reply("Hi there! What's your name?")
# You can use state '*' if you need to handle all states # You can use state '*' if you need to handle all states
@dp.message_handler(state='*', commands=['cancel']) @dp.message_handler(state='*', commands=['cancel'])
@dp.message_handler(state='*', func=lambda message: message.text.lower() == 'cancel') @dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
async def cancel_handler(message: types.Message): async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None):
""" """
Allow user to cancel any action Allow user to cancel any action
""" """
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: if raw_state is None:
# Ignore command if user is not in any (defined) state return
if await state.get_state() is None:
return
# Otherwise cancel state and inform user about it # Cancel state and inform user about it
# And remove keyboard (just in case) await state.finish()
await state.reset_state(with_data=True) # And remove keyboard (just in case)
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
@dp.message_handler(state=NAME) @dp.message_handler(state=Form.name)
async def process_name(message: types.Message): async def process_name(message: types.Message, state: FSMContext):
""" """
Process user name Process user name
""" """
# Save name to storage and go to next step await Form.next()
# You can use context manager await state.update_data(name=message.text)
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state:
await state.update_data(name=message.text)
await state.set_state(AGE)
await message.reply("How old are you?") await message.reply("How old are you?")
# Check age. Age gotta be digit # Check age. Age gotta be digit
@dp.message_handler(state=AGE, func=lambda message: not message.text.isdigit()) @dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
async def failed_process_age(message: types.Message): async def failed_process_age(message: types.Message):
""" """
If age is invalid If age is invalid
@ -77,12 +74,11 @@ async def failed_process_age(message: types.Message):
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
@dp.message_handler(state=AGE, func=lambda message: message.text.isdigit()) @dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
async def process_age(message: types.Message): async def process_age(message: types.Message, state: FSMContext):
# Update state and data # Update state and data
with dp.current_state(chat=message.chat.id, user=message.from_user.id) as state: await Form.next()
await state.set_state(GENDER) await state.update_data(age=int(message.text))
await state.update_data(age=int(message.text))
# Configure ReplyKeyboardMarkup # Configure ReplyKeyboardMarkup
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
@ -92,7 +88,7 @@ async def process_age(message: types.Message):
await message.reply("What is your gender?", reply_markup=markup) await message.reply("What is your gender?", reply_markup=markup)
@dp.message_handler(state=GENDER, func=lambda message: message.text not in ["Male", "Female", "Other"]) @dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
async def failed_process_gender(message: types.Message): async def failed_process_gender(message: types.Message):
""" """
In this example gender has to be one of: Male, Female, Other. In this example gender has to be one of: Male, Female, Other.
@ -100,10 +96,8 @@ async def failed_process_gender(message: types.Message):
return await message.reply("Bad gender name. Choose you gender from keyboard.") return await message.reply("Bad gender name. Choose you gender from keyboard.")
@dp.message_handler(state=GENDER) @dp.message_handler(state=Form.gender)
async def process_gender(message: types.Message): async def process_gender(message: types.Message, state: FSMContext):
state = dp.current_state(chat=message.chat.id, user=message.from_user.id)
data = await state.get_data() data = await state.get_data()
data['gender'] = message.text data['gender'] = message.text
@ -111,10 +105,10 @@ async def process_gender(message: types.Message):
markup = types.ReplyKeyboardRemove() markup = types.ReplyKeyboardRemove()
# And send message # And send message
await bot.send_message(message.chat.id, text( await bot.send_message(message.chat.id, md.text(
text('Hi! Nice to meet you,', bold(data['name'])), md.text('Hi! Nice to meet you,', md.bold(data['name'])),
text('Age:', data['age']), md.text('Age:', data['age']),
text('Gender:', data['gender']), md.text('Gender:', data['gender']),
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN) sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
# Finish conversation # Finish conversation
@ -122,10 +116,5 @@ async def process_gender(message: types.Message):
await state.finish() await state.finish()
async def shutdown(dispatcher: Dispatcher):
await dispatcher.storage.close()
await dispatcher.storage.wait_closed()
if __name__ == '__main__': if __name__ == '__main__':
executor.start_polling(dp, loop=loop, skip_updates=True, on_shutdown=shutdown) executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -0,0 +1,126 @@
"""
This example is equals with 'finite_state_machine_example.py' but with FSM Middleware
Note that FSM Middleware implements the more simple methods for working with storage.
With that middleware all data from storage will be loaded before event will be processed
and data will be stored after processing the event.
"""
import asyncio
import aiogram.utils.markdown as md
from aiogram import Bot, Dispatcher, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.contrib.middlewares.fsm import FSMMiddleware, FSMSStorageProxy
from aiogram.dispatcher.filters.state import State, StatesGroup
from aiogram.utils import executor
API_TOKEN = 'BOT TOKEN HERE'
loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop)
# For example use simple MemoryStorage for Dispatcher.
storage = MemoryStorage()
dp = Dispatcher(bot, storage=storage)
dp.middleware.setup(FSMMiddleware())
# States
class Form(StatesGroup):
name = State() # Will be represented in storage as 'Form:name'
age = State() # Will be represented in storage as 'Form:age'
gender = State() # Will be represented in storage as 'Form:gender'
@dp.message_handler(commands=['start'])
async def cmd_start(message: types.Message):
"""
Conversation's entry point
"""
# Set state
await Form.first()
await message.reply("Hi there! What's your name?")
# You can use state '*' if you need to handle all states
@dp.message_handler(state='*', commands=['cancel'])
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
async def cancel_handler(message: types.Message, state_data: FSMSStorageProxy):
"""
Allow user to cancel any action
"""
if state_data.state is None:
return
# Cancel state and inform user about it
del state_data.state
# And remove keyboard (just in case)
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
@dp.message_handler(state=Form.name)
async def process_name(message: types.Message, state_data: FSMSStorageProxy):
"""
Process user name
"""
state_data.state = Form.age
state_data['name'] = message.text
await message.reply("How old are you?")
# Check age. Age gotta be digit
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
async def failed_process_age(message: types.Message):
"""
If age is invalid
"""
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
async def process_age(message: types.Message, state_data: FSMSStorageProxy):
# Update state and data
state_data.state = Form.gender
state_data['age'] = int(message.text)
# Configure ReplyKeyboardMarkup
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
markup.add("Male", "Female")
markup.add("Other")
await message.reply("What is your gender?", reply_markup=markup)
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
async def failed_process_gender(message: types.Message):
"""
In this example gender has to be one of: Male, Female, Other.
"""
return await message.reply("Bad gender name. Choose you gender from keyboard.")
@dp.message_handler(state=Form.gender)
async def process_gender(message: types.Message, state_data: FSMSStorageProxy):
state_data['gender'] = message.text
# Remove keyboard
markup = types.ReplyKeyboardRemove()
# And send message
await bot.send_message(message.chat.id, md.text(
md.text('Hi! Nice to meet you,', md.bold(state_data['name'])),
md.text('Age:', state_data['age']),
md.text('Gender:', state_data['gender']),
sep='\n'), reply_markup=markup, parse_mode=types.ParseMode.MARKDOWN)
# Finish conversation
# WARNING! This method will destroy all data in storage for current user!
state_data.clear()
if __name__ == '__main__':
executor.start_polling(dp, loop=loop, skip_updates=True)

56
examples/i18n_example.py Normal file
View file

@ -0,0 +1,56 @@
"""
Internalize your bot
Step 1: extract texts
# pybabel extract i18n_example.py -o locales/mybot.pot
Step 2: create *.po files. For e.g. create en, ru, uk locales.
# echo {en,ru,uk} | xargs -n1 pybabel init -i locales/mybot.pot -d locales -D mybot -l
Step 3: translate texts
Step 4: compile translations
# pybabel compile -d locales -D mybot
Step 5: When you change the code of your bot you need to update po & mo files.
Step 5.1: regenerate pot file:
command from step 1
Step 5.2: update po files
# pybabel update -d locales -D mybot -i locales/mybot.pot
Step 5.3: update your translations
Step 5.4: compile mo files
command from step 4
"""
from pathlib import Path
from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.middlewares.i18n import I18nMiddleware
TOKEN = 'BOT TOKEN HERE'
I18N_DOMAIN = 'mybot'
BASE_DIR = Path(__file__).parent
LOCALES_DIR = BASE_DIR / 'locales'
bot = Bot(TOKEN, parse_mode=types.ParseMode.HTML)
dp = Dispatcher(bot)
# Setup i18n middleware
i18n = I18nMiddleware(I18N_DOMAIN, LOCALES_DIR)
dp.middleware.setup(i18n)
# Alias for gettext method
_ = i18n.gettext
@dp.message_handler(commands=['start'])
async def cmd_start(message: types.Message):
# Simply use `_('message')` instead of `'message'` and never use f-strings for translatable texts.
await message.reply(_('Hello, <b>{user}</b>!').format(user=message.from_user.full_name))
@dp.message_handler(commands=['lang'])
async def cmd_lang(message: types.Message, locale):
await message.reply(_('Your current language: <i>{language}</i>').format(language=locale))
if __name__ == '__main__':
executor.start_polling(dp, skip_updates=True)

View file

@ -1,9 +1,7 @@
import asyncio import asyncio
import logging import logging
from aiogram import Bot, types from aiogram import Bot, types, Dispatcher, executor
from aiogram.dispatcher import Dispatcher
from aiogram.utils.executor import start_polling
API_TOKEN = 'BOT TOKEN HERE' API_TOKEN = 'BOT TOKEN HERE'
@ -23,4 +21,4 @@ async def inline_echo(inline_query: types.InlineQuery):
if __name__ == '__main__': if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True) executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -0,0 +1,28 @@
# English translations for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n"
"Language-Team: en <LL@li.org>\n"
"Plural-Forms: nplurals=2; plural=(n != 1)\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr ""
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr ""

View file

@ -0,0 +1,27 @@
# Translations template for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr ""
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr ""

View file

@ -0,0 +1,29 @@
# Russian translations for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: ru\n"
"Language-Team: ru <LL@li.org>\n"
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr "Привет, <b>{user}</b>!"
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr "Твой язык: <i>{language}</i>"

View file

@ -0,0 +1,29 @@
# Ukrainian translations for PROJECT.
# Copyright (C) 2018 ORGANIZATION
# This file is distributed under the same license as the PROJECT project.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2018.
#
msgid ""
msgstr ""
"Project-Id-Version: PROJECT VERSION\n"
"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n"
"POT-Creation-Date: 2018-06-30 03:50+0300\n"
"PO-Revision-Date: 2018-06-30 03:43+0300\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: uk\n"
"Language-Team: uk <LL@li.org>\n"
"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
"n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2)\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.6.0\n"
#: i18n_example.py:48
msgid "Hello, <b>{user}</b>!"
msgstr "Привіт, <b>{user}</b>!"
#: i18n_example.py:53
msgid "Your current language: <i>{language}</i>"
msgstr "Твоя мова: <i>{language}</i>"

View file

@ -1,9 +1,6 @@
import asyncio import asyncio
from aiogram import Bot, types from aiogram import Bot, Dispatcher, executor, filters, types
from aiogram.dispatcher import Dispatcher
from aiogram.types import ChatActions
from aiogram.utils.executor import start_polling
API_TOKEN = 'BOT TOKEN HERE' API_TOKEN = 'BOT TOKEN HERE'
@ -12,7 +9,7 @@ bot = Bot(token=API_TOKEN, loop=loop)
dp = Dispatcher(bot) dp = Dispatcher(bot)
@dp.message_handler(commands=['start']) @dp.message_handler(filters.CommandStart())
async def send_welcome(message: types.Message): async def send_welcome(message: types.Message):
# So... At first I want to send something like this: # So... At first I want to send something like this:
await message.reply("Do you want to see many pussies? Are you ready?") await message.reply("Do you want to see many pussies? Are you ready?")
@ -21,7 +18,7 @@ async def send_welcome(message: types.Message):
await asyncio.sleep(1) await asyncio.sleep(1)
# Good bots should send chat actions. Or not. # Good bots should send chat actions. Or not.
await ChatActions.upload_photo() await types.ChatActions.upload_photo()
# Create media group # Create media group
media = types.MediaGroup() media = types.MediaGroup()
@ -39,9 +36,8 @@ async def send_welcome(message: types.Message):
# media.attach_photo('<file_id>', 'cat-cat-cat.') # media.attach_photo('<file_id>', 'cat-cat-cat.')
# Done! Send media group # Done! Send media group
await bot.send_media_group(message.chat.id, media=media, await message.reply_media_group(media=media)
reply_to_message_id=message.message_id)
if __name__ == '__main__': if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True) executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -1,11 +1,10 @@
import asyncio import asyncio
from aiogram import Bot, types from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2 from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler
from aiogram.dispatcher.middlewares import BaseMiddleware from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils import context, executor
from aiogram.utils.exceptions import Throttled
TOKEN = 'BOT TOKEN HERE' TOKEN = 'BOT TOKEN HERE'
@ -53,10 +52,10 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message: :param message:
""" """
# Get current handler # Get current handler
handler = context.get_value('handler') # handler = context.get_value('handler')
# Get dispatcher from context # Get dispatcher from context
dispatcher = ctx.get_dispatcher() dispatcher = Dispatcher.current()
# If handler was configured, get rate limit and key from handler # If handler was configured, get rate limit and key from handler
if handler: if handler:
@ -83,8 +82,8 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message: :param message:
:param throttled: :param throttled:
""" """
handler = context.get_value('handler') # handler = context.get_value('handler')
dispatcher = ctx.get_dispatcher() dispatcher = Dispatcher.current()
if handler: if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else: else:

View file

@ -4,7 +4,7 @@ from aiogram import Bot
from aiogram import types from aiogram import types
from aiogram.utils import executor from aiogram.utils import executor
from aiogram.dispatcher import Dispatcher from aiogram.dispatcher import Dispatcher
from aiogram.types.message import ContentType from aiogram.types.message import ContentTypes
BOT_TOKEN = 'BOT TOKEN HERE' BOT_TOKEN = 'BOT TOKEN HERE'
@ -86,7 +86,7 @@ async def checkout(pre_checkout_query: types.PreCheckoutQuery):
" try to pay again in a few minutes, we need a small rest.") " try to pay again in a few minutes, we need a small rest.")
@dp.message_handler(content_types=ContentType.SUCCESSFUL_PAYMENT) @dp.message_handler(content_types=ContentTypes.SUCCESSFUL_PAYMENT)
async def got_payment(message: types.Message): async def got_payment(message: types.Message):
await bot.send_message(message.chat.id, await bot.send_message(message.chat.id,
'Hoooooray! Thanks for payment! We will proceed your order for `{} {}`' 'Hoooooray! Thanks for payment! We will proceed your order for `{} {}`'

View file

@ -18,7 +18,7 @@ PROXY_URL = 'http://PROXY_URL' # Or 'socks5://...'
# PROXY_AUTH = aiohttp.BasicAuth(login='login', password='password') # PROXY_AUTH = aiohttp.BasicAuth(login='login', password='password')
# And add `proxy_auth=PROXY_AUTH` argument in line 25, like this: # And add `proxy_auth=PROXY_AUTH` argument in line 25, like this:
# >>> bot = Bot(token=API_TOKEN, loop=loop, proxy=PROXY_URL, proxy_auth=PROXY_AUTH) # >>> bot = Bot(token=API_TOKEN, loop=loop, proxy=PROXY_URL, proxy_auth=PROXY_AUTH)
# Also you can use Socks5 proxy but you need manually install aiosocksy package. # Also you can use Socks5 proxy but you need manually install aiohttp_socks package.
# Get my ip URL # Get my ip URL
GET_IP_URL = 'http://bot.whatismyipaddress.com/' GET_IP_URL = 'http://bot.whatismyipaddress.com/'

View file

@ -9,7 +9,7 @@ from aiogram import Bot, types, Version
from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import Dispatcher from aiogram.dispatcher import Dispatcher
from aiogram.dispatcher.webhook import get_new_configured_app, SendMessage from aiogram.dispatcher.webhook import get_new_configured_app, SendMessage
from aiogram.types import ChatType, ParseMode, ContentType from aiogram.types import ChatType, ParseMode, ContentTypes
from aiogram.utils.markdown import hbold, bold, text, link from aiogram.utils.markdown import hbold, bold, text, link
TOKEN = 'BOT TOKEN HERE' TOKEN = 'BOT TOKEN HERE'
@ -31,7 +31,7 @@ WEBHOOK_URL = f"https://{WEBHOOK_HOST}:{WEBHOOK_PORT}{WEBHOOK_URL_PATH}"
WEBAPP_HOST = 'localhost' WEBAPP_HOST = 'localhost'
WEBAPP_PORT = 3001 WEBAPP_PORT = 3001
BAD_CONTENT = ContentType.PHOTO & ContentType.DOCUMENT & ContentType.STICKER & ContentType.AUDIO BAD_CONTENT = ContentTypes.PHOTO & ContentTypes.DOCUMENT & ContentTypes.STICKER & ContentTypes.AUDIO
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
bot = Bot(TOKEN, loop=loop) bot = Bot(TOKEN, loop=loop)

View file

@ -13,7 +13,7 @@ except ImportError: # pip >= 10.0.0
WORK_DIR = pathlib.Path(__file__).parent WORK_DIR = pathlib.Path(__file__).parent
# Check python version # Check python version
MINIMAL_PY_VERSION = (3, 6) MINIMAL_PY_VERSION = (3, 7)
if sys.version_info < MINIMAL_PY_VERSION: if sys.version_info < MINIMAL_PY_VERSION:
raise RuntimeError('aiogram works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION)))) raise RuntimeError('aiogram works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION))))
@ -65,7 +65,7 @@ setup(
url='https://github.com/aiogram/aiogram', url='https://github.com/aiogram/aiogram',
license='MIT', license='MIT',
author='Alex Root Junior', author='Alex Root Junior',
requires_python='>=3.6', requires_python='>=3.7',
author_email='aiogram@illemius.xyz', author_email='aiogram@illemius.xyz',
description='Is a pretty simple and fully asynchronous library for Telegram Bot API', description='Is a pretty simple and fully asynchronous library for Telegram Bot API',
long_description=get_description(), long_description=get_description(),
@ -76,7 +76,7 @@ setup(
'Intended Audience :: Developers', 'Intended Audience :: Developers',
'Intended Audience :: System Administrators', 'Intended Audience :: System Administrators',
'License :: OSI Approved :: MIT License', 'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7',
'Topic :: Software Development :: Libraries :: Application Frameworks', 'Topic :: Software Development :: Libraries :: Application Frameworks',
], ],
install_requires=get_requirements() install_requires=get_requirements()

102
tests/states_group.py Normal file
View file

@ -0,0 +1,102 @@
import pytest
from aiogram.dispatcher.filters.state import State, StatesGroup, any_state, default_state
class MyGroup(StatesGroup):
state = State()
state_1 = State()
state_2 = State()
class MySubGroup(StatesGroup):
sub_state = State()
sub_state_1 = State()
sub_state_2 = State()
in_custom_group = State(group_name='custom_group')
class NewGroup(StatesGroup):
spam = State()
renamed_state = State(state='spam_state')
alone_state = State('alone')
alone_in_group = State('alone', group_name='home')
def test_default_state():
assert default_state.state is None
def test_any_state():
assert any_state.state == '*'
def test_alone_state():
assert alone_state.state == '@:alone'
assert alone_in_group.state == 'home:alone'
def test_group_names():
assert MyGroup.__group_name__ == 'MyGroup'
assert MyGroup.__full_group_name__ == 'MyGroup'
assert MyGroup.MySubGroup.__group_name__ == 'MySubGroup'
assert MyGroup.MySubGroup.__full_group_name__ == 'MyGroup.MySubGroup'
assert MyGroup.MySubGroup.NewGroup.__group_name__ == 'NewGroup'
assert MyGroup.MySubGroup.NewGroup.__full_group_name__ == 'MyGroup.MySubGroup.NewGroup'
def test_custom_group_in_group():
assert MyGroup.MySubGroup.in_custom_group.state == 'custom_group:in_custom_group'
def test_custom_state_name_in_group():
assert MyGroup.MySubGroup.NewGroup.renamed_state.state == 'MyGroup.MySubGroup.NewGroup:spam_state'
def test_group_states_names():
assert len(MyGroup.states) == 3
assert len(MyGroup.all_states) == 9
assert MyGroup.states_names == ('MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2')
assert MyGroup.MySubGroup.states_names == (
'MyGroup.MySubGroup:sub_state', 'MyGroup.MySubGroup:sub_state_1', 'MyGroup.MySubGroup:sub_state_2',
'custom_group:in_custom_group')
assert MyGroup.MySubGroup.NewGroup.states_names == (
'MyGroup.MySubGroup.NewGroup:spam', 'MyGroup.MySubGroup.NewGroup:spam_state')
assert MyGroup.all_states_names == (
'MyGroup:state', 'MyGroup:state_1', 'MyGroup:state_2',
'MyGroup.MySubGroup:sub_state',
'MyGroup.MySubGroup:sub_state_1',
'MyGroup.MySubGroup:sub_state_2',
'custom_group:in_custom_group',
'MyGroup.MySubGroup.NewGroup:spam',
'MyGroup.MySubGroup.NewGroup:spam_state')
assert MyGroup.MySubGroup.all_states_names == (
'MyGroup.MySubGroup:sub_state',
'MyGroup.MySubGroup:sub_state_1',
'MyGroup.MySubGroup:sub_state_2',
'custom_group:in_custom_group',
'MyGroup.MySubGroup.NewGroup:spam',
'MyGroup.MySubGroup.NewGroup:spam_state')
assert MyGroup.MySubGroup.NewGroup.all_states_names == (
'MyGroup.MySubGroup.NewGroup:spam',
'MyGroup.MySubGroup.NewGroup:spam_state')
def test_root_element():
root = MyGroup.MySubGroup.NewGroup.spam.get_root()
assert issubclass(root, StatesGroup)
assert root == MyGroup
assert root == MyGroup.state.get_root()
assert root == MyGroup.MySubGroup.get_root()
with pytest.raises(RuntimeError):
any_state.get_root()

View file

@ -1,5 +1,5 @@
[tox] [tox]
envlist = py36 envlist = py37
[testenv] [testenv]
deps = -rdev_requirements.txt deps = -rdev_requirements.txt