diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 0a81998a..5fe01dde 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -12,6 +12,19 @@ from aiogram.dispatcher.filters.filters import BoundFilter, Filter from aiogram.types import CallbackQuery, Message, InlineQuery, Poll, ChatType +ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str, int] + + +def extract_chat_ids(chat_id: ChatIDArgumentType) -> typing.Set[int]: + # since "str" is also an "Iterable", we have to check for it first + if isinstance(chat_id, str): + return {int(chat_id), } + if isinstance(chat_id, Iterable): + return {int(item) for (item) in chat_id} + # the last possible type is a single "int" + return {chat_id, } + + class Command(Filter): """ You can handle commands by using this filter. @@ -545,10 +558,9 @@ class ExceptionsFilter(BoundFilter): class IDFilter(Filter): - def __init__(self, - user_id: Optional[Union[Iterable[Union[int, str]], str, int]] = None, - chat_id: Optional[Union[Iterable[Union[int, str]], str, int]] = None, + user_id: Optional[ChatIDArgumentType] = None, + chat_id: Optional[ChatIDArgumentType] = None, ): """ :param user_id: @@ -557,18 +569,14 @@ class IDFilter(Filter): if user_id is None and chat_id is None: raise ValueError("Both user_id and chat_id can't be None") - self.user_id = None - self.chat_id = None + self.user_id: Optional[typing.Set[int]] = None + self.chat_id: Optional[typing.Set[int]] = None + if user_id: - if isinstance(user_id, Iterable): - self.user_id = list(map(int, user_id)) - else: - self.user_id = [int(user_id), ] + self.user_id = extract_chat_ids(user_id) + if chat_id: - if isinstance(chat_id, Iterable): - self.chat_id = list(map(int, chat_id)) - else: - self.chat_id = [int(chat_id), ] + self.chat_id = extract_chat_ids(chat_id) @classmethod def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: @@ -614,22 +622,20 @@ class AdminFilter(Filter): is_chat_admin is required for InlineQuery. """ - def __init__(self, is_chat_admin: Optional[Union[Iterable[Union[int, str]], str, int, bool]] = None): + def __init__(self, is_chat_admin: Optional[Union[ChatIDArgumentType, bool]] = None): self._check_current = False self._chat_ids = None if is_chat_admin is False: raise ValueError("is_chat_admin cannot be False") - if is_chat_admin: - if isinstance(is_chat_admin, bool): - self._check_current = is_chat_admin - if isinstance(is_chat_admin, Iterable): - self._chat_ids = list(is_chat_admin) - else: - self._chat_ids = [is_chat_admin] - else: + if not is_chat_admin: self._check_current = True + return + + if isinstance(is_chat_admin, bool): + self._check_current = is_chat_admin + self._chat_ids = extract_chat_ids(is_chat_admin) @classmethod def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: diff --git a/tests/test_dispatcher/test_filters/test_builtin.py b/tests/test_dispatcher/test_filters/test_builtin.py index 86344cec..a26fc139 100644 --- a/tests/test_dispatcher/test_filters/test_builtin.py +++ b/tests/test_dispatcher/test_filters/test_builtin.py @@ -1,6 +1,12 @@ +from typing import Set + import pytest -from aiogram.dispatcher.filters.builtin import Text +from aiogram.dispatcher.filters.builtin import ( + Text, + extract_chat_ids, + ChatIDArgumentType, +) class TestText: @@ -16,3 +22,50 @@ class TestText: config = {param: value} res = Text.validate(config) assert res == {key: value} + + +@pytest.mark.parametrize( + ('chat_id', 'expected'), + ( + pytest.param('-64856280', {-64856280,}, id='single negative int as string'), + pytest.param('64856280', {64856280,}, id='single positive int as string'), + pytest.param(-64856280, {-64856280,}, id='single negative int'), + pytest.param(64856280, {64856280,}, id='single positive negative int'), + pytest.param( + ['-64856280'], {-64856280,}, id='list of single negative int as string' + ), + pytest.param([-64856280], {-64856280,}, id='list of single negative int'), + pytest.param( + ['-64856280', '-64856280'], + {-64856280,}, + id='list of two duplicated negative ints as strings', + ), + pytest.param( + ['-64856280', -64856280], + {-64856280,}, + id='list of one negative int as string and one negative int', + ), + pytest.param( + [-64856280, -64856280], + {-64856280,}, + id='list of two duplicated negative ints', + ), + pytest.param( + iter(['-64856280']), + {-64856280,}, + id='iterator from a list of single negative int as string', + ), + pytest.param( + [10000000, 20000000, 30000000], + {10000000, 20000000, 30000000}, + id='list of several positive ints', + ), + pytest.param( + [10000000, '20000000', -30000000], + {10000000, 20000000, -30000000}, + id='list of positive int, positive int as string, negative int', + ), + ), +) +def test_extract_chat_ids(chat_id: ChatIDArgumentType, expected: Set[int]): + assert extract_chat_ids(chat_id) == expected