diff --git a/aiogram/bot/api.py b/aiogram/bot/api.py index d0639aa2..ee3bd21b 100644 --- a/aiogram/bot/api.py +++ b/aiogram/bot/api.py @@ -10,7 +10,7 @@ from ..utils import exceptions, json from ..utils.helper import Helper, HelperMode, Item # Main aiogram logger -log = logging.getLogger('aiogram') +log = logging.getLogger("aiogram") @dataclass(frozen=True) @@ -43,7 +43,7 @@ class TelegramAPIServer: return self.file.format(token=token, path=path) @classmethod - def from_base(cls, base: str) -> 'TelegramAPIServer': + def from_base(cls, base: str) -> "TelegramAPIServer": base = base.rstrip("/") return cls( base=f"{base}/bot{{token}}/{{method}}", @@ -62,17 +62,19 @@ def check_token(token: str) -> bool: :return: """ if not isinstance(token, str): - message = (f"Token is invalid! " - f"It must be 'str' type instead of {type(token)} type.") + message = ( + f"Token is invalid! " + f"It must be 'str' type instead of {type(token)} type." + ) raise exceptions.ValidationError(message) if any(x.isspace() for x in token): message = "Token is invalid! It can't contains spaces." raise exceptions.ValidationError(message) - left, sep, right = token.partition(':') + left, sep, right = token.partition(":") if (not sep) or (not left.isdigit()) or (not right): - raise exceptions.ValidationError('Token is invalid!') + raise exceptions.ValidationError("Token is invalid!") return True @@ -94,24 +96,27 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st """ log.debug('Response for %s: [%d] "%r"', method_name, status_code, body) - if content_type != 'application/json': - raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"") + if content_type != "application/json": + raise exceptions.NetworkError( + f'Invalid response with content type {content_type}: "{body}"' + ) try: result_json = json.loads(body) except ValueError: result_json = {} - description = result_json.get('description') or body - parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {}) + description = result_json.get("description") or body + parameters = types.ResponseParameters( + **result_json.get("parameters", {}) or {}) if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED: - return result_json.get('result') + return result_json.get("result") if parameters.retry_after: raise exceptions.RetryAfter(parameters.retry_after) - elif parameters.migrate_to_chat_id: + if parameters.migrate_to_chat_id: raise exceptions.MigrateToChat(parameters.migrate_to_chat_id) - elif status_code == HTTPStatus.BAD_REQUEST: + if status_code == HTTPStatus.BAD_REQUEST: exceptions.BadRequest.detect(description) elif status_code == HTTPStatus.NOT_FOUND: exceptions.NotFound.detect(description) @@ -120,26 +125,33 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st elif status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): exceptions.Unauthorized.detect(description) elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: - raise exceptions.NetworkError('File too large for uploading. ' - 'Check telegram api limits https://core.telegram.org/bots/api#senddocument') + raise exceptions.NetworkError( + "File too large for uploading. " + "Check telegram api limits https://core.telegram.org/bots/api#senddocument" + ) elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR: - if 'restart' in description: + if "restart" in description: raise exceptions.RestartingTelegram() raise exceptions.TelegramAPIError(description) raise exceptions.TelegramAPIError(f"{description} [{status_code}]") async def make_request(session, server, token, method, data=None, files=None, **kwargs): - log.debug('Make request: "%s" with data: "%r" and files "%r"', method, data, files) + log.debug('Make request: "%s" with data: "%r" and files "%r"', + method, data, files) url = server.api_url(token=token, method=method) req = compose_data(data, files) try: async with session.post(url, data=req, **kwargs) as response: - return check_result(method, response.content_type, response.status, await response.text()) + return check_result( + method, response.content_type, response.status, await response.text() + ) except aiohttp.ClientError as e: - raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}") + raise exceptions.NetworkError( + f"aiohttp client throws an error: {e.__class__.__name__}: {e}" + ) def guess_filename(obj): @@ -149,8 +161,8 @@ def guess_filename(obj): :param obj: :return: """ - name = getattr(obj, 'name', None) - if name and isinstance(name, str) and name[0] != '<' and name[-1] != '>': + name = getattr(obj, "name", None) + if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": return os.path.basename(name) @@ -174,7 +186,9 @@ def compose_data(params=None, files=None): if len(f) == 2: filename, fileobj = f else: - raise ValueError('Tuple must have exactly 2 elements: filename, fileobj') + raise ValueError( + "Tuple must have exactly 2 elements: filename, fileobj" + ) elif isinstance(f, types.InputFile): filename, fileobj = f.filename, f.file else: @@ -191,6 +205,7 @@ class Methods(Helper): List is updated to Bot API 5.0 """ + mode = HelperMode.lowerCamelCase # Getting Updates diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 2a23e8fc..e0c456b7 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -18,11 +18,15 @@ ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str, def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]: # since "str" is also an "Iterable", we have to check for it first if isinstance(chat_id, str): - return {int(chat_id), } + return { + int(chat_id), + } if isinstance(chat_id, Iterable): return {int(item) for (item) in chat_id} # the last possible type is a single "int" - return {chat_id, } + return { + chat_id, + } class Command(Filter): @@ -34,11 +38,14 @@ class Command(Filter): By default this filter is registered for messages and edited messages handlers. """ - def __init__(self, commands: Union[Iterable, str], - prefixes: Union[Iterable, str] = '/', - ignore_case: bool = True, - ignore_mention: bool = False, - ignore_caption: bool = True): + def __init__( + self, + commands: Union[Iterable, str], + prefixes: Union[Iterable, str] = "/", + ignore_case: bool = True, + ignore_mention: bool = False, + ignore_caption: bool = True, + ): """ Filter can be initialized from filters factory or by simply creating instance of this class. @@ -69,7 +76,8 @@ class Command(Filter): if isinstance(commands, str): commands = (commands,) - self.commands = list(map(str.lower, commands)) if ignore_case else commands + self.commands = list(map(str.lower, commands) + ) if ignore_case else commands self.prefixes = prefixes self.ignore_case = ignore_case self.ignore_mention = ignore_mention @@ -91,36 +99,61 @@ class Command(Filter): :return: config or empty dict """ config = {} - if 'commands' in full_config: - config['commands'] = full_config.pop('commands') - if config and 'commands_prefix' in full_config: - config['prefixes'] = full_config.pop('commands_prefix') - if config and 'commands_ignore_mention' in full_config: - config['ignore_mention'] = full_config.pop('commands_ignore_mention') - if config and 'commands_ignore_caption' in full_config: - config['ignore_caption'] = full_config.pop('commands_ignore_caption') + if "commands" in full_config: + config["commands"] = full_config.pop("commands") + if config and "commands_prefix" in full_config: + config["prefixes"] = full_config.pop("commands_prefix") + if config and "commands_ignore_mention" in full_config: + config["ignore_mention"] = full_config.pop( + "commands_ignore_mention") + if config and "commands_ignore_caption" in full_config: + config["ignore_caption"] = full_config.pop( + "commands_ignore_caption") return config async def check(self, message: types.Message): - return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention, self.ignore_caption) + return await self.check_command( + message, + self.commands, + self.prefixes, + self.ignore_case, + self.ignore_mention, + self.ignore_caption, + ) @classmethod - async def check_command(cls, message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False, ignore_caption=True): - text = message.text or (message.caption if not ignore_caption else None) + async def check_command( + cls, + message: types.Message, + commands, + prefixes, + ignore_case=True, + ignore_mention=False, + ignore_caption=True, + ): + text = message.text or ( + message.caption if not ignore_caption else None) if not text: return False full_command = text.split()[0] - prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@') + prefix, (command, _, + mention) = full_command[0], full_command[1:].partition("@") - if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower(): + if ( + not ignore_mention + and mention + and (await message.bot.me).username.lower() != mention.lower() + ): return False if prefix not in prefixes: return False if (command.lower() if ignore_case else command) not in commands: return False - return {'command': cls.CommandObj(command=command, prefix=prefix, mention=mention)} + return { + "command": cls.CommandObj(command=command, prefix=prefix, mention=mention) + } @dataclass class CommandObj: @@ -131,9 +164,9 @@ class Command(Filter): """ """Command prefix""" - prefix: str = '/' + prefix: str = "/" """Command without prefix and mention""" - command: str = '' + command: str = "" """Mention (if available)""" mention: str = None """Command argument""" @@ -157,9 +190,9 @@ class Command(Filter): """ line = self.prefix + self.command if self.mentioned: - line += '@' + self.mention + line += "@" + self.mention if self.args: - line += ' ' + self.args + line += " " + self.args return line @@ -168,9 +201,12 @@ class CommandStart(Command): This filter based on :obj:`Command` filter but can handle only ``/start`` command. """ - def __init__(self, - deep_link: typing.Optional[typing.Union[str, typing.Pattern[str]]] = None, - encoded: bool = False): + def __init__( + self, + deep_link: typing.Optional[typing.Union[str, + typing.Pattern[str]]] = None, + encoded: bool = False, + ): """ Also this filter can handle `deep-linking `_ arguments. @@ -183,7 +219,7 @@ class CommandStart(Command): :param deep_link: string or compiled regular expression (by ``re.compile(...)``). :param encoded: set True if you're waiting for encoded payload (default - False). """ - super().__init__(['start']) + super().__init__(["start"]) self.deep_link = deep_link self.encoded = encoded @@ -195,17 +231,22 @@ class CommandStart(Command): :return: """ from ...utils.deep_linking import decode_payload + check = await super().check(message) if check and self.deep_link is not None: - payload = decode_payload(message.get_args()) if self.encoded else message.get_args() + payload = ( + decode_payload(message.get_args()) + if self.encoded + else message.get_args() + ) if not isinstance(self.deep_link, typing.Pattern): - return False if payload != self.deep_link else {'deep_link': payload} + return False if payload != self.deep_link else {"deep_link": payload} match = self.deep_link.match(payload) if match: - return {'deep_link': match} + return {"deep_link": match} return False return check @@ -217,7 +258,7 @@ class CommandHelp(Command): """ def __init__(self): - super().__init__(['help']) + super().__init__(["help"]) class CommandSettings(Command): @@ -226,7 +267,7 @@ class CommandSettings(Command): """ def __init__(self): - super().__init__(['settings']) + super().__init__(["settings"]) class CommandPrivacy(Command): @@ -235,7 +276,7 @@ class CommandPrivacy(Command): """ def __init__(self): - super().__init__(['privacy']) + super().__init__(["privacy"]) class Text(Filter): @@ -244,18 +285,27 @@ class Text(Filter): """ _default_params = ( - ('text', 'equals'), - ('text_contains', 'contains'), - ('text_startswith', 'startswith'), - ('text_endswith', 'endswith'), + ("text", "equals"), + ("text_contains", "contains"), + ("text_startswith", "startswith"), + ("text_endswith", "endswith"), ) - def __init__(self, - equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, - contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, - startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, - endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, - ignore_case=False): + def __init__( + self, + equals: Optional[Union[str, LazyProxy, + Iterable[Union[str, LazyProxy]]]] = None, + contains: Optional[ + Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]] + ] = None, + startswith: Optional[ + Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]] + ] = None, + endswith: Optional[ + Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]] + ] = None, + ignore_case=False, + ): """ Check text for one of pattern. Only one mode can be used in one filter. In every pattern, a single string is treated as a list with 1 element. @@ -267,15 +317,24 @@ class Text(Filter): :param ignore_case: case insensitive """ # Only one mode can be used. check it. - check = sum(map(lambda s: s is not None, (equals, contains, startswith, endswith))) + check = sum( + map(lambda s: s is not None, (equals, contains, startswith, endswith)) + ) if check > 1: - args = "' and '".join([arg[0] for arg in [('equals', equals), - ('contains', contains), - ('startswith', startswith), - ('endswith', endswith) - ] if arg[1] is not None]) + args = "' and '".join( + [ + arg[0] + for arg in [ + ("equals", equals), + ("contains", contains), + ("startswith", startswith), + ("endswith", endswith), + ] + if arg[1] is not None + ] + ) raise ValueError(f"Arguments '{args}' cannot be used together.") - elif check == 0: + if check == 0: raise ValueError(f"No one mode is specified!") equals, contains, endswith, startswith = map( @@ -297,7 +356,7 @@ class Text(Filter): async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]): if isinstance(obj, Message): - text = obj.text or obj.caption or '' + text = obj.text or obj.caption or "" if not text and obj.poll: text = obj.poll.question elif isinstance(obj, CallbackQuery): @@ -311,7 +370,10 @@ class Text(Filter): if self.ignore_case: text = text.lower() - _pre_process_func = lambda s: str(s).lower() + + def _pre_process_func(s): + return str(s).lower() + else: _pre_process_func = str @@ -344,7 +406,7 @@ class HashTag(Filter): def __init__(self, hashtags=None, cashtags=None): if not hashtags and not cashtags: - raise ValueError('No one hashtag or cashtag is specified!') + raise ValueError("No one hashtag or cashtag is specified!") if hashtags is None: hashtags = [] @@ -364,10 +426,10 @@ class HashTag(Filter): @classmethod def validate(cls, full_config: Dict[str, Any]): config = {} - if 'hashtags' in full_config: - config['hashtags'] = full_config.pop('hashtags') - if 'cashtags' in full_config: - config['cashtags'] = full_config.pop('cashtags') + if "hashtags" in full_config: + config["hashtags"] = full_config.pop("hashtags") + if "cashtags" in full_config: + config["cashtags"] = full_config.pop("cashtags") return config async def check(self, message: types.Message): @@ -381,9 +443,13 @@ class HashTag(Filter): return False hashtags, cashtags = self._get_tags(text, entities) - if self.hashtags and set(hashtags) & set(self.hashtags) \ - or self.cashtags and set(cashtags) & set(self.cashtags): - return {'hashtags': hashtags, 'cashtags': cashtags} + if ( + self.hashtags + and set(hashtags) & set(self.hashtags) + or self.cashtags + and set(cashtags) & set(self.cashtags) + ): + return {"hashtags": hashtags, "cashtags": cashtags} @staticmethod def _get_tags(text, entities): @@ -392,11 +458,11 @@ class HashTag(Filter): for entity in entities: if entity.type == types.MessageEntityType.HASHTAG: - value = entity.get_text(text).lstrip('#') + value = entity.get_text(text).lstrip("#") hashtags.append(value) elif entity.type == types.MessageEntityType.CASHTAG: - value = entity.get_text(text).lstrip('$') + value = entity.get_text(text).lstrip("$") cashtags.append(value) return hashtags, cashtags @@ -414,12 +480,12 @@ class Regexp(Filter): @classmethod def validate(cls, full_config: Dict[str, Any]): - if 'regexp' in full_config: - return {'regexp': full_config.pop('regexp')} + if "regexp" in full_config: + return {"regexp": full_config.pop("regexp")} async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]): if isinstance(obj, Message): - content = obj.text or obj.caption or '' + content = obj.text or obj.caption or "" if not content and obj.poll: content = obj.poll.question elif isinstance(obj, CallbackQuery) and obj.data: @@ -434,7 +500,7 @@ class Regexp(Filter): match = self.regexp.search(content) if match: - return {'regexp': match} + return {"regexp": match} return False @@ -443,17 +509,20 @@ class RegexpCommandsFilter(BoundFilter): Check commands by regexp in message """ - key = 'regexp_commands' + key = "regexp_commands" def __init__(self, regexp_commands): - self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] + self.regexp_commands = [ + re.compile(command, flags=re.IGNORECASE | re.MULTILINE) + for command in regexp_commands + ] async def check(self, message): if not message.is_command(): return False command = message.text.split()[0][1:] - command, _, mention = command.partition('@') + command, _, mention = command.partition("@") if mention and mention != (await message.bot.me).username: return False @@ -461,7 +530,7 @@ class RegexpCommandsFilter(BoundFilter): for command in self.regexp_commands: search = command.search(message.text) if search: - return {'regexp_command': search} + return {"regexp_command": search} return False @@ -470,7 +539,7 @@ class ContentTypeFilter(BoundFilter): Check message content type """ - key = 'content_types' + key = "content_types" required = True default = types.ContentTypes.TEXT @@ -480,8 +549,10 @@ class ContentTypeFilter(BoundFilter): self.content_types = content_types async def check(self, message): - return types.ContentType.ANY in self.content_types or \ - message.content_type in self.content_types + return ( + types.ContentType.ANY in self.content_types + or message.content_type in self.content_types + ) class IsSenderContact(BoundFilter): @@ -491,7 +562,8 @@ class IsSenderContact(BoundFilter): `is_sender_contact=True` - contact matches the sender `is_sender_contact=False` - result will be inverted """ - key = 'is_sender_contact' + + key = "is_sender_contact" def __init__(self, is_sender_contact: bool): self.is_sender_contact = is_sender_contact @@ -509,10 +581,11 @@ class StateFilter(BoundFilter): """ Check user state """ - key = 'state' + + key = "state" required = True - ctx_state = ContextVar('user_state') + ctx_state = ContextVar("user_state") def __init__(self, dispatcher, state): from aiogram.dispatcher.filters.state import State, StatesGroup @@ -520,7 +593,9 @@ class StateFilter(BoundFilter): self.dispatcher = dispatcher states = [] if not isinstance(state, (list, set, tuple, frozenset)) or state is None: - state = [state, ] + state = [ + state, + ] for item in state: if isinstance(item, State): states.append(item.state) @@ -532,11 +607,13 @@ class StateFilter(BoundFilter): @staticmethod def get_target(obj): - return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) + return getattr(getattr(obj, "chat", None), "id", None), getattr( + getattr(obj, "from_user", None), "id", None + ) async def check(self, obj): - if '*' in self.states: - return {'state': self.dispatcher.current_state()} + if "*" in self.states: + return {"state": self.dispatcher.current_state()} try: state = self.ctx_state.get() @@ -547,11 +624,14 @@ class StateFilter(BoundFilter): state = await self.dispatcher.storage.get_state(chat=chat, user=user) self.ctx_state.set(state) if state in self.states: - return {'state': self.dispatcher.current_state(), 'raw_state': state} + return { + "state": self.dispatcher.current_state(), + "raw_state": state, + } else: if state in self.states: - return {'state': self.dispatcher.current_state(), 'raw_state': state} + return {"state": self.dispatcher.current_state(), "raw_state": state} return False @@ -561,7 +641,7 @@ class ExceptionsFilter(BoundFilter): Filter for exceptions """ - key = 'exception' + key = "exception" def __init__(self, exception): self.exception = exception @@ -576,10 +656,11 @@ class ExceptionsFilter(BoundFilter): class IDFilter(Filter): - def __init__(self, - user_id: Optional[ChatIDArgumentType] = None, - chat_id: Optional[ChatIDArgumentType] = None, - ): + def __init__( + self, + user_id: Optional[ChatIDArgumentType] = None, + chat_id: Optional[ChatIDArgumentType] = None, + ): """ :param user_id: :param chat_id: @@ -597,13 +678,15 @@ class IDFilter(Filter): self.chat_id = extract_chat_ids(chat_id) @classmethod - def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + def validate( + cls, full_config: typing.Dict[str, typing.Any] + ) -> typing.Optional[typing.Dict[str, typing.Any]]: result = {} - if 'user_id' in full_config: - result['user_id'] = full_config.pop('user_id') + if "user_id" in full_config: + result["user_id"] = full_config.pop("user_id") - if 'chat_id' in full_config: - result['chat_id'] = full_config.pop('chat_id') + if "chat_id" in full_config: + result["chat_id"] = full_config.pop("chat_id") return result @@ -658,7 +741,9 @@ class AdminFilter(Filter): self._chat_ids = extract_chat_ids(is_chat_admin) @classmethod - def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + def validate( + cls, full_config: typing.Dict[str, typing.Any] + ) -> typing.Optional[typing.Dict[str, typing.Any]]: result = {} if "is_chat_admin" in full_config: @@ -676,13 +761,19 @@ class AdminFilter(Filter): message = obj.message else: return False - if message.chat.type == ChatType.PRIVATE: # there is no admin in private chats + if ( + message.chat.type == ChatType.PRIVATE + ): # there is no admin in private chats return False chat_ids = [message.chat.id] else: chat_ids = self._chat_ids - admins = [member.user.id for chat_id in chat_ids for member in await obj.bot.get_chat_administrators(chat_id)] + admins = [ + member.user.id + for chat_id in chat_ids + for member in await obj.bot.get_chat_administrators(chat_id) + ] return user_id in admins @@ -691,20 +782,21 @@ class IsReplyFilter(BoundFilter): """ Check if message is replied and send reply message to handler """ - key = 'is_reply' + + key = "is_reply" def __init__(self, is_reply): self.is_reply = is_reply async def check(self, msg: Message): if msg.reply_to_message and self.is_reply: - return {'reply': msg.reply_to_message} + return {"reply": msg.reply_to_message} if not msg.reply_to_message and not self.is_reply: return True class ForwardedMessageFilter(BoundFilter): - key = 'is_forwarded' + key = "is_forwarded" def __init__(self, is_forwarded: bool): self.is_forwarded = is_forwarded @@ -714,7 +806,7 @@ class ForwardedMessageFilter(BoundFilter): class ChatTypeFilter(BoundFilter): - key = 'chat_type' + key = "chat_type" def __init__(self, chat_type: typing.Container[ChatType]): if isinstance(chat_type, str): @@ -728,7 +820,8 @@ class ChatTypeFilter(BoundFilter): elif isinstance(obj, CallbackQuery): obj = obj.message.chat else: - warnings.warn("ChatTypeFilter doesn't support %s as input", type(obj)) + warnings.warn( + "ChatTypeFilter doesn't support %s as input", type(obj)) return False return obj.type in self.chat_type diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py index 09ab2477..210b9a85 100644 --- a/aiogram/dispatcher/filters/filters.py +++ b/aiogram/dispatcher/filters/filters.py @@ -13,9 +13,11 @@ def wrap_async(func): async def async_wrapper(*args, **kwargs): return func(*args, **kwargs) - if inspect.isawaitable(func) \ - or inspect.iscoroutinefunction(func) \ - or isinstance(func, AbstractFilter): + if ( + inspect.isawaitable(func) + or inspect.iscoroutinefunction(func) + or isinstance(func, AbstractFilter) + ): return func return async_wrapper @@ -23,14 +25,16 @@ def wrap_async(func): def get_filter_spec(dispatcher, filter_: callable): kwargs = {} if not callable(filter_): - raise TypeError('Filter must be callable and/or awaitable!') + raise TypeError("Filter must be callable and/or awaitable!") spec = inspect.getfullargspec(filter_) - if 'dispatcher' in spec: - kwargs['dispatcher'] = dispatcher - if inspect.isawaitable(filter_) \ - or inspect.iscoroutinefunction(filter_) \ - or isinstance(filter_, AbstractFilter): + if "dispatcher" in spec: + kwargs["dispatcher"] = dispatcher + if ( + inspect.isawaitable(filter_) + or inspect.iscoroutinefunction(filter_) + or isinstance(filter_, AbstractFilter) + ): return FilterObj(filter=filter_, kwargs=kwargs, is_async=True) return FilterObj(filter=filter_, kwargs=kwargs, is_async=False) @@ -70,7 +74,7 @@ async def check_filters(filters: typing.Iterable[FilterObj], args): f = await execute_filter(filter_, args) if not f: raise FilterNotPassed() - elif isinstance(f, dict): + if isinstance(f, dict): data.update(f) return data @@ -80,12 +84,17 @@ class FilterRecord: Filters record for factory """ - def __init__(self, callback: typing.Union[typing.Callable, 'AbstractFilter'], - validator: typing.Optional[typing.Callable] = None, - event_handlers: typing.Optional[typing.Iterable[Handler]] = None, - exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None): + def __init__( + self, + callback: typing.Union[typing.Callable, "AbstractFilter"], + validator: typing.Optional[typing.Callable] = None, + event_handlers: typing.Optional[typing.Iterable[Handler]] = None, + exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None, + ): if event_handlers and exclude_event_handlers: - raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.") + raise ValueError( + "'event_handlers' and 'exclude_event_handlers' arguments cannot be used together." + ) self.callback = callback self.event_handlers = event_handlers @@ -93,22 +102,23 @@ class FilterRecord: if validator is not None: if not callable(validator): - raise TypeError(f"validator must be callable, not {type(validator)}") + raise TypeError( + f"validator must be callable, not {type(validator)}") self.resolver = validator elif issubclass(callback, AbstractFilter): self.resolver = callback.validate else: - raise RuntimeError('validator is required!') + raise RuntimeError("validator is required!") def resolve(self, dispatcher, event_handler, full_config): if not self._check_event_handler(event_handler): return config = self.resolver(full_config) if config: - if 'dispatcher' not in config: + if "dispatcher" not in config: spec = inspect.getfullargspec(self.callback) - if 'dispatcher' in spec.args: - config['dispatcher'] = dispatcher + if "dispatcher" in spec.args: + config["dispatcher"] = dispatcher for key in config: if key in full_config: @@ -131,7 +141,9 @@ class AbstractFilter(abc.ABC): @classmethod @abc.abstractmethod - def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + def validate( + cls, full_config: typing.Dict[str, typing.Any] + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """ Validate and parse config. @@ -182,7 +194,9 @@ class Filter(AbstractFilter): """ @classmethod - def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + def validate( + cls, full_config: typing.Dict[str, typing.Any] + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """ Here method ``validate`` is optional. If you need to use filter from filters factory you need to override this method. @@ -200,16 +214,18 @@ class BoundFilter(Filter): You need to implement ``__init__`` method with single argument related with key attribute and ``check`` method where you need to implement filter logic. """ - + key = None """Unique name of the filter argument. You need to override this attribute.""" required = False """If :obj:`True` this filter will be added to the all of the registered handlers""" default = None """Default value for configure required filters""" - + @classmethod - def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: + def validate( + cls, full_config: typing.Dict[str, typing.Any] + ) -> typing.Dict[str, typing.Any]: """ If ``cls.key`` is not :obj:`None` and that is in config returns config with that argument. @@ -226,7 +242,7 @@ class BoundFilter(Filter): class _LogicFilter(Filter): @classmethod def validate(cls, full_config: typing.Dict[str, typing.Any]): - raise ValueError('That filter can\'t be used in filters factory!') + raise ValueError("That filter can't be used in filters factory!") class NotFilter(_LogicFilter): @@ -238,7 +254,6 @@ class NotFilter(_LogicFilter): class AndFilter(_LogicFilter): - def __init__(self, *targets): self.targets = [wrap_async(target) for target in targets] diff --git a/aiogram/utils/callback_data.py b/aiogram/utils/callback_data.py index e24ad7b1..09df040b 100644 --- a/aiogram/utils/callback_data.py +++ b/aiogram/utils/callback_data.py @@ -26,15 +26,17 @@ class CallbackData: Callback data factory """ - def __init__(self, prefix, *parts, sep=':'): + def __init__(self, prefix, *parts, sep=":"): if not isinstance(prefix, str): - raise TypeError(f'Prefix must be instance of str not {type(prefix).__name__}') + raise TypeError( + f"Prefix must be instance of str not {type(prefix).__name__}" + ) if not prefix: raise ValueError("Prefix can't be empty") if sep in prefix: raise ValueError(f"Separator {sep!r} can't be used in prefix") if not parts: - raise TypeError('Parts were not passed!') + raise TypeError("Parts were not passed!") self.prefix = prefix self.sep = sep @@ -59,7 +61,7 @@ class CallbackData: if args: value = args.pop(0) else: - raise ValueError(f'Value for {part!r} was not passed!') + raise ValueError(f"Value for {part!r} was not passed!") if value is not None and not isinstance(value, str): value = str(value) @@ -67,16 +69,18 @@ class CallbackData: if not value: raise ValueError(f"Value for part {part!r} can't be empty!'") if self.sep in value: - raise ValueError(f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values") + raise ValueError( + f"Symbol {self.sep!r} is defined as the separator and can't be used in parts' values" + ) data.append(value) if args or kwargs: - raise TypeError('Too many arguments were passed!') + raise TypeError("Too many arguments were passed!") callback_data = self.sep.join(data) if len(callback_data.encode()) > 64: - raise ValueError('Resulted callback data is too long!') + raise ValueError("Resulted callback data is too long!") return callback_data @@ -89,11 +93,12 @@ class CallbackData: """ prefix, *parts = callback_data.split(self.sep) if prefix != self.prefix: - raise ValueError("Passed callback data can't be parsed with that prefix.") - elif len(parts) != len(self._part_names): - raise ValueError('Invalid parts count!') + raise ValueError( + "Passed callback data can't be parsed with that prefix.") + if len(parts) != len(self._part_names): + raise ValueError("Invalid parts count!") - result = {'@': prefix} + result = {"@": prefix} result.update(zip(self._part_names, parts)) return result @@ -106,12 +111,11 @@ class CallbackData: """ for key in config.keys(): if key not in self._part_names: - raise ValueError(f'Invalid field name {key!r}') + raise ValueError(f"Invalid field name {key!r}") return CallbackDataFilter(self, config) class CallbackDataFilter(Filter): - def __init__(self, factory: CallbackData, config: typing.Dict[str, str]): self.config = config self.factory = factory @@ -133,4 +137,4 @@ class CallbackDataFilter(Filter): else: if data.get(key) != value: return False - return {'callback_data': data} + return {"callback_data": data}