mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Format code with black, autopep8 and isort
This commit fixes the style issues introduced in7ca09aaaccording to the output from black, autopep8 and isort. Details:520006ec-d635-4384-b8eb-d844943c6f8c/
This commit is contained in:
parent
7ca09aac1b
commit
0fe5a92d4e
4 changed files with 288 additions and 161 deletions
|
|
@ -10,7 +10,7 @@ from ..utils import exceptions, json
|
||||||
from ..utils.helper import Helper, HelperMode, Item
|
from ..utils.helper import Helper, HelperMode, Item
|
||||||
|
|
||||||
# Main aiogram logger
|
# Main aiogram logger
|
||||||
log = logging.getLogger('aiogram')
|
log = logging.getLogger("aiogram")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -43,7 +43,7 @@ class TelegramAPIServer:
|
||||||
return self.file.format(token=token, path=path)
|
return self.file.format(token=token, path=path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_base(cls, base: str) -> 'TelegramAPIServer':
|
def from_base(cls, base: str) -> "TelegramAPIServer":
|
||||||
base = base.rstrip("/")
|
base = base.rstrip("/")
|
||||||
return cls(
|
return cls(
|
||||||
base=f"{base}/bot{{token}}/{{method}}",
|
base=f"{base}/bot{{token}}/{{method}}",
|
||||||
|
|
@ -62,17 +62,19 @@ def check_token(token: str) -> bool:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not isinstance(token, str):
|
if not isinstance(token, str):
|
||||||
message = (f"Token is invalid! "
|
message = (
|
||||||
f"It must be 'str' type instead of {type(token)} type.")
|
f"Token is invalid! "
|
||||||
|
f"It must be 'str' type instead of {type(token)} type."
|
||||||
|
)
|
||||||
raise exceptions.ValidationError(message)
|
raise exceptions.ValidationError(message)
|
||||||
|
|
||||||
if any(x.isspace() for x in token):
|
if any(x.isspace() for x in token):
|
||||||
message = "Token is invalid! It can't contains spaces."
|
message = "Token is invalid! It can't contains spaces."
|
||||||
raise exceptions.ValidationError(message)
|
raise exceptions.ValidationError(message)
|
||||||
|
|
||||||
left, sep, right = token.partition(':')
|
left, sep, right = token.partition(":")
|
||||||
if (not sep) or (not left.isdigit()) or (not right):
|
if (not sep) or (not left.isdigit()) or (not right):
|
||||||
raise exceptions.ValidationError('Token is invalid!')
|
raise exceptions.ValidationError("Token is invalid!")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
@ -94,19 +96,22 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
|
||||||
"""
|
"""
|
||||||
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
|
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
|
||||||
|
|
||||||
if content_type != 'application/json':
|
if content_type != "application/json":
|
||||||
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
|
raise exceptions.NetworkError(
|
||||||
|
f'Invalid response with content type {content_type}: "{body}"'
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result_json = json.loads(body)
|
result_json = json.loads(body)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
result_json = {}
|
result_json = {}
|
||||||
|
|
||||||
description = result_json.get('description') or body
|
description = result_json.get("description") or body
|
||||||
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
|
parameters = types.ResponseParameters(
|
||||||
|
**result_json.get("parameters", {}) or {})
|
||||||
|
|
||||||
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
|
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
|
||||||
return result_json.get('result')
|
return result_json.get("result")
|
||||||
if parameters.retry_after:
|
if parameters.retry_after:
|
||||||
raise exceptions.RetryAfter(parameters.retry_after)
|
raise exceptions.RetryAfter(parameters.retry_after)
|
||||||
if parameters.migrate_to_chat_id:
|
if parameters.migrate_to_chat_id:
|
||||||
|
|
@ -120,26 +125,33 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
|
||||||
elif status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
|
elif status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
|
||||||
exceptions.Unauthorized.detect(description)
|
exceptions.Unauthorized.detect(description)
|
||||||
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
|
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
|
||||||
raise exceptions.NetworkError('File too large for uploading. '
|
raise exceptions.NetworkError(
|
||||||
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
|
"File too large for uploading. "
|
||||||
|
"Check telegram api limits https://core.telegram.org/bots/api#senddocument"
|
||||||
|
)
|
||||||
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
||||||
if 'restart' in description:
|
if "restart" in description:
|
||||||
raise exceptions.RestartingTelegram()
|
raise exceptions.RestartingTelegram()
|
||||||
raise exceptions.TelegramAPIError(description)
|
raise exceptions.TelegramAPIError(description)
|
||||||
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
|
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
|
||||||
|
|
||||||
|
|
||||||
async def make_request(session, server, token, method, data=None, files=None, **kwargs):
|
async def make_request(session, server, token, method, data=None, files=None, **kwargs):
|
||||||
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files)
|
log.debug('Make request: "%s" with data: "%r" and files "%r"',
|
||||||
|
method, data, files)
|
||||||
|
|
||||||
url = server.api_url(token=token, method=method)
|
url = server.api_url(token=token, method=method)
|
||||||
|
|
||||||
req = compose_data(data, files)
|
req = compose_data(data, files)
|
||||||
try:
|
try:
|
||||||
async with session.post(url, data=req, **kwargs) as response:
|
async with session.post(url, data=req, **kwargs) as response:
|
||||||
return check_result(method, response.content_type, response.status, await response.text())
|
return check_result(
|
||||||
|
method, response.content_type, response.status, await response.text()
|
||||||
|
)
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
raise exceptions.NetworkError(
|
||||||
|
f"aiohttp client throws an error: {e.__class__.__name__}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def guess_filename(obj):
|
def guess_filename(obj):
|
||||||
|
|
@ -149,8 +161,8 @@ def guess_filename(obj):
|
||||||
:param obj:
|
:param obj:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
name = getattr(obj, 'name', None)
|
name = getattr(obj, "name", None)
|
||||||
if name and isinstance(name, str) and name[0] != '<' and name[-1] != '>':
|
if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
|
||||||
return os.path.basename(name)
|
return os.path.basename(name)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -174,7 +186,9 @@ def compose_data(params=None, files=None):
|
||||||
if len(f) == 2:
|
if len(f) == 2:
|
||||||
filename, fileobj = f
|
filename, fileobj = f
|
||||||
else:
|
else:
|
||||||
raise ValueError('Tuple must have exactly 2 elements: filename, fileobj')
|
raise ValueError(
|
||||||
|
"Tuple must have exactly 2 elements: filename, fileobj"
|
||||||
|
)
|
||||||
elif isinstance(f, types.InputFile):
|
elif isinstance(f, types.InputFile):
|
||||||
filename, fileobj = f.filename, f.file
|
filename, fileobj = f.filename, f.file
|
||||||
else:
|
else:
|
||||||
|
|
@ -191,6 +205,7 @@ class Methods(Helper):
|
||||||
|
|
||||||
List is updated to Bot API 5.0
|
List is updated to Bot API 5.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mode = HelperMode.lowerCamelCase
|
mode = HelperMode.lowerCamelCase
|
||||||
|
|
||||||
# Getting Updates
|
# Getting Updates
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,15 @@ ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str,
|
||||||
def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]:
|
def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]:
|
||||||
# since "str" is also an "Iterable", we have to check for it first
|
# since "str" is also an "Iterable", we have to check for it first
|
||||||
if isinstance(chat_id, str):
|
if isinstance(chat_id, str):
|
||||||
return {int(chat_id), }
|
return {
|
||||||
|
int(chat_id),
|
||||||
|
}
|
||||||
if isinstance(chat_id, Iterable):
|
if isinstance(chat_id, Iterable):
|
||||||
return {int(item) for (item) in chat_id}
|
return {int(item) for (item) in chat_id}
|
||||||
# the last possible type is a single "int"
|
# the last possible type is a single "int"
|
||||||
return {chat_id, }
|
return {
|
||||||
|
chat_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Command(Filter):
|
class Command(Filter):
|
||||||
|
|
@ -34,11 +38,14 @@ class Command(Filter):
|
||||||
By default this filter is registered for messages and edited messages handlers.
|
By default this filter is registered for messages and edited messages handlers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, commands: Union[Iterable, str],
|
def __init__(
|
||||||
prefixes: Union[Iterable, str] = '/',
|
self,
|
||||||
ignore_case: bool = True,
|
commands: Union[Iterable, str],
|
||||||
ignore_mention: bool = False,
|
prefixes: Union[Iterable, str] = "/",
|
||||||
ignore_caption: bool = True):
|
ignore_case: bool = True,
|
||||||
|
ignore_mention: bool = False,
|
||||||
|
ignore_caption: bool = True,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Filter can be initialized from filters factory or by simply creating instance of this class.
|
Filter can be initialized from filters factory or by simply creating instance of this class.
|
||||||
|
|
||||||
|
|
@ -69,7 +76,8 @@ class Command(Filter):
|
||||||
if isinstance(commands, str):
|
if isinstance(commands, str):
|
||||||
commands = (commands,)
|
commands = (commands,)
|
||||||
|
|
||||||
self.commands = list(map(str.lower, commands)) if ignore_case else commands
|
self.commands = list(map(str.lower, commands)
|
||||||
|
) if ignore_case else commands
|
||||||
self.prefixes = prefixes
|
self.prefixes = prefixes
|
||||||
self.ignore_case = ignore_case
|
self.ignore_case = ignore_case
|
||||||
self.ignore_mention = ignore_mention
|
self.ignore_mention = ignore_mention
|
||||||
|
|
@ -91,36 +99,61 @@ class Command(Filter):
|
||||||
:return: config or empty dict
|
:return: config or empty dict
|
||||||
"""
|
"""
|
||||||
config = {}
|
config = {}
|
||||||
if 'commands' in full_config:
|
if "commands" in full_config:
|
||||||
config['commands'] = full_config.pop('commands')
|
config["commands"] = full_config.pop("commands")
|
||||||
if config and 'commands_prefix' in full_config:
|
if config and "commands_prefix" in full_config:
|
||||||
config['prefixes'] = full_config.pop('commands_prefix')
|
config["prefixes"] = full_config.pop("commands_prefix")
|
||||||
if config and 'commands_ignore_mention' in full_config:
|
if config and "commands_ignore_mention" in full_config:
|
||||||
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
|
config["ignore_mention"] = full_config.pop(
|
||||||
if config and 'commands_ignore_caption' in full_config:
|
"commands_ignore_mention")
|
||||||
config['ignore_caption'] = full_config.pop('commands_ignore_caption')
|
if config and "commands_ignore_caption" in full_config:
|
||||||
|
config["ignore_caption"] = full_config.pop(
|
||||||
|
"commands_ignore_caption")
|
||||||
return config
|
return config
|
||||||
|
|
||||||
async def check(self, message: types.Message):
|
async def check(self, message: types.Message):
|
||||||
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention, self.ignore_caption)
|
return await self.check_command(
|
||||||
|
message,
|
||||||
|
self.commands,
|
||||||
|
self.prefixes,
|
||||||
|
self.ignore_case,
|
||||||
|
self.ignore_mention,
|
||||||
|
self.ignore_caption,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def check_command(cls, message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False, ignore_caption=True):
|
async def check_command(
|
||||||
text = message.text or (message.caption if not ignore_caption else None)
|
cls,
|
||||||
|
message: types.Message,
|
||||||
|
commands,
|
||||||
|
prefixes,
|
||||||
|
ignore_case=True,
|
||||||
|
ignore_mention=False,
|
||||||
|
ignore_caption=True,
|
||||||
|
):
|
||||||
|
text = message.text or (
|
||||||
|
message.caption if not ignore_caption else None)
|
||||||
if not text:
|
if not text:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
full_command = text.split()[0]
|
full_command = text.split()[0]
|
||||||
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
|
prefix, (command, _,
|
||||||
|
mention) = full_command[0], full_command[1:].partition("@")
|
||||||
|
|
||||||
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
|
if (
|
||||||
|
not ignore_mention
|
||||||
|
and mention
|
||||||
|
and (await message.bot.me).username.lower() != mention.lower()
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
if prefix not in prefixes:
|
if prefix not in prefixes:
|
||||||
return False
|
return False
|
||||||
if (command.lower() if ignore_case else command) not in commands:
|
if (command.lower() if ignore_case else command) not in commands:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return {'command': cls.CommandObj(command=command, prefix=prefix, mention=mention)}
|
return {
|
||||||
|
"command": cls.CommandObj(command=command, prefix=prefix, mention=mention)
|
||||||
|
}
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommandObj:
|
class CommandObj:
|
||||||
|
|
@ -131,9 +164,9 @@ class Command(Filter):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""Command prefix"""
|
"""Command prefix"""
|
||||||
prefix: str = '/'
|
prefix: str = "/"
|
||||||
"""Command without prefix and mention"""
|
"""Command without prefix and mention"""
|
||||||
command: str = ''
|
command: str = ""
|
||||||
"""Mention (if available)"""
|
"""Mention (if available)"""
|
||||||
mention: str = None
|
mention: str = None
|
||||||
"""Command argument"""
|
"""Command argument"""
|
||||||
|
|
@ -157,9 +190,9 @@ class Command(Filter):
|
||||||
"""
|
"""
|
||||||
line = self.prefix + self.command
|
line = self.prefix + self.command
|
||||||
if self.mentioned:
|
if self.mentioned:
|
||||||
line += '@' + self.mention
|
line += "@" + self.mention
|
||||||
if self.args:
|
if self.args:
|
||||||
line += ' ' + self.args
|
line += " " + self.args
|
||||||
return line
|
return line
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -168,9 +201,12 @@ class CommandStart(Command):
|
||||||
This filter based on :obj:`Command` filter but can handle only ``/start`` command.
|
This filter based on :obj:`Command` filter but can handle only ``/start`` command.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
deep_link: typing.Optional[typing.Union[str, typing.Pattern[str]]] = None,
|
self,
|
||||||
encoded: bool = False):
|
deep_link: typing.Optional[typing.Union[str,
|
||||||
|
typing.Pattern[str]]] = None,
|
||||||
|
encoded: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Also this filter can handle `deep-linking <https://core.telegram.org/bots#deep-linking>`_ arguments.
|
Also this filter can handle `deep-linking <https://core.telegram.org/bots#deep-linking>`_ arguments.
|
||||||
|
|
||||||
|
|
@ -183,7 +219,7 @@ class CommandStart(Command):
|
||||||
:param deep_link: string or compiled regular expression (by ``re.compile(...)``).
|
:param deep_link: string or compiled regular expression (by ``re.compile(...)``).
|
||||||
:param encoded: set True if you're waiting for encoded payload (default - False).
|
:param encoded: set True if you're waiting for encoded payload (default - False).
|
||||||
"""
|
"""
|
||||||
super().__init__(['start'])
|
super().__init__(["start"])
|
||||||
self.deep_link = deep_link
|
self.deep_link = deep_link
|
||||||
self.encoded = encoded
|
self.encoded = encoded
|
||||||
|
|
||||||
|
|
@ -195,17 +231,22 @@ class CommandStart(Command):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
from ...utils.deep_linking import decode_payload
|
from ...utils.deep_linking import decode_payload
|
||||||
|
|
||||||
check = await super().check(message)
|
check = await super().check(message)
|
||||||
|
|
||||||
if check and self.deep_link is not None:
|
if check and self.deep_link is not None:
|
||||||
payload = decode_payload(message.get_args()) if self.encoded else message.get_args()
|
payload = (
|
||||||
|
decode_payload(message.get_args())
|
||||||
|
if self.encoded
|
||||||
|
else message.get_args()
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(self.deep_link, typing.Pattern):
|
if not isinstance(self.deep_link, typing.Pattern):
|
||||||
return False if payload != self.deep_link else {'deep_link': payload}
|
return False if payload != self.deep_link else {"deep_link": payload}
|
||||||
|
|
||||||
match = self.deep_link.match(payload)
|
match = self.deep_link.match(payload)
|
||||||
if match:
|
if match:
|
||||||
return {'deep_link': match}
|
return {"deep_link": match}
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return check
|
return check
|
||||||
|
|
@ -217,7 +258,7 @@ class CommandHelp(Command):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(['help'])
|
super().__init__(["help"])
|
||||||
|
|
||||||
|
|
||||||
class CommandSettings(Command):
|
class CommandSettings(Command):
|
||||||
|
|
@ -226,7 +267,7 @@ class CommandSettings(Command):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(['settings'])
|
super().__init__(["settings"])
|
||||||
|
|
||||||
|
|
||||||
class CommandPrivacy(Command):
|
class CommandPrivacy(Command):
|
||||||
|
|
@ -235,7 +276,7 @@ class CommandPrivacy(Command):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(['privacy'])
|
super().__init__(["privacy"])
|
||||||
|
|
||||||
|
|
||||||
class Text(Filter):
|
class Text(Filter):
|
||||||
|
|
@ -244,18 +285,27 @@ class Text(Filter):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_default_params = (
|
_default_params = (
|
||||||
('text', 'equals'),
|
("text", "equals"),
|
||||||
('text_contains', 'contains'),
|
("text_contains", "contains"),
|
||||||
('text_startswith', 'startswith'),
|
("text_startswith", "startswith"),
|
||||||
('text_endswith', 'endswith'),
|
("text_endswith", "endswith"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
self,
|
||||||
contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
equals: Optional[Union[str, LazyProxy,
|
||||||
startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
Iterable[Union[str, LazyProxy]]]] = None,
|
||||||
endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
contains: Optional[
|
||||||
ignore_case=False):
|
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
|
||||||
|
] = None,
|
||||||
|
startswith: Optional[
|
||||||
|
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
|
||||||
|
] = None,
|
||||||
|
endswith: Optional[
|
||||||
|
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
|
||||||
|
] = None,
|
||||||
|
ignore_case=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Check text for one of pattern. Only one mode can be used in one filter.
|
Check text for one of pattern. Only one mode can be used in one filter.
|
||||||
In every pattern, a single string is treated as a list with 1 element.
|
In every pattern, a single string is treated as a list with 1 element.
|
||||||
|
|
@ -267,13 +317,22 @@ class Text(Filter):
|
||||||
:param ignore_case: case insensitive
|
:param ignore_case: case insensitive
|
||||||
"""
|
"""
|
||||||
# Only one mode can be used. check it.
|
# Only one mode can be used. check it.
|
||||||
check = sum(map(lambda s: s is not None, (equals, contains, startswith, endswith)))
|
check = sum(
|
||||||
|
map(lambda s: s is not None, (equals, contains, startswith, endswith))
|
||||||
|
)
|
||||||
if check > 1:
|
if check > 1:
|
||||||
args = "' and '".join([arg[0] for arg in [('equals', equals),
|
args = "' and '".join(
|
||||||
('contains', contains),
|
[
|
||||||
('startswith', startswith),
|
arg[0]
|
||||||
('endswith', endswith)
|
for arg in [
|
||||||
] if arg[1] is not None])
|
("equals", equals),
|
||||||
|
("contains", contains),
|
||||||
|
("startswith", startswith),
|
||||||
|
("endswith", endswith),
|
||||||
|
]
|
||||||
|
if arg[1] is not None
|
||||||
|
]
|
||||||
|
)
|
||||||
raise ValueError(f"Arguments '{args}' cannot be used together.")
|
raise ValueError(f"Arguments '{args}' cannot be used together.")
|
||||||
if check == 0:
|
if check == 0:
|
||||||
raise ValueError(f"No one mode is specified!")
|
raise ValueError(f"No one mode is specified!")
|
||||||
|
|
@ -297,7 +356,7 @@ class Text(Filter):
|
||||||
|
|
||||||
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
|
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
|
||||||
if isinstance(obj, Message):
|
if isinstance(obj, Message):
|
||||||
text = obj.text or obj.caption or ''
|
text = obj.text or obj.caption or ""
|
||||||
if not text and obj.poll:
|
if not text and obj.poll:
|
||||||
text = obj.poll.question
|
text = obj.poll.question
|
||||||
elif isinstance(obj, CallbackQuery):
|
elif isinstance(obj, CallbackQuery):
|
||||||
|
|
@ -311,7 +370,10 @@ class Text(Filter):
|
||||||
|
|
||||||
if self.ignore_case:
|
if self.ignore_case:
|
||||||
text = text.lower()
|
text = text.lower()
|
||||||
_pre_process_func = lambda s: str(s).lower()
|
|
||||||
|
def _pre_process_func(s):
|
||||||
|
return str(s).lower()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_pre_process_func = str
|
_pre_process_func = str
|
||||||
|
|
||||||
|
|
@ -344,7 +406,7 @@ class HashTag(Filter):
|
||||||
|
|
||||||
def __init__(self, hashtags=None, cashtags=None):
|
def __init__(self, hashtags=None, cashtags=None):
|
||||||
if not hashtags and not cashtags:
|
if not hashtags and not cashtags:
|
||||||
raise ValueError('No one hashtag or cashtag is specified!')
|
raise ValueError("No one hashtag or cashtag is specified!")
|
||||||
|
|
||||||
if hashtags is None:
|
if hashtags is None:
|
||||||
hashtags = []
|
hashtags = []
|
||||||
|
|
@ -364,10 +426,10 @@ class HashTag(Filter):
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: Dict[str, Any]):
|
def validate(cls, full_config: Dict[str, Any]):
|
||||||
config = {}
|
config = {}
|
||||||
if 'hashtags' in full_config:
|
if "hashtags" in full_config:
|
||||||
config['hashtags'] = full_config.pop('hashtags')
|
config["hashtags"] = full_config.pop("hashtags")
|
||||||
if 'cashtags' in full_config:
|
if "cashtags" in full_config:
|
||||||
config['cashtags'] = full_config.pop('cashtags')
|
config["cashtags"] = full_config.pop("cashtags")
|
||||||
return config
|
return config
|
||||||
|
|
||||||
async def check(self, message: types.Message):
|
async def check(self, message: types.Message):
|
||||||
|
|
@ -381,9 +443,13 @@ class HashTag(Filter):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
hashtags, cashtags = self._get_tags(text, entities)
|
hashtags, cashtags = self._get_tags(text, entities)
|
||||||
if self.hashtags and set(hashtags) & set(self.hashtags) \
|
if (
|
||||||
or self.cashtags and set(cashtags) & set(self.cashtags):
|
self.hashtags
|
||||||
return {'hashtags': hashtags, 'cashtags': cashtags}
|
and set(hashtags) & set(self.hashtags)
|
||||||
|
or self.cashtags
|
||||||
|
and set(cashtags) & set(self.cashtags)
|
||||||
|
):
|
||||||
|
return {"hashtags": hashtags, "cashtags": cashtags}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_tags(text, entities):
|
def _get_tags(text, entities):
|
||||||
|
|
@ -392,11 +458,11 @@ class HashTag(Filter):
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if entity.type == types.MessageEntityType.HASHTAG:
|
if entity.type == types.MessageEntityType.HASHTAG:
|
||||||
value = entity.get_text(text).lstrip('#')
|
value = entity.get_text(text).lstrip("#")
|
||||||
hashtags.append(value)
|
hashtags.append(value)
|
||||||
|
|
||||||
elif entity.type == types.MessageEntityType.CASHTAG:
|
elif entity.type == types.MessageEntityType.CASHTAG:
|
||||||
value = entity.get_text(text).lstrip('$')
|
value = entity.get_text(text).lstrip("$")
|
||||||
cashtags.append(value)
|
cashtags.append(value)
|
||||||
|
|
||||||
return hashtags, cashtags
|
return hashtags, cashtags
|
||||||
|
|
@ -414,12 +480,12 @@ class Regexp(Filter):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: Dict[str, Any]):
|
def validate(cls, full_config: Dict[str, Any]):
|
||||||
if 'regexp' in full_config:
|
if "regexp" in full_config:
|
||||||
return {'regexp': full_config.pop('regexp')}
|
return {"regexp": full_config.pop("regexp")}
|
||||||
|
|
||||||
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
|
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
|
||||||
if isinstance(obj, Message):
|
if isinstance(obj, Message):
|
||||||
content = obj.text or obj.caption or ''
|
content = obj.text or obj.caption or ""
|
||||||
if not content and obj.poll:
|
if not content and obj.poll:
|
||||||
content = obj.poll.question
|
content = obj.poll.question
|
||||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||||
|
|
@ -434,7 +500,7 @@ class Regexp(Filter):
|
||||||
match = self.regexp.search(content)
|
match = self.regexp.search(content)
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
return {'regexp': match}
|
return {"regexp": match}
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -443,17 +509,20 @@ class RegexpCommandsFilter(BoundFilter):
|
||||||
Check commands by regexp in message
|
Check commands by regexp in message
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = 'regexp_commands'
|
key = "regexp_commands"
|
||||||
|
|
||||||
def __init__(self, regexp_commands):
|
def __init__(self, regexp_commands):
|
||||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
self.regexp_commands = [
|
||||||
|
re.compile(command, flags=re.IGNORECASE | re.MULTILINE)
|
||||||
|
for command in regexp_commands
|
||||||
|
]
|
||||||
|
|
||||||
async def check(self, message):
|
async def check(self, message):
|
||||||
if not message.is_command():
|
if not message.is_command():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
command = message.text.split()[0][1:]
|
command = message.text.split()[0][1:]
|
||||||
command, _, mention = command.partition('@')
|
command, _, mention = command.partition("@")
|
||||||
|
|
||||||
if mention and mention != (await message.bot.me).username:
|
if mention and mention != (await message.bot.me).username:
|
||||||
return False
|
return False
|
||||||
|
|
@ -461,7 +530,7 @@ class RegexpCommandsFilter(BoundFilter):
|
||||||
for command in self.regexp_commands:
|
for command in self.regexp_commands:
|
||||||
search = command.search(message.text)
|
search = command.search(message.text)
|
||||||
if search:
|
if search:
|
||||||
return {'regexp_command': search}
|
return {"regexp_command": search}
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -470,7 +539,7 @@ class ContentTypeFilter(BoundFilter):
|
||||||
Check message content type
|
Check message content type
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = 'content_types'
|
key = "content_types"
|
||||||
required = True
|
required = True
|
||||||
default = types.ContentTypes.TEXT
|
default = types.ContentTypes.TEXT
|
||||||
|
|
||||||
|
|
@ -480,8 +549,10 @@ class ContentTypeFilter(BoundFilter):
|
||||||
self.content_types = content_types
|
self.content_types = content_types
|
||||||
|
|
||||||
async def check(self, message):
|
async def check(self, message):
|
||||||
return types.ContentType.ANY in self.content_types or \
|
return (
|
||||||
message.content_type in self.content_types
|
types.ContentType.ANY in self.content_types
|
||||||
|
or message.content_type in self.content_types
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IsSenderContact(BoundFilter):
|
class IsSenderContact(BoundFilter):
|
||||||
|
|
@ -491,7 +562,8 @@ class IsSenderContact(BoundFilter):
|
||||||
`is_sender_contact=True` - contact matches the sender
|
`is_sender_contact=True` - contact matches the sender
|
||||||
`is_sender_contact=False` - result will be inverted
|
`is_sender_contact=False` - result will be inverted
|
||||||
"""
|
"""
|
||||||
key = 'is_sender_contact'
|
|
||||||
|
key = "is_sender_contact"
|
||||||
|
|
||||||
def __init__(self, is_sender_contact: bool):
|
def __init__(self, is_sender_contact: bool):
|
||||||
self.is_sender_contact = is_sender_contact
|
self.is_sender_contact = is_sender_contact
|
||||||
|
|
@ -509,10 +581,11 @@ class StateFilter(BoundFilter):
|
||||||
"""
|
"""
|
||||||
Check user state
|
Check user state
|
||||||
"""
|
"""
|
||||||
key = 'state'
|
|
||||||
|
key = "state"
|
||||||
required = True
|
required = True
|
||||||
|
|
||||||
ctx_state = ContextVar('user_state')
|
ctx_state = ContextVar("user_state")
|
||||||
|
|
||||||
def __init__(self, dispatcher, state):
|
def __init__(self, dispatcher, state):
|
||||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||||
|
|
@ -520,7 +593,9 @@ class StateFilter(BoundFilter):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
states = []
|
states = []
|
||||||
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
||||||
state = [state, ]
|
state = [
|
||||||
|
state,
|
||||||
|
]
|
||||||
for item in state:
|
for item in state:
|
||||||
if isinstance(item, State):
|
if isinstance(item, State):
|
||||||
states.append(item.state)
|
states.append(item.state)
|
||||||
|
|
@ -532,11 +607,13 @@ class StateFilter(BoundFilter):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_target(obj):
|
def get_target(obj):
|
||||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
return getattr(getattr(obj, "chat", None), "id", None), getattr(
|
||||||
|
getattr(obj, "from_user", None), "id", None
|
||||||
|
)
|
||||||
|
|
||||||
async def check(self, obj):
|
async def check(self, obj):
|
||||||
if '*' in self.states:
|
if "*" in self.states:
|
||||||
return {'state': self.dispatcher.current_state()}
|
return {"state": self.dispatcher.current_state()}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state = self.ctx_state.get()
|
state = self.ctx_state.get()
|
||||||
|
|
@ -547,11 +624,14 @@ class StateFilter(BoundFilter):
|
||||||
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
|
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
|
||||||
self.ctx_state.set(state)
|
self.ctx_state.set(state)
|
||||||
if state in self.states:
|
if state in self.states:
|
||||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
return {
|
||||||
|
"state": self.dispatcher.current_state(),
|
||||||
|
"raw_state": state,
|
||||||
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if state in self.states:
|
if state in self.states:
|
||||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
return {"state": self.dispatcher.current_state(), "raw_state": state}
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -561,7 +641,7 @@ class ExceptionsFilter(BoundFilter):
|
||||||
Filter for exceptions
|
Filter for exceptions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = 'exception'
|
key = "exception"
|
||||||
|
|
||||||
def __init__(self, exception):
|
def __init__(self, exception):
|
||||||
self.exception = exception
|
self.exception = exception
|
||||||
|
|
@ -576,10 +656,11 @@ class ExceptionsFilter(BoundFilter):
|
||||||
|
|
||||||
|
|
||||||
class IDFilter(Filter):
|
class IDFilter(Filter):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
user_id: Optional[ChatIDArgumentType] = None,
|
self,
|
||||||
chat_id: Optional[ChatIDArgumentType] = None,
|
user_id: Optional[ChatIDArgumentType] = None,
|
||||||
):
|
chat_id: Optional[ChatIDArgumentType] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
:param user_id:
|
:param user_id:
|
||||||
:param chat_id:
|
:param chat_id:
|
||||||
|
|
@ -597,13 +678,15 @@ class IDFilter(Filter):
|
||||||
self.chat_id = extract_chat_ids(chat_id)
|
self.chat_id = extract_chat_ids(chat_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
def validate(
|
||||||
|
cls, full_config: typing.Dict[str, typing.Any]
|
||||||
|
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
result = {}
|
result = {}
|
||||||
if 'user_id' in full_config:
|
if "user_id" in full_config:
|
||||||
result['user_id'] = full_config.pop('user_id')
|
result["user_id"] = full_config.pop("user_id")
|
||||||
|
|
||||||
if 'chat_id' in full_config:
|
if "chat_id" in full_config:
|
||||||
result['chat_id'] = full_config.pop('chat_id')
|
result["chat_id"] = full_config.pop("chat_id")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -658,7 +741,9 @@ class AdminFilter(Filter):
|
||||||
self._chat_ids = extract_chat_ids(is_chat_admin)
|
self._chat_ids = extract_chat_ids(is_chat_admin)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
def validate(
|
||||||
|
cls, full_config: typing.Dict[str, typing.Any]
|
||||||
|
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if "is_chat_admin" in full_config:
|
if "is_chat_admin" in full_config:
|
||||||
|
|
@ -676,13 +761,19 @@ class AdminFilter(Filter):
|
||||||
message = obj.message
|
message = obj.message
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
if message.chat.type == ChatType.PRIVATE: # there is no admin in private chats
|
if (
|
||||||
|
message.chat.type == ChatType.PRIVATE
|
||||||
|
): # there is no admin in private chats
|
||||||
return False
|
return False
|
||||||
chat_ids = [message.chat.id]
|
chat_ids = [message.chat.id]
|
||||||
else:
|
else:
|
||||||
chat_ids = self._chat_ids
|
chat_ids = self._chat_ids
|
||||||
|
|
||||||
admins = [member.user.id for chat_id in chat_ids for member in await obj.bot.get_chat_administrators(chat_id)]
|
admins = [
|
||||||
|
member.user.id
|
||||||
|
for chat_id in chat_ids
|
||||||
|
for member in await obj.bot.get_chat_administrators(chat_id)
|
||||||
|
]
|
||||||
|
|
||||||
return user_id in admins
|
return user_id in admins
|
||||||
|
|
||||||
|
|
@ -691,20 +782,21 @@ class IsReplyFilter(BoundFilter):
|
||||||
"""
|
"""
|
||||||
Check if message is replied and send reply message to handler
|
Check if message is replied and send reply message to handler
|
||||||
"""
|
"""
|
||||||
key = 'is_reply'
|
|
||||||
|
key = "is_reply"
|
||||||
|
|
||||||
def __init__(self, is_reply):
|
def __init__(self, is_reply):
|
||||||
self.is_reply = is_reply
|
self.is_reply = is_reply
|
||||||
|
|
||||||
async def check(self, msg: Message):
|
async def check(self, msg: Message):
|
||||||
if msg.reply_to_message and self.is_reply:
|
if msg.reply_to_message and self.is_reply:
|
||||||
return {'reply': msg.reply_to_message}
|
return {"reply": msg.reply_to_message}
|
||||||
if not msg.reply_to_message and not self.is_reply:
|
if not msg.reply_to_message and not self.is_reply:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ForwardedMessageFilter(BoundFilter):
|
class ForwardedMessageFilter(BoundFilter):
|
||||||
key = 'is_forwarded'
|
key = "is_forwarded"
|
||||||
|
|
||||||
def __init__(self, is_forwarded: bool):
|
def __init__(self, is_forwarded: bool):
|
||||||
self.is_forwarded = is_forwarded
|
self.is_forwarded = is_forwarded
|
||||||
|
|
@ -714,7 +806,7 @@ class ForwardedMessageFilter(BoundFilter):
|
||||||
|
|
||||||
|
|
||||||
class ChatTypeFilter(BoundFilter):
|
class ChatTypeFilter(BoundFilter):
|
||||||
key = 'chat_type'
|
key = "chat_type"
|
||||||
|
|
||||||
def __init__(self, chat_type: typing.Container[ChatType]):
|
def __init__(self, chat_type: typing.Container[ChatType]):
|
||||||
if isinstance(chat_type, str):
|
if isinstance(chat_type, str):
|
||||||
|
|
@ -728,7 +820,8 @@ class ChatTypeFilter(BoundFilter):
|
||||||
elif isinstance(obj, CallbackQuery):
|
elif isinstance(obj, CallbackQuery):
|
||||||
obj = obj.message.chat
|
obj = obj.message.chat
|
||||||
else:
|
else:
|
||||||
warnings.warn("ChatTypeFilter doesn't support %s as input", type(obj))
|
warnings.warn(
|
||||||
|
"ChatTypeFilter doesn't support %s as input", type(obj))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return obj.type in self.chat_type
|
return obj.type in self.chat_type
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,11 @@ def wrap_async(func):
|
||||||
async def async_wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
if inspect.isawaitable(func) \
|
if (
|
||||||
or inspect.iscoroutinefunction(func) \
|
inspect.isawaitable(func)
|
||||||
or isinstance(func, AbstractFilter):
|
or inspect.iscoroutinefunction(func)
|
||||||
|
or isinstance(func, AbstractFilter)
|
||||||
|
):
|
||||||
return func
|
return func
|
||||||
return async_wrapper
|
return async_wrapper
|
||||||
|
|
||||||
|
|
@ -23,14 +25,16 @@ def wrap_async(func):
|
||||||
def get_filter_spec(dispatcher, filter_: callable):
|
def get_filter_spec(dispatcher, filter_: callable):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if not callable(filter_):
|
if not callable(filter_):
|
||||||
raise TypeError('Filter must be callable and/or awaitable!')
|
raise TypeError("Filter must be callable and/or awaitable!")
|
||||||
|
|
||||||
spec = inspect.getfullargspec(filter_)
|
spec = inspect.getfullargspec(filter_)
|
||||||
if 'dispatcher' in spec:
|
if "dispatcher" in spec:
|
||||||
kwargs['dispatcher'] = dispatcher
|
kwargs["dispatcher"] = dispatcher
|
||||||
if inspect.isawaitable(filter_) \
|
if (
|
||||||
or inspect.iscoroutinefunction(filter_) \
|
inspect.isawaitable(filter_)
|
||||||
or isinstance(filter_, AbstractFilter):
|
or inspect.iscoroutinefunction(filter_)
|
||||||
|
or isinstance(filter_, AbstractFilter)
|
||||||
|
):
|
||||||
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
|
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
|
||||||
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
|
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
|
||||||
|
|
||||||
|
|
@ -80,12 +84,17 @@ class FilterRecord:
|
||||||
Filters record for factory
|
Filters record for factory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, callback: typing.Union[typing.Callable, 'AbstractFilter'],
|
def __init__(
|
||||||
validator: typing.Optional[typing.Callable] = None,
|
self,
|
||||||
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
callback: typing.Union[typing.Callable, "AbstractFilter"],
|
||||||
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
validator: typing.Optional[typing.Callable] = None,
|
||||||
|
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
||||||
|
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
||||||
|
):
|
||||||
if event_handlers and exclude_event_handlers:
|
if event_handlers and exclude_event_handlers:
|
||||||
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
|
raise ValueError(
|
||||||
|
"'event_handlers' and 'exclude_event_handlers' arguments cannot be used together."
|
||||||
|
)
|
||||||
|
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
self.event_handlers = event_handlers
|
self.event_handlers = event_handlers
|
||||||
|
|
@ -93,22 +102,23 @@ class FilterRecord:
|
||||||
|
|
||||||
if validator is not None:
|
if validator is not None:
|
||||||
if not callable(validator):
|
if not callable(validator):
|
||||||
raise TypeError(f"validator must be callable, not {type(validator)}")
|
raise TypeError(
|
||||||
|
f"validator must be callable, not {type(validator)}")
|
||||||
self.resolver = validator
|
self.resolver = validator
|
||||||
elif issubclass(callback, AbstractFilter):
|
elif issubclass(callback, AbstractFilter):
|
||||||
self.resolver = callback.validate
|
self.resolver = callback.validate
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('validator is required!')
|
raise RuntimeError("validator is required!")
|
||||||
|
|
||||||
def resolve(self, dispatcher, event_handler, full_config):
|
def resolve(self, dispatcher, event_handler, full_config):
|
||||||
if not self._check_event_handler(event_handler):
|
if not self._check_event_handler(event_handler):
|
||||||
return
|
return
|
||||||
config = self.resolver(full_config)
|
config = self.resolver(full_config)
|
||||||
if config:
|
if config:
|
||||||
if 'dispatcher' not in config:
|
if "dispatcher" not in config:
|
||||||
spec = inspect.getfullargspec(self.callback)
|
spec = inspect.getfullargspec(self.callback)
|
||||||
if 'dispatcher' in spec.args:
|
if "dispatcher" in spec.args:
|
||||||
config['dispatcher'] = dispatcher
|
config["dispatcher"] = dispatcher
|
||||||
|
|
||||||
for key in config:
|
for key in config:
|
||||||
if key in full_config:
|
if key in full_config:
|
||||||
|
|
@ -131,7 +141,9 @@ class AbstractFilter(abc.ABC):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
def validate(
|
||||||
|
cls, full_config: typing.Dict[str, typing.Any]
|
||||||
|
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
"""
|
"""
|
||||||
Validate and parse config.
|
Validate and parse config.
|
||||||
|
|
||||||
|
|
@ -182,7 +194,9 @@ class Filter(AbstractFilter):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
def validate(
|
||||||
|
cls, full_config: typing.Dict[str, typing.Any]
|
||||||
|
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||||
"""
|
"""
|
||||||
Here method ``validate`` is optional.
|
Here method ``validate`` is optional.
|
||||||
If you need to use filter from filters factory you need to override this method.
|
If you need to use filter from filters factory you need to override this method.
|
||||||
|
|
@ -200,16 +214,18 @@ class BoundFilter(Filter):
|
||||||
You need to implement ``__init__`` method with single argument related with key attribute
|
You need to implement ``__init__`` method with single argument related with key attribute
|
||||||
and ``check`` method where you need to implement filter logic.
|
and ``check`` method where you need to implement filter logic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = None
|
key = None
|
||||||
"""Unique name of the filter argument. You need to override this attribute."""
|
"""Unique name of the filter argument. You need to override this attribute."""
|
||||||
required = False
|
required = False
|
||||||
"""If :obj:`True` this filter will be added to the all of the registered handlers"""
|
"""If :obj:`True` this filter will be added to the all of the registered handlers"""
|
||||||
default = None
|
default = None
|
||||||
"""Default value for configure required filters"""
|
"""Default value for configure required filters"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
def validate(
|
||||||
|
cls, full_config: typing.Dict[str, typing.Any]
|
||||||
|
) -> typing.Dict[str, typing.Any]:
|
||||||
"""
|
"""
|
||||||
If ``cls.key`` is not :obj:`None` and that is in config returns config with that argument.
|
If ``cls.key`` is not :obj:`None` and that is in config returns config with that argument.
|
||||||
|
|
||||||
|
|
@ -226,7 +242,7 @@ class BoundFilter(Filter):
|
||||||
class _LogicFilter(Filter):
|
class _LogicFilter(Filter):
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
||||||
raise ValueError('That filter can\'t be used in filters factory!')
|
raise ValueError("That filter can't be used in filters factory!")
|
||||||
|
|
||||||
|
|
||||||
class NotFilter(_LogicFilter):
|
class NotFilter(_LogicFilter):
|
||||||
|
|
@ -238,7 +254,6 @@ class NotFilter(_LogicFilter):
|
||||||
|
|
||||||
|
|
||||||
class AndFilter(_LogicFilter):
|
class AndFilter(_LogicFilter):
|
||||||
|
|
||||||
def __init__(self, *targets):
|
def __init__(self, *targets):
|
||||||
self.targets = [wrap_async(target) for target in targets]
|
self.targets = [wrap_async(target) for target in targets]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,15 +26,17 @@ class CallbackData:
|
||||||
Callback data factory
|
Callback data factory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prefix, *parts, sep=':'):
|
def __init__(self, prefix, *parts, sep=":"):
|
||||||
if not isinstance(prefix, str):
|
if not isinstance(prefix, str):
|
||||||
raise TypeError(f'Prefix must be instance of str not {type(prefix).__name__}')
|
raise TypeError(
|
||||||
|
f"Prefix must be instance of str not {type(prefix).__name__}"
|
||||||
|
)
|
||||||
if not prefix:
|
if not prefix:
|
||||||
raise ValueError("Prefix can't be empty")
|
raise ValueError("Prefix can't be empty")
|
||||||
if sep in prefix:
|
if sep in prefix:
|
||||||
raise ValueError(f"Separator {sep!r} can't be used in prefix")
|
raise ValueError(f"Separator {sep!r} can't be used in prefix")
|
||||||
if not parts:
|
if not parts:
|
||||||
raise TypeError('Parts were not passed!')
|
raise TypeError("Parts were not passed!")
|
||||||
|
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.sep = sep
|
self.sep = sep
|
||||||
|
|
@ -59,7 +61,7 @@ class CallbackData:
|
||||||
if args:
|
if args:
|
||||||
value = args.pop(0)
|
value = args.pop(0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Value for {part!r} was not passed!')
|
raise ValueError(f"Value for {part!r} was not passed!")
|
||||||
|
|
||||||
if value is not None and not isinstance(value, str):
|
if value is not None and not isinstance(value, str):
|
||||||
value = str(value)
|
value = str(value)
|
||||||
|
|
@ -67,16 +69,18 @@ class CallbackData:
|
||||||
if not value:
|
if not value:
|
||||||
raise ValueError(f"Value for part {part!r} can't be empty!'")
|
raise ValueError(f"Value for part {part!r} can't be empty!'")
|
||||||
if self.sep in value:
|
if self.sep in value:
|
||||||
raise ValueError(f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values")
|
raise ValueError(
|
||||||
|
f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values"
|
||||||
|
)
|
||||||
|
|
||||||
data.append(value)
|
data.append(value)
|
||||||
|
|
||||||
if args or kwargs:
|
if args or kwargs:
|
||||||
raise TypeError('Too many arguments were passed!')
|
raise TypeError("Too many arguments were passed!")
|
||||||
|
|
||||||
callback_data = self.sep.join(data)
|
callback_data = self.sep.join(data)
|
||||||
if len(callback_data.encode()) > 64:
|
if len(callback_data.encode()) > 64:
|
||||||
raise ValueError('Resulted callback data is too long!')
|
raise ValueError("Resulted callback data is too long!")
|
||||||
|
|
||||||
return callback_data
|
return callback_data
|
||||||
|
|
||||||
|
|
@ -89,11 +93,12 @@ class CallbackData:
|
||||||
"""
|
"""
|
||||||
prefix, *parts = callback_data.split(self.sep)
|
prefix, *parts = callback_data.split(self.sep)
|
||||||
if prefix != self.prefix:
|
if prefix != self.prefix:
|
||||||
raise ValueError("Passed callback data can't be parsed with that prefix.")
|
raise ValueError(
|
||||||
|
"Passed callback data can't be parsed with that prefix.")
|
||||||
if len(parts) != len(self._part_names):
|
if len(parts) != len(self._part_names):
|
||||||
raise ValueError('Invalid parts count!')
|
raise ValueError("Invalid parts count!")
|
||||||
|
|
||||||
result = {'@': prefix}
|
result = {"@": prefix}
|
||||||
result.update(zip(self._part_names, parts))
|
result.update(zip(self._part_names, parts))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -106,12 +111,11 @@ class CallbackData:
|
||||||
"""
|
"""
|
||||||
for key in config.keys():
|
for key in config.keys():
|
||||||
if key not in self._part_names:
|
if key not in self._part_names:
|
||||||
raise ValueError(f'Invalid field name {key!r}')
|
raise ValueError(f"Invalid field name {key!r}")
|
||||||
return CallbackDataFilter(self, config)
|
return CallbackDataFilter(self, config)
|
||||||
|
|
||||||
|
|
||||||
class CallbackDataFilter(Filter):
|
class CallbackDataFilter(Filter):
|
||||||
|
|
||||||
def __init__(self, factory: CallbackData, config: typing.Dict[str, str]):
|
def __init__(self, factory: CallbackData, config: typing.Dict[str, str]):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.factory = factory
|
self.factory = factory
|
||||||
|
|
@ -133,4 +137,4 @@ class CallbackDataFilter(Filter):
|
||||||
else:
|
else:
|
||||||
if data.get(key) != value:
|
if data.get(key) != value:
|
||||||
return False
|
return False
|
||||||
return {'callback_data': data}
|
return {"callback_data": data}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue