Format code with black, autopep8 and isort

This commit fixes the style issues introduced in 7ca09aa according to the output
from black, autopep8 and isort.

Details: 520006ec-d635-4384-b8eb-d844943c6f8c/
This commit is contained in:
deepsource-autofix[bot] 2020-11-08 22:01:28 +00:00 committed by GitHub
parent 7ca09aac1b
commit 0fe5a92d4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 288 additions and 161 deletions

View file

@ -10,7 +10,7 @@ from ..utils import exceptions, json
from ..utils.helper import Helper, HelperMode, Item from ..utils.helper import Helper, HelperMode, Item
# Main aiogram logger # Main aiogram logger
log = logging.getLogger('aiogram') log = logging.getLogger("aiogram")
@dataclass(frozen=True) @dataclass(frozen=True)
@ -43,7 +43,7 @@ class TelegramAPIServer:
return self.file.format(token=token, path=path) return self.file.format(token=token, path=path)
@classmethod @classmethod
def from_base(cls, base: str) -> 'TelegramAPIServer': def from_base(cls, base: str) -> "TelegramAPIServer":
base = base.rstrip("/") base = base.rstrip("/")
return cls( return cls(
base=f"{base}/bot{{token}}/{{method}}", base=f"{base}/bot{{token}}/{{method}}",
@ -62,17 +62,19 @@ def check_token(token: str) -> bool:
:return: :return:
""" """
if not isinstance(token, str): if not isinstance(token, str):
message = (f"Token is invalid! " message = (
f"It must be 'str' type instead of {type(token)} type.") f"Token is invalid! "
f"It must be 'str' type instead of {type(token)} type."
)
raise exceptions.ValidationError(message) raise exceptions.ValidationError(message)
if any(x.isspace() for x in token): if any(x.isspace() for x in token):
message = "Token is invalid! It can't contains spaces." message = "Token is invalid! It can't contains spaces."
raise exceptions.ValidationError(message) raise exceptions.ValidationError(message)
left, sep, right = token.partition(':') left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (not right): if (not sep) or (not left.isdigit()) or (not right):
raise exceptions.ValidationError('Token is invalid!') raise exceptions.ValidationError("Token is invalid!")
return True return True
@ -94,19 +96,22 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
""" """
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body) log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
if content_type != 'application/json': if content_type != "application/json":
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"") raise exceptions.NetworkError(
f'Invalid response with content type {content_type}: "{body}"'
)
try: try:
result_json = json.loads(body) result_json = json.loads(body)
except ValueError: except ValueError:
result_json = {} result_json = {}
description = result_json.get('description') or body description = result_json.get("description") or body
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {}) parameters = types.ResponseParameters(
**result_json.get("parameters", {}) or {})
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED: if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
return result_json.get('result') return result_json.get("result")
if parameters.retry_after: if parameters.retry_after:
raise exceptions.RetryAfter(parameters.retry_after) raise exceptions.RetryAfter(parameters.retry_after)
if parameters.migrate_to_chat_id: if parameters.migrate_to_chat_id:
@ -120,26 +125,33 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
elif status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): elif status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
exceptions.Unauthorized.detect(description) exceptions.Unauthorized.detect(description)
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
raise exceptions.NetworkError('File too large for uploading. ' raise exceptions.NetworkError(
'Check telegram api limits https://core.telegram.org/bots/api#senddocument') "File too large for uploading. "
"Check telegram api limits https://core.telegram.org/bots/api#senddocument"
)
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR: elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
if 'restart' in description: if "restart" in description:
raise exceptions.RestartingTelegram() raise exceptions.RestartingTelegram()
raise exceptions.TelegramAPIError(description) raise exceptions.TelegramAPIError(description)
raise exceptions.TelegramAPIError(f"{description} [{status_code}]") raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
async def make_request(session, server, token, method, data=None, files=None, **kwargs): async def make_request(session, server, token, method, data=None, files=None, **kwargs):
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files) log.debug('Make request: "%s" with data: "%r" and files "%r"',
method, data, files)
url = server.api_url(token=token, method=method) url = server.api_url(token=token, method=method)
req = compose_data(data, files) req = compose_data(data, files)
try: try:
async with session.post(url, data=req, **kwargs) as response: async with session.post(url, data=req, **kwargs) as response:
return check_result(method, response.content_type, response.status, await response.text()) return check_result(
method, response.content_type, response.status, await response.text()
)
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}") raise exceptions.NetworkError(
f"aiohttp client throws an error: {e.__class__.__name__}: {e}"
)
def guess_filename(obj): def guess_filename(obj):
@ -149,8 +161,8 @@ def guess_filename(obj):
:param obj: :param obj:
:return: :return:
""" """
name = getattr(obj, 'name', None) name = getattr(obj, "name", None)
if name and isinstance(name, str) and name[0] != '<' and name[-1] != '>': if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
return os.path.basename(name) return os.path.basename(name)
@ -174,7 +186,9 @@ def compose_data(params=None, files=None):
if len(f) == 2: if len(f) == 2:
filename, fileobj = f filename, fileobj = f
else: else:
raise ValueError('Tuple must have exactly 2 elements: filename, fileobj') raise ValueError(
"Tuple must have exactly 2 elements: filename, fileobj"
)
elif isinstance(f, types.InputFile): elif isinstance(f, types.InputFile):
filename, fileobj = f.filename, f.file filename, fileobj = f.filename, f.file
else: else:
@ -191,6 +205,7 @@ class Methods(Helper):
List is updated to Bot API 5.0 List is updated to Bot API 5.0
""" """
mode = HelperMode.lowerCamelCase mode = HelperMode.lowerCamelCase
# Getting Updates # Getting Updates

View file

@ -18,11 +18,15 @@ ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str,
def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]: def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]:
# since "str" is also an "Iterable", we have to check for it first # since "str" is also an "Iterable", we have to check for it first
if isinstance(chat_id, str): if isinstance(chat_id, str):
return {int(chat_id), } return {
int(chat_id),
}
if isinstance(chat_id, Iterable): if isinstance(chat_id, Iterable):
return {int(item) for (item) in chat_id} return {int(item) for (item) in chat_id}
# the last possible type is a single "int" # the last possible type is a single "int"
return {chat_id, } return {
chat_id,
}
class Command(Filter): class Command(Filter):
@ -34,11 +38,14 @@ class Command(Filter):
By default this filter is registered for messages and edited messages handlers. By default this filter is registered for messages and edited messages handlers.
""" """
def __init__(self, commands: Union[Iterable, str], def __init__(
prefixes: Union[Iterable, str] = '/', self,
ignore_case: bool = True, commands: Union[Iterable, str],
ignore_mention: bool = False, prefixes: Union[Iterable, str] = "/",
ignore_caption: bool = True): ignore_case: bool = True,
ignore_mention: bool = False,
ignore_caption: bool = True,
):
""" """
Filter can be initialized from filters factory or by simply creating instance of this class. Filter can be initialized from filters factory or by simply creating instance of this class.
@ -69,7 +76,8 @@ class Command(Filter):
if isinstance(commands, str): if isinstance(commands, str):
commands = (commands,) commands = (commands,)
self.commands = list(map(str.lower, commands)) if ignore_case else commands self.commands = list(map(str.lower, commands)
) if ignore_case else commands
self.prefixes = prefixes self.prefixes = prefixes
self.ignore_case = ignore_case self.ignore_case = ignore_case
self.ignore_mention = ignore_mention self.ignore_mention = ignore_mention
@ -91,36 +99,61 @@ class Command(Filter):
:return: config or empty dict :return: config or empty dict
""" """
config = {} config = {}
if 'commands' in full_config: if "commands" in full_config:
config['commands'] = full_config.pop('commands') config["commands"] = full_config.pop("commands")
if config and 'commands_prefix' in full_config: if config and "commands_prefix" in full_config:
config['prefixes'] = full_config.pop('commands_prefix') config["prefixes"] = full_config.pop("commands_prefix")
if config and 'commands_ignore_mention' in full_config: if config and "commands_ignore_mention" in full_config:
config['ignore_mention'] = full_config.pop('commands_ignore_mention') config["ignore_mention"] = full_config.pop(
if config and 'commands_ignore_caption' in full_config: "commands_ignore_mention")
config['ignore_caption'] = full_config.pop('commands_ignore_caption') if config and "commands_ignore_caption" in full_config:
config["ignore_caption"] = full_config.pop(
"commands_ignore_caption")
return config return config
async def check(self, message: types.Message): async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention, self.ignore_caption) return await self.check_command(
message,
self.commands,
self.prefixes,
self.ignore_case,
self.ignore_mention,
self.ignore_caption,
)
@classmethod @classmethod
async def check_command(cls, message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False, ignore_caption=True): async def check_command(
text = message.text or (message.caption if not ignore_caption else None) cls,
message: types.Message,
commands,
prefixes,
ignore_case=True,
ignore_mention=False,
ignore_caption=True,
):
text = message.text or (
message.caption if not ignore_caption else None)
if not text: if not text:
return False return False
full_command = text.split()[0] full_command = text.split()[0]
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@') prefix, (command, _,
mention) = full_command[0], full_command[1:].partition("@")
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower(): if (
not ignore_mention
and mention
and (await message.bot.me).username.lower() != mention.lower()
):
return False return False
if prefix not in prefixes: if prefix not in prefixes:
return False return False
if (command.lower() if ignore_case else command) not in commands: if (command.lower() if ignore_case else command) not in commands:
return False return False
return {'command': cls.CommandObj(command=command, prefix=prefix, mention=mention)} return {
"command": cls.CommandObj(command=command, prefix=prefix, mention=mention)
}
@dataclass @dataclass
class CommandObj: class CommandObj:
@ -131,9 +164,9 @@ class Command(Filter):
""" """
"""Command prefix""" """Command prefix"""
prefix: str = '/' prefix: str = "/"
"""Command without prefix and mention""" """Command without prefix and mention"""
command: str = '' command: str = ""
"""Mention (if available)""" """Mention (if available)"""
mention: str = None mention: str = None
"""Command argument""" """Command argument"""
@ -157,9 +190,9 @@ class Command(Filter):
""" """
line = self.prefix + self.command line = self.prefix + self.command
if self.mentioned: if self.mentioned:
line += '@' + self.mention line += "@" + self.mention
if self.args: if self.args:
line += ' ' + self.args line += " " + self.args
return line return line
@ -168,9 +201,12 @@ class CommandStart(Command):
This filter based on :obj:`Command` filter but can handle only ``/start`` command. This filter based on :obj:`Command` filter but can handle only ``/start`` command.
""" """
def __init__(self, def __init__(
deep_link: typing.Optional[typing.Union[str, typing.Pattern[str]]] = None, self,
encoded: bool = False): deep_link: typing.Optional[typing.Union[str,
typing.Pattern[str]]] = None,
encoded: bool = False,
):
""" """
Also this filter can handle `deep-linking <https://core.telegram.org/bots#deep-linking>`_ arguments. Also this filter can handle `deep-linking <https://core.telegram.org/bots#deep-linking>`_ arguments.
@ -183,7 +219,7 @@ class CommandStart(Command):
:param deep_link: string or compiled regular expression (by ``re.compile(...)``). :param deep_link: string or compiled regular expression (by ``re.compile(...)``).
:param encoded: set True if you're waiting for encoded payload (default - False). :param encoded: set True if you're waiting for encoded payload (default - False).
""" """
super().__init__(['start']) super().__init__(["start"])
self.deep_link = deep_link self.deep_link = deep_link
self.encoded = encoded self.encoded = encoded
@ -195,17 +231,22 @@ class CommandStart(Command):
:return: :return:
""" """
from ...utils.deep_linking import decode_payload from ...utils.deep_linking import decode_payload
check = await super().check(message) check = await super().check(message)
if check and self.deep_link is not None: if check and self.deep_link is not None:
payload = decode_payload(message.get_args()) if self.encoded else message.get_args() payload = (
decode_payload(message.get_args())
if self.encoded
else message.get_args()
)
if not isinstance(self.deep_link, typing.Pattern): if not isinstance(self.deep_link, typing.Pattern):
return False if payload != self.deep_link else {'deep_link': payload} return False if payload != self.deep_link else {"deep_link": payload}
match = self.deep_link.match(payload) match = self.deep_link.match(payload)
if match: if match:
return {'deep_link': match} return {"deep_link": match}
return False return False
return check return check
@ -217,7 +258,7 @@ class CommandHelp(Command):
""" """
def __init__(self): def __init__(self):
super().__init__(['help']) super().__init__(["help"])
class CommandSettings(Command): class CommandSettings(Command):
@ -226,7 +267,7 @@ class CommandSettings(Command):
""" """
def __init__(self): def __init__(self):
super().__init__(['settings']) super().__init__(["settings"])
class CommandPrivacy(Command): class CommandPrivacy(Command):
@ -235,7 +276,7 @@ class CommandPrivacy(Command):
""" """
def __init__(self): def __init__(self):
super().__init__(['privacy']) super().__init__(["privacy"])
class Text(Filter): class Text(Filter):
@ -244,18 +285,27 @@ class Text(Filter):
""" """
_default_params = ( _default_params = (
('text', 'equals'), ("text", "equals"),
('text_contains', 'contains'), ("text_contains", "contains"),
('text_startswith', 'startswith'), ("text_startswith", "startswith"),
('text_endswith', 'endswith'), ("text_endswith", "endswith"),
) )
def __init__(self, def __init__(
equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, self,
contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, equals: Optional[Union[str, LazyProxy,
startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, Iterable[Union[str, LazyProxy]]]] = None,
endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, contains: Optional[
ignore_case=False): Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
] = None,
startswith: Optional[
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
] = None,
endswith: Optional[
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
] = None,
ignore_case=False,
):
""" """
Check text for one of pattern. Only one mode can be used in one filter. Check text for one of pattern. Only one mode can be used in one filter.
In every pattern, a single string is treated as a list with 1 element. In every pattern, a single string is treated as a list with 1 element.
@ -267,13 +317,22 @@ class Text(Filter):
:param ignore_case: case insensitive :param ignore_case: case insensitive
""" """
# Only one mode can be used. check it. # Only one mode can be used. check it.
check = sum(map(lambda s: s is not None, (equals, contains, startswith, endswith))) check = sum(
map(lambda s: s is not None, (equals, contains, startswith, endswith))
)
if check > 1: if check > 1:
args = "' and '".join([arg[0] for arg in [('equals', equals), args = "' and '".join(
('contains', contains), [
('startswith', startswith), arg[0]
('endswith', endswith) for arg in [
] if arg[1] is not None]) ("equals", equals),
("contains", contains),
("startswith", startswith),
("endswith", endswith),
]
if arg[1] is not None
]
)
raise ValueError(f"Arguments '{args}' cannot be used together.") raise ValueError(f"Arguments '{args}' cannot be used together.")
if check == 0: if check == 0:
raise ValueError(f"No one mode is specified!") raise ValueError(f"No one mode is specified!")
@ -297,7 +356,7 @@ class Text(Filter):
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]): async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
if isinstance(obj, Message): if isinstance(obj, Message):
text = obj.text or obj.caption or '' text = obj.text or obj.caption or ""
if not text and obj.poll: if not text and obj.poll:
text = obj.poll.question text = obj.poll.question
elif isinstance(obj, CallbackQuery): elif isinstance(obj, CallbackQuery):
@ -311,7 +370,10 @@ class Text(Filter):
if self.ignore_case: if self.ignore_case:
text = text.lower() text = text.lower()
_pre_process_func = lambda s: str(s).lower()
def _pre_process_func(s):
return str(s).lower()
else: else:
_pre_process_func = str _pre_process_func = str
@ -344,7 +406,7 @@ class HashTag(Filter):
def __init__(self, hashtags=None, cashtags=None): def __init__(self, hashtags=None, cashtags=None):
if not hashtags and not cashtags: if not hashtags and not cashtags:
raise ValueError('No one hashtag or cashtag is specified!') raise ValueError("No one hashtag or cashtag is specified!")
if hashtags is None: if hashtags is None:
hashtags = [] hashtags = []
@ -364,10 +426,10 @@ class HashTag(Filter):
@classmethod @classmethod
def validate(cls, full_config: Dict[str, Any]): def validate(cls, full_config: Dict[str, Any]):
config = {} config = {}
if 'hashtags' in full_config: if "hashtags" in full_config:
config['hashtags'] = full_config.pop('hashtags') config["hashtags"] = full_config.pop("hashtags")
if 'cashtags' in full_config: if "cashtags" in full_config:
config['cashtags'] = full_config.pop('cashtags') config["cashtags"] = full_config.pop("cashtags")
return config return config
async def check(self, message: types.Message): async def check(self, message: types.Message):
@ -381,9 +443,13 @@ class HashTag(Filter):
return False return False
hashtags, cashtags = self._get_tags(text, entities) hashtags, cashtags = self._get_tags(text, entities)
if self.hashtags and set(hashtags) & set(self.hashtags) \ if (
or self.cashtags and set(cashtags) & set(self.cashtags): self.hashtags
return {'hashtags': hashtags, 'cashtags': cashtags} and set(hashtags) & set(self.hashtags)
or self.cashtags
and set(cashtags) & set(self.cashtags)
):
return {"hashtags": hashtags, "cashtags": cashtags}
@staticmethod @staticmethod
def _get_tags(text, entities): def _get_tags(text, entities):
@ -392,11 +458,11 @@ class HashTag(Filter):
for entity in entities: for entity in entities:
if entity.type == types.MessageEntityType.HASHTAG: if entity.type == types.MessageEntityType.HASHTAG:
value = entity.get_text(text).lstrip('#') value = entity.get_text(text).lstrip("#")
hashtags.append(value) hashtags.append(value)
elif entity.type == types.MessageEntityType.CASHTAG: elif entity.type == types.MessageEntityType.CASHTAG:
value = entity.get_text(text).lstrip('$') value = entity.get_text(text).lstrip("$")
cashtags.append(value) cashtags.append(value)
return hashtags, cashtags return hashtags, cashtags
@ -414,12 +480,12 @@ class Regexp(Filter):
@classmethod @classmethod
def validate(cls, full_config: Dict[str, Any]): def validate(cls, full_config: Dict[str, Any]):
if 'regexp' in full_config: if "regexp" in full_config:
return {'regexp': full_config.pop('regexp')} return {"regexp": full_config.pop("regexp")}
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]): async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
if isinstance(obj, Message): if isinstance(obj, Message):
content = obj.text or obj.caption or '' content = obj.text or obj.caption or ""
if not content and obj.poll: if not content and obj.poll:
content = obj.poll.question content = obj.poll.question
elif isinstance(obj, CallbackQuery) and obj.data: elif isinstance(obj, CallbackQuery) and obj.data:
@ -434,7 +500,7 @@ class Regexp(Filter):
match = self.regexp.search(content) match = self.regexp.search(content)
if match: if match:
return {'regexp': match} return {"regexp": match}
return False return False
@ -443,17 +509,20 @@ class RegexpCommandsFilter(BoundFilter):
Check commands by regexp in message Check commands by regexp in message
""" """
key = 'regexp_commands' key = "regexp_commands"
def __init__(self, regexp_commands): def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] self.regexp_commands = [
re.compile(command, flags=re.IGNORECASE | re.MULTILINE)
for command in regexp_commands
]
async def check(self, message): async def check(self, message):
if not message.is_command(): if not message.is_command():
return False return False
command = message.text.split()[0][1:] command = message.text.split()[0][1:]
command, _, mention = command.partition('@') command, _, mention = command.partition("@")
if mention and mention != (await message.bot.me).username: if mention and mention != (await message.bot.me).username:
return False return False
@ -461,7 +530,7 @@ class RegexpCommandsFilter(BoundFilter):
for command in self.regexp_commands: for command in self.regexp_commands:
search = command.search(message.text) search = command.search(message.text)
if search: if search:
return {'regexp_command': search} return {"regexp_command": search}
return False return False
@ -470,7 +539,7 @@ class ContentTypeFilter(BoundFilter):
Check message content type Check message content type
""" """
key = 'content_types' key = "content_types"
required = True required = True
default = types.ContentTypes.TEXT default = types.ContentTypes.TEXT
@ -480,8 +549,10 @@ class ContentTypeFilter(BoundFilter):
self.content_types = content_types self.content_types = content_types
async def check(self, message): async def check(self, message):
return types.ContentType.ANY in self.content_types or \ return (
message.content_type in self.content_types types.ContentType.ANY in self.content_types
or message.content_type in self.content_types
)
class IsSenderContact(BoundFilter): class IsSenderContact(BoundFilter):
@ -491,7 +562,8 @@ class IsSenderContact(BoundFilter):
`is_sender_contact=True` - contact matches the sender `is_sender_contact=True` - contact matches the sender
`is_sender_contact=False` - result will be inverted `is_sender_contact=False` - result will be inverted
""" """
key = 'is_sender_contact'
key = "is_sender_contact"
def __init__(self, is_sender_contact: bool): def __init__(self, is_sender_contact: bool):
self.is_sender_contact = is_sender_contact self.is_sender_contact = is_sender_contact
@ -509,10 +581,11 @@ class StateFilter(BoundFilter):
""" """
Check user state Check user state
""" """
key = 'state'
key = "state"
required = True required = True
ctx_state = ContextVar('user_state') ctx_state = ContextVar("user_state")
def __init__(self, dispatcher, state): def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State, StatesGroup from aiogram.dispatcher.filters.state import State, StatesGroup
@ -520,7 +593,9 @@ class StateFilter(BoundFilter):
self.dispatcher = dispatcher self.dispatcher = dispatcher
states = [] states = []
if not isinstance(state, (list, set, tuple, frozenset)) or state is None: if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
state = [state, ] state = [
state,
]
for item in state: for item in state:
if isinstance(item, State): if isinstance(item, State):
states.append(item.state) states.append(item.state)
@ -532,11 +607,13 @@ class StateFilter(BoundFilter):
@staticmethod @staticmethod
def get_target(obj): def get_target(obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) return getattr(getattr(obj, "chat", None), "id", None), getattr(
getattr(obj, "from_user", None), "id", None
)
async def check(self, obj): async def check(self, obj):
if '*' in self.states: if "*" in self.states:
return {'state': self.dispatcher.current_state()} return {"state": self.dispatcher.current_state()}
try: try:
state = self.ctx_state.get() state = self.ctx_state.get()
@ -547,11 +624,14 @@ class StateFilter(BoundFilter):
state = await self.dispatcher.storage.get_state(chat=chat, user=user) state = await self.dispatcher.storage.get_state(chat=chat, user=user)
self.ctx_state.set(state) self.ctx_state.set(state)
if state in self.states: if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state} return {
"state": self.dispatcher.current_state(),
"raw_state": state,
}
else: else:
if state in self.states: if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state} return {"state": self.dispatcher.current_state(), "raw_state": state}
return False return False
@ -561,7 +641,7 @@ class ExceptionsFilter(BoundFilter):
Filter for exceptions Filter for exceptions
""" """
key = 'exception' key = "exception"
def __init__(self, exception): def __init__(self, exception):
self.exception = exception self.exception = exception
@ -576,10 +656,11 @@ class ExceptionsFilter(BoundFilter):
class IDFilter(Filter): class IDFilter(Filter):
def __init__(self, def __init__(
user_id: Optional[ChatIDArgumentType] = None, self,
chat_id: Optional[ChatIDArgumentType] = None, user_id: Optional[ChatIDArgumentType] = None,
): chat_id: Optional[ChatIDArgumentType] = None,
):
""" """
:param user_id: :param user_id:
:param chat_id: :param chat_id:
@ -597,13 +678,15 @@ class IDFilter(Filter):
self.chat_id = extract_chat_ids(chat_id) self.chat_id = extract_chat_ids(chat_id)
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
result = {} result = {}
if 'user_id' in full_config: if "user_id" in full_config:
result['user_id'] = full_config.pop('user_id') result["user_id"] = full_config.pop("user_id")
if 'chat_id' in full_config: if "chat_id" in full_config:
result['chat_id'] = full_config.pop('chat_id') result["chat_id"] = full_config.pop("chat_id")
return result return result
@ -658,7 +741,9 @@ class AdminFilter(Filter):
self._chat_ids = extract_chat_ids(is_chat_admin) self._chat_ids = extract_chat_ids(is_chat_admin)
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
result = {} result = {}
if "is_chat_admin" in full_config: if "is_chat_admin" in full_config:
@ -676,13 +761,19 @@ class AdminFilter(Filter):
message = obj.message message = obj.message
else: else:
return False return False
if message.chat.type == ChatType.PRIVATE: # there is no admin in private chats if (
message.chat.type == ChatType.PRIVATE
): # there is no admin in private chats
return False return False
chat_ids = [message.chat.id] chat_ids = [message.chat.id]
else: else:
chat_ids = self._chat_ids chat_ids = self._chat_ids
admins = [member.user.id for chat_id in chat_ids for member in await obj.bot.get_chat_administrators(chat_id)] admins = [
member.user.id
for chat_id in chat_ids
for member in await obj.bot.get_chat_administrators(chat_id)
]
return user_id in admins return user_id in admins
@ -691,20 +782,21 @@ class IsReplyFilter(BoundFilter):
""" """
Check if message is replied and send reply message to handler Check if message is replied and send reply message to handler
""" """
key = 'is_reply'
key = "is_reply"
def __init__(self, is_reply): def __init__(self, is_reply):
self.is_reply = is_reply self.is_reply = is_reply
async def check(self, msg: Message): async def check(self, msg: Message):
if msg.reply_to_message and self.is_reply: if msg.reply_to_message and self.is_reply:
return {'reply': msg.reply_to_message} return {"reply": msg.reply_to_message}
if not msg.reply_to_message and not self.is_reply: if not msg.reply_to_message and not self.is_reply:
return True return True
class ForwardedMessageFilter(BoundFilter): class ForwardedMessageFilter(BoundFilter):
key = 'is_forwarded' key = "is_forwarded"
def __init__(self, is_forwarded: bool): def __init__(self, is_forwarded: bool):
self.is_forwarded = is_forwarded self.is_forwarded = is_forwarded
@ -714,7 +806,7 @@ class ForwardedMessageFilter(BoundFilter):
class ChatTypeFilter(BoundFilter): class ChatTypeFilter(BoundFilter):
key = 'chat_type' key = "chat_type"
def __init__(self, chat_type: typing.Container[ChatType]): def __init__(self, chat_type: typing.Container[ChatType]):
if isinstance(chat_type, str): if isinstance(chat_type, str):
@ -728,7 +820,8 @@ class ChatTypeFilter(BoundFilter):
elif isinstance(obj, CallbackQuery): elif isinstance(obj, CallbackQuery):
obj = obj.message.chat obj = obj.message.chat
else: else:
warnings.warn("ChatTypeFilter doesn't support %s as input", type(obj)) warnings.warn(
"ChatTypeFilter doesn't support %s as input", type(obj))
return False return False
return obj.type in self.chat_type return obj.type in self.chat_type

View file

@ -13,9 +13,11 @@ def wrap_async(func):
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
if inspect.isawaitable(func) \ if (
or inspect.iscoroutinefunction(func) \ inspect.isawaitable(func)
or isinstance(func, AbstractFilter): or inspect.iscoroutinefunction(func)
or isinstance(func, AbstractFilter)
):
return func return func
return async_wrapper return async_wrapper
@ -23,14 +25,16 @@ def wrap_async(func):
def get_filter_spec(dispatcher, filter_: callable): def get_filter_spec(dispatcher, filter_: callable):
kwargs = {} kwargs = {}
if not callable(filter_): if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!') raise TypeError("Filter must be callable and/or awaitable!")
spec = inspect.getfullargspec(filter_) spec = inspect.getfullargspec(filter_)
if 'dispatcher' in spec: if "dispatcher" in spec:
kwargs['dispatcher'] = dispatcher kwargs["dispatcher"] = dispatcher
if inspect.isawaitable(filter_) \ if (
or inspect.iscoroutinefunction(filter_) \ inspect.isawaitable(filter_)
or isinstance(filter_, AbstractFilter): or inspect.iscoroutinefunction(filter_)
or isinstance(filter_, AbstractFilter)
):
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True) return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False) return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
@ -80,12 +84,17 @@ class FilterRecord:
Filters record for factory Filters record for factory
""" """
def __init__(self, callback: typing.Union[typing.Callable, 'AbstractFilter'], def __init__(
validator: typing.Optional[typing.Callable] = None, self,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None, callback: typing.Union[typing.Callable, "AbstractFilter"],
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None): validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
):
if event_handlers and exclude_event_handlers: if event_handlers and exclude_event_handlers:
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.") raise ValueError(
"'event_handlers' and 'exclude_event_handlers' arguments cannot be used together."
)
self.callback = callback self.callback = callback
self.event_handlers = event_handlers self.event_handlers = event_handlers
@ -93,22 +102,23 @@ class FilterRecord:
if validator is not None: if validator is not None:
if not callable(validator): if not callable(validator):
raise TypeError(f"validator must be callable, not {type(validator)}") raise TypeError(
f"validator must be callable, not {type(validator)}")
self.resolver = validator self.resolver = validator
elif issubclass(callback, AbstractFilter): elif issubclass(callback, AbstractFilter):
self.resolver = callback.validate self.resolver = callback.validate
else: else:
raise RuntimeError('validator is required!') raise RuntimeError("validator is required!")
def resolve(self, dispatcher, event_handler, full_config): def resolve(self, dispatcher, event_handler, full_config):
if not self._check_event_handler(event_handler): if not self._check_event_handler(event_handler):
return return
config = self.resolver(full_config) config = self.resolver(full_config)
if config: if config:
if 'dispatcher' not in config: if "dispatcher" not in config:
spec = inspect.getfullargspec(self.callback) spec = inspect.getfullargspec(self.callback)
if 'dispatcher' in spec.args: if "dispatcher" in spec.args:
config['dispatcher'] = dispatcher config["dispatcher"] = dispatcher
for key in config: for key in config:
if key in full_config: if key in full_config:
@ -131,7 +141,9 @@ class AbstractFilter(abc.ABC):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
""" """
Validate and parse config. Validate and parse config.
@ -182,7 +194,9 @@ class Filter(AbstractFilter):
""" """
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
""" """
Here method ``validate`` is optional. Here method ``validate`` is optional.
If you need to use filter from filters factory you need to override this method. If you need to use filter from filters factory you need to override this method.
@ -200,16 +214,18 @@ class BoundFilter(Filter):
You need to implement ``__init__`` method with single argument related with key attribute You need to implement ``__init__`` method with single argument related with key attribute
and ``check`` method where you need to implement filter logic. and ``check`` method where you need to implement filter logic.
""" """
key = None key = None
"""Unique name of the filter argument. You need to override this attribute.""" """Unique name of the filter argument. You need to override this attribute."""
required = False required = False
"""If :obj:`True` this filter will be added to the all of the registered handlers""" """If :obj:`True` this filter will be added to the all of the registered handlers"""
default = None default = None
"""Default value for configure required filters""" """Default value for configure required filters"""
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Dict[str, typing.Any]:
""" """
If ``cls.key`` is not :obj:`None` and that is in config returns config with that argument. If ``cls.key`` is not :obj:`None` and that is in config returns config with that argument.
@ -226,7 +242,7 @@ class BoundFilter(Filter):
class _LogicFilter(Filter): class _LogicFilter(Filter):
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]): def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise ValueError('That filter can\'t be used in filters factory!') raise ValueError("That filter can't be used in filters factory!")
class NotFilter(_LogicFilter): class NotFilter(_LogicFilter):
@ -238,7 +254,6 @@ class NotFilter(_LogicFilter):
class AndFilter(_LogicFilter): class AndFilter(_LogicFilter):
def __init__(self, *targets): def __init__(self, *targets):
self.targets = [wrap_async(target) for target in targets] self.targets = [wrap_async(target) for target in targets]

View file

@ -26,15 +26,17 @@ class CallbackData:
Callback data factory Callback data factory
""" """
def __init__(self, prefix, *parts, sep=':'): def __init__(self, prefix, *parts, sep=":"):
if not isinstance(prefix, str): if not isinstance(prefix, str):
raise TypeError(f'Prefix must be instance of str not {type(prefix).__name__}') raise TypeError(
f"Prefix must be instance of str not {type(prefix).__name__}"
)
if not prefix: if not prefix:
raise ValueError("Prefix can't be empty") raise ValueError("Prefix can't be empty")
if sep in prefix: if sep in prefix:
raise ValueError(f"Separator {sep!r} can't be used in prefix") raise ValueError(f"Separator {sep!r} can't be used in prefix")
if not parts: if not parts:
raise TypeError('Parts were not passed!') raise TypeError("Parts were not passed!")
self.prefix = prefix self.prefix = prefix
self.sep = sep self.sep = sep
@ -59,7 +61,7 @@ class CallbackData:
if args: if args:
value = args.pop(0) value = args.pop(0)
else: else:
raise ValueError(f'Value for {part!r} was not passed!') raise ValueError(f"Value for {part!r} was not passed!")
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
value = str(value) value = str(value)
@ -67,16 +69,18 @@ class CallbackData:
if not value: if not value:
raise ValueError(f"Value for part {part!r} can't be empty!'") raise ValueError(f"Value for part {part!r} can't be empty!'")
if self.sep in value: if self.sep in value:
raise ValueError(f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values") raise ValueError(
f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values"
)
data.append(value) data.append(value)
if args or kwargs: if args or kwargs:
raise TypeError('Too many arguments were passed!') raise TypeError("Too many arguments were passed!")
callback_data = self.sep.join(data) callback_data = self.sep.join(data)
if len(callback_data.encode()) > 64: if len(callback_data.encode()) > 64:
raise ValueError('Resulted callback data is too long!') raise ValueError("Resulted callback data is too long!")
return callback_data return callback_data
@ -89,11 +93,12 @@ class CallbackData:
""" """
prefix, *parts = callback_data.split(self.sep) prefix, *parts = callback_data.split(self.sep)
if prefix != self.prefix: if prefix != self.prefix:
raise ValueError("Passed callback data can't be parsed with that prefix.") raise ValueError(
"Passed callback data can't be parsed with that prefix.")
if len(parts) != len(self._part_names): if len(parts) != len(self._part_names):
raise ValueError('Invalid parts count!') raise ValueError("Invalid parts count!")
result = {'@': prefix} result = {"@": prefix}
result.update(zip(self._part_names, parts)) result.update(zip(self._part_names, parts))
return result return result
@ -106,12 +111,11 @@ class CallbackData:
""" """
for key in config.keys(): for key in config.keys():
if key not in self._part_names: if key not in self._part_names:
raise ValueError(f'Invalid field name {key!r}') raise ValueError(f"Invalid field name {key!r}")
return CallbackDataFilter(self, config) return CallbackDataFilter(self, config)
class CallbackDataFilter(Filter): class CallbackDataFilter(Filter):
def __init__(self, factory: CallbackData, config: typing.Dict[str, str]): def __init__(self, factory: CallbackData, config: typing.Dict[str, str]):
self.config = config self.config = config
self.factory = factory self.factory = factory
@ -133,4 +137,4 @@ class CallbackDataFilter(Filter):
else: else:
if data.get(key) != value: if data.get(key) != value:
return False return False
return {'callback_data': data} return {"callback_data": data}