Backport RedisStorage, deep-linking

This commit is contained in:
Alex Root Junior 2021-06-13 02:45:16 +03:00
parent bc96bdd3b6
commit 988d55ff65
30 changed files with 852 additions and 183 deletions

View file

@ -64,9 +64,14 @@ jobs:
run: |
poetry run black --check --diff aiogram tests
- name: Start Redis
uses: supercharge/redis-github-action@1.2.0
with:
redis-version: 6
- name: Run tests
run: |
poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml
poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml --redis redis://redis:6379/0
- uses: codecov/codecov-action@v1
with:

View file

@ -6,6 +6,8 @@ python := $(py) python
reports_dir := reports
redis_connection := redis://localhost:6379
.PHONY: help
help:
@echo "======================================================================================="
@ -99,12 +101,12 @@ lint: isort black flake8 mypy
.PHONY: test
test:
$(py) pytest --cov=aiogram --cov-config .coveragerc tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc tests/ --redis $(redis_connection)
.PHONY: test-coverage
test-coverage:
mkdir -p $(reports_dir)/tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ --redis $(redis_connection)
.PHONY: test-coverage-report
test-coverage-report:

View file

@ -6,6 +6,8 @@ from .dispatcher import filters, handler
from .dispatcher.dispatcher import Dispatcher
from .dispatcher.middlewares.base import BaseMiddleware
from .dispatcher.router import Router
from .utils.text_decorations import html_decoration as _html_decoration
from .utils.text_decorations import markdown_decoration as _markdown_decoration
try:
import uvloop as _uvloop
@ -15,6 +17,8 @@ except ImportError: # pragma: no cover
pass
F = MagicFilter()
html = _html_decoration
md = _markdown_decoration
__all__ = (
"__api_version__",
@ -29,6 +33,8 @@ __all__ = (
"filters",
"handler",
"F",
"html",
"md",
)
__version__ = "3.0.0-alpha.8"

View file

@ -4,7 +4,7 @@ import asyncio
import contextvars
import warnings
from asyncio import CancelledError, Future, Lock
from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
from typing import Any, AsyncGenerator, Dict, Optional, Union
from .. import loggers
from ..client.bot import Bot
@ -13,7 +13,6 @@ from ..types import TelegramObject, Update, User
from ..utils.exceptions.base import TelegramAPIError
from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver
from .fsm.context import FSMContext
from .fsm.middleware import FSMContextMiddleware
from .fsm.storage.base import BaseStorage
from .fsm.storage.memory import MemoryStorage
@ -32,7 +31,7 @@ class Dispatcher(Router):
self,
storage: Optional[BaseStorage] = None,
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
isolate_events: bool = True,
isolate_events: bool = False,
**kwargs: Any,
) -> None:
super(Dispatcher, self).__init__(**kwargs)
@ -255,7 +254,9 @@ class Dispatcher(Router):
)
return True # because update was processed but unsuccessful
async def _polling(self, bot: Bot, polling_timeout: int = 30, **kwargs: Any) -> None:
async def _polling(
self, bot: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
"""
Internal polling process
@ -264,7 +265,11 @@ class Dispatcher(Router):
:return:
"""
async for update in self._listen_updates(bot, polling_timeout=polling_timeout):
await self._process_update(bot=bot, update=update, **kwargs)
handle_update = self._process_update(bot=bot, update=update, **kwargs)
if handle_as_tasks:
asyncio.create_task(handle_update)
else:
await handle_update
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
"""
@ -342,11 +347,15 @@ class Dispatcher(Router):
return None
async def start_polling(self, *bots: Bot, polling_timeout: int = 10, **kwargs: Any) -> None:
async def start_polling(
self, *bots: Bot, polling_timeout: int = 10, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
"""
Polling runner
:param bots:
:param polling_timeout:
:param handle_as_tasks:
:param kwargs:
:return:
"""
@ -363,7 +372,12 @@ class Dispatcher(Router):
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
coro_list.append(
self._polling(bot=bot, polling_timeout=polling_timeout, **kwargs)
self._polling(
bot=bot,
handle_as_tasks=handle_as_tasks,
polling_timeout=polling_timeout,
**kwargs,
)
)
await asyncio.gather(*coro_list)
finally:
@ -372,22 +386,27 @@ class Dispatcher(Router):
loggers.dispatcher.info("Polling stopped")
await self.emit_shutdown(**workflow_data)
def run_polling(self, *bots: Bot, polling_timeout: int = 30, **kwargs: Any) -> None:
def run_polling(
self, *bots: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
"""
Run many bots with polling
:param bots: Bot instances
:param polling_timeout: Poling timeout
:param handle_as_tasks: Run task for each event and no wait result
:param kwargs: contextual data
:return:
"""
try:
return asyncio.run(
self.start_polling(*bots, **kwargs, polling_timeout=polling_timeout)
self.start_polling(
*bots,
**kwargs,
polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks,
)
)
except (KeyboardInterrupt, SystemExit): # pragma: no cover
# Allow to graceful shutdown
pass
def current_state(self, chat_id: int, user_id: int) -> FSMContext:
return cast(FSMContext, self.fsm.resolve_context(chat_id=chat_id, user_id=user_id))

View file

@ -1,18 +1,24 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast
from pydantic import validator
from magic_filter import MagicFilter
from pydantic import Field, validator
from aiogram import Bot
from aiogram.dispatcher.filters import BaseFilter
from aiogram.types import Message
from aiogram.utils.deep_linking import decode_payload
CommandPatterType = Union[str, re.Pattern]
class CommandException(Exception):
pass
class Command(BaseFilter):
"""
This filter can be helpful for handling commands from the text messages.
@ -29,6 +35,8 @@ class Command(BaseFilter):
"""Ignore case (Does not work with regexp, use flags instead)"""
commands_ignore_mention: bool = False
"""Ignore bot mention. By default bot can not handle commands intended for other bots"""
command_magic: Optional[MagicFilter] = None
"""Validate command object via Magic filter after all checks done"""
@validator("commands", always=True)
def _validate_commands(
@ -39,22 +47,17 @@ class Command(BaseFilter):
return value
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
if not message.text:
text = message.text or message.caption
if not text:
return False
return await self.parse_command(text=message.text, bot=bot)
async def parse_command(self, text: str, bot: Bot) -> Union[bool, Dict[str, CommandObject]]:
"""
Extract command from the text and validate
:param text:
:param bot:
:return:
"""
if not text.strip():
try:
command = await self.parse_command(text=cast(str, message.text), bot=bot)
except CommandException:
return False
return {"command": command}
def extract_command(self, text: str) -> CommandObject:
# First step: separate command with arguments
# "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"]
full_command, *args = text.split(maxsplit=1)
@ -62,46 +65,52 @@ class Command(BaseFilter):
# Separate command into valuable parts
# "/command@mention" -> "/", ("command", "@", "mention")
prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@")
return CommandObject(
prefix=prefix, command=command, mention=mention, args=args[0] if args else None
)
# Validate prefixes
if prefix not in self.commands_prefix:
return False
def validate_prefix(self, command: CommandObject) -> None:
if command.prefix not in self.commands_prefix:
raise CommandException("Invalid command prefix")
# Validate mention
if mention and not self.commands_ignore_mention:
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
if command.mention and not self.commands_ignore_mention:
me = await bot.me()
if me.username and mention.lower() != me.username.lower():
return False
if me.username and command.mention.lower() != me.username.lower():
raise CommandException("Mention did not match")
# Validate command
def validate_command(self, command: CommandObject) -> CommandObject:
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
# Command can be presented as regexp pattern or raw string
# then need to validate that in different ways
if isinstance(allowed_command, Pattern): # Regexp
result = allowed_command.match(command)
result = allowed_command.match(command.command)
if result:
return {
"command": CommandObject(
prefix=prefix,
command=command,
mention=mention,
args=args[0] if args else None,
match=result,
)
}
return replace(command, match=result)
elif command.command == allowed_command: # String
return command
raise CommandException("Command did not match pattern")
elif command == allowed_command: # String
return {
"command": CommandObject(
prefix=prefix,
command=command,
mention=mention,
args=args[0] if args else None,
match=None,
)
}
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
"""
Extract command from the text and validate
return False
:param text:
:param bot:
:return:
"""
command = self.extract_command(text)
self.validate_prefix(command=command)
await self.validate_mention(bot=bot, command=command)
command = self.validate_command(command)
self.do_magic(command=command)
return command
def do_magic(self, command: CommandObject) -> None:
if not self.command_magic:
return
if not self.command_magic.resolve(command):
raise CommandException("Rejected via magic filter")
class Config:
arbitrary_types_allowed = True
@ -143,3 +152,40 @@ class CommandObject:
if self.args:
line += " " + self.args
return line
class CommandStart(Command):
commands: Tuple[str] = Field(("start",), const=True)
commands_prefix: str = Field("/", const=True)
deep_link: bool = False
deep_link_encoded: bool = False
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
"""
Extract command from the text and validate
:param text:
:param bot:
:return:
"""
command = self.extract_command(text)
self.validate_prefix(command=command)
await self.validate_mention(bot=bot, command=command)
command = self.validate_command(command)
command = self.validate_deeplink(command=command)
self.do_magic(command=command)
return command
def validate_deeplink(self, command: CommandObject) -> CommandObject:
if not self.deep_link:
return command
if not command.args:
raise CommandException("Deep-link was missing")
args = command.args
if self.deep_link_encoded:
try:
args = decode_payload(args)
except UnicodeDecodeError as e:
raise CommandException(f"Failed to decode Base64: {e}")
return replace(command, args=args)
return command

View file

@ -1,25 +1,35 @@
from typing import Any, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
class FSMContext:
def __init__(self, storage: BaseStorage, chat_id: int, user_id: int) -> None:
def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
self.bot = bot
self.storage = storage
self.chat_id = chat_id
self.user_id = user_id
async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(chat_id=self.chat_id, user_id=self.user_id, state=state)
await self.storage.set_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
)
async def get_state(self) -> Optional[str]:
return await self.storage.get_state(chat_id=self.chat_id, user_id=self.user_id)
return await self.storage.get_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data(chat_id=self.chat_id, user_id=self.user_id, data=data)
await self.storage.set_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
)
async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data(chat_id=self.chat_id, user_id=self.user_id)
return await self.storage.get_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
@ -27,7 +37,7 @@ class FSMContext:
if data:
kwargs.update(data)
return await self.storage.update_data(
chat_id=self.chat_id, user_id=self.user_id, data=kwargs
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
)
async def clear(self) -> None:

View file

@ -1,5 +1,6 @@
from typing import Any, Awaitable, Callable, Dict, Optional
from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
@ -24,24 +25,27 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
event: Update,
data: Dict[str, Any],
) -> Any:
context = self.resolve_event_context(data)
bot: Bot = cast(Bot, data["bot"])
context = self.resolve_event_context(bot, data)
data["fsm_storage"] = self.storage
if context:
data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events:
async with self.storage.lock(chat_id=context.chat_id, user_id=context.user_id):
async with self.storage.lock(
bot=bot, chat_id=context.chat_id, user_id=context.user_id
):
return await handler(event, data)
return await handler(event, data)
def resolve_event_context(self, data: Dict[str, Any]) -> Optional[FSMContext]:
def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
chat_id = chat.id if chat else None
user_id = user.id if user else None
return self.resolve_context(chat_id=chat_id, user_id=user_id)
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
def resolve_context(
self, chat_id: Optional[int], user_id: Optional[int]
self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
@ -50,8 +54,8 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
)
return self.get_context(chat_id=chat_id, user_id=user_id)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
return None
def get_context(self, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id)
def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)

View file

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Union
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
StateType = Optional[Union[str, State]]
@ -11,34 +12,42 @@ class BaseStorage(ABC):
@abstractmethod
@asynccontextmanager
async def lock(
self, chat_id: int, user_id: int
self, bot: Bot, chat_id: int, user_id: int
) -> AsyncGenerator[None, None]: # pragma: no cover
yield None
@abstractmethod
async def set_state(
self, chat_id: int, user_id: int, state: StateType = None
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None: # pragma: no cover
pass
@abstractmethod
async def get_state(self, chat_id: int, user_id: int) -> Optional[str]: # pragma: no cover
async def get_state(
self, bot: Bot, chat_id: int, user_id: int
) -> Optional[str]: # pragma: no cover
pass
@abstractmethod
async def set_data(
self, chat_id: int, user_id: int, data: Dict[str, Any]
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> None: # pragma: no cover
pass
@abstractmethod
async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]: # pragma: no cover
async def get_data(
self, bot: Bot, chat_id: int, user_id: int
) -> Dict[str, Any]: # pragma: no cover
pass
async def update_data(
self, chat_id: int, user_id: int, data: Dict[str, Any]
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> Dict[str, Any]:
current_data = await self.get_data(chat_id=chat_id, user_id=user_id)
current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
current_data.update(data)
await self.set_data(chat_id=chat_id, user_id=user_id, data=current_data)
await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
return current_data.copy()
@abstractmethod
async def close(self) -> None: # pragma: no cover
pass

View file

@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
@ -17,23 +18,30 @@ class MemoryStorageRecord:
class MemoryStorage(BaseStorage):
def __init__(self) -> None:
self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict(
lambda: defaultdict(MemoryStorageRecord)
)
self.storage: DefaultDict[
Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
async def close(self) -> None:
pass
@asynccontextmanager
async def lock(self, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
async with self.storage[chat_id][user_id].lock:
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
async with self.storage[bot][chat_id][user_id].lock:
yield None
async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None:
self.storage[chat_id][user_id].state = state.state if isinstance(state, State) else state
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
self.storage[bot][chat_id][user_id].state = (
state.state if isinstance(state, State) else state
)
async def get_state(self, chat_id: int, user_id: int) -> Optional[str]:
return self.storage[chat_id][user_id].state
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
return self.storage[bot][chat_id][user_id].state
async def set_data(self, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
self.storage[chat_id][user_id].data = data.copy()
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
self.storage[bot][chat_id][user_id].data = data.copy()
async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]:
return self.storage[chat_id][user_id].data.copy()
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
return self.storage[bot][chat_id][user_id].data.copy()

View file

@ -0,0 +1,101 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
from aioredis import ConnectionPool, Redis
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
PrefixFactoryType = Callable[[Bot], str]
STATE_KEY = "state"
STATE_DATA_KEY = "data"
STATE_LOCK_KEY = "lock"
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
class RedisStorage(BaseStorage):
def __init__(
self,
redis: Redis,
prefix: str = "fsm",
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
state_ttl: Optional[int] = None,
data_ttl: Optional[int] = None,
lock_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if lock_kwargs is None:
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
self.redis = redis
self.prefix = prefix
self.prefix_bot = prefix_bot
self.state_ttl = state_ttl
self.data_ttl = data_ttl
self.lock_kwargs = lock_kwargs
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> "RedisStorage":
if connection_kwargs is None:
connection_kwargs = {}
pool = ConnectionPool.from_url(url, **connection_kwargs)
redis = Redis(connection_pool=pool)
return cls(redis=redis, **kwargs)
async def close(self) -> None:
await self.redis.close()
def generate_key(self, bot: Bot, *parts: Any) -> str:
prefix_parts = [self.prefix]
if self.prefix_bot:
if isinstance(self.prefix_bot, dict):
prefix_parts.append(self.prefix_bot[bot.id])
elif callable(self.prefix_bot):
prefix_parts.append(self.prefix_bot(bot))
else:
prefix_parts.append(str(bot.id))
prefix_parts.extend(parts)
return ":".join(map(str, prefix_parts))
@asynccontextmanager
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
key = self.generate_key(bot, chat_id, user_id, STATE_LOCK_KEY)
async with self.redis.lock(name=key, **self.lock_kwargs):
yield None
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
if state is None:
await self.redis.delete(key)
else:
await self.redis.set(
key, state.state if isinstance(state, State) else state, ex=self.state_ttl
)
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
value = await self.redis.get(key)
if isinstance(value, bytes):
return value.decode("utf-8")
return cast(Optional[str], value)
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
if not data:
await self.redis.delete(key)
return
json_data = bot.session.json_dumps(data)
await self.redis.set(key, json_data, ex=self.data_ttl)
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
value = await self.redis.get(key)
if value is None:
return {}
if isinstance(value, bytes):
value = value.decode("utf-8")
return cast(Dict[str, Any], bot.session.json_loads(value))

View file

@ -0,0 +1,34 @@
import hashlib
import hmac
from typing import Any, Dict
def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
"""
Generate hexadecimal representation
of the HMAC-SHA-256 signature of the data-check-string
with the SHA256 hash of the bot's token used as a secret key
:param token:
:param hash:
:param kwargs: all params received on auth
:return:
"""
secret = hashlib.sha256(token.encode("utf-8"))
check_string = "\n".join(map(lambda k: f"{k}={kwargs[k]}", sorted(kwargs)))
hmac_string = hmac.new(
secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256
).hexdigest()
return hmac_string == hash
def check_integrity(token: str, data: Dict[str, Any]) -> bool:
"""
Verify the authentication and the integrity
of the data received on user's auth
:param token: Bot's token
:param data: all data that came on auth
:return:
"""
return check_signature(token, **data)

View file

@ -0,0 +1,131 @@
"""
Deep linking
Telegram bots have a deep linking mechanism, that allows for passing
additional parameters to the bot on startup. It could be a command that
launches the bot or an auth token to connect the user's Telegram
account to their account on some external service.
You can read detailed description in the source:
https://core.telegram.org/bots#deep-linking
We have add some utils to get deep links more handy.
Basic link example:
.. code-block:: python
from aiogram.utils.deep_linking import get_start_link
link = await get_start_link('foo')
# result: 'https://t.me/MyBot?start=foo'
Encoded link example:
.. code-block:: python
from aiogram.utils.deep_linking import get_start_link
link = await get_start_link('foo', encode=True)
# result: 'https://t.me/MyBot?start=Zm9v'
Decode it back example:
.. code-block:: python
from aiogram.utils.deep_linking import decode_payload
from aiogram.types import Message
@dp.message_handler(commands=["start"])
async def handler(message: Message):
args = message.get_args()
payload = decode_payload(args)
await message.answer(f"Your payload: {payload}")
"""
import re
from base64 import urlsafe_b64decode, urlsafe_b64encode
from typing import Literal, cast
from aiogram import Bot
from aiogram.utils.link import create_telegram_link
BAD_PATTERN = re.compile(r"[^_A-z0-9-]")
async def create_start_link(bot: Bot, payload: str, encode: bool = False) -> str:
"""
Create 'start' deep link with your payload.
If you need to encode payload or pass special characters -
set encode as True
:param bot: bot instance
:param payload: args passed with /start
:param encode: encode payload with base64url
:return: link
"""
username = (await bot.me()).username
return create_deep_link(username=username, link_type="start", payload=payload, encode=encode)
async def create_startgroup_link(bot: Bot, payload: str, encode: bool = False) -> str:
"""
Create 'startgroup' deep link with your payload.
If you need to encode payload or pass special characters -
set encode as True
:param bot: bot instance
:param payload: args passed with /start
:param encode: encode payload with base64url
:return: link
"""
username = (await bot.me()).username
return create_deep_link(
username=username, link_type="startgroup", payload=payload, encode=encode
)
def create_deep_link(
username: str, link_type: Literal["start", "startgroup"], payload: str, encode: bool = False
) -> str:
"""
Create deep link.
:param username:
:param link_type: `start` or `startgroup`
:param payload: any string-convertible data
:param encode: pass True to encode the payload
:return: deeplink
"""
if not isinstance(payload, str):
payload = str(payload)
if encode:
payload = encode_payload(payload)
if re.search(BAD_PATTERN, payload):
raise ValueError(
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
"Pass `encode=True` or encode payload manually."
)
if len(payload) > 64:
raise ValueError("Payload must be up to 64 characters long.")
return create_telegram_link(username, **{cast(str, link_type): payload})
def encode_payload(payload: str) -> str:
"""Encode payload with URL-safe base64url."""
payload = str(payload)
bytes_payload: bytes = urlsafe_b64encode(payload.encode())
str_payload = bytes_payload.decode()
return str_payload.replace("=", "")
def decode_payload(payload: str) -> str:
"""Decode payload with URL-safe base64url."""
payload += "=" * (4 - len(payload) % 4)
result: bytes = urlsafe_b64decode(payload)
return result.decode()

View file

@ -2,13 +2,27 @@ from __future__ import annotations
from itertools import chain
from itertools import cycle as repeat_all
from typing import Any, Generator, Generic, Iterable, List, Optional, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Generator,
Generic,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
no_type_check,
)
from aiogram.dispatcher.filters.callback_data import CallbackData
from aiogram.types import (
CallbackGame,
InlineKeyboardButton,
InlineKeyboardMarkup,
KeyboardButton,
LoginUrl,
ReplyKeyboardMarkup,
)
@ -239,3 +253,28 @@ def repeat_last(items: Iterable[T]) -> Generator[T, None, None]:
except StopIteration:
finished = True
yield value
class InlineKeyboardConstructor(KeyboardConstructor[InlineKeyboardButton]):
if TYPE_CHECKING: # pragma: no cover
@no_type_check
def button(
self,
text: str,
url: Optional[str] = None,
login_url: Optional[LoginUrl] = None,
callback_data: Optional[Union[str, CallbackData]] = None,
switch_inline_query: Optional[str] = None,
switch_inline_query_current_chat: Optional[str] = None,
callback_game: Optional[CallbackGame] = None,
pay: Optional[bool] = None,
**kwargs: Any,
) -> "KeyboardConstructor[InlineKeyboardButton]":
...
def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup:
...
def __init__(self) -> None:
super().__init__(InlineKeyboardButton)

18
aiogram/utils/link.py Normal file
View file

@ -0,0 +1,18 @@
from typing import Any
from urllib.parse import urlencode, urljoin
def create_tg_link(link: str, **kwargs: Any) -> str:
url = f"tg://{link}"
if kwargs:
query = urlencode(kwargs)
url += f"?{query}"
return url
def create_telegram_link(uri: str, **kwargs: Any) -> str:
url = urljoin("https://t.me", uri)
if kwargs:
query = urlencode(query=kwargs)
url += f"?{query}"
return url

View file

@ -183,7 +183,7 @@ class MarkdownDecoration(TextDecoration):
return f"`{value}`"
def pre(self, value: str) -> str:
return f"```{value}```"
return f"```\n{value}\n```"
def pre_language(self, value: str, language: str) -> str:
return f"```{language}\n{value}\n```"

View file

@ -1,6 +1,6 @@
[mypy]
;plugins = pydantic.mypy
python_version = 3.7
python_version = 3.8
show_error_codes = True
show_error_context = True
pretty = True
@ -29,3 +29,6 @@ ignore_missing_imports = True
[mypy-uvloop]
ignore_missing_imports = True
[mypy-aioredis]
ignore_missing_imports = True

37
poetry.lock generated
View file

@ -38,6 +38,21 @@ aiohttp = ">=2.3.2"
attrs = ">=19.2.0"
python-socks = {version = ">=1.0.1", extras = ["asyncio"]}
[[package]]
name = "aioredis"
version = "2.0.0a1"
description = "asyncio (PEP 3156) Redis support"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
async-timeout = "*"
typing-extensions = "*"
[package.extras]
hiredis = ["hiredis (>=1.0)"]
[[package]]
name = "alabaster"
version = "0.7.12"
@ -764,6 +779,17 @@ python-versions = ">=3.6"
pytest = ">=5.0,<6.0.0 || >6.0.0"
pytest-metadata = "*"
[[package]]
name = "pytest-lazy-fixture"
version = "0.6.3"
description = "It helps to use fixtures in pytest.mark.parametrize"
category = "dev"
optional = false
python-versions = "*"
[package.dependencies]
pytest = ">=3.2.5"
[[package]]
name = "pytest-metadata"
version = "1.11.0"
@ -1171,11 +1197,12 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pyt
docs = ["sphinx", "sphinx-intl", "sphinx-autobuild", "sphinx-copybutton", "furo", "sphinx-prompt", "Sphinx-Substitution-Extensions"]
fast = []
proxy = ["aiohttp-socks"]
redis = ["aioredis"]
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "2fcd44a8937b3ea48196c8eba8ceb0533281af34c884103bcc5b4f5f16b817d5"
content-hash = "bc8fa6e61728e0463bdfb156aa53f52d9c4323fd3a78a73f587278bf6f908eac"
[metadata.files]
aiofiles = [
@ -1225,6 +1252,10 @@ aiohttp-socks = [
{file = "aiohttp_socks-0.5.5-py3-none-any.whl", hash = "sha256:faaa25ed4dc34440ca888d23e089420f3b1918dc4ecf062c3fd9474827ad6a39"},
{file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"},
]
aioredis = [
{file = "aioredis-2.0.0a1-py3-none-any.whl", hash = "sha256:32d7910724282a475c91b8b34403867069a4f07bf0c5ad5fe66cd797322f9a0d"},
{file = "aioredis-2.0.0a1.tar.gz", hash = "sha256:5884f384b8ecb143bb73320a96e7c464fd38e117950a7d48340a35db8e35e7d2"},
]
alabaster = [
{file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"},
{file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"},
@ -1644,6 +1675,10 @@ pytest-html = [
{file = "pytest-html-3.1.1.tar.gz", hash = "sha256:3ee1cf319c913d19fe53aeb0bc400e7b0bc2dbeb477553733db1dad12eb75ee3"},
{file = "pytest_html-3.1.1-py3-none-any.whl", hash = "sha256:b7f82f123936a3f4d2950bc993c2c1ca09ce262c9ae12f9ac763a2401380b455"},
]
pytest-lazy-fixture = [
{file = "pytest-lazy-fixture-0.6.3.tar.gz", hash = "sha256:0e7d0c7f74ba33e6e80905e9bfd81f9d15ef9a790de97993e34213deb5ad10ac"},
{file = "pytest_lazy_fixture-0.6.3-py3-none-any.whl", hash = "sha256:e0b379f38299ff27a653f03eaa69b08a6fd4484e46fd1c9907d984b9f9daeda6"},
]
pytest-metadata = [
{file = "pytest-metadata-1.11.0.tar.gz", hash = "sha256:71b506d49d34e539cc3cfdb7ce2c5f072bea5c953320002c95968e0238f8ecf1"},
{file = "pytest_metadata-1.11.0-py2.py3-none-any.whl", hash = "sha256:576055b8336dd4a9006dd2a47615f76f2f8c30ab12b1b1c039d99e834583523f"},

View file

@ -38,8 +38,9 @@ Babel = "^2.9.1"
aiofiles = "^0.6.0"
async_lru = "^1.0.2"
aiohttp-socks = { version = "^0.5.5", optional = true }
aioredis = { version = "^1.3.1", optional = true }
typing-extensions = { version = "^3.7.4", python = "<3.8" }
magic-filter = {version = "1.0.0a1", allow-prereleases = true}
magic-filter = { version = "1.0.0a1", allow-prereleases = true }
sphinx = { version = "^3.1.0", optional = true }
sphinx-intl = { version = "^2.0.1", optional = true }
sphinx-autobuild = { version = "^2020.9.1", optional = true }
@ -50,6 +51,7 @@ Sphinx-Substitution-Extensions = { version = "^2020.9.30", optional = true }
[tool.poetry.dev-dependencies]
aiohttp-socks = "^0.5"
aioredis = {version = "^1.3.1", allow-prereleases = true}
ipython = "^7.22.0"
uvloop = { version = "^0.15.2", markers = "sys_platform == 'darwin' or sys_platform == 'linux'" }
black = "^21.4b2"
@ -79,9 +81,11 @@ sphinx-copybutton = "^0.3.1"
furo = "^2020.11.15-beta.17"
sphinx-prompt = "^1.3.0"
Sphinx-Substitution-Extensions = "^2020.9.30"
pytest-lazy-fixture = "^0.6.3"
[tool.poetry.extras]
fast = ["uvloop"]
redis = ["aioredis"]
proxy = ["aiohttp-socks"]
docs = [
"sphinx",

View file

@ -1,9 +1,62 @@
import pytest
from _pytest.config import UsageError
from aioredis.connection import parse_url as parse_redis_url
from aiogram import Bot
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot
def pytest_addoption(parser):
parser.addoption("--redis", default=None, help="run tests which require redis connection")
def pytest_configure(config):
config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
def pytest_collection_modifyitems(config, items):
redis_uri = config.getoption("--redis")
if redis_uri is None:
skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
for item in items:
if "redis" in item.keywords:
item.add_marker(skip_redis)
return
try:
parse_redis_url(redis_uri)
except ValueError as e:
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
@pytest.fixture(scope="session")
def redis_server(request):
redis_uri = request.config.getoption("--redis")
return redis_uri
@pytest.fixture()
@pytest.mark.redis
async def redis_storage(redis_server):
storage = RedisStorage.from_url(redis_server)
try:
yield storage
finally:
conn = await storage.redis
await conn.flushdb()
await storage.close()
@pytest.fixture()
async def memory_storage():
storage = MemoryStorage()
try:
yield storage
finally:
await storage.close()
@pytest.fixture()
def bot():
bot = MockedBot()

7
tests/docker-compose.yml Normal file
View file

@ -0,0 +1,7 @@
version: "3.9"
services:
redis:
image: redis:6-alpine
ports:
- "${REDIS_PORT-6379}:6379"

View file

@ -2,8 +2,8 @@ import datetime
from typing import Optional
import pytest
from aiogram.types import Chat, Message
from aiogram.types import Chat, Message
from tests.mocked_bot import MockedBot

View file

@ -423,7 +423,7 @@ class TestDispatcher:
assert User.get_current(False)
return kwargs
result = await router.update.trigger(update, test="PASS")
result = await router.update.trigger(update, test="PASS", bot=None)
assert isinstance(result, dict)
assert result["event_update"] == update
assert result["event_router"] == router
@ -526,8 +526,9 @@ class TestDispatcher:
assert len(log_records) == 1
assert "Cause exception while process update" in log_records[0]
@pytest.mark.parametrize("as_task", [True, False])
@pytest.mark.asyncio
async def test_polling(self, bot: MockedBot):
async def test_polling(self, bot: MockedBot, as_task: bool):
dispatcher = Dispatcher()
async def _mock_updates(*_):
@ -539,8 +540,11 @@ class TestDispatcher:
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
) as patched_listen_updates:
patched_listen_updates.return_value = _mock_updates()
await dispatcher._polling(bot=bot)
mocked_process_update.assert_awaited()
await dispatcher._polling(bot=bot, handle_as_tasks=as_task)
if as_task:
pass
else:
mocked_process_update.assert_awaited()
@pytest.mark.asyncio
async def test_exception_handler_catch_exceptions(self):
@ -548,9 +552,12 @@ class TestDispatcher:
router = Router()
dp.include_router(router)
class CustomException(Exception):
pass
@router.message()
async def message_handler(message: Message):
raise Exception("KABOOM")
raise CustomException("KABOOM")
update = Update(
update_id=42,
@ -562,23 +569,23 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"),
),
)
with pytest.raises(Exception, match="KABOOM"):
await dp.update.trigger(update)
with pytest.raises(CustomException, match="KABOOM"):
await dp.update.trigger(update, bot=None)
@router.errors()
async def error_handler(event: Update, exception: Exception):
return "KABOOM"
response = await dp.update.trigger(update)
response = await dp.update.trigger(update, bot=None)
assert response == "KABOOM"
@dp.errors()
async def root_error_handler(event: Update, exception: Exception):
return exception
response = await dp.update.trigger(update)
response = await dp.update.trigger(update, bot=None)
assert isinstance(response, Exception)
assert isinstance(response, CustomException)
assert str(response) == "KABOOM"
@pytest.mark.asyncio
@ -654,20 +661,3 @@ class TestDispatcher:
log_records = [rec.message for rec in caplog.records]
assert "Cause exception while process update" in log_records[0]
@pytest.mark.parametrize(
"strategy,case,expected",
[
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
[FSMStrategy.CHAT, (-42, 42), (-42, -42)],
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
[FSMStrategy.CHAT, (42, 42), (42, 42)],
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
],
)
def test_get_current_state_context(self, strategy, case, expected):
dp = Dispatcher(fsm_strategy=strategy)
chat_id, user_id = case
state = dp.current_state(chat_id=chat_id, user_id=user_id)
assert (state.chat_id, state.user_id) == expected

View file

@ -1,45 +0,0 @@
import pytest
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage, MemoryStorageRecord
@pytest.fixture()
def storage():
return MemoryStorage()
class TestMemoryStorage:
@pytest.mark.asyncio
async def test_set_state(self, storage: MemoryStorage):
assert await storage.get_state(chat_id=-42, user_id=42) is None
await storage.set_state(chat_id=-42, user_id=42, state="state")
assert await storage.get_state(chat_id=-42, user_id=42) == "state"
assert -42 in storage.storage
assert 42 in storage.storage[-42]
assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
assert storage.storage[-42][42].state == "state"
@pytest.mark.asyncio
async def test_set_data(self, storage: MemoryStorage):
assert await storage.get_data(chat_id=-42, user_id=42) == {}
await storage.set_data(chat_id=-42, user_id=42, data={"foo": "bar"})
assert await storage.get_data(chat_id=-42, user_id=42) == {"foo": "bar"}
assert -42 in storage.storage
assert 42 in storage.storage[-42]
assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
assert storage.storage[-42][42].data == {"foo": "bar"}
@pytest.mark.asyncio
async def test_update_data(self, storage: MemoryStorage):
assert await storage.get_data(chat_id=-42, user_id=42) == {}
assert await storage.update_data(chat_id=-42, user_id=42, data={"foo": "bar"}) == {
"foo": "bar"
}
assert await storage.update_data(chat_id=-42, user_id=42, data={"baz": "spam"}) == {
"foo": "bar",
"baz": "spam",
}

View file

@ -0,0 +1,21 @@
import pytest
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot
@pytest.mark.redis
class TestRedisStorage:
@pytest.mark.parametrize(
"prefix_bot,result",
[
[False, "fsm:-1:2"],
[True, "fsm:42:-1:2"],
[{42: "kaboom"}, "fsm:kaboom:-1:2"],
[lambda bot: "kaboom", "fsm:kaboom:-1:2"],
],
)
@pytest.mark.asyncio
async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result):
storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot)
assert storage.generate_key(bot, -1, 2) == result

View file

@ -0,0 +1,44 @@
import pytest
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from tests.mocked_bot import MockedBot
@pytest.mark.parametrize(
"storage",
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
)
class TestStorages:
@pytest.mark.asyncio
async def test_lock(self, bot: MockedBot, storage: BaseStorage):
# TODO: ?!?
async with storage.lock(bot=bot, chat_id=-42, user_id=42):
assert True, "You are kidding me?"
@pytest.mark.asyncio
async def test_set_state(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state")
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state"
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None)
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
@pytest.mark.asyncio
async def test_set_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
@pytest.mark.asyncio
async def test_update_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}
) == {"foo": "bar"}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"}
) == {"foo": "bar", "baz": "spam"}

View file

@ -2,27 +2,28 @@ import pytest
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from tests.mocked_bot import MockedBot
@pytest.fixture()
def state():
def state(bot: MockedBot):
storage = MemoryStorage()
ctx = storage.storage[-42][42]
ctx = storage.storage[bot][-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
return FSMContext(storage=storage, user_id=-42, chat_id=42)
return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42)
class TestFSMContext:
@pytest.mark.asyncio
async def test_address_mapping(self):
async def test_address_mapping(self, bot: MockedBot):
storage = MemoryStorage()
ctx = storage.storage[-42][42]
ctx = storage.storage[bot][-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
state = FSMContext(storage=storage, chat_id=-42, user_id=42)
state2 = FSMContext(storage=storage, chat_id=42, user_id=42)
state3 = FSMContext(storage=storage, chat_id=69, user_id=69)
state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42)
state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42)
state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69)
assert await state.get_state() == "test"
assert await state2.get_state() is None

View file

@ -0,0 +1,27 @@
import pytest
from aiogram.utils.auth_widget import check_integrity
TOKEN = "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11"
@pytest.fixture
def data():
return {
"id": "42",
"first_name": "John",
"last_name": "Smith",
"username": "username",
"photo_url": "https://t.me/i/userpic/320/picname.jpg",
"auth_date": "1565810688",
"hash": "c303db2b5a06fe41d23a9b14f7c545cfc11dcc7473c07c9c5034ae60062461ce",
}
class TestCheckIntegrity:
def test_ok(self, data):
assert check_integrity(TOKEN, data) is True
def test_fail(self, data):
data.pop("username")
assert check_integrity(TOKEN, data) is False

View file

@ -0,0 +1,97 @@
import pytest
from async_lru import alru_cache
from aiogram.utils.deep_linking import (
create_start_link,
create_startgroup_link,
decode_payload,
encode_payload,
)
# enable asyncio mode
from tests.mocked_bot import MockedBot
pytestmark = pytest.mark.asyncio
PAYLOADS = [
"foo",
"AAbbCCddEEff1122334455",
"aaBBccDDeeFF5544332211",
-12345678901234567890,
12345678901234567890,
]
WRONG_PAYLOADS = [
"@BotFather",
"Some:special$characters#=",
"spaces spaces spaces",
1234567890123456789.0,
]
@pytest.fixture(params=PAYLOADS, name="payload")
def payload_fixture(request):
return request.param
@pytest.fixture(params=WRONG_PAYLOADS, name="wrong_payload")
def wrong_payload_fixture(request):
return request.param
@pytest.fixture(autouse=True)
def get_bot_user_fixture(monkeypatch):
"""Monkey patching of bot.me calling."""
@alru_cache()
async def get_bot_user_mock(self):
from aiogram.types import User
return User(
id=12345678,
is_bot=True,
first_name="FirstName",
last_name="LastName",
username="username",
language_code="uk-UA",
)
monkeypatch.setattr(MockedBot, "me", get_bot_user_mock)
class TestDeepLinking:
async def test_get_start_link(self, bot, payload):
link = await create_start_link(bot=bot, payload=payload)
assert link == f"https://t.me/username?start={payload}"
async def test_wrong_symbols(self, bot, wrong_payload):
with pytest.raises(ValueError):
await create_start_link(bot, wrong_payload)
async def test_get_startgroup_link(self, bot, payload):
link = await create_startgroup_link(bot, payload)
assert link == f"https://t.me/username?startgroup={payload}"
async def test_filter_encode_and_decode(self, payload):
encoded = encode_payload(payload)
decoded = decode_payload(encoded)
assert decoded == str(payload)
async def test_get_start_link_with_encoding(self, bot, wrong_payload):
# define link
link = await create_start_link(bot, wrong_payload, encode=True)
# define reference link
encoded_payload = encode_payload(wrong_payload)
assert link == f"https://t.me/username?start={encoded_payload}"
async def test_64_len_payload(self, bot):
payload = "p" * 64
link = await create_start_link(bot, payload)
assert link
async def test_too_long_payload(self, bot):
payload = "p" * 65
print(payload, len(payload))
with pytest.raises(ValueError):
await create_start_link(bot, payload)

View file

@ -35,7 +35,7 @@ class TestMarkdown:
[hitalic, ("test", "test"), " ", "<i>test test</i>"],
[code, ("test", "test"), " ", "`test test`"],
[hcode, ("test", "test"), " ", "<code>test test</code>"],
[pre, ("test", "test"), " ", "```test test```"],
[pre, ("test", "test"), " ", "```\ntest test\n```"],
[hpre, ("test", "test"), " ", "<pre>test test</pre>"],
[underline, ("test", "test"), " ", "__\rtest test__\r"],
[hunderline, ("test", "test"), " ", "<u>test test</u>"],

View file

@ -55,7 +55,7 @@ class TestTextDecoration:
[markdown_decoration, MessageEntity(type="bold", offset=0, length=5), "*test*"],
[markdown_decoration, MessageEntity(type="italic", offset=0, length=5), "_\rtest_\r"],
[markdown_decoration, MessageEntity(type="code", offset=0, length=5), "`test`"],
[markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```test```"],
[markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```\ntest\n```"],
[
markdown_decoration,
MessageEntity(type="pre", offset=0, length=5, language="python"),