diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 55ed63e5..f6aeaa14 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -140,7 +140,9 @@ class CommandStart(Command): This filter based on :obj:`Command` filter but can handle only ``/start`` command. """ - def __init__(self, deep_link: typing.Optional[typing.Union[str, re.Pattern]] = None): + def __init__(self, + deep_link: typing.Optional[typing.Union[str, typing.Pattern[str]]] = None, + encoded: bool = False): """ Also this filter can handle `deep-linking `_ arguments. @@ -151,9 +153,11 @@ class CommandStart(Command): @dp.message_handler(CommandStart(re.compile(r'ref-([\\d]+)'))) :param deep_link: string or compiled regular expression (by ``re.compile(...)``). + :param encoded: set True if you're waiting for encoded payload (default - False). """ super().__init__(['start']) self.deep_link = deep_link + self.encoded = encoded async def check(self, message: types.Message): """ @@ -162,18 +166,21 @@ class CommandStart(Command): :param message: :return: """ + from ...utils.deep_linking import decode_payload check = await super().check(message) if check and self.deep_link is not None: - if not isinstance(self.deep_link, re.Pattern): - return message.get_args() == self.deep_link + payload = decode_payload(message.get_args()) if self.encoded else message.get_args() - match = self.deep_link.match(message.get_args()) + if not isinstance(self.deep_link, typing.Pattern): + return False if payload != self.deep_link else {'deep_link': payload} + + match = self.deep_link.match(payload) if match: return {'deep_link': match} return False - return check + return {'deep_link': None} class CommandHelp(Command): @@ -244,7 +251,7 @@ class Text(Filter): raise ValueError(f"No one mode is specified!") equals, contains, endswith, startswith = map(lambda e: [e] if isinstance(e, str) or isinstance(e, LazyProxy) - else e, + else e, (equals, contains, endswith, startswith)) self.equals = equals self.contains = contains @@ -370,7 +377,7 @@ class Regexp(Filter): """ def __init__(self, regexp): - if not isinstance(regexp, re.Pattern): + if not isinstance(regexp, typing.Pattern): regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) self.regexp = regexp diff --git a/aiogram/utils/deep_linking.py b/aiogram/utils/deep_linking.py new file mode 100644 index 00000000..acb105da --- /dev/null +++ b/aiogram/utils/deep_linking.py @@ -0,0 +1,101 @@ +""" +Deep linking + +Telegram bots have a deep linking mechanism, that allows for passing additional +parameters to the bot on startup. It could be a command that launches the bot — or +an auth token to connect the user's Telegram account to their account on some +external service. + +You can read detailed description in the source: +https://core.telegram.org/bots#deep-linking + +We have add some utils to get deep links more handy. + +Basic link example: + + .. code-block:: python + + from aiogram.utils.deep_linking import get_start_link + link = await get_start_link('foo') # result: 'https://t.me/MyBot?start=foo' + +Encoded link example: + + .. code-block:: python + + from aiogram.utils.deep_linking import get_start_link, decode_payload + link = await get_start_link('foo', encode=True) # result: 'https://t.me/MyBot?start=Zm9v' + # and decode it back: + payload = decode_payload('Zm9v') # result: 'foo' + +""" + + +async def get_start_link(payload: str, encode=False) -> str: + """ + Use this method to handy get 'start' deep link with your payload. + If you need to encode payload or pass special characters - set encode as True + + :param payload: args passed with /start + :param encode: encode payload with base64url + :return: link + """ + return await _create_link('start', payload, encode) + + +async def get_startgroup_link(payload: str, encode=False) -> str: + """ + Use this method to handy get 'startgroup' deep link with your payload. + If you need to encode payload or pass special characters - set encode as True + + :param payload: args passed with /start + :param encode: encode payload with base64url + :return: link + """ + return await _create_link('startgroup', payload, encode) + + +async def _create_link(link_type, payload: str, encode=False): + bot = await _get_bot_user() + payload = filter_payload(payload) + if encode: + payload = encode_payload(payload) + return f'https://t.me/{bot.username}?{link_type}={payload}' + + +def encode_payload(payload: str) -> str: + """ Encode payload with URL-safe base64url. """ + from base64 import urlsafe_b64encode + result: bytes = urlsafe_b64encode(payload.encode()) + return result.decode() + + +def decode_payload(payload: str) -> str: + """ Decode payload with URL-safe base64url. """ + from base64 import urlsafe_b64decode + result: bytes = urlsafe_b64decode(payload + '=' * (4 - len(payload) % 4)) + return result.decode() + + +def filter_payload(payload: str) -> str: + """ Convert payload to text and search for not allowed symbols. """ + import re + + # convert to string + if not isinstance(payload, str): + payload = str(payload) + + # search for not allowed characters + if re.search(r'[^_A-z0-9-]', payload): + message = ('Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. ' + 'We recommend to encode parameters with binary and other ' + 'types of content.') + raise ValueError(message) + + return payload + + +async def _get_bot_user(): + """ Get current user of bot. """ + from ..bot import Bot + bot = Bot.get_current() + return await bot.me diff --git a/docs/source/utils/deep_linking.rst b/docs/source/utils/deep_linking.rst new file mode 100644 index 00000000..e00e0d20 --- /dev/null +++ b/docs/source/utils/deep_linking.rst @@ -0,0 +1,6 @@ +============ +Deep linking +============ + +.. automodule:: aiogram.utils.deep_linking + :members: diff --git a/tests/test_filters.py b/tests/test_filters.py index 609db736..0592f31b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,8 +1,14 @@ +import re +from typing import Match + import pytest -from aiogram.dispatcher.filters import Text +from aiogram.dispatcher.filters import Text, CommandStart from aiogram.types import Message, CallbackQuery, InlineQuery, Poll +# enable asyncio mode +pytestmark = pytest.mark.asyncio + def data_sample_1(): return [ @@ -22,15 +28,16 @@ def data_sample_1(): ('EXample_string', 'not_example_string'), ] + class TestTextFilter: - async def _run_check(self, check, test_text): + @staticmethod + async def _run_check(check, test_text): assert await check(Message(text=test_text)) assert await check(CallbackQuery(data=test_text)) assert await check(InlineQuery(query=test_text)) assert await check(Poll(question=test_text)) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_prefix, test_text", data_sample_1()) async def test_startswith(self, test_prefix, test_text, ignore_case): @@ -49,7 +56,6 @@ class TestTextFilter: await self._run_check(check, test_text) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_prefix_list, test_text", [ (['not_example', ''], ''), @@ -83,7 +89,6 @@ class TestTextFilter: await self._run_check(check, test_text) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_postfix, test_text", data_sample_1()) async def test_endswith(self, test_postfix, test_text, ignore_case): @@ -102,7 +107,6 @@ class TestTextFilter: await self._run_check(check, test_text) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_postfix_list, test_text", [ (['', 'not_example'], ''), @@ -133,9 +137,9 @@ class TestTextFilter: _test_text = test_text return result is any(map(_test_text.endswith, _test_postfix_list)) + await self._run_check(check, test_text) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_string, test_text", [ ('', ''), @@ -169,7 +173,6 @@ class TestTextFilter: await self._run_check(check, test_text) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_filter_list, test_text", [ (['a', 'ab', 'abc'], 'A'), @@ -193,7 +196,6 @@ class TestTextFilter: await self._run_check(check, test_text) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_filter_text, test_text", [ ('', ''), @@ -222,7 +224,6 @@ class TestTextFilter: await self._run_check(check, test_text) - @pytest.mark.asyncio @pytest.mark.parametrize('ignore_case', (True, False)) @pytest.mark.parametrize("test_filter_list, test_text", [ (['new_string', ''], ''), @@ -261,3 +262,50 @@ class TestTextFilter: await check(CallbackQuery(data=test_text)) await check(InlineQuery(query=test_text)) await check(Poll(question=test_text)) + + +class TestCommandStart: + START = '/start' + GOOD = 'foo' + BAD = 'bar' + GOOD_PATTERN = re.compile(r'^f..$') + BAD_PATTERN = re.compile(r'^b..$') + ENCODED = 'Zm9v' + + async def test_start_command_without_payload(self): + test_filter = CommandStart() # empty filter + message = Message(text=self.START) + result = await test_filter.check(message) + assert result == {'deep_link': None} + + async def test_start_command_payload_is_matched(self): + test_filter = CommandStart(deep_link=self.GOOD) + message = Message(text=f'{self.START} {self.GOOD}') + result = await test_filter.check(message) + assert result == {'deep_link': self.GOOD} + + async def test_start_command_payload_is_not_matched(self): + test_filter = CommandStart(deep_link=self.GOOD) + message = Message(text=f'{self.START} {self.BAD}') + result = await test_filter.check(message) + assert result is False + + async def test_start_command_payload_pattern_is_matched(self): + test_filter = CommandStart(deep_link=self.GOOD_PATTERN) + message = Message(text=f'{self.START} {self.GOOD}') + result = await test_filter.check(message) + assert isinstance(result, dict) + match = result.get('deep_link') + assert isinstance(match, Match) + + async def test_start_command_payload_pattern_is_not_matched(self): + test_filter = CommandStart(deep_link=self.BAD_PATTERN) + message = Message(text=f'{self.START} {self.GOOD}') + result = await test_filter.check(message) + assert result is False + + async def test_start_command_payload_is_encoded(self): + test_filter = CommandStart(deep_link=self.GOOD, encoded=True) + message = Message(text=f'{self.START} {self.ENCODED}') + result = await test_filter.check(message) + assert result == {'deep_link': self.GOOD} diff --git a/tests/test_utils/test_deep_linking.py b/tests/test_utils/test_deep_linking.py new file mode 100644 index 00000000..a1d01e4e --- /dev/null +++ b/tests/test_utils/test_deep_linking.py @@ -0,0 +1,74 @@ +import pytest + +from aiogram.utils.deep_linking import decode_payload, encode_payload, filter_payload +from aiogram.utils.deep_linking import get_start_link, get_startgroup_link +from tests.types import dataset + +# enable asyncio mode +pytestmark = pytest.mark.asyncio + +PAYLOADS = [ + 'foo', + 'AAbbCCddEEff1122334455', + 'aaBBccDDeeFF5544332211', + -12345678901234567890, + 12345678901234567890, +] + +WRONG_PAYLOADS = [ + '@BotFather', + 'spaces spaces spaces', + 1234567890123456789.0, +] + + +@pytest.fixture(params=PAYLOADS, name='payload') +def payload_fixture(request): + return request.param + + +@pytest.fixture(params=WRONG_PAYLOADS, name='wrong_payload') +def wrong_payload_fixture(request): + return request.param + + +@pytest.fixture(autouse=True) +def get_bot_user_fixture(monkeypatch): + """ Monkey patching of bot.me calling. """ + from aiogram.utils import deep_linking + + async def get_bot_user_mock(): + from aiogram.types import User + return User(**dataset.USER) + + monkeypatch.setattr(deep_linking, '_get_bot_user', get_bot_user_mock) + + +class TestDeepLinking: + async def test_get_start_link(self, payload): + link = await get_start_link(payload) + assert link == f'https://t.me/{dataset.USER["username"]}?start={payload}' + + async def test_wrong_symbols(self, wrong_payload): + with pytest.raises(ValueError): + await get_start_link(wrong_payload) + + async def test_get_startgroup_link(self, payload): + link = await get_startgroup_link(payload) + assert link == f'https://t.me/{dataset.USER["username"]}?startgroup={payload}' + + async def test_filter_encode_and_decode(self, payload): + _payload = filter_payload(payload) + encoded = encode_payload(_payload) + decoded = decode_payload(encoded) + assert decoded == str(payload) + + async def test_get_start_link_with_encoding(self, payload): + # define link + link = await get_start_link(payload, encode=True) + + # define reference link + payload = filter_payload(payload) + encoded_payload = encode_payload(payload) + + assert link == f'https://t.me/{dataset.USER["username"]}?start={encoded_payload}'