diff --git a/tests/test_filters.py b/tests/test_filters.py index da530910..e0002381 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -195,3 +195,60 @@ class TestTextFilter: assert await check(CallbackQuery(data=test_text)) assert await check(InlineQuery(query=test_text)) assert await check(Poll(question=test_text)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_filter_list, test_text, ignore_case", + [(['', 'new_string'], '', True), + (['new_string', ''], 'exAmple_string', True), + (['new_string', ''], '', False), + (['', 'new_string'], 'exAmple_string', False), + + (['example_string'], 'example_string', True), + (['example_string'], 'exAmple_string', True), + (['exAmple_string'], 'example_string', True), + + (['example_string'], 'example_string', False), + (['example_string'], 'exAmple_string', False), + (['exAmple_string'], 'example_string', False), + + (['example_string'], 'not_example_string', True), + (['example_string'], 'not_eXample_string', True), + (['EXample_string'], 'not_eXample_string', True), + + (['example_string'], 'not_example_string', False), + (['example_string'], 'not_eXample_string', False), + (['EXample_string'], 'not_example_string', False), + + (['example_string', 'new_string'], 'example_string', True), + (['new_string', 'example_string'], 'exAmple_string', True), + (['exAmple_string', 'new_string'], 'example_string', True), + + (['example_string', 'new_string'], 'example_string', False), + (['new_string', 'example_string'], 'exAmple_string', False), + (['exAmple_string', 'new_string'], 'example_string', False), + + (['example_string', 'new_string'], 'not_example_string', True), + (['new_string', 'example_string'], 'not_eXample_string', True), + (['EXample_string', 'new_string'], 'not_eXample_string', True), + + (['example_string', 'new_string'], 'not_example_string', False), + (['new_string', 'example_string'], 'not_eXample_string', False), + (['EXample_string', 'new_string'], 'not_example_string', False), + ]) + async def test_equals_list(self, test_filter_list, test_text, ignore_case): + test_filter = Text(equals=test_filter_list, ignore_case=ignore_case) + + async def check(obj): + result = await test_filter.check(obj) + if ignore_case: + _test_filter_list = list(map(str.lower, test_filter_list)) + _test_text = test_text.lower() + else: + _test_filter_list = test_filter_list + _test_text = test_text + assert result is (_test_text in _test_filter_list) + + await check(Message(text=test_text)) + await check(CallbackQuery(data=test_text)) + await check(InlineQuery(query=test_text)) + await check(Poll(question=test_text))