mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge pull request #13 from muhammedfurkan/deepsource-fix-f820c4b0
Refactor unnecessary `else` / `elif` when `if` block has a `raise` statement
This commit is contained in:
commit
bd51d299df
4 changed files with 293 additions and 166 deletions
|
|
@ -10,7 +10,7 @@ from ..utils import exceptions, json
|
|||
from ..utils.helper import Helper, HelperMode, Item
|
||||
|
||||
# Main aiogram logger
|
||||
log = logging.getLogger('aiogram')
|
||||
log = logging.getLogger("aiogram")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -43,7 +43,7 @@ class TelegramAPIServer:
|
|||
return self.file.format(token=token, path=path)
|
||||
|
||||
@classmethod
|
||||
def from_base(cls, base: str) -> 'TelegramAPIServer':
|
||||
def from_base(cls, base: str) -> "TelegramAPIServer":
|
||||
base = base.rstrip("/")
|
||||
return cls(
|
||||
base=f"{base}/bot{{token}}/{{method}}",
|
||||
|
|
@ -62,17 +62,19 @@ def check_token(token: str) -> bool:
|
|||
:return:
|
||||
"""
|
||||
if not isinstance(token, str):
|
||||
message = (f"Token is invalid! "
|
||||
f"It must be 'str' type instead of {type(token)} type.")
|
||||
message = (
|
||||
f"Token is invalid! "
|
||||
f"It must be 'str' type instead of {type(token)} type."
|
||||
)
|
||||
raise exceptions.ValidationError(message)
|
||||
|
||||
if any(x.isspace() for x in token):
|
||||
message = "Token is invalid! It can't contains spaces."
|
||||
raise exceptions.ValidationError(message)
|
||||
|
||||
left, sep, right = token.partition(':')
|
||||
left, sep, right = token.partition(":")
|
||||
if (not sep) or (not left.isdigit()) or (not right):
|
||||
raise exceptions.ValidationError('Token is invalid!')
|
||||
raise exceptions.ValidationError("Token is invalid!")
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -94,24 +96,27 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
|
|||
"""
|
||||
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
|
||||
|
||||
if content_type != 'application/json':
|
||||
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
|
||||
if content_type != "application/json":
|
||||
raise exceptions.NetworkError(
|
||||
f'Invalid response with content type {content_type}: "{body}"'
|
||||
)
|
||||
|
||||
try:
|
||||
result_json = json.loads(body)
|
||||
except ValueError:
|
||||
result_json = {}
|
||||
|
||||
description = result_json.get('description') or body
|
||||
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
|
||||
description = result_json.get("description") or body
|
||||
parameters = types.ResponseParameters(
|
||||
**result_json.get("parameters", {}) or {})
|
||||
|
||||
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
|
||||
return result_json.get('result')
|
||||
return result_json.get("result")
|
||||
if parameters.retry_after:
|
||||
raise exceptions.RetryAfter(parameters.retry_after)
|
||||
elif parameters.migrate_to_chat_id:
|
||||
if parameters.migrate_to_chat_id:
|
||||
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
|
||||
elif status_code == HTTPStatus.BAD_REQUEST:
|
||||
if status_code == HTTPStatus.BAD_REQUEST:
|
||||
exceptions.BadRequest.detect(description)
|
||||
elif status_code == HTTPStatus.NOT_FOUND:
|
||||
exceptions.NotFound.detect(description)
|
||||
|
|
@ -120,26 +125,33 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
|
|||
elif status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
|
||||
exceptions.Unauthorized.detect(description)
|
||||
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
|
||||
raise exceptions.NetworkError('File too large for uploading. '
|
||||
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
|
||||
raise exceptions.NetworkError(
|
||||
"File too large for uploading. "
|
||||
"Check telegram api limits https://core.telegram.org/bots/api#senddocument"
|
||||
)
|
||||
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
|
||||
if 'restart' in description:
|
||||
if "restart" in description:
|
||||
raise exceptions.RestartingTelegram()
|
||||
raise exceptions.TelegramAPIError(description)
|
||||
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
|
||||
|
||||
|
||||
async def make_request(session, server, token, method, data=None, files=None, **kwargs):
|
||||
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files)
|
||||
log.debug('Make request: "%s" with data: "%r" and files "%r"',
|
||||
method, data, files)
|
||||
|
||||
url = server.api_url(token=token, method=method)
|
||||
|
||||
req = compose_data(data, files)
|
||||
try:
|
||||
async with session.post(url, data=req, **kwargs) as response:
|
||||
return check_result(method, response.content_type, response.status, await response.text())
|
||||
return check_result(
|
||||
method, response.content_type, response.status, await response.text()
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
||||
raise exceptions.NetworkError(
|
||||
f"aiohttp client throws an error: {e.__class__.__name__}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def guess_filename(obj):
|
||||
|
|
@ -149,8 +161,8 @@ def guess_filename(obj):
|
|||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
name = getattr(obj, 'name', None)
|
||||
if name and isinstance(name, str) and name[0] != '<' and name[-1] != '>':
|
||||
name = getattr(obj, "name", None)
|
||||
if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
|
||||
return os.path.basename(name)
|
||||
|
||||
|
||||
|
|
@ -174,7 +186,9 @@ def compose_data(params=None, files=None):
|
|||
if len(f) == 2:
|
||||
filename, fileobj = f
|
||||
else:
|
||||
raise ValueError('Tuple must have exactly 2 elements: filename, fileobj')
|
||||
raise ValueError(
|
||||
"Tuple must have exactly 2 elements: filename, fileobj"
|
||||
)
|
||||
elif isinstance(f, types.InputFile):
|
||||
filename, fileobj = f.filename, f.file
|
||||
else:
|
||||
|
|
@ -191,6 +205,7 @@ class Methods(Helper):
|
|||
|
||||
List is updated to Bot API 5.0
|
||||
"""
|
||||
|
||||
mode = HelperMode.lowerCamelCase
|
||||
|
||||
# Getting Updates
|
||||
|
|
|
|||
|
|
@ -18,11 +18,15 @@ ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str,
|
|||
def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]:
|
||||
# since "str" is also an "Iterable", we have to check for it first
|
||||
if isinstance(chat_id, str):
|
||||
return {int(chat_id), }
|
||||
return {
|
||||
int(chat_id),
|
||||
}
|
||||
if isinstance(chat_id, Iterable):
|
||||
return {int(item) for (item) in chat_id}
|
||||
# the last possible type is a single "int"
|
||||
return {chat_id, }
|
||||
return {
|
||||
chat_id,
|
||||
}
|
||||
|
||||
|
||||
class Command(Filter):
|
||||
|
|
@ -34,11 +38,14 @@ class Command(Filter):
|
|||
By default this filter is registered for messages and edited messages handlers.
|
||||
"""
|
||||
|
||||
def __init__(self, commands: Union[Iterable, str],
|
||||
prefixes: Union[Iterable, str] = '/',
|
||||
ignore_case: bool = True,
|
||||
ignore_mention: bool = False,
|
||||
ignore_caption: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
commands: Union[Iterable, str],
|
||||
prefixes: Union[Iterable, str] = "/",
|
||||
ignore_case: bool = True,
|
||||
ignore_mention: bool = False,
|
||||
ignore_caption: bool = True,
|
||||
):
|
||||
"""
|
||||
Filter can be initialized from filters factory or by simply creating instance of this class.
|
||||
|
||||
|
|
@ -69,7 +76,8 @@ class Command(Filter):
|
|||
if isinstance(commands, str):
|
||||
commands = (commands,)
|
||||
|
||||
self.commands = list(map(str.lower, commands)) if ignore_case else commands
|
||||
self.commands = list(map(str.lower, commands)
|
||||
) if ignore_case else commands
|
||||
self.prefixes = prefixes
|
||||
self.ignore_case = ignore_case
|
||||
self.ignore_mention = ignore_mention
|
||||
|
|
@ -91,36 +99,61 @@ class Command(Filter):
|
|||
:return: config or empty dict
|
||||
"""
|
||||
config = {}
|
||||
if 'commands' in full_config:
|
||||
config['commands'] = full_config.pop('commands')
|
||||
if config and 'commands_prefix' in full_config:
|
||||
config['prefixes'] = full_config.pop('commands_prefix')
|
||||
if config and 'commands_ignore_mention' in full_config:
|
||||
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
|
||||
if config and 'commands_ignore_caption' in full_config:
|
||||
config['ignore_caption'] = full_config.pop('commands_ignore_caption')
|
||||
if "commands" in full_config:
|
||||
config["commands"] = full_config.pop("commands")
|
||||
if config and "commands_prefix" in full_config:
|
||||
config["prefixes"] = full_config.pop("commands_prefix")
|
||||
if config and "commands_ignore_mention" in full_config:
|
||||
config["ignore_mention"] = full_config.pop(
|
||||
"commands_ignore_mention")
|
||||
if config and "commands_ignore_caption" in full_config:
|
||||
config["ignore_caption"] = full_config.pop(
|
||||
"commands_ignore_caption")
|
||||
return config
|
||||
|
||||
async def check(self, message: types.Message):
|
||||
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention, self.ignore_caption)
|
||||
return await self.check_command(
|
||||
message,
|
||||
self.commands,
|
||||
self.prefixes,
|
||||
self.ignore_case,
|
||||
self.ignore_mention,
|
||||
self.ignore_caption,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def check_command(cls, message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False, ignore_caption=True):
|
||||
text = message.text or (message.caption if not ignore_caption else None)
|
||||
async def check_command(
|
||||
cls,
|
||||
message: types.Message,
|
||||
commands,
|
||||
prefixes,
|
||||
ignore_case=True,
|
||||
ignore_mention=False,
|
||||
ignore_caption=True,
|
||||
):
|
||||
text = message.text or (
|
||||
message.caption if not ignore_caption else None)
|
||||
if not text:
|
||||
return False
|
||||
|
||||
full_command = text.split()[0]
|
||||
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
|
||||
prefix, (command, _,
|
||||
mention) = full_command[0], full_command[1:].partition("@")
|
||||
|
||||
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
|
||||
if (
|
||||
not ignore_mention
|
||||
and mention
|
||||
and (await message.bot.me).username.lower() != mention.lower()
|
||||
):
|
||||
return False
|
||||
if prefix not in prefixes:
|
||||
return False
|
||||
if (command.lower() if ignore_case else command) not in commands:
|
||||
return False
|
||||
|
||||
return {'command': cls.CommandObj(command=command, prefix=prefix, mention=mention)}
|
||||
return {
|
||||
"command": cls.CommandObj(command=command, prefix=prefix, mention=mention)
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class CommandObj:
|
||||
|
|
@ -131,9 +164,9 @@ class Command(Filter):
|
|||
"""
|
||||
|
||||
"""Command prefix"""
|
||||
prefix: str = '/'
|
||||
prefix: str = "/"
|
||||
"""Command without prefix and mention"""
|
||||
command: str = ''
|
||||
command: str = ""
|
||||
"""Mention (if available)"""
|
||||
mention: str = None
|
||||
"""Command argument"""
|
||||
|
|
@ -157,9 +190,9 @@ class Command(Filter):
|
|||
"""
|
||||
line = self.prefix + self.command
|
||||
if self.mentioned:
|
||||
line += '@' + self.mention
|
||||
line += "@" + self.mention
|
||||
if self.args:
|
||||
line += ' ' + self.args
|
||||
line += " " + self.args
|
||||
return line
|
||||
|
||||
|
||||
|
|
@ -168,9 +201,12 @@ class CommandStart(Command):
|
|||
This filter based on :obj:`Command` filter but can handle only ``/start`` command.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
deep_link: typing.Optional[typing.Union[str, typing.Pattern[str]]] = None,
|
||||
encoded: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
deep_link: typing.Optional[typing.Union[str,
|
||||
typing.Pattern[str]]] = None,
|
||||
encoded: bool = False,
|
||||
):
|
||||
"""
|
||||
Also this filter can handle `deep-linking <https://core.telegram.org/bots#deep-linking>`_ arguments.
|
||||
|
||||
|
|
@ -183,7 +219,7 @@ class CommandStart(Command):
|
|||
:param deep_link: string or compiled regular expression (by ``re.compile(...)``).
|
||||
:param encoded: set True if you're waiting for encoded payload (default - False).
|
||||
"""
|
||||
super().__init__(['start'])
|
||||
super().__init__(["start"])
|
||||
self.deep_link = deep_link
|
||||
self.encoded = encoded
|
||||
|
||||
|
|
@ -195,17 +231,22 @@ class CommandStart(Command):
|
|||
:return:
|
||||
"""
|
||||
from ...utils.deep_linking import decode_payload
|
||||
|
||||
check = await super().check(message)
|
||||
|
||||
if check and self.deep_link is not None:
|
||||
payload = decode_payload(message.get_args()) if self.encoded else message.get_args()
|
||||
payload = (
|
||||
decode_payload(message.get_args())
|
||||
if self.encoded
|
||||
else message.get_args()
|
||||
)
|
||||
|
||||
if not isinstance(self.deep_link, typing.Pattern):
|
||||
return False if payload != self.deep_link else {'deep_link': payload}
|
||||
return False if payload != self.deep_link else {"deep_link": payload}
|
||||
|
||||
match = self.deep_link.match(payload)
|
||||
if match:
|
||||
return {'deep_link': match}
|
||||
return {"deep_link": match}
|
||||
return False
|
||||
|
||||
return check
|
||||
|
|
@ -217,7 +258,7 @@ class CommandHelp(Command):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(['help'])
|
||||
super().__init__(["help"])
|
||||
|
||||
|
||||
class CommandSettings(Command):
|
||||
|
|
@ -226,7 +267,7 @@ class CommandSettings(Command):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(['settings'])
|
||||
super().__init__(["settings"])
|
||||
|
||||
|
||||
class CommandPrivacy(Command):
|
||||
|
|
@ -235,7 +276,7 @@ class CommandPrivacy(Command):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(['privacy'])
|
||||
super().__init__(["privacy"])
|
||||
|
||||
|
||||
class Text(Filter):
|
||||
|
|
@ -244,18 +285,27 @@ class Text(Filter):
|
|||
"""
|
||||
|
||||
_default_params = (
|
||||
('text', 'equals'),
|
||||
('text_contains', 'contains'),
|
||||
('text_startswith', 'startswith'),
|
||||
('text_endswith', 'endswith'),
|
||||
("text", "equals"),
|
||||
("text_contains", "contains"),
|
||||
("text_startswith", "startswith"),
|
||||
("text_endswith", "endswith"),
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
||||
contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
||||
startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
||||
endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
|
||||
ignore_case=False):
|
||||
def __init__(
|
||||
self,
|
||||
equals: Optional[Union[str, LazyProxy,
|
||||
Iterable[Union[str, LazyProxy]]]] = None,
|
||||
contains: Optional[
|
||||
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
|
||||
] = None,
|
||||
startswith: Optional[
|
||||
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
|
||||
] = None,
|
||||
endswith: Optional[
|
||||
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
|
||||
] = None,
|
||||
ignore_case=False,
|
||||
):
|
||||
"""
|
||||
Check text for one of pattern. Only one mode can be used in one filter.
|
||||
In every pattern, a single string is treated as a list with 1 element.
|
||||
|
|
@ -267,15 +317,24 @@ class Text(Filter):
|
|||
:param ignore_case: case insensitive
|
||||
"""
|
||||
# Only one mode can be used. check it.
|
||||
check = sum(map(lambda s: s is not None, (equals, contains, startswith, endswith)))
|
||||
check = sum(
|
||||
map(lambda s: s is not None, (equals, contains, startswith, endswith))
|
||||
)
|
||||
if check > 1:
|
||||
args = "' and '".join([arg[0] for arg in [('equals', equals),
|
||||
('contains', contains),
|
||||
('startswith', startswith),
|
||||
('endswith', endswith)
|
||||
] if arg[1] is not None])
|
||||
args = "' and '".join(
|
||||
[
|
||||
arg[0]
|
||||
for arg in [
|
||||
("equals", equals),
|
||||
("contains", contains),
|
||||
("startswith", startswith),
|
||||
("endswith", endswith),
|
||||
]
|
||||
if arg[1] is not None
|
||||
]
|
||||
)
|
||||
raise ValueError(f"Arguments '{args}' cannot be used together.")
|
||||
elif check == 0:
|
||||
if check == 0:
|
||||
raise ValueError(f"No one mode is specified!")
|
||||
|
||||
equals, contains, endswith, startswith = map(
|
||||
|
|
@ -297,7 +356,7 @@ class Text(Filter):
|
|||
|
||||
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
|
||||
if isinstance(obj, Message):
|
||||
text = obj.text or obj.caption or ''
|
||||
text = obj.text or obj.caption or ""
|
||||
if not text and obj.poll:
|
||||
text = obj.poll.question
|
||||
elif isinstance(obj, CallbackQuery):
|
||||
|
|
@ -311,7 +370,10 @@ class Text(Filter):
|
|||
|
||||
if self.ignore_case:
|
||||
text = text.lower()
|
||||
_pre_process_func = lambda s: str(s).lower()
|
||||
|
||||
def _pre_process_func(s):
|
||||
return str(s).lower()
|
||||
|
||||
else:
|
||||
_pre_process_func = str
|
||||
|
||||
|
|
@ -344,7 +406,7 @@ class HashTag(Filter):
|
|||
|
||||
def __init__(self, hashtags=None, cashtags=None):
|
||||
if not hashtags and not cashtags:
|
||||
raise ValueError('No one hashtag or cashtag is specified!')
|
||||
raise ValueError("No one hashtag or cashtag is specified!")
|
||||
|
||||
if hashtags is None:
|
||||
hashtags = []
|
||||
|
|
@ -364,10 +426,10 @@ class HashTag(Filter):
|
|||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]):
|
||||
config = {}
|
||||
if 'hashtags' in full_config:
|
||||
config['hashtags'] = full_config.pop('hashtags')
|
||||
if 'cashtags' in full_config:
|
||||
config['cashtags'] = full_config.pop('cashtags')
|
||||
if "hashtags" in full_config:
|
||||
config["hashtags"] = full_config.pop("hashtags")
|
||||
if "cashtags" in full_config:
|
||||
config["cashtags"] = full_config.pop("cashtags")
|
||||
return config
|
||||
|
||||
async def check(self, message: types.Message):
|
||||
|
|
@ -381,9 +443,13 @@ class HashTag(Filter):
|
|||
return False
|
||||
|
||||
hashtags, cashtags = self._get_tags(text, entities)
|
||||
if self.hashtags and set(hashtags) & set(self.hashtags) \
|
||||
or self.cashtags and set(cashtags) & set(self.cashtags):
|
||||
return {'hashtags': hashtags, 'cashtags': cashtags}
|
||||
if (
|
||||
self.hashtags
|
||||
and set(hashtags) & set(self.hashtags)
|
||||
or self.cashtags
|
||||
and set(cashtags) & set(self.cashtags)
|
||||
):
|
||||
return {"hashtags": hashtags, "cashtags": cashtags}
|
||||
|
||||
@staticmethod
|
||||
def _get_tags(text, entities):
|
||||
|
|
@ -392,11 +458,11 @@ class HashTag(Filter):
|
|||
|
||||
for entity in entities:
|
||||
if entity.type == types.MessageEntityType.HASHTAG:
|
||||
value = entity.get_text(text).lstrip('#')
|
||||
value = entity.get_text(text).lstrip("#")
|
||||
hashtags.append(value)
|
||||
|
||||
elif entity.type == types.MessageEntityType.CASHTAG:
|
||||
value = entity.get_text(text).lstrip('$')
|
||||
value = entity.get_text(text).lstrip("$")
|
||||
cashtags.append(value)
|
||||
|
||||
return hashtags, cashtags
|
||||
|
|
@ -414,12 +480,12 @@ class Regexp(Filter):
|
|||
|
||||
@classmethod
|
||||
def validate(cls, full_config: Dict[str, Any]):
|
||||
if 'regexp' in full_config:
|
||||
return {'regexp': full_config.pop('regexp')}
|
||||
if "regexp" in full_config:
|
||||
return {"regexp": full_config.pop("regexp")}
|
||||
|
||||
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
|
||||
if isinstance(obj, Message):
|
||||
content = obj.text or obj.caption or ''
|
||||
content = obj.text or obj.caption or ""
|
||||
if not content and obj.poll:
|
||||
content = obj.poll.question
|
||||
elif isinstance(obj, CallbackQuery) and obj.data:
|
||||
|
|
@ -434,7 +500,7 @@ class Regexp(Filter):
|
|||
match = self.regexp.search(content)
|
||||
|
||||
if match:
|
||||
return {'regexp': match}
|
||||
return {"regexp": match}
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -443,17 +509,20 @@ class RegexpCommandsFilter(BoundFilter):
|
|||
Check commands by regexp in message
|
||||
"""
|
||||
|
||||
key = 'regexp_commands'
|
||||
key = "regexp_commands"
|
||||
|
||||
def __init__(self, regexp_commands):
|
||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||
self.regexp_commands = [
|
||||
re.compile(command, flags=re.IGNORECASE | re.MULTILINE)
|
||||
for command in regexp_commands
|
||||
]
|
||||
|
||||
async def check(self, message):
|
||||
if not message.is_command():
|
||||
return False
|
||||
|
||||
command = message.text.split()[0][1:]
|
||||
command, _, mention = command.partition('@')
|
||||
command, _, mention = command.partition("@")
|
||||
|
||||
if mention and mention != (await message.bot.me).username:
|
||||
return False
|
||||
|
|
@ -461,7 +530,7 @@ class RegexpCommandsFilter(BoundFilter):
|
|||
for command in self.regexp_commands:
|
||||
search = command.search(message.text)
|
||||
if search:
|
||||
return {'regexp_command': search}
|
||||
return {"regexp_command": search}
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -470,7 +539,7 @@ class ContentTypeFilter(BoundFilter):
|
|||
Check message content type
|
||||
"""
|
||||
|
||||
key = 'content_types'
|
||||
key = "content_types"
|
||||
required = True
|
||||
default = types.ContentTypes.TEXT
|
||||
|
||||
|
|
@ -480,8 +549,10 @@ class ContentTypeFilter(BoundFilter):
|
|||
self.content_types = content_types
|
||||
|
||||
async def check(self, message):
|
||||
return types.ContentType.ANY in self.content_types or \
|
||||
message.content_type in self.content_types
|
||||
return (
|
||||
types.ContentType.ANY in self.content_types
|
||||
or message.content_type in self.content_types
|
||||
)
|
||||
|
||||
|
||||
class IsSenderContact(BoundFilter):
|
||||
|
|
@ -491,7 +562,8 @@ class IsSenderContact(BoundFilter):
|
|||
`is_sender_contact=True` - contact matches the sender
|
||||
`is_sender_contact=False` - result will be inverted
|
||||
"""
|
||||
key = 'is_sender_contact'
|
||||
|
||||
key = "is_sender_contact"
|
||||
|
||||
def __init__(self, is_sender_contact: bool):
|
||||
self.is_sender_contact = is_sender_contact
|
||||
|
|
@ -509,10 +581,11 @@ class StateFilter(BoundFilter):
|
|||
"""
|
||||
Check user state
|
||||
"""
|
||||
key = 'state'
|
||||
|
||||
key = "state"
|
||||
required = True
|
||||
|
||||
ctx_state = ContextVar('user_state')
|
||||
ctx_state = ContextVar("user_state")
|
||||
|
||||
def __init__(self, dispatcher, state):
|
||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||
|
|
@ -520,7 +593,9 @@ class StateFilter(BoundFilter):
|
|||
self.dispatcher = dispatcher
|
||||
states = []
|
||||
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
|
||||
state = [state, ]
|
||||
state = [
|
||||
state,
|
||||
]
|
||||
for item in state:
|
||||
if isinstance(item, State):
|
||||
states.append(item.state)
|
||||
|
|
@ -532,11 +607,13 @@ class StateFilter(BoundFilter):
|
|||
|
||||
@staticmethod
|
||||
def get_target(obj):
|
||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||
return getattr(getattr(obj, "chat", None), "id", None), getattr(
|
||||
getattr(obj, "from_user", None), "id", None
|
||||
)
|
||||
|
||||
async def check(self, obj):
|
||||
if '*' in self.states:
|
||||
return {'state': self.dispatcher.current_state()}
|
||||
if "*" in self.states:
|
||||
return {"state": self.dispatcher.current_state()}
|
||||
|
||||
try:
|
||||
state = self.ctx_state.get()
|
||||
|
|
@ -547,11 +624,14 @@ class StateFilter(BoundFilter):
|
|||
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
|
||||
self.ctx_state.set(state)
|
||||
if state in self.states:
|
||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||
return {
|
||||
"state": self.dispatcher.current_state(),
|
||||
"raw_state": state,
|
||||
}
|
||||
|
||||
else:
|
||||
if state in self.states:
|
||||
return {'state': self.dispatcher.current_state(), 'raw_state': state}
|
||||
return {"state": self.dispatcher.current_state(), "raw_state": state}
|
||||
|
||||
return False
|
||||
|
||||
|
|
@ -561,7 +641,7 @@ class ExceptionsFilter(BoundFilter):
|
|||
Filter for exceptions
|
||||
"""
|
||||
|
||||
key = 'exception'
|
||||
key = "exception"
|
||||
|
||||
def __init__(self, exception):
|
||||
self.exception = exception
|
||||
|
|
@ -576,10 +656,11 @@ class ExceptionsFilter(BoundFilter):
|
|||
|
||||
|
||||
class IDFilter(Filter):
|
||||
def __init__(self,
|
||||
user_id: Optional[ChatIDArgumentType] = None,
|
||||
chat_id: Optional[ChatIDArgumentType] = None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: Optional[ChatIDArgumentType] = None,
|
||||
chat_id: Optional[ChatIDArgumentType] = None,
|
||||
):
|
||||
"""
|
||||
:param user_id:
|
||||
:param chat_id:
|
||||
|
|
@ -597,13 +678,15 @@ class IDFilter(Filter):
|
|||
self.chat_id = extract_chat_ids(chat_id)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
def validate(
|
||||
cls, full_config: typing.Dict[str, typing.Any]
|
||||
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
result = {}
|
||||
if 'user_id' in full_config:
|
||||
result['user_id'] = full_config.pop('user_id')
|
||||
if "user_id" in full_config:
|
||||
result["user_id"] = full_config.pop("user_id")
|
||||
|
||||
if 'chat_id' in full_config:
|
||||
result['chat_id'] = full_config.pop('chat_id')
|
||||
if "chat_id" in full_config:
|
||||
result["chat_id"] = full_config.pop("chat_id")
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -658,7 +741,9 @@ class AdminFilter(Filter):
|
|||
self._chat_ids = extract_chat_ids(is_chat_admin)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
def validate(
|
||||
cls, full_config: typing.Dict[str, typing.Any]
|
||||
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
result = {}
|
||||
|
||||
if "is_chat_admin" in full_config:
|
||||
|
|
@ -676,13 +761,19 @@ class AdminFilter(Filter):
|
|||
message = obj.message
|
||||
else:
|
||||
return False
|
||||
if message.chat.type == ChatType.PRIVATE: # there is no admin in private chats
|
||||
if (
|
||||
message.chat.type == ChatType.PRIVATE
|
||||
): # there is no admin in private chats
|
||||
return False
|
||||
chat_ids = [message.chat.id]
|
||||
else:
|
||||
chat_ids = self._chat_ids
|
||||
|
||||
admins = [member.user.id for chat_id in chat_ids for member in await obj.bot.get_chat_administrators(chat_id)]
|
||||
admins = [
|
||||
member.user.id
|
||||
for chat_id in chat_ids
|
||||
for member in await obj.bot.get_chat_administrators(chat_id)
|
||||
]
|
||||
|
||||
return user_id in admins
|
||||
|
||||
|
|
@ -691,20 +782,21 @@ class IsReplyFilter(BoundFilter):
|
|||
"""
|
||||
Check if message is replied and send reply message to handler
|
||||
"""
|
||||
key = 'is_reply'
|
||||
|
||||
key = "is_reply"
|
||||
|
||||
def __init__(self, is_reply):
|
||||
self.is_reply = is_reply
|
||||
|
||||
async def check(self, msg: Message):
|
||||
if msg.reply_to_message and self.is_reply:
|
||||
return {'reply': msg.reply_to_message}
|
||||
return {"reply": msg.reply_to_message}
|
||||
if not msg.reply_to_message and not self.is_reply:
|
||||
return True
|
||||
|
||||
|
||||
class ForwardedMessageFilter(BoundFilter):
|
||||
key = 'is_forwarded'
|
||||
key = "is_forwarded"
|
||||
|
||||
def __init__(self, is_forwarded: bool):
|
||||
self.is_forwarded = is_forwarded
|
||||
|
|
@ -714,7 +806,7 @@ class ForwardedMessageFilter(BoundFilter):
|
|||
|
||||
|
||||
class ChatTypeFilter(BoundFilter):
|
||||
key = 'chat_type'
|
||||
key = "chat_type"
|
||||
|
||||
def __init__(self, chat_type: typing.Container[ChatType]):
|
||||
if isinstance(chat_type, str):
|
||||
|
|
@ -728,7 +820,8 @@ class ChatTypeFilter(BoundFilter):
|
|||
elif isinstance(obj, CallbackQuery):
|
||||
obj = obj.message.chat
|
||||
else:
|
||||
warnings.warn("ChatTypeFilter doesn't support %s as input", type(obj))
|
||||
warnings.warn(
|
||||
"ChatTypeFilter doesn't support %s as input", type(obj))
|
||||
return False
|
||||
|
||||
return obj.type in self.chat_type
|
||||
|
|
|
|||
|
|
@ -13,9 +13,11 @@ def wrap_async(func):
|
|||
async def async_wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isawaitable(func) \
|
||||
or inspect.iscoroutinefunction(func) \
|
||||
or isinstance(func, AbstractFilter):
|
||||
if (
|
||||
inspect.isawaitable(func)
|
||||
or inspect.iscoroutinefunction(func)
|
||||
or isinstance(func, AbstractFilter)
|
||||
):
|
||||
return func
|
||||
return async_wrapper
|
||||
|
||||
|
|
@ -23,14 +25,16 @@ def wrap_async(func):
|
|||
def get_filter_spec(dispatcher, filter_: callable):
|
||||
kwargs = {}
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
raise TypeError("Filter must be callable and/or awaitable!")
|
||||
|
||||
spec = inspect.getfullargspec(filter_)
|
||||
if 'dispatcher' in spec:
|
||||
kwargs['dispatcher'] = dispatcher
|
||||
if inspect.isawaitable(filter_) \
|
||||
or inspect.iscoroutinefunction(filter_) \
|
||||
or isinstance(filter_, AbstractFilter):
|
||||
if "dispatcher" in spec:
|
||||
kwargs["dispatcher"] = dispatcher
|
||||
if (
|
||||
inspect.isawaitable(filter_)
|
||||
or inspect.iscoroutinefunction(filter_)
|
||||
or isinstance(filter_, AbstractFilter)
|
||||
):
|
||||
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
|
||||
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
|
||||
|
||||
|
|
@ -70,7 +74,7 @@ async def check_filters(filters: typing.Iterable[FilterObj], args):
|
|||
f = await execute_filter(filter_, args)
|
||||
if not f:
|
||||
raise FilterNotPassed()
|
||||
elif isinstance(f, dict):
|
||||
if isinstance(f, dict):
|
||||
data.update(f)
|
||||
return data
|
||||
|
||||
|
|
@ -80,12 +84,17 @@ class FilterRecord:
|
|||
Filters record for factory
|
||||
"""
|
||||
|
||||
def __init__(self, callback: typing.Union[typing.Callable, 'AbstractFilter'],
|
||||
validator: typing.Optional[typing.Callable] = None,
|
||||
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
||||
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
callback: typing.Union[typing.Callable, "AbstractFilter"],
|
||||
validator: typing.Optional[typing.Callable] = None,
|
||||
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
||||
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
|
||||
):
|
||||
if event_handlers and exclude_event_handlers:
|
||||
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
|
||||
raise ValueError(
|
||||
"'event_handlers' and 'exclude_event_handlers' arguments cannot be used together."
|
||||
)
|
||||
|
||||
self.callback = callback
|
||||
self.event_handlers = event_handlers
|
||||
|
|
@ -93,22 +102,23 @@ class FilterRecord:
|
|||
|
||||
if validator is not None:
|
||||
if not callable(validator):
|
||||
raise TypeError(f"validator must be callable, not {type(validator)}")
|
||||
raise TypeError(
|
||||
f"validator must be callable, not {type(validator)}")
|
||||
self.resolver = validator
|
||||
elif issubclass(callback, AbstractFilter):
|
||||
self.resolver = callback.validate
|
||||
else:
|
||||
raise RuntimeError('validator is required!')
|
||||
raise RuntimeError("validator is required!")
|
||||
|
||||
def resolve(self, dispatcher, event_handler, full_config):
|
||||
if not self._check_event_handler(event_handler):
|
||||
return
|
||||
config = self.resolver(full_config)
|
||||
if config:
|
||||
if 'dispatcher' not in config:
|
||||
if "dispatcher" not in config:
|
||||
spec = inspect.getfullargspec(self.callback)
|
||||
if 'dispatcher' in spec.args:
|
||||
config['dispatcher'] = dispatcher
|
||||
if "dispatcher" in spec.args:
|
||||
config["dispatcher"] = dispatcher
|
||||
|
||||
for key in config:
|
||||
if key in full_config:
|
||||
|
|
@ -131,7 +141,9 @@ class AbstractFilter(abc.ABC):
|
|||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
def validate(
|
||||
cls, full_config: typing.Dict[str, typing.Any]
|
||||
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
"""
|
||||
Validate and parse config.
|
||||
|
||||
|
|
@ -182,7 +194,9 @@ class Filter(AbstractFilter):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
def validate(
|
||||
cls, full_config: typing.Dict[str, typing.Any]
|
||||
) -> typing.Optional[typing.Dict[str, typing.Any]]:
|
||||
"""
|
||||
Here method ``validate`` is optional.
|
||||
If you need to use filter from filters factory you need to override this method.
|
||||
|
|
@ -200,16 +214,18 @@ class BoundFilter(Filter):
|
|||
You need to implement ``__init__`` method with single argument related with key attribute
|
||||
and ``check`` method where you need to implement filter logic.
|
||||
"""
|
||||
|
||||
|
||||
key = None
|
||||
"""Unique name of the filter argument. You need to override this attribute."""
|
||||
required = False
|
||||
"""If :obj:`True` this filter will be added to the all of the registered handlers"""
|
||||
default = None
|
||||
"""Default value for configure required filters"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
||||
def validate(
|
||||
cls, full_config: typing.Dict[str, typing.Any]
|
||||
) -> typing.Dict[str, typing.Any]:
|
||||
"""
|
||||
If ``cls.key`` is not :obj:`None` and that is in config returns config with that argument.
|
||||
|
||||
|
|
@ -226,7 +242,7 @@ class BoundFilter(Filter):
|
|||
class _LogicFilter(Filter):
|
||||
@classmethod
|
||||
def validate(cls, full_config: typing.Dict[str, typing.Any]):
|
||||
raise ValueError('That filter can\'t be used in filters factory!')
|
||||
raise ValueError("That filter can't be used in filters factory!")
|
||||
|
||||
|
||||
class NotFilter(_LogicFilter):
|
||||
|
|
@ -238,7 +254,6 @@ class NotFilter(_LogicFilter):
|
|||
|
||||
|
||||
class AndFilter(_LogicFilter):
|
||||
|
||||
def __init__(self, *targets):
|
||||
self.targets = [wrap_async(target) for target in targets]
|
||||
|
||||
|
|
|
|||
|
|
@ -26,15 +26,17 @@ class CallbackData:
|
|||
Callback data factory
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, *parts, sep=':'):
|
||||
def __init__(self, prefix, *parts, sep=":"):
|
||||
if not isinstance(prefix, str):
|
||||
raise TypeError(f'Prefix must be instance of str not {type(prefix).__name__}')
|
||||
raise TypeError(
|
||||
f"Prefix must be instance of str not {type(prefix).__name__}"
|
||||
)
|
||||
if not prefix:
|
||||
raise ValueError("Prefix can't be empty")
|
||||
if sep in prefix:
|
||||
raise ValueError(f"Separator {sep!r} can't be used in prefix")
|
||||
if not parts:
|
||||
raise TypeError('Parts were not passed!')
|
||||
raise TypeError("Parts were not passed!")
|
||||
|
||||
self.prefix = prefix
|
||||
self.sep = sep
|
||||
|
|
@ -59,7 +61,7 @@ class CallbackData:
|
|||
if args:
|
||||
value = args.pop(0)
|
||||
else:
|
||||
raise ValueError(f'Value for {part!r} was not passed!')
|
||||
raise ValueError(f"Value for {part!r} was not passed!")
|
||||
|
||||
if value is not None and not isinstance(value, str):
|
||||
value = str(value)
|
||||
|
|
@ -67,16 +69,18 @@ class CallbackData:
|
|||
if not value:
|
||||
raise ValueError(f"Value for part {part!r} can't be empty!'")
|
||||
if self.sep in value:
|
||||
raise ValueError(f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values")
|
||||
raise ValueError(
|
||||
f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values"
|
||||
)
|
||||
|
||||
data.append(value)
|
||||
|
||||
if args or kwargs:
|
||||
raise TypeError('Too many arguments were passed!')
|
||||
raise TypeError("Too many arguments were passed!")
|
||||
|
||||
callback_data = self.sep.join(data)
|
||||
if len(callback_data.encode()) > 64:
|
||||
raise ValueError('Resulted callback data is too long!')
|
||||
raise ValueError("Resulted callback data is too long!")
|
||||
|
||||
return callback_data
|
||||
|
||||
|
|
@ -89,11 +93,12 @@ class CallbackData:
|
|||
"""
|
||||
prefix, *parts = callback_data.split(self.sep)
|
||||
if prefix != self.prefix:
|
||||
raise ValueError("Passed callback data can't be parsed with that prefix.")
|
||||
elif len(parts) != len(self._part_names):
|
||||
raise ValueError('Invalid parts count!')
|
||||
raise ValueError(
|
||||
"Passed callback data can't be parsed with that prefix.")
|
||||
if len(parts) != len(self._part_names):
|
||||
raise ValueError("Invalid parts count!")
|
||||
|
||||
result = {'@': prefix}
|
||||
result = {"@": prefix}
|
||||
result.update(zip(self._part_names, parts))
|
||||
return result
|
||||
|
||||
|
|
@ -106,12 +111,11 @@ class CallbackData:
|
|||
"""
|
||||
for key in config.keys():
|
||||
if key not in self._part_names:
|
||||
raise ValueError(f'Invalid field name {key!r}')
|
||||
raise ValueError(f"Invalid field name {key!r}")
|
||||
return CallbackDataFilter(self, config)
|
||||
|
||||
|
||||
class CallbackDataFilter(Filter):
|
||||
|
||||
def __init__(self, factory: CallbackData, config: typing.Dict[str, str]):
|
||||
self.config = config
|
||||
self.factory = factory
|
||||
|
|
@ -133,4 +137,4 @@ class CallbackDataFilter(Filter):
|
|||
else:
|
||||
if data.get(key) != value:
|
||||
return False
|
||||
return {'callback_data': data}
|
||||
return {"callback_data": data}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue