diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 15cd73dd..df72b9d0 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -206,18 +206,19 @@ class Text(Filter): """ def __init__(self, - equals: Optional[Union[str, LazyProxy]] = None, - contains: Optional[Union[str, LazyProxy]] = None, - startswith: Optional[Union[str, LazyProxy]] = None, - endswith: Optional[Union[str, LazyProxy]] = None, + equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, + contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, + startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, + endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, ignore_case=False): """ Check text for one of pattern. Only one mode can be used in one filter. + In every pattern, a single string is treated as a list with 1 element. - :param equals: - :param contains: - :param startswith: - :param endswith: + :param equals: True if object's text in the list + :param contains: True if object's text contains all strings from the list + :param startswith: True if object's text starts with any of strings from the list + :param endswith: True if object's text ends with any of strings from the list :param ignore_case: case insensitive """ # Only one mode can be used. check it. @@ -232,6 +233,9 @@ class Text(Filter): elif check == 0: raise ValueError(f"No one mode is specified!") + equals, contains, endswith, startswith = map(lambda e: [e] if isinstance(e, str) or isinstance(e, LazyProxy) + else e, + (equals, contains, endswith, startswith)) self.equals = equals self.contains = contains self.endswith = endswith @@ -267,25 +271,17 @@ class Text(Filter): text = text.lower() if self.equals is not None: - self.equals = str(self.equals) - if self.ignore_case: - self.equals = self.equals.lower() - return text == self.equals + self.equals = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.equals)) + return text in self.equals elif self.contains is not None: - self.contains = str(self.contains) - if self.ignore_case: - self.contains = self.contains.lower() - return self.contains in text + self.contains = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.contains)) + return all(map(text.__contains__, self.contains)) elif self.startswith is not None: - self.startswith = str(self.startswith) - if self.ignore_case: - self.startswith = self.startswith.lower() - return text.startswith(self.startswith) + self.startswith = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.startswith)) + return any(map(text.startswith, self.startswith)) elif self.endswith is not None: - self.endswith = str(self.endswith) - if self.ignore_case: - self.endswith = self.endswith.lower() - return text.endswith(self.endswith) + self.endswith = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.endswith)) + return any(map(text.endswith, self.endswith)) return False diff --git a/examples/text_filter_example.py b/examples/text_filter_example.py new file mode 100644 index 00000000..a47a5c92 --- /dev/null +++ b/examples/text_filter_example.py @@ -0,0 +1,50 @@ +""" +This is a bot to show the usage of the builtin Text filter +Instead of a list, a single element can be passed to any filter, it will be treated as list with an element +""" + +import logging + +from aiogram import Bot, Dispatcher, executor, types + +API_TOKEN = 'API_TOKEN_HERE' + +# Configure logging +logging.basicConfig(level=logging.INFO) + +# Initialize bot and dispatcher +bot = Bot(token=API_TOKEN) +dp = Dispatcher(bot) + +# if the text from user in the list +@dp.message_handler(text=['text1', 'text2']) +async def text_in_handler(message: types.Message): + await message.answer("The message text is in the list!") + +# if the text contains any string +@dp.message_handler(text_contains='example1') +@dp.message_handler(text_contains='example2') +async def text_contains_any_handler(message: types.Message): + await message.answer("The message text contains any of strings") + + +# if the text contains all the strings from the list +@dp.message_handler(text_contains=['str1', 'str2']) +async def text_contains_all_handler(message: types.Message): + await message.answer("The message text contains all strings from the list") + + +# if the text starts with any string from the list +@dp.message_handler(text_startswith=['prefix1', 'prefix2']) +async def text_startswith_handler(message: types.Message): + await message.answer("The message text starts with any of prefixes") + + +# if the text ends with any string from the list +@dp.message_handler(text_endswith=['postfix1', 'postfix2']) +async def text_endswith_handler(message: types.Message): + await message.answer("The message text ends with any of postfixes") + + +if __name__ == '__main__': + executor.start_polling(dp, skip_updates=True) diff --git a/tests/test_filters.py b/tests/test_filters.py index da530910..288aa0a4 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -55,6 +55,56 @@ class TestTextFilter: assert await check(InlineQuery(query=test_text)) assert await check(Poll(question=test_text)) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_prefix_list, test_text, ignore_case", + [(['not_example', ''], '', True), + (['', 'not_example'], 'exAmple_string', True), + (['not_example', ''], '', False), + (['', 'not_example'], 'exAmple_string', False), + + (['example_string', 'not_example'], 'example_string', True), + (['not_example', 'example_string'], 'exAmple_string', True), + (['exAmple_string', 'not_example'], 'example_string', True), + + (['not_example', 'example_string'], 'example_string', False), + (['example_string', 'not_example'], 'exAmple_string', False), + (['not_example', 'exAmple_string'], 'example_string', False), + + (['example_string', 'not_example'], 'example_string_dsf', True), + (['not_example', 'example_string'], 'example_striNG_dsf', True), + (['example_striNG', 'not_example'], 'example_string_dsf', True), + + (['not_example', 'example_string'], 'example_string_dsf', False), + (['example_string', 'not_example'], 'example_striNG_dsf', False), + (['not_example', 'example_striNG'], 'example_string_dsf', False), + + (['example_string', 'not_example'], 'not_example_string', True), + (['not_example', 'example_string'], 'not_eXample_string', True), + (['EXample_string', 'not_example'], 'not_example_string', True), + + (['not_example', 'example_string'], 'not_example_string', False), + (['example_string', 'not_example'], 'not_eXample_string', False), + (['not_example', 'EXample_string'], 'not_example_string', False), + ]) + async def test_startswith_list(self, test_prefix_list, test_text, ignore_case): + test_filter = Text(startswith=test_prefix_list, ignore_case=ignore_case) + + async def check(obj): + result = await test_filter.check(obj) + if ignore_case: + _test_prefix_list = map(str.lower, test_prefix_list) + _test_text = test_text.lower() + else: + _test_prefix_list = test_prefix_list + _test_text = test_text + + return result is any(map(_test_text.startswith, _test_prefix_list)) + + assert await check(Message(text=test_text)) + assert await check(CallbackQuery(data=test_text)) + assert await check(InlineQuery(query=test_text)) + assert await check(Poll(question=test_text)) + @pytest.mark.asyncio @pytest.mark.parametrize("test_postfix, test_text, ignore_case", [('', '', True), @@ -105,6 +155,55 @@ class TestTextFilter: assert await check(InlineQuery(query=test_text)) assert await check(Poll(question=test_text)) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_postfix_list, test_text, ignore_case", + [(['', 'not_example'], '', True), + (['not_example', ''], 'exAmple_string', True), + (['', 'not_example'], '', False), + (['not_example', ''], 'exAmple_string', False), + + (['example_string', 'not_example'], 'example_string', True), + (['not_example', 'example_string'], 'exAmple_string', True), + (['exAmple_string', 'not_example'], 'example_string', True), + + (['example_string', 'not_example'], 'example_string', False), + (['not_example', 'example_string'], 'exAmple_string', False), + (['exAmple_string', 'not_example'], 'example_string', False), + + (['example_string', 'not_example'], 'example_string_dsf', True), + (['not_example', 'example_string'], 'example_striNG_dsf', True), + (['example_striNG', 'not_example'], 'example_string_dsf', True), + + (['not_example', 'example_string'], 'example_string_dsf', False), + (['example_string', 'not_example'], 'example_striNG_dsf', False), + (['not_example', 'example_striNG'], 'example_string_dsf', False), + + (['not_example', 'example_string'], 'not_example_string', True), + (['example_string', 'not_example'], 'not_eXample_string', True), + (['not_example', 'EXample_string'], 'not_eXample_string', True), + + (['not_example', 'example_string'], 'not_example_string', False), + (['example_string', 'not_example'], 'not_eXample_string', False), + (['not_example', 'EXample_string'], 'not_example_string', False), + ]) + async def test_endswith_list(self, test_postfix_list, test_text, ignore_case): + test_filter = Text(endswith=test_postfix_list, ignore_case=ignore_case) + + async def check(obj): + result = await test_filter.check(obj) + if ignore_case: + _test_postfix_list = map(str.lower, test_postfix_list) + _test_text = test_text.lower() + else: + _test_postfix_list = test_postfix_list + _test_text = test_text + + return result is any(map(_test_text.endswith, _test_postfix_list)) + assert await check(Message(text=test_text)) + assert await check(CallbackQuery(data=test_text)) + assert await check(InlineQuery(query=test_text)) + assert await check(Poll(question=test_text)) + @pytest.mark.asyncio @pytest.mark.parametrize("test_string, test_text, ignore_case", [('', '', True), @@ -155,6 +254,37 @@ class TestTextFilter: assert await check(InlineQuery(query=test_text)) assert await check(Poll(question=test_text)) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_filter_list, test_text, ignore_case", + [(['a', 'ab', 'abc'], 'A', True), + (['a', 'ab', 'abc'], 'ab', True), + (['a', 'ab', 'abc'], 'aBc', True), + (['a', 'ab', 'abc'], 'd', True), + + (['a', 'ab', 'abc'], 'A', False), + (['a', 'ab', 'abc'], 'ab', False), + (['a', 'ab', 'abc'], 'aBc', False), + (['a', 'ab', 'abc'], 'd', False), + ]) + async def test_contains_list(self, test_filter_list, test_text, ignore_case): + test_filter = Text(contains=test_filter_list, ignore_case=ignore_case) + + async def check(obj): + result = await test_filter.check(obj) + if ignore_case: + _test_filter_list = list(map(str.lower, test_filter_list)) + _test_text = test_text.lower() + else: + _test_filter_list = test_filter_list + _test_text = test_text + + return result is all(map(_test_text.__contains__, _test_filter_list)) + + assert await check(Message(text=test_text)) + assert await check(CallbackQuery(data=test_text)) + assert await check(InlineQuery(query=test_text)) + assert await check(Poll(question=test_text)) + @pytest.mark.asyncio @pytest.mark.parametrize("test_filter_text, test_text, ignore_case", [('', '', True), @@ -195,3 +325,60 @@ class TestTextFilter: assert await check(CallbackQuery(data=test_text)) assert await check(InlineQuery(query=test_text)) assert await check(Poll(question=test_text)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_filter_list, test_text, ignore_case", + [(['', 'new_string'], '', True), + (['new_string', ''], 'exAmple_string', True), + (['new_string', ''], '', False), + (['', 'new_string'], 'exAmple_string', False), + + (['example_string'], 'example_string', True), + (['example_string'], 'exAmple_string', True), + (['exAmple_string'], 'example_string', True), + + (['example_string'], 'example_string', False), + (['example_string'], 'exAmple_string', False), + (['exAmple_string'], 'example_string', False), + + (['example_string'], 'not_example_string', True), + (['example_string'], 'not_eXample_string', True), + (['EXample_string'], 'not_eXample_string', True), + + (['example_string'], 'not_example_string', False), + (['example_string'], 'not_eXample_string', False), + (['EXample_string'], 'not_example_string', False), + + (['example_string', 'new_string'], 'example_string', True), + (['new_string', 'example_string'], 'exAmple_string', True), + (['exAmple_string', 'new_string'], 'example_string', True), + + (['example_string', 'new_string'], 'example_string', False), + (['new_string', 'example_string'], 'exAmple_string', False), + (['exAmple_string', 'new_string'], 'example_string', False), + + (['example_string', 'new_string'], 'not_example_string', True), + (['new_string', 'example_string'], 'not_eXample_string', True), + (['EXample_string', 'new_string'], 'not_eXample_string', True), + + (['example_string', 'new_string'], 'not_example_string', False), + (['new_string', 'example_string'], 'not_eXample_string', False), + (['EXample_string', 'new_string'], 'not_example_string', False), + ]) + async def test_equals_list(self, test_filter_list, test_text, ignore_case): + test_filter = Text(equals=test_filter_list, ignore_case=ignore_case) + + async def check(obj): + result = await test_filter.check(obj) + if ignore_case: + _test_filter_list = list(map(str.lower, test_filter_list)) + _test_text = test_text.lower() + else: + _test_filter_list = test_filter_list + _test_text = test_text + assert result is (_test_text in _test_filter_list) + + await check(Message(text=test_text)) + await check(CallbackQuery(data=test_text)) + await check(InlineQuery(query=test_text)) + await check(Poll(question=test_text))