mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Deep linking util fix (#569)
* fix: deep linking util fixed and refactored * fix: wrong payload split * feat: check payload length
This commit is contained in:
parent
ea28e2a77a
commit
08f0635afe
2 changed files with 110 additions and 58 deletions
|
|
@ -1,10 +1,10 @@
|
|||
"""
|
||||
Deep linking
|
||||
|
||||
Telegram bots have a deep linking mechanism, that allows for passing additional
|
||||
parameters to the bot on startup. It could be a command that launches the bot — or
|
||||
an auth token to connect the user's Telegram account to their account on some
|
||||
external service.
|
||||
Telegram bots have a deep linking mechanism, that allows for passing
|
||||
additional parameters to the bot on startup. It could be a command that
|
||||
launches the bot — or an auth token to connect the user's Telegram
|
||||
account to their account on some external service.
|
||||
|
||||
You can read detailed description in the source:
|
||||
https://core.telegram.org/bots#deep-linking
|
||||
|
|
@ -16,86 +16,123 @@ Basic link example:
|
|||
.. code-block:: python
|
||||
|
||||
from aiogram.utils.deep_linking import get_start_link
|
||||
link = await get_start_link('foo') # result: 'https://t.me/MyBot?start=foo'
|
||||
link = await get_start_link('foo')
|
||||
|
||||
# result: 'https://t.me/MyBot?start=foo'
|
||||
|
||||
Encoded link example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from aiogram.utils.deep_linking import get_start_link, decode_payload
|
||||
link = await get_start_link('foo', encode=True) # result: 'https://t.me/MyBot?start=Zm9v'
|
||||
# and decode it back:
|
||||
payload = decode_payload('Zm9v') # result: 'foo'
|
||||
from aiogram.utils.deep_linking import get_start_link
|
||||
|
||||
link = await get_start_link('foo', encode=True)
|
||||
# result: 'https://t.me/MyBot?start=Zm9v'
|
||||
|
||||
Decode it back example:
|
||||
.. code-block:: python
|
||||
|
||||
from aiogram.utils.deep_linking import decode_payload
|
||||
from aiogram.types import Message
|
||||
|
||||
@dp.message_handler(commands=["start"])
|
||||
async def handler(message: Message):
|
||||
args = message.get_args()
|
||||
payload = decode_payload(args)
|
||||
await message.answer(f"Your payload: {payload}")
|
||||
|
||||
"""
|
||||
import re
|
||||
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
||||
|
||||
from ..bot import Bot
|
||||
|
||||
BAD_PATTERN = re.compile(r"[^_A-z0-9-]")
|
||||
|
||||
|
||||
async def get_start_link(payload: str, encode=False) -> str:
|
||||
"""
|
||||
Use this method to handy get 'start' deep link with your payload.
|
||||
If you need to encode payload or pass special characters - set encode as True
|
||||
Get 'start' deep link with your payload.
|
||||
|
||||
If you need to encode payload or pass special characters -
|
||||
set encode as True
|
||||
|
||||
:param payload: args passed with /start
|
||||
:param encode: encode payload with base64url
|
||||
:return: link
|
||||
"""
|
||||
return await _create_link('start', payload, encode)
|
||||
return await _create_link(
|
||||
link_type="start",
|
||||
payload=payload,
|
||||
encode=encode,
|
||||
)
|
||||
|
||||
|
||||
async def get_startgroup_link(payload: str, encode=False) -> str:
|
||||
"""
|
||||
Use this method to handy get 'startgroup' deep link with your payload.
|
||||
If you need to encode payload or pass special characters - set encode as True
|
||||
Get 'startgroup' deep link with your payload.
|
||||
|
||||
If you need to encode payload or pass special characters -
|
||||
set encode as True
|
||||
|
||||
:param payload: args passed with /start
|
||||
:param encode: encode payload with base64url
|
||||
:return: link
|
||||
"""
|
||||
return await _create_link('startgroup', payload, encode)
|
||||
return await _create_link(
|
||||
link_type="startgroup",
|
||||
payload=payload,
|
||||
encode=encode,
|
||||
)
|
||||
|
||||
|
||||
async def _create_link(link_type, payload: str, encode=False):
|
||||
"""
|
||||
Create deep link.
|
||||
|
||||
:param link_type: `start` or `startgroup`
|
||||
:param payload: any string-convertible data
|
||||
:param encode: pass True to encode the payload
|
||||
:return: deeplink
|
||||
"""
|
||||
bot = await _get_bot_user()
|
||||
payload = filter_payload(payload)
|
||||
if encode:
|
||||
payload = encode_payload(payload)
|
||||
return f'https://t.me/{bot.username}?{link_type}={payload}'
|
||||
|
||||
|
||||
def encode_payload(payload: str) -> str:
|
||||
""" Encode payload with URL-safe base64url. """
|
||||
from base64 import urlsafe_b64encode
|
||||
result: bytes = urlsafe_b64encode(payload.encode())
|
||||
return result.decode()
|
||||
|
||||
|
||||
def decode_payload(payload: str) -> str:
|
||||
""" Decode payload with URL-safe base64url. """
|
||||
from base64 import urlsafe_b64decode
|
||||
result: bytes = urlsafe_b64decode(payload + '=' * (4 - len(payload) % 4))
|
||||
return result.decode()
|
||||
|
||||
|
||||
def filter_payload(payload: str) -> str:
|
||||
""" Convert payload to text and search for not allowed symbols. """
|
||||
import re
|
||||
|
||||
# convert to string
|
||||
if not isinstance(payload, str):
|
||||
payload = str(payload)
|
||||
|
||||
# search for not allowed characters
|
||||
if re.search(r'[^_A-z0-9-]', payload):
|
||||
message = ('Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. '
|
||||
'We recommend to encode parameters with binary and other '
|
||||
'types of content.')
|
||||
if encode:
|
||||
payload = encode_payload(payload)
|
||||
|
||||
if re.search(BAD_PATTERN, payload):
|
||||
message = (
|
||||
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
|
||||
"Pass `encode=True` or encode payload manually."
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
return payload
|
||||
if len(payload) > 64:
|
||||
message = "Payload must be up to 64 characters long."
|
||||
raise ValueError(message)
|
||||
|
||||
return f"https://t.me/{bot.username}?{link_type}={payload}"
|
||||
|
||||
|
||||
def encode_payload(payload: str) -> str:
|
||||
"""Encode payload with URL-safe base64url."""
|
||||
payload = str(payload)
|
||||
bytes_payload: bytes = urlsafe_b64encode(payload.encode())
|
||||
str_payload = bytes_payload.decode()
|
||||
return str_payload.replace("=", "")
|
||||
|
||||
|
||||
def decode_payload(payload: str) -> str:
|
||||
"""Decode payload with URL-safe base64url."""
|
||||
payload += "=" * (4 - len(payload) % 4)
|
||||
result: bytes = urlsafe_b64decode(payload)
|
||||
return result.decode()
|
||||
|
||||
|
||||
async def _get_bot_user():
|
||||
""" Get current user of bot. """
|
||||
from ..bot import Bot
|
||||
"""Get current user of bot."""
|
||||
bot = Bot.get_current()
|
||||
return await bot.me
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.utils.deep_linking import decode_payload, encode_payload, filter_payload
|
||||
from aiogram.utils.deep_linking import get_start_link, get_startgroup_link
|
||||
from aiogram.utils.deep_linking import (
|
||||
decode_payload,
|
||||
encode_payload,
|
||||
get_start_link,
|
||||
get_startgroup_link,
|
||||
)
|
||||
from tests.types import dataset
|
||||
|
||||
# enable asyncio mode
|
||||
|
|
@ -17,9 +21,11 @@ PAYLOADS = [
|
|||
|
||||
WRONG_PAYLOADS = [
|
||||
'@BotFather',
|
||||
"Some:special$characters#=",
|
||||
'spaces spaces spaces',
|
||||
1234567890123456789.0,
|
||||
]
|
||||
USERNAME = dataset.USER["username"]
|
||||
|
||||
|
||||
@pytest.fixture(params=PAYLOADS, name='payload')
|
||||
|
|
@ -47,7 +53,7 @@ def get_bot_user_fixture(monkeypatch):
|
|||
class TestDeepLinking:
|
||||
async def test_get_start_link(self, payload):
|
||||
link = await get_start_link(payload)
|
||||
assert link == f'https://t.me/{dataset.USER["username"]}?start={payload}'
|
||||
assert link == f'https://t.me/{USERNAME}?start={payload}'
|
||||
|
||||
async def test_wrong_symbols(self, wrong_payload):
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -55,20 +61,29 @@ class TestDeepLinking:
|
|||
|
||||
async def test_get_startgroup_link(self, payload):
|
||||
link = await get_startgroup_link(payload)
|
||||
assert link == f'https://t.me/{dataset.USER["username"]}?startgroup={payload}'
|
||||
assert link == f'https://t.me/{USERNAME}?startgroup={payload}'
|
||||
|
||||
async def test_filter_encode_and_decode(self, payload):
|
||||
_payload = filter_payload(payload)
|
||||
encoded = encode_payload(_payload)
|
||||
encoded = encode_payload(payload)
|
||||
decoded = decode_payload(encoded)
|
||||
assert decoded == str(payload)
|
||||
|
||||
async def test_get_start_link_with_encoding(self, payload):
|
||||
async def test_get_start_link_with_encoding(self, wrong_payload):
|
||||
# define link
|
||||
link = await get_start_link(payload, encode=True)
|
||||
link = await get_start_link(wrong_payload, encode=True)
|
||||
|
||||
# define reference link
|
||||
payload = filter_payload(payload)
|
||||
encoded_payload = encode_payload(payload)
|
||||
encoded_payload = encode_payload(wrong_payload)
|
||||
|
||||
assert link == f'https://t.me/{dataset.USER["username"]}?start={encoded_payload}'
|
||||
assert link == f'https://t.me/{USERNAME}?start={encoded_payload}'
|
||||
|
||||
async def test_64_len_payload(self):
|
||||
payload = "p" * 64
|
||||
link = await get_start_link(payload)
|
||||
assert link
|
||||
|
||||
async def test_too_long_payload(self):
|
||||
payload = "p" * 65
|
||||
print(payload, len(payload))
|
||||
with pytest.raises(ValueError):
|
||||
await get_start_link(payload)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue