From 995a0d7e9b3774c869a1a0f015dc63ae3e4e358a Mon Sep 17 00:00:00 2001 From: Oleg A Date: Sun, 3 Sep 2023 00:26:57 +0300 Subject: [PATCH] Custom encoding support (#1278) * Custom encoding support in deep-linking --- CHANGES/1262.feature | 1 + aiogram/utils/deep_linking.py | 77 +++++++++++------- aiogram/utils/payload.py | 108 ++++++++++++++++++++++++++ pyproject.toml | 3 +- tests/test_utils/test_deep_linking.py | 30 ++++++- 5 files changed, 187 insertions(+), 32 deletions(-) create mode 100644 CHANGES/1262.feature create mode 100644 aiogram/utils/payload.py diff --git a/CHANGES/1262.feature b/CHANGES/1262.feature new file mode 100644 index 00000000..822c82ae --- /dev/null +++ b/CHANGES/1262.feature @@ -0,0 +1 @@ +Added support for custom encoders/decoders for payload (and also for deep-linking). diff --git a/aiogram/utils/deep_linking.py b/aiogram/utils/deep_linking.py index 3d27633f..7f9b3583 100644 --- a/aiogram/utils/deep_linking.py +++ b/aiogram/utils/deep_linking.py @@ -16,7 +16,7 @@ Basic link example: .. code-block:: python from aiogram.utils.deep_linking import create_start_link - + link = await create_start_link(bot, 'foo') # result: 'https://t.me/MyBot?start=foo' @@ -46,19 +46,33 @@ Decode it back example: """ from __future__ import annotations +__all__ = [ + "create_start_link", + "create_startgroup_link", + "create_deep_link", + "create_telegram_link", + "encode_payload", + "decode_payload", +] + import re -from base64 import urlsafe_b64decode, urlsafe_b64encode -from typing import TYPE_CHECKING, Literal, cast +from typing import Callable, Literal, Optional, TYPE_CHECKING, cast from aiogram.utils.link import create_telegram_link +from aiogram.utils.payload import encode_payload, decode_payload if TYPE_CHECKING: from aiogram import Bot -BAD_PATTERN = re.compile(r"[^_A-z0-9-]") +BAD_PATTERN = re.compile(r"[^A-z0-9-]") -async def create_start_link(bot: Bot, payload: str, encode: bool = False) -> str: +async def create_start_link( + bot: Bot, + payload: str, + encode: bool = False, + encoder: Optional[Callable[[bytes], bytes]] = None, +) -> str: """ Create 'start' deep link with your payload. @@ -67,16 +81,26 @@ async def create_start_link(bot: Bot, payload: str, encode: bool = False) -> str :param bot: bot instance :param payload: args passed with /start - :param encode: encode payload with base64url + :param encode: encode payload with base64url or custom encoder + :param encoder: custom encoder callable :return: link """ username = (await bot.me()).username return create_deep_link( - username=cast(str, username), link_type="start", payload=payload, encode=encode + username=cast(str, username), + link_type="start", + payload=payload, + encode=encode, + encoder=encoder, ) -async def create_startgroup_link(bot: Bot, payload: str, encode: bool = False) -> str: +async def create_startgroup_link( + bot: Bot, + payload: str, + encode: bool = False, + encoder: Optional[Callable[[bytes], bytes]] = None, +) -> str: """ Create 'startgroup' deep link with your payload. @@ -85,17 +109,26 @@ async def create_startgroup_link(bot: Bot, payload: str, encode: bool = False) - :param bot: bot instance :param payload: args passed with /start - :param encode: encode payload with base64url + :param encode: encode payload with base64url or custom encoder + :param encoder: custom encoder callable :return: link """ username = (await bot.me()).username return create_deep_link( - username=cast(str, username), link_type="startgroup", payload=payload, encode=encode + username=cast(str, username), + link_type="startgroup", + payload=payload, + encode=encode, + encoder=encoder, ) def create_deep_link( - username: str, link_type: Literal["start", "startgroup"], payload: str, encode: bool = False + username: str, + link_type: Literal["start", "startgroup"], + payload: str, + encode: bool = False, + encoder: Optional[Callable[[bytes], bytes]] = None, ) -> str: """ Create deep link. @@ -103,14 +136,15 @@ def create_deep_link( :param username: :param link_type: `start` or `startgroup` :param payload: any string-convertible data - :param encode: pass True to encode the payload + :param encode: encode payload with base64url or custom encoder + :param encoder: custom encoder callable :return: deeplink """ if not isinstance(payload, str): payload = str(payload) - if encode: - payload = encode_payload(payload) + if encode or encoder: + payload = encode_payload(payload, encoder=encoder) if re.search(BAD_PATTERN, payload): raise ValueError( @@ -122,18 +156,3 @@ def create_deep_link( raise ValueError("Payload must be up to 64 characters long.") return create_telegram_link(username, **{cast(str, link_type): payload}) - - -def encode_payload(payload: str) -> str: - """Encode payload with URL-safe base64url.""" - payload = str(payload) - bytes_payload: bytes = urlsafe_b64encode(payload.encode()) - str_payload = bytes_payload.decode() - return str_payload.replace("=", "") - - -def decode_payload(payload: str) -> str: - """Decode payload with URL-safe base64url.""" - payload += "=" * (4 - len(payload) % 4) - result: bytes = urlsafe_b64decode(payload) - return result.decode() diff --git a/aiogram/utils/payload.py b/aiogram/utils/payload.py new file mode 100644 index 00000000..f070dd4e --- /dev/null +++ b/aiogram/utils/payload.py @@ -0,0 +1,108 @@ +""" +Payload preparing + +We have added some utils to make work with payload easier. + +Basic encode example: + + .. code-block:: python + + from aiogram.utils.payload import encode_payload + + encoded = encode_payload("foo") + + # result: "Zm9v" + +Basic decode it back example: + + .. code-block:: python + + from aiogram.utils.payload import decode_payload + + encoded = "Zm9v" + decoded = decode_payload(encoded) + # result: "foo" + +Encoding and decoding with your own methods: + + 1. Create your own cryptor + + .. code-block:: python + + from Cryptodome.Cipher import AES + from Cryptodome.Util.Padding import pad, unpad + + class Cryptor: + def __init__(self, key: str): + self.key = key.encode("utf-8") + self.mode = AES.MODE_ECB # never use ECB in strong systems obviously + self.size = 32 + + @property + def cipher(self): + return AES.new(self.key, self.mode) + + def encrypt(self, data: bytes) -> bytes: + return self.cipher.encrypt(pad(data, self.size)) + + def decrypt(self, data: bytes) -> bytes: + decrypted_data = self.cipher.decrypt(data) + return unpad(decrypted_data, self.size) + + 2. Pass cryptor callable methods to aiogram payload tools + + .. code-block:: python + + cryptor = Cryptor("abcdefghijklmnop") + encoded = encode_payload("foo", encoder=cryptor.encrypt) + decoded = decode_payload(encoded_payload, decoder=cryptor.decrypt) + + # result: decoded == "foo" + +""" +from base64 import urlsafe_b64decode, urlsafe_b64encode +from typing import Callable, Optional + + +def encode_payload( + payload: str, + encoder: Optional[Callable[[bytes], bytes]] = None, +) -> str: + """Encode payload with encoder. + + Result also will be encoded with URL-safe base64url. + """ + if not isinstance(payload, str): + payload = str(payload) + + payload_bytes = payload.encode("utf-8") + if encoder is not None: + payload_bytes = encoder(payload_bytes) + + return _encode_b64(payload_bytes) + + +def decode_payload( + payload: str, + decoder: Optional[Callable[[bytes], bytes]] = None, +) -> str: + """Decode URL-safe base64url payload with decoder.""" + original_payload = _decode_b64(payload) + + if decoder is None: + return original_payload.decode() + + return decoder(original_payload).decode() + + +def _encode_b64(payload: bytes) -> str: + """Encode with URL-safe base64url.""" + bytes_payload: bytes = urlsafe_b64encode(payload) + str_payload = bytes_payload.decode() + return str_payload.replace("=", "") + + +def _decode_b64(payload: str) -> bytes: + """Decode with URL-safe base64url.""" + payload += "=" * (4 - len(payload) % 4) + return urlsafe_b64decode(payload.encode()) diff --git a/pyproject.toml b/pyproject.toml index a287076f..aa31ab34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,8 @@ test = [ "pytest-cov~=4.0.0", "pytest-aiohttp~=1.0.4", "aresponses~=2.1.6", - "pytz~=2022.7.1" + "pytz~=2022.7.1", + "pycryptodomex~=3.18", ] docs = [ "Sphinx~=7.1.1", diff --git a/tests/test_utils/test_deep_linking.py b/tests/test_utils/test_deep_linking.py index 3c1dbec2..df06c1e5 100644 --- a/tests/test_utils/test_deep_linking.py +++ b/tests/test_utils/test_deep_linking.py @@ -3,9 +3,8 @@ import pytest from aiogram.utils.deep_linking import ( create_start_link, create_startgroup_link, - decode_payload, - encode_payload, ) +from aiogram.utils.payload import decode_payload, encode_payload from tests.mocked_bot import MockedBot PAYLOADS = [ @@ -51,6 +50,33 @@ class TestDeepLinking: decoded = decode_payload(encoded) assert decoded == str(payload) + async def test_custom_encode_decode(self, payload: str): + from Cryptodome.Cipher import AES + from Cryptodome.Util.Padding import pad, unpad + + class Cryptor: + def __init__(self, key: str): + self.key = key.encode("utf-8") + self.mode = AES.MODE_ECB # never use ECB in strong systems obviously + self.size = 32 + + @property + def cipher(self): + return AES.new(self.key, self.mode) + + def encrypt(self, data: bytes) -> bytes: + return self.cipher.encrypt(pad(data, self.size)) + + def decrypt(self, data: bytes) -> bytes: + decrypted_data = self.cipher.decrypt(data) + return unpad(decrypted_data, self.size) + + cryptor = Cryptor("abcdefghijklmnop") + encoded_payload = encode_payload(payload, encoder=cryptor.encrypt) + decoded_payload = decode_payload(encoded_payload, decoder=cryptor.decrypt) + + assert decoded_payload == str(payload) + async def test_get_start_link_with_encoding(self, bot: MockedBot, wrong_payload: str): # define link link = await create_start_link(bot, wrong_payload, encode=True)