Merge pull request #13 from muhammedfurkan/deepsource-fix-f820c4b0

Refactor unnecessary `else` / `elif` when `if` block has a `raise` statement
This commit is contained in:
M.Furkan 2020-11-09 01:02:44 +03:00 committed by GitHub
commit bd51d299df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 293 additions and 166 deletions

View file

@ -10,7 +10,7 @@ from ..utils import exceptions, json
from ..utils.helper import Helper, HelperMode, Item
# Main aiogram logger
log = logging.getLogger('aiogram')
log = logging.getLogger("aiogram")
@dataclass(frozen=True)
@ -43,7 +43,7 @@ class TelegramAPIServer:
return self.file.format(token=token, path=path)
@classmethod
def from_base(cls, base: str) -> 'TelegramAPIServer':
def from_base(cls, base: str) -> "TelegramAPIServer":
base = base.rstrip("/")
return cls(
base=f"{base}/bot{{token}}/{{method}}",
@ -62,17 +62,19 @@ def check_token(token: str) -> bool:
:return:
"""
if not isinstance(token, str):
message = (f"Token is invalid! "
f"It must be 'str' type instead of {type(token)} type.")
message = (
f"Token is invalid! "
f"It must be 'str' type instead of {type(token)} type."
)
raise exceptions.ValidationError(message)
if any(x.isspace() for x in token):
message = "Token is invalid! It can't contains spaces."
raise exceptions.ValidationError(message)
left, sep, right = token.partition(':')
left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (not right):
raise exceptions.ValidationError('Token is invalid!')
raise exceptions.ValidationError("Token is invalid!")
return True
@ -94,24 +96,27 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
"""
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
if content_type != 'application/json':
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
if content_type != "application/json":
raise exceptions.NetworkError(
f'Invalid response with content type {content_type}: "{body}"'
)
try:
result_json = json.loads(body)
except ValueError:
result_json = {}
description = result_json.get('description') or body
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
description = result_json.get("description") or body
parameters = types.ResponseParameters(
**result_json.get("parameters", {}) or {})
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
return result_json.get('result')
return result_json.get("result")
if parameters.retry_after:
raise exceptions.RetryAfter(parameters.retry_after)
elif parameters.migrate_to_chat_id:
if parameters.migrate_to_chat_id:
raise exceptions.MigrateToChat(parameters.migrate_to_chat_id)
elif status_code == HTTPStatus.BAD_REQUEST:
if status_code == HTTPStatus.BAD_REQUEST:
exceptions.BadRequest.detect(description)
elif status_code == HTTPStatus.NOT_FOUND:
exceptions.NotFound.detect(description)
@ -120,26 +125,33 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
elif status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
exceptions.Unauthorized.detect(description)
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
raise exceptions.NetworkError('File too large for uploading. '
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
raise exceptions.NetworkError(
"File too large for uploading. "
"Check telegram api limits https://core.telegram.org/bots/api#senddocument"
)
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
if 'restart' in description:
if "restart" in description:
raise exceptions.RestartingTelegram()
raise exceptions.TelegramAPIError(description)
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
async def make_request(session, server, token, method, data=None, files=None, **kwargs):
log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files)
log.debug('Make request: "%s" with data: "%r" and files "%r"',
method, data, files)
url = server.api_url(token=token, method=method)
req = compose_data(data, files)
try:
async with session.post(url, data=req, **kwargs) as response:
return check_result(method, response.content_type, response.status, await response.text())
return check_result(
method, response.content_type, response.status, await response.text()
)
except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
raise exceptions.NetworkError(
f"aiohttp client throws an error: {e.__class__.__name__}: {e}"
)
def guess_filename(obj):
@ -149,8 +161,8 @@ def guess_filename(obj):
:param obj:
:return:
"""
name = getattr(obj, 'name', None)
if name and isinstance(name, str) and name[0] != '<' and name[-1] != '>':
name = getattr(obj, "name", None)
if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
return os.path.basename(name)
@ -174,7 +186,9 @@ def compose_data(params=None, files=None):
if len(f) == 2:
filename, fileobj = f
else:
raise ValueError('Tuple must have exactly 2 elements: filename, fileobj')
raise ValueError(
"Tuple must have exactly 2 elements: filename, fileobj"
)
elif isinstance(f, types.InputFile):
filename, fileobj = f.filename, f.file
else:
@ -191,6 +205,7 @@ class Methods(Helper):
List is updated to Bot API 5.0
"""
mode = HelperMode.lowerCamelCase
# Getting Updates

View file

@ -18,11 +18,15 @@ ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str,
def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]:
# since "str" is also an "Iterable", we have to check for it first
if isinstance(chat_id, str):
return {int(chat_id), }
return {
int(chat_id),
}
if isinstance(chat_id, Iterable):
return {int(item) for (item) in chat_id}
# the last possible type is a single "int"
return {chat_id, }
return {
chat_id,
}
class Command(Filter):
@ -34,11 +38,14 @@ class Command(Filter):
By default this filter is registered for messages and edited messages handlers.
"""
def __init__(self, commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = '/',
ignore_case: bool = True,
ignore_mention: bool = False,
ignore_caption: bool = True):
def __init__(
self,
commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = "/",
ignore_case: bool = True,
ignore_mention: bool = False,
ignore_caption: bool = True,
):
"""
Filter can be initialized from filters factory or by simply creating instance of this class.
@ -69,7 +76,8 @@ class Command(Filter):
if isinstance(commands, str):
commands = (commands,)
self.commands = list(map(str.lower, commands)) if ignore_case else commands
self.commands = list(map(str.lower, commands)
) if ignore_case else commands
self.prefixes = prefixes
self.ignore_case = ignore_case
self.ignore_mention = ignore_mention
@ -91,36 +99,61 @@ class Command(Filter):
:return: config or empty dict
"""
config = {}
if 'commands' in full_config:
config['commands'] = full_config.pop('commands')
if config and 'commands_prefix' in full_config:
config['prefixes'] = full_config.pop('commands_prefix')
if config and 'commands_ignore_mention' in full_config:
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
if config and 'commands_ignore_caption' in full_config:
config['ignore_caption'] = full_config.pop('commands_ignore_caption')
if "commands" in full_config:
config["commands"] = full_config.pop("commands")
if config and "commands_prefix" in full_config:
config["prefixes"] = full_config.pop("commands_prefix")
if config and "commands_ignore_mention" in full_config:
config["ignore_mention"] = full_config.pop(
"commands_ignore_mention")
if config and "commands_ignore_caption" in full_config:
config["ignore_caption"] = full_config.pop(
"commands_ignore_caption")
return config
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention, self.ignore_caption)
return await self.check_command(
message,
self.commands,
self.prefixes,
self.ignore_case,
self.ignore_mention,
self.ignore_caption,
)
@classmethod
async def check_command(cls, message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False, ignore_caption=True):
text = message.text or (message.caption if not ignore_caption else None)
async def check_command(
cls,
message: types.Message,
commands,
prefixes,
ignore_case=True,
ignore_mention=False,
ignore_caption=True,
):
text = message.text or (
message.caption if not ignore_caption else None)
if not text:
return False
full_command = text.split()[0]
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
prefix, (command, _,
mention) = full_command[0], full_command[1:].partition("@")
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
if (
not ignore_mention
and mention
and (await message.bot.me).username.lower() != mention.lower()
):
return False
if prefix not in prefixes:
return False
if (command.lower() if ignore_case else command) not in commands:
return False
return {'command': cls.CommandObj(command=command, prefix=prefix, mention=mention)}
return {
"command": cls.CommandObj(command=command, prefix=prefix, mention=mention)
}
@dataclass
class CommandObj:
@ -131,9 +164,9 @@ class Command(Filter):
"""
"""Command prefix"""
prefix: str = '/'
prefix: str = "/"
"""Command without prefix and mention"""
command: str = ''
command: str = ""
"""Mention (if available)"""
mention: str = None
"""Command argument"""
@ -157,9 +190,9 @@ class Command(Filter):
"""
line = self.prefix + self.command
if self.mentioned:
line += '@' + self.mention
line += "@" + self.mention
if self.args:
line += ' ' + self.args
line += " " + self.args
return line
@ -168,9 +201,12 @@ class CommandStart(Command):
This filter based on :obj:`Command` filter but can handle only ``/start`` command.
"""
def __init__(self,
deep_link: typing.Optional[typing.Union[str, typing.Pattern[str]]] = None,
encoded: bool = False):
def __init__(
self,
deep_link: typing.Optional[typing.Union[str,
typing.Pattern[str]]] = None,
encoded: bool = False,
):
"""
Also this filter can handle `deep-linking <https://core.telegram.org/bots#deep-linking>`_ arguments.
@ -183,7 +219,7 @@ class CommandStart(Command):
:param deep_link: string or compiled regular expression (by ``re.compile(...)``).
:param encoded: set True if you're waiting for encoded payload (default - False).
"""
super().__init__(['start'])
super().__init__(["start"])
self.deep_link = deep_link
self.encoded = encoded
@ -195,17 +231,22 @@ class CommandStart(Command):
:return:
"""
from ...utils.deep_linking import decode_payload
check = await super().check(message)
if check and self.deep_link is not None:
payload = decode_payload(message.get_args()) if self.encoded else message.get_args()
payload = (
decode_payload(message.get_args())
if self.encoded
else message.get_args()
)
if not isinstance(self.deep_link, typing.Pattern):
return False if payload != self.deep_link else {'deep_link': payload}
return False if payload != self.deep_link else {"deep_link": payload}
match = self.deep_link.match(payload)
if match:
return {'deep_link': match}
return {"deep_link": match}
return False
return check
@ -217,7 +258,7 @@ class CommandHelp(Command):
"""
def __init__(self):
super().__init__(['help'])
super().__init__(["help"])
class CommandSettings(Command):
@ -226,7 +267,7 @@ class CommandSettings(Command):
"""
def __init__(self):
super().__init__(['settings'])
super().__init__(["settings"])
class CommandPrivacy(Command):
@ -235,7 +276,7 @@ class CommandPrivacy(Command):
"""
def __init__(self):
super().__init__(['privacy'])
super().__init__(["privacy"])
class Text(Filter):
@ -244,18 +285,27 @@ class Text(Filter):
"""
_default_params = (
('text', 'equals'),
('text_contains', 'contains'),
('text_startswith', 'startswith'),
('text_endswith', 'endswith'),
("text", "equals"),
("text_contains", "contains"),
("text_startswith", "startswith"),
("text_endswith", "endswith"),
)
def __init__(self,
equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
ignore_case=False):
def __init__(
self,
equals: Optional[Union[str, LazyProxy,
Iterable[Union[str, LazyProxy]]]] = None,
contains: Optional[
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
] = None,
startswith: Optional[
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
] = None,
endswith: Optional[
Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]
] = None,
ignore_case=False,
):
"""
Check text for one of pattern. Only one mode can be used in one filter.
In every pattern, a single string is treated as a list with 1 element.
@ -267,15 +317,24 @@ class Text(Filter):
:param ignore_case: case insensitive
"""
# Only one mode can be used. check it.
check = sum(map(lambda s: s is not None, (equals, contains, startswith, endswith)))
check = sum(
map(lambda s: s is not None, (equals, contains, startswith, endswith))
)
if check > 1:
args = "' and '".join([arg[0] for arg in [('equals', equals),
('contains', contains),
('startswith', startswith),
('endswith', endswith)
] if arg[1] is not None])
args = "' and '".join(
[
arg[0]
for arg in [
("equals", equals),
("contains", contains),
("startswith", startswith),
("endswith", endswith),
]
if arg[1] is not None
]
)
raise ValueError(f"Arguments '{args}' cannot be used together.")
elif check == 0:
if check == 0:
raise ValueError(f"No one mode is specified!")
equals, contains, endswith, startswith = map(
@ -297,7 +356,7 @@ class Text(Filter):
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
if isinstance(obj, Message):
text = obj.text or obj.caption or ''
text = obj.text or obj.caption or ""
if not text and obj.poll:
text = obj.poll.question
elif isinstance(obj, CallbackQuery):
@ -311,7 +370,10 @@ class Text(Filter):
if self.ignore_case:
text = text.lower()
_pre_process_func = lambda s: str(s).lower()
def _pre_process_func(s):
return str(s).lower()
else:
_pre_process_func = str
@ -344,7 +406,7 @@ class HashTag(Filter):
def __init__(self, hashtags=None, cashtags=None):
if not hashtags and not cashtags:
raise ValueError('No one hashtag or cashtag is specified!')
raise ValueError("No one hashtag or cashtag is specified!")
if hashtags is None:
hashtags = []
@ -364,10 +426,10 @@ class HashTag(Filter):
@classmethod
def validate(cls, full_config: Dict[str, Any]):
config = {}
if 'hashtags' in full_config:
config['hashtags'] = full_config.pop('hashtags')
if 'cashtags' in full_config:
config['cashtags'] = full_config.pop('cashtags')
if "hashtags" in full_config:
config["hashtags"] = full_config.pop("hashtags")
if "cashtags" in full_config:
config["cashtags"] = full_config.pop("cashtags")
return config
async def check(self, message: types.Message):
@ -381,9 +443,13 @@ class HashTag(Filter):
return False
hashtags, cashtags = self._get_tags(text, entities)
if self.hashtags and set(hashtags) & set(self.hashtags) \
or self.cashtags and set(cashtags) & set(self.cashtags):
return {'hashtags': hashtags, 'cashtags': cashtags}
if (
self.hashtags
and set(hashtags) & set(self.hashtags)
or self.cashtags
and set(cashtags) & set(self.cashtags)
):
return {"hashtags": hashtags, "cashtags": cashtags}
@staticmethod
def _get_tags(text, entities):
@ -392,11 +458,11 @@ class HashTag(Filter):
for entity in entities:
if entity.type == types.MessageEntityType.HASHTAG:
value = entity.get_text(text).lstrip('#')
value = entity.get_text(text).lstrip("#")
hashtags.append(value)
elif entity.type == types.MessageEntityType.CASHTAG:
value = entity.get_text(text).lstrip('$')
value = entity.get_text(text).lstrip("$")
cashtags.append(value)
return hashtags, cashtags
@ -414,12 +480,12 @@ class Regexp(Filter):
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'regexp' in full_config:
return {'regexp': full_config.pop('regexp')}
if "regexp" in full_config:
return {"regexp": full_config.pop("regexp")}
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
if isinstance(obj, Message):
content = obj.text or obj.caption or ''
content = obj.text or obj.caption or ""
if not content and obj.poll:
content = obj.poll.question
elif isinstance(obj, CallbackQuery) and obj.data:
@ -434,7 +500,7 @@ class Regexp(Filter):
match = self.regexp.search(content)
if match:
return {'regexp': match}
return {"regexp": match}
return False
@ -443,17 +509,20 @@ class RegexpCommandsFilter(BoundFilter):
Check commands by regexp in message
"""
key = 'regexp_commands'
key = "regexp_commands"
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
self.regexp_commands = [
re.compile(command, flags=re.IGNORECASE | re.MULTILINE)
for command in regexp_commands
]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
command, _, mention = command.partition("@")
if mention and mention != (await message.bot.me).username:
return False
@ -461,7 +530,7 @@ class RegexpCommandsFilter(BoundFilter):
for command in self.regexp_commands:
search = command.search(message.text)
if search:
return {'regexp_command': search}
return {"regexp_command": search}
return False
@ -470,7 +539,7 @@ class ContentTypeFilter(BoundFilter):
Check message content type
"""
key = 'content_types'
key = "content_types"
required = True
default = types.ContentTypes.TEXT
@ -480,8 +549,10 @@ class ContentTypeFilter(BoundFilter):
self.content_types = content_types
async def check(self, message):
return types.ContentType.ANY in self.content_types or \
message.content_type in self.content_types
return (
types.ContentType.ANY in self.content_types
or message.content_type in self.content_types
)
class IsSenderContact(BoundFilter):
@ -491,7 +562,8 @@ class IsSenderContact(BoundFilter):
`is_sender_contact=True` - contact matches the sender
`is_sender_contact=False` - result will be inverted
"""
key = 'is_sender_contact'
key = "is_sender_contact"
def __init__(self, is_sender_contact: bool):
self.is_sender_contact = is_sender_contact
@ -509,10 +581,11 @@ class StateFilter(BoundFilter):
"""
Check user state
"""
key = 'state'
key = "state"
required = True
ctx_state = ContextVar('user_state')
ctx_state = ContextVar("user_state")
def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State, StatesGroup
@ -520,7 +593,9 @@ class StateFilter(BoundFilter):
self.dispatcher = dispatcher
states = []
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
state = [state, ]
state = [
state,
]
for item in state:
if isinstance(item, State):
states.append(item.state)
@ -532,11 +607,13 @@ class StateFilter(BoundFilter):
@staticmethod
def get_target(obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
return getattr(getattr(obj, "chat", None), "id", None), getattr(
getattr(obj, "from_user", None), "id", None
)
async def check(self, obj):
if '*' in self.states:
return {'state': self.dispatcher.current_state()}
if "*" in self.states:
return {"state": self.dispatcher.current_state()}
try:
state = self.ctx_state.get()
@ -547,11 +624,14 @@ class StateFilter(BoundFilter):
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
self.ctx_state.set(state)
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
return {
"state": self.dispatcher.current_state(),
"raw_state": state,
}
else:
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
return {"state": self.dispatcher.current_state(), "raw_state": state}
return False
@ -561,7 +641,7 @@ class ExceptionsFilter(BoundFilter):
Filter for exceptions
"""
key = 'exception'
key = "exception"
def __init__(self, exception):
self.exception = exception
@ -576,10 +656,11 @@ class ExceptionsFilter(BoundFilter):
class IDFilter(Filter):
def __init__(self,
user_id: Optional[ChatIDArgumentType] = None,
chat_id: Optional[ChatIDArgumentType] = None,
):
def __init__(
self,
user_id: Optional[ChatIDArgumentType] = None,
chat_id: Optional[ChatIDArgumentType] = None,
):
"""
:param user_id:
:param chat_id:
@ -597,13 +678,15 @@ class IDFilter(Filter):
self.chat_id = extract_chat_ids(chat_id)
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
result = {}
if 'user_id' in full_config:
result['user_id'] = full_config.pop('user_id')
if "user_id" in full_config:
result["user_id"] = full_config.pop("user_id")
if 'chat_id' in full_config:
result['chat_id'] = full_config.pop('chat_id')
if "chat_id" in full_config:
result["chat_id"] = full_config.pop("chat_id")
return result
@ -658,7 +741,9 @@ class AdminFilter(Filter):
self._chat_ids = extract_chat_ids(is_chat_admin)
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
result = {}
if "is_chat_admin" in full_config:
@ -676,13 +761,19 @@ class AdminFilter(Filter):
message = obj.message
else:
return False
if message.chat.type == ChatType.PRIVATE: # there is no admin in private chats
if (
message.chat.type == ChatType.PRIVATE
): # there is no admin in private chats
return False
chat_ids = [message.chat.id]
else:
chat_ids = self._chat_ids
admins = [member.user.id for chat_id in chat_ids for member in await obj.bot.get_chat_administrators(chat_id)]
admins = [
member.user.id
for chat_id in chat_ids
for member in await obj.bot.get_chat_administrators(chat_id)
]
return user_id in admins
@ -691,20 +782,21 @@ class IsReplyFilter(BoundFilter):
"""
Check if message is replied and send reply message to handler
"""
key = 'is_reply'
key = "is_reply"
def __init__(self, is_reply):
self.is_reply = is_reply
async def check(self, msg: Message):
if msg.reply_to_message and self.is_reply:
return {'reply': msg.reply_to_message}
return {"reply": msg.reply_to_message}
if not msg.reply_to_message and not self.is_reply:
return True
class ForwardedMessageFilter(BoundFilter):
key = 'is_forwarded'
key = "is_forwarded"
def __init__(self, is_forwarded: bool):
self.is_forwarded = is_forwarded
@ -714,7 +806,7 @@ class ForwardedMessageFilter(BoundFilter):
class ChatTypeFilter(BoundFilter):
key = 'chat_type'
key = "chat_type"
def __init__(self, chat_type: typing.Container[ChatType]):
if isinstance(chat_type, str):
@ -728,7 +820,8 @@ class ChatTypeFilter(BoundFilter):
elif isinstance(obj, CallbackQuery):
obj = obj.message.chat
else:
warnings.warn("ChatTypeFilter doesn't support %s as input", type(obj))
warnings.warn(
"ChatTypeFilter doesn't support %s as input", type(obj))
return False
return obj.type in self.chat_type

View file

@ -13,9 +13,11 @@ def wrap_async(func):
async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)
if inspect.isawaitable(func) \
or inspect.iscoroutinefunction(func) \
or isinstance(func, AbstractFilter):
if (
inspect.isawaitable(func)
or inspect.iscoroutinefunction(func)
or isinstance(func, AbstractFilter)
):
return func
return async_wrapper
@ -23,14 +25,16 @@ def wrap_async(func):
def get_filter_spec(dispatcher, filter_: callable):
kwargs = {}
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
raise TypeError("Filter must be callable and/or awaitable!")
spec = inspect.getfullargspec(filter_)
if 'dispatcher' in spec:
kwargs['dispatcher'] = dispatcher
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, AbstractFilter):
if "dispatcher" in spec:
kwargs["dispatcher"] = dispatcher
if (
inspect.isawaitable(filter_)
or inspect.iscoroutinefunction(filter_)
or isinstance(filter_, AbstractFilter)
):
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
@ -70,7 +74,7 @@ async def check_filters(filters: typing.Iterable[FilterObj], args):
f = await execute_filter(filter_, args)
if not f:
raise FilterNotPassed()
elif isinstance(f, dict):
if isinstance(f, dict):
data.update(f)
return data
@ -80,12 +84,17 @@ class FilterRecord:
Filters record for factory
"""
def __init__(self, callback: typing.Union[typing.Callable, 'AbstractFilter'],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
def __init__(
self,
callback: typing.Union[typing.Callable, "AbstractFilter"],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
):
if event_handlers and exclude_event_handlers:
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
raise ValueError(
"'event_handlers' and 'exclude_event_handlers' arguments cannot be used together."
)
self.callback = callback
self.event_handlers = event_handlers
@ -93,22 +102,23 @@ class FilterRecord:
if validator is not None:
if not callable(validator):
raise TypeError(f"validator must be callable, not {type(validator)}")
raise TypeError(
f"validator must be callable, not {type(validator)}")
self.resolver = validator
elif issubclass(callback, AbstractFilter):
self.resolver = callback.validate
else:
raise RuntimeError('validator is required!')
raise RuntimeError("validator is required!")
def resolve(self, dispatcher, event_handler, full_config):
if not self._check_event_handler(event_handler):
return
config = self.resolver(full_config)
if config:
if 'dispatcher' not in config:
if "dispatcher" not in config:
spec = inspect.getfullargspec(self.callback)
if 'dispatcher' in spec.args:
config['dispatcher'] = dispatcher
if "dispatcher" in spec.args:
config["dispatcher"] = dispatcher
for key in config:
if key in full_config:
@ -131,7 +141,9 @@ class AbstractFilter(abc.ABC):
@classmethod
@abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""
Validate and parse config.
@ -182,7 +194,9 @@ class Filter(AbstractFilter):
"""
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""
Here method ``validate`` is optional.
If you need to use filter from filters factory you need to override this method.
@ -200,16 +214,18 @@ class BoundFilter(Filter):
You need to implement ``__init__`` method with single argument related with key attribute
and ``check`` method where you need to implement filter logic.
"""
key = None
"""Unique name of the filter argument. You need to override this attribute."""
required = False
"""If :obj:`True` this filter will be added to the all of the registered handlers"""
default = None
"""Default value for configure required filters"""
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Dict[str, typing.Any]:
"""
If ``cls.key`` is not :obj:`None` and that is in config returns config with that argument.
@ -226,7 +242,7 @@ class BoundFilter(Filter):
class _LogicFilter(Filter):
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise ValueError('That filter can\'t be used in filters factory!')
raise ValueError("That filter can't be used in filters factory!")
class NotFilter(_LogicFilter):
@ -238,7 +254,6 @@ class NotFilter(_LogicFilter):
class AndFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = [wrap_async(target) for target in targets]

View file

@ -26,15 +26,17 @@ class CallbackData:
Callback data factory
"""
def __init__(self, prefix, *parts, sep=':'):
def __init__(self, prefix, *parts, sep=":"):
if not isinstance(prefix, str):
raise TypeError(f'Prefix must be instance of str not {type(prefix).__name__}')
raise TypeError(
f"Prefix must be instance of str not {type(prefix).__name__}"
)
if not prefix:
raise ValueError("Prefix can't be empty")
if sep in prefix:
raise ValueError(f"Separator {sep!r} can't be used in prefix")
if not parts:
raise TypeError('Parts were not passed!')
raise TypeError("Parts were not passed!")
self.prefix = prefix
self.sep = sep
@ -59,7 +61,7 @@ class CallbackData:
if args:
value = args.pop(0)
else:
raise ValueError(f'Value for {part!r} was not passed!')
raise ValueError(f"Value for {part!r} was not passed!")
if value is not None and not isinstance(value, str):
value = str(value)
@ -67,16 +69,18 @@ class CallbackData:
if not value:
raise ValueError(f"Value for part {part!r} can't be empty!'")
if self.sep in value:
raise ValueError(f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values")
raise ValueError(
f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values"
)
data.append(value)
if args or kwargs:
raise TypeError('Too many arguments were passed!')
raise TypeError("Too many arguments were passed!")
callback_data = self.sep.join(data)
if len(callback_data.encode()) > 64:
raise ValueError('Resulted callback data is too long!')
raise ValueError("Resulted callback data is too long!")
return callback_data
@ -89,11 +93,12 @@ class CallbackData:
"""
prefix, *parts = callback_data.split(self.sep)
if prefix != self.prefix:
raise ValueError("Passed callback data can't be parsed with that prefix.")
elif len(parts) != len(self._part_names):
raise ValueError('Invalid parts count!')
raise ValueError(
"Passed callback data can't be parsed with that prefix.")
if len(parts) != len(self._part_names):
raise ValueError("Invalid parts count!")
result = {'@': prefix}
result = {"@": prefix}
result.update(zip(self._part_names, parts))
return result
@ -106,12 +111,11 @@ class CallbackData:
"""
for key in config.keys():
if key not in self._part_names:
raise ValueError(f'Invalid field name {key!r}')
raise ValueError(f"Invalid field name {key!r}")
return CallbackDataFilter(self, config)
class CallbackDataFilter(Filter):
def __init__(self, factory: CallbackData, config: typing.Dict[str, str]):
self.config = config
self.factory = factory
@ -133,4 +137,4 @@ class CallbackDataFilter(Filter):
else:
if data.get(key) != value:
return False
return {'callback_data': data}
return {"callback_data": data}