diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index f35271f0..d87811c7 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -64,9 +64,14 @@ jobs:
run: |
poetry run black --check --diff aiogram tests
+ - name: Start Redis
+ uses: supercharge/redis-github-action@1.2.0
+ with:
+ redis-version: 6
+
- name: Run tests
run: |
- poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml
+ poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml --redis redis://redis:6379/0
- uses: codecov/codecov-action@v1
with:
diff --git a/Makefile b/Makefile
index c1fa9797..19420ac1 100644
--- a/Makefile
+++ b/Makefile
@@ -6,6 +6,8 @@ python := $(py) python
reports_dir := reports
+redis_connection := redis://localhost:6379
+
.PHONY: help
help:
@echo "======================================================================================="
@@ -99,12 +101,12 @@ lint: isort black flake8 mypy
.PHONY: test
test:
- $(py) pytest --cov=aiogram --cov-config .coveragerc tests/
+ $(py) pytest --cov=aiogram --cov-config .coveragerc tests/ --redis $(redis_connection)
.PHONY: test-coverage
test-coverage:
mkdir -p $(reports_dir)/tests/
- $(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/
+ $(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ --redis $(redis_connection)
.PHONY: test-coverage-report
test-coverage-report:
diff --git a/aiogram/__init__.py b/aiogram/__init__.py
index 31b52552..639e68c9 100644
--- a/aiogram/__init__.py
+++ b/aiogram/__init__.py
@@ -6,6 +6,8 @@ from .dispatcher import filters, handler
from .dispatcher.dispatcher import Dispatcher
from .dispatcher.middlewares.base import BaseMiddleware
from .dispatcher.router import Router
+from .utils.text_decorations import html_decoration as _html_decoration
+from .utils.text_decorations import markdown_decoration as _markdown_decoration
try:
import uvloop as _uvloop
@@ -15,6 +17,8 @@ except ImportError: # pragma: no cover
pass
F = MagicFilter()
+html = _html_decoration
+md = _markdown_decoration
__all__ = (
"__api_version__",
@@ -29,6 +33,8 @@ __all__ = (
"filters",
"handler",
"F",
+ "html",
+ "md",
)
__version__ = "3.0.0-alpha.8"
diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py
index 78ff5aaf..95c721a1 100644
--- a/aiogram/dispatcher/dispatcher.py
+++ b/aiogram/dispatcher/dispatcher.py
@@ -4,7 +4,7 @@ import asyncio
import contextvars
import warnings
from asyncio import CancelledError, Future, Lock
-from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
+from typing import Any, AsyncGenerator, Dict, Optional, Union
from .. import loggers
from ..client.bot import Bot
@@ -13,7 +13,6 @@ from ..types import TelegramObject, Update, User
from ..utils.exceptions.base import TelegramAPIError
from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver
-from .fsm.context import FSMContext
from .fsm.middleware import FSMContextMiddleware
from .fsm.storage.base import BaseStorage
from .fsm.storage.memory import MemoryStorage
@@ -32,7 +31,7 @@ class Dispatcher(Router):
self,
storage: Optional[BaseStorage] = None,
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
- isolate_events: bool = True,
+ isolate_events: bool = False,
**kwargs: Any,
) -> None:
super(Dispatcher, self).__init__(**kwargs)
@@ -255,7 +254,9 @@ class Dispatcher(Router):
)
return True # because update was processed but unsuccessful
- async def _polling(self, bot: Bot, polling_timeout: int = 30, **kwargs: Any) -> None:
+ async def _polling(
+ self, bot: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
+ ) -> None:
"""
Internal polling process
@@ -264,7 +265,11 @@ class Dispatcher(Router):
:return:
"""
async for update in self._listen_updates(bot, polling_timeout=polling_timeout):
- await self._process_update(bot=bot, update=update, **kwargs)
+ handle_update = self._process_update(bot=bot, update=update, **kwargs)
+ if handle_as_tasks:
+ asyncio.create_task(handle_update)
+ else:
+ await handle_update
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
"""
@@ -342,11 +347,15 @@ class Dispatcher(Router):
return None
- async def start_polling(self, *bots: Bot, polling_timeout: int = 10, **kwargs: Any) -> None:
+ async def start_polling(
+ self, *bots: Bot, polling_timeout: int = 10, handle_as_tasks: bool = True, **kwargs: Any
+ ) -> None:
"""
Polling runner
:param bots:
+ :param polling_timeout:
+ :param handle_as_tasks:
:param kwargs:
:return:
"""
@@ -363,7 +372,12 @@ class Dispatcher(Router):
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
coro_list.append(
- self._polling(bot=bot, polling_timeout=polling_timeout, **kwargs)
+ self._polling(
+ bot=bot,
+ handle_as_tasks=handle_as_tasks,
+ polling_timeout=polling_timeout,
+ **kwargs,
+ )
)
await asyncio.gather(*coro_list)
finally:
@@ -372,22 +386,27 @@ class Dispatcher(Router):
loggers.dispatcher.info("Polling stopped")
await self.emit_shutdown(**workflow_data)
- def run_polling(self, *bots: Bot, polling_timeout: int = 30, **kwargs: Any) -> None:
+ def run_polling(
+ self, *bots: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
+ ) -> None:
"""
Run many bots with polling
:param bots: Bot instances
:param polling_timeout: Poling timeout
+ :param handle_as_tasks: Run task for each event and no wait result
:param kwargs: contextual data
:return:
"""
try:
return asyncio.run(
- self.start_polling(*bots, **kwargs, polling_timeout=polling_timeout)
+ self.start_polling(
+ *bots,
+ **kwargs,
+ polling_timeout=polling_timeout,
+ handle_as_tasks=handle_as_tasks,
+ )
)
except (KeyboardInterrupt, SystemExit): # pragma: no cover
# Allow to graceful shutdown
pass
-
- def current_state(self, chat_id: int, user_id: int) -> FSMContext:
- return cast(FSMContext, self.fsm.resolve_context(chat_id=chat_id, user_id=user_id))
diff --git a/aiogram/dispatcher/filters/command.py b/aiogram/dispatcher/filters/command.py
index 899b09be..43ff31c2 100644
--- a/aiogram/dispatcher/filters/command.py
+++ b/aiogram/dispatcher/filters/command.py
@@ -1,18 +1,24 @@
from __future__ import annotations
import re
-from dataclasses import dataclass, field
-from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast
+from dataclasses import dataclass, field, replace
+from typing import Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast
-from pydantic import validator
+from magic_filter import MagicFilter
+from pydantic import Field, validator
from aiogram import Bot
from aiogram.dispatcher.filters import BaseFilter
from aiogram.types import Message
+from aiogram.utils.deep_linking import decode_payload
CommandPatterType = Union[str, re.Pattern]
+class CommandException(Exception):
+ pass
+
+
class Command(BaseFilter):
"""
This filter can be helpful for handling commands from the text messages.
@@ -29,6 +35,8 @@ class Command(BaseFilter):
"""Ignore case (Does not work with regexp, use flags instead)"""
commands_ignore_mention: bool = False
"""Ignore bot mention. By default bot can not handle commands intended for other bots"""
+ command_magic: Optional[MagicFilter] = None
+ """Validate command object via Magic filter after all checks done"""
@validator("commands", always=True)
def _validate_commands(
@@ -39,22 +47,17 @@ class Command(BaseFilter):
return value
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
- if not message.text:
+ text = message.text or message.caption
+ if not text:
return False
- return await self.parse_command(text=message.text, bot=bot)
-
- async def parse_command(self, text: str, bot: Bot) -> Union[bool, Dict[str, CommandObject]]:
- """
- Extract command from the text and validate
-
- :param text:
- :param bot:
- :return:
- """
- if not text.strip():
+ try:
+ command = await self.parse_command(text=cast(str, message.text), bot=bot)
+ except CommandException:
return False
+ return {"command": command}
+ def extract_command(self, text: str) -> CommandObject:
# First step: separate command with arguments
# "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"]
full_command, *args = text.split(maxsplit=1)
@@ -62,46 +65,52 @@ class Command(BaseFilter):
# Separate command into valuable parts
# "/command@mention" -> "/", ("command", "@", "mention")
prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@")
+ return CommandObject(
+ prefix=prefix, command=command, mention=mention, args=args[0] if args else None
+ )
- # Validate prefixes
- if prefix not in self.commands_prefix:
- return False
+ def validate_prefix(self, command: CommandObject) -> None:
+ if command.prefix not in self.commands_prefix:
+ raise CommandException("Invalid command prefix")
- # Validate mention
- if mention and not self.commands_ignore_mention:
+ async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
+ if command.mention and not self.commands_ignore_mention:
me = await bot.me()
- if me.username and mention.lower() != me.username.lower():
- return False
+ if me.username and command.mention.lower() != me.username.lower():
+ raise CommandException("Mention did not match")
- # Validate command
+ def validate_command(self, command: CommandObject) -> CommandObject:
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
# Command can be presented as regexp pattern or raw string
# then need to validate that in different ways
if isinstance(allowed_command, Pattern): # Regexp
- result = allowed_command.match(command)
+ result = allowed_command.match(command.command)
if result:
- return {
- "command": CommandObject(
- prefix=prefix,
- command=command,
- mention=mention,
- args=args[0] if args else None,
- match=result,
- )
- }
+ return replace(command, match=result)
+ elif command.command == allowed_command: # String
+ return command
+ raise CommandException("Command did not match pattern")
- elif command == allowed_command: # String
- return {
- "command": CommandObject(
- prefix=prefix,
- command=command,
- mention=mention,
- args=args[0] if args else None,
- match=None,
- )
- }
+ async def parse_command(self, text: str, bot: Bot) -> CommandObject:
+ """
+ Extract command from the text and validate
- return False
+ :param text:
+ :param bot:
+ :return:
+ """
+ command = self.extract_command(text)
+ self.validate_prefix(command=command)
+ await self.validate_mention(bot=bot, command=command)
+ command = self.validate_command(command)
+ self.do_magic(command=command)
+ return command
+
+ def do_magic(self, command: CommandObject) -> None:
+ if not self.command_magic:
+ return
+ if not self.command_magic.resolve(command):
+ raise CommandException("Rejected via magic filter")
class Config:
arbitrary_types_allowed = True
@@ -143,3 +152,40 @@ class CommandObject:
if self.args:
line += " " + self.args
return line
+
+
+class CommandStart(Command):
+ commands: Tuple[str] = Field(("start",), const=True)
+ commands_prefix: str = Field("/", const=True)
+ deep_link: bool = False
+ deep_link_encoded: bool = False
+
+ async def parse_command(self, text: str, bot: Bot) -> CommandObject:
+ """
+ Extract command from the text and validate
+
+ :param text:
+ :param bot:
+ :return:
+ """
+ command = self.extract_command(text)
+ self.validate_prefix(command=command)
+ await self.validate_mention(bot=bot, command=command)
+ command = self.validate_command(command)
+ command = self.validate_deeplink(command=command)
+ self.do_magic(command=command)
+ return command
+
+ def validate_deeplink(self, command: CommandObject) -> CommandObject:
+ if not self.deep_link:
+ return command
+ if not command.args:
+ raise CommandException("Deep-link was missing")
+ args = command.args
+ if self.deep_link_encoded:
+ try:
+ args = decode_payload(args)
+ except UnicodeDecodeError as e:
+ raise CommandException(f"Failed to decode Base64: {e}")
+ return replace(command, args=args)
+ return command
diff --git a/aiogram/dispatcher/fsm/context.py b/aiogram/dispatcher/fsm/context.py
index 78ed480b..dc4e4030 100644
--- a/aiogram/dispatcher/fsm/context.py
+++ b/aiogram/dispatcher/fsm/context.py
@@ -1,25 +1,35 @@
from typing import Any, Dict, Optional
+from aiogram import Bot
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
class FSMContext:
- def __init__(self, storage: BaseStorage, chat_id: int, user_id: int) -> None:
+ def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
+ self.bot = bot
self.storage = storage
self.chat_id = chat_id
self.user_id = user_id
async def set_state(self, state: StateType = None) -> None:
- await self.storage.set_state(chat_id=self.chat_id, user_id=self.user_id, state=state)
+ await self.storage.set_state(
+ bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
+ )
async def get_state(self) -> Optional[str]:
- return await self.storage.get_state(chat_id=self.chat_id, user_id=self.user_id)
+ return await self.storage.get_state(
+ bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
+ )
async def set_data(self, data: Dict[str, Any]) -> None:
- await self.storage.set_data(chat_id=self.chat_id, user_id=self.user_id, data=data)
+ await self.storage.set_data(
+ bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
+ )
async def get_data(self) -> Dict[str, Any]:
- return await self.storage.get_data(chat_id=self.chat_id, user_id=self.user_id)
+ return await self.storage.get_data(
+ bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
+ )
async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
@@ -27,7 +37,7 @@ class FSMContext:
if data:
kwargs.update(data)
return await self.storage.update_data(
- chat_id=self.chat_id, user_id=self.user_id, data=kwargs
+ bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
)
async def clear(self) -> None:
diff --git a/aiogram/dispatcher/fsm/middleware.py b/aiogram/dispatcher/fsm/middleware.py
index 1e3ba91c..734c5825 100644
--- a/aiogram/dispatcher/fsm/middleware.py
+++ b/aiogram/dispatcher/fsm/middleware.py
@@ -1,5 +1,6 @@
-from typing import Any, Awaitable, Callable, Dict, Optional
+from typing import Any, Awaitable, Callable, Dict, Optional, cast
+from aiogram import Bot
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
@@ -24,24 +25,27 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
event: Update,
data: Dict[str, Any],
) -> Any:
- context = self.resolve_event_context(data)
+ bot: Bot = cast(Bot, data["bot"])
+ context = self.resolve_event_context(bot, data)
data["fsm_storage"] = self.storage
if context:
data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events:
- async with self.storage.lock(chat_id=context.chat_id, user_id=context.user_id):
+ async with self.storage.lock(
+ bot=bot, chat_id=context.chat_id, user_id=context.user_id
+ ):
return await handler(event, data)
return await handler(event, data)
- def resolve_event_context(self, data: Dict[str, Any]) -> Optional[FSMContext]:
+ def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
chat_id = chat.id if chat else None
user_id = user.id if user else None
- return self.resolve_context(chat_id=chat_id, user_id=user_id)
+ return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
def resolve_context(
- self, chat_id: Optional[int], user_id: Optional[int]
+ self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
@@ -50,8 +54,8 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
)
- return self.get_context(chat_id=chat_id, user_id=user_id)
+ return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
return None
- def get_context(self, chat_id: int, user_id: int) -> FSMContext:
- return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id)
+ def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
+ return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)
diff --git a/aiogram/dispatcher/fsm/storage/base.py b/aiogram/dispatcher/fsm/storage/base.py
index f394cd61..42826915 100644
--- a/aiogram/dispatcher/fsm/storage/base.py
+++ b/aiogram/dispatcher/fsm/storage/base.py
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Union
+from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
StateType = Optional[Union[str, State]]
@@ -11,34 +12,42 @@ class BaseStorage(ABC):
@abstractmethod
@asynccontextmanager
async def lock(
- self, chat_id: int, user_id: int
+ self, bot: Bot, chat_id: int, user_id: int
) -> AsyncGenerator[None, None]: # pragma: no cover
yield None
@abstractmethod
async def set_state(
- self, chat_id: int, user_id: int, state: StateType = None
+ self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None: # pragma: no cover
pass
@abstractmethod
- async def get_state(self, chat_id: int, user_id: int) -> Optional[str]: # pragma: no cover
+ async def get_state(
+ self, bot: Bot, chat_id: int, user_id: int
+ ) -> Optional[str]: # pragma: no cover
pass
@abstractmethod
async def set_data(
- self, chat_id: int, user_id: int, data: Dict[str, Any]
+ self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> None: # pragma: no cover
pass
@abstractmethod
- async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]: # pragma: no cover
+ async def get_data(
+ self, bot: Bot, chat_id: int, user_id: int
+ ) -> Dict[str, Any]: # pragma: no cover
pass
async def update_data(
- self, chat_id: int, user_id: int, data: Dict[str, Any]
+ self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> Dict[str, Any]:
- current_data = await self.get_data(chat_id=chat_id, user_id=user_id)
+ current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
current_data.update(data)
- await self.set_data(chat_id=chat_id, user_id=user_id, data=current_data)
+ await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
return current_data.copy()
+
+ @abstractmethod
+ async def close(self) -> None: # pragma: no cover
+ pass
diff --git a/aiogram/dispatcher/fsm/storage/memory.py b/aiogram/dispatcher/fsm/storage/memory.py
index 933e225c..3e82d306 100644
--- a/aiogram/dispatcher/fsm/storage/memory.py
+++ b/aiogram/dispatcher/fsm/storage/memory.py
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
+from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
@@ -17,23 +18,30 @@ class MemoryStorageRecord:
class MemoryStorage(BaseStorage):
def __init__(self) -> None:
- self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict(
- lambda: defaultdict(MemoryStorageRecord)
- )
+ self.storage: DefaultDict[
+ Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
+ ] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
+
+ async def close(self) -> None:
+ pass
@asynccontextmanager
- async def lock(self, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
- async with self.storage[chat_id][user_id].lock:
+ async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
+ async with self.storage[bot][chat_id][user_id].lock:
yield None
- async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None:
- self.storage[chat_id][user_id].state = state.state if isinstance(state, State) else state
+ async def set_state(
+ self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
+ ) -> None:
+ self.storage[bot][chat_id][user_id].state = (
+ state.state if isinstance(state, State) else state
+ )
- async def get_state(self, chat_id: int, user_id: int) -> Optional[str]:
- return self.storage[chat_id][user_id].state
+ async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
+ return self.storage[bot][chat_id][user_id].state
- async def set_data(self, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
- self.storage[chat_id][user_id].data = data.copy()
+ async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
+ self.storage[bot][chat_id][user_id].data = data.copy()
- async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]:
- return self.storage[chat_id][user_id].data.copy()
+ async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
+ return self.storage[bot][chat_id][user_id].data.copy()
diff --git a/aiogram/dispatcher/fsm/storage/redis.py b/aiogram/dispatcher/fsm/storage/redis.py
new file mode 100644
index 00000000..64c832f9
--- /dev/null
+++ b/aiogram/dispatcher/fsm/storage/redis.py
@@ -0,0 +1,101 @@
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
+
+from aioredis import ConnectionPool, Redis
+
+from aiogram import Bot
+from aiogram.dispatcher.fsm.state import State
+from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
+
+PrefixFactoryType = Callable[[Bot], str]
+STATE_KEY = "state"
+STATE_DATA_KEY = "data"
+STATE_LOCK_KEY = "lock"
+
+DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
+
+
+class RedisStorage(BaseStorage):
+ def __init__(
+ self,
+ redis: Redis,
+ prefix: str = "fsm",
+ prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
+ state_ttl: Optional[int] = None,
+ data_ttl: Optional[int] = None,
+ lock_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ if lock_kwargs is None:
+ lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
+ self.redis = redis
+ self.prefix = prefix
+ self.prefix_bot = prefix_bot
+ self.state_ttl = state_ttl
+ self.data_ttl = data_ttl
+ self.lock_kwargs = lock_kwargs
+
+ @classmethod
+ def from_url(
+ cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
+ ) -> "RedisStorage":
+ if connection_kwargs is None:
+ connection_kwargs = {}
+ pool = ConnectionPool.from_url(url, **connection_kwargs)
+ redis = Redis(connection_pool=pool)
+ return cls(redis=redis, **kwargs)
+
+ async def close(self) -> None:
+ await self.redis.close()
+
+ def generate_key(self, bot: Bot, *parts: Any) -> str:
+ prefix_parts = [self.prefix]
+ if self.prefix_bot:
+ if isinstance(self.prefix_bot, dict):
+ prefix_parts.append(self.prefix_bot[bot.id])
+ elif callable(self.prefix_bot):
+ prefix_parts.append(self.prefix_bot(bot))
+ else:
+ prefix_parts.append(str(bot.id))
+ prefix_parts.extend(parts)
+ return ":".join(map(str, prefix_parts))
+
+ @asynccontextmanager
+ async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
+ key = self.generate_key(bot, chat_id, user_id, STATE_LOCK_KEY)
+ async with self.redis.lock(name=key, **self.lock_kwargs):
+ yield None
+
+ async def set_state(
+ self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
+ ) -> None:
+ key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
+ if state is None:
+ await self.redis.delete(key)
+ else:
+ await self.redis.set(
+ key, state.state if isinstance(state, State) else state, ex=self.state_ttl
+ )
+
+ async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
+ key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
+ value = await self.redis.get(key)
+ if isinstance(value, bytes):
+ return value.decode("utf-8")
+ return cast(Optional[str], value)
+
+ async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
+ key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
+ if not data:
+ await self.redis.delete(key)
+ return
+ json_data = bot.session.json_dumps(data)
+ await self.redis.set(key, json_data, ex=self.data_ttl)
+
+ async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
+ key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
+ value = await self.redis.get(key)
+ if value is None:
+ return {}
+ if isinstance(value, bytes):
+ value = value.decode("utf-8")
+ return cast(Dict[str, Any], bot.session.json_loads(value))
diff --git a/aiogram/utils/auth_widget.py b/aiogram/utils/auth_widget.py
new file mode 100644
index 00000000..a67afe65
--- /dev/null
+++ b/aiogram/utils/auth_widget.py
@@ -0,0 +1,34 @@
+import hashlib
+import hmac
+from typing import Any, Dict
+
+
+def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
+ """
+ Generate hexadecimal representation
+ of the HMAC-SHA-256 signature of the data-check-string
+ with the SHA256 hash of the bot's token used as a secret key
+
+ :param token:
+ :param hash:
+ :param kwargs: all params received on auth
+ :return:
+ """
+ secret = hashlib.sha256(token.encode("utf-8"))
+ check_string = "\n".join(map(lambda k: f"{k}={kwargs[k]}", sorted(kwargs)))
+ hmac_string = hmac.new(
+ secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256
+ ).hexdigest()
+ return hmac_string == hash
+
+
+def check_integrity(token: str, data: Dict[str, Any]) -> bool:
+ """
+ Verify the authentication and the integrity
+ of the data received on user's auth
+
+ :param token: Bot's token
+ :param data: all data that came on auth
+ :return:
+ """
+ return check_signature(token, **data)
diff --git a/aiogram/utils/deep_linking.py b/aiogram/utils/deep_linking.py
new file mode 100644
index 00000000..caac2c26
--- /dev/null
+++ b/aiogram/utils/deep_linking.py
@@ -0,0 +1,131 @@
+"""
+Deep linking
+
+Telegram bots have a deep linking mechanism, that allows for passing
+additional parameters to the bot on startup. It could be a command that
+launches the bot — or an auth token to connect the user's Telegram
+account to their account on some external service.
+
+You can read detailed description in the source:
+https://core.telegram.org/bots#deep-linking
+
+We have add some utils to get deep links more handy.
+
+Basic link example:
+
+ .. code-block:: python
+
+ from aiogram.utils.deep_linking import get_start_link
+ link = await get_start_link('foo')
+
+ # result: 'https://t.me/MyBot?start=foo'
+
+Encoded link example:
+
+ .. code-block:: python
+
+ from aiogram.utils.deep_linking import get_start_link
+
+ link = await get_start_link('foo', encode=True)
+ # result: 'https://t.me/MyBot?start=Zm9v'
+
+Decode it back example:
+ .. code-block:: python
+
+ from aiogram.utils.deep_linking import decode_payload
+ from aiogram.types import Message
+
+ @dp.message_handler(commands=["start"])
+ async def handler(message: Message):
+ args = message.get_args()
+ payload = decode_payload(args)
+ await message.answer(f"Your payload: {payload}")
+
+"""
+import re
+from base64 import urlsafe_b64decode, urlsafe_b64encode
+from typing import Literal, cast
+
+from aiogram import Bot
+from aiogram.utils.link import create_telegram_link
+
+BAD_PATTERN = re.compile(r"[^_A-z0-9-]")
+
+
+async def create_start_link(bot: Bot, payload: str, encode: bool = False) -> str:
+ """
+ Create 'start' deep link with your payload.
+
+ If you need to encode payload or pass special characters -
+ set encode as True
+
+ :param bot: bot instance
+ :param payload: args passed with /start
+ :param encode: encode payload with base64url
+ :return: link
+ """
+ username = (await bot.me()).username
+ return create_deep_link(username=username, link_type="start", payload=payload, encode=encode)
+
+
+async def create_startgroup_link(bot: Bot, payload: str, encode: bool = False) -> str:
+ """
+ Create 'startgroup' deep link with your payload.
+
+ If you need to encode payload or pass special characters -
+ set encode as True
+
+ :param bot: bot instance
+ :param payload: args passed with /start
+ :param encode: encode payload with base64url
+ :return: link
+ """
+ username = (await bot.me()).username
+ return create_deep_link(
+ username=username, link_type="startgroup", payload=payload, encode=encode
+ )
+
+
+def create_deep_link(
+ username: str, link_type: Literal["start", "startgroup"], payload: str, encode: bool = False
+) -> str:
+ """
+ Create deep link.
+
+ :param username:
+ :param link_type: `start` or `startgroup`
+ :param payload: any string-convertible data
+ :param encode: pass True to encode the payload
+ :return: deeplink
+ """
+ if not isinstance(payload, str):
+ payload = str(payload)
+
+ if encode:
+ payload = encode_payload(payload)
+
+ if re.search(BAD_PATTERN, payload):
+ raise ValueError(
+ "Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
+ "Pass `encode=True` or encode payload manually."
+ )
+
+ if len(payload) > 64:
+ raise ValueError("Payload must be up to 64 characters long.")
+
+ return create_telegram_link(username, **{cast(str, link_type): payload})
+
+
+def encode_payload(payload: str) -> str:
+ """Encode payload with URL-safe base64url."""
+ payload = str(payload)
+ bytes_payload: bytes = urlsafe_b64encode(payload.encode())
+ str_payload = bytes_payload.decode()
+ return str_payload.replace("=", "")
+
+
+def decode_payload(payload: str) -> str:
+ """Decode payload with URL-safe base64url."""
+ payload += "=" * (4 - len(payload) % 4)
+ result: bytes = urlsafe_b64decode(payload)
+ return result.decode()
diff --git a/aiogram/utils/keyboard.py b/aiogram/utils/keyboard.py
index 19409c94..9cb10b02 100644
--- a/aiogram/utils/keyboard.py
+++ b/aiogram/utils/keyboard.py
@@ -2,13 +2,27 @@ from __future__ import annotations
from itertools import chain
from itertools import cycle as repeat_all
-from typing import Any, Generator, Generic, Iterable, List, Optional, Type, TypeVar, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Generator,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ no_type_check,
+)
from aiogram.dispatcher.filters.callback_data import CallbackData
from aiogram.types import (
+ CallbackGame,
InlineKeyboardButton,
InlineKeyboardMarkup,
KeyboardButton,
+ LoginUrl,
ReplyKeyboardMarkup,
)
@@ -239,3 +253,28 @@ def repeat_last(items: Iterable[T]) -> Generator[T, None, None]:
except StopIteration:
finished = True
yield value
+
+
+class InlineKeyboardConstructor(KeyboardConstructor[InlineKeyboardButton]):
+ if TYPE_CHECKING: # pragma: no cover
+
+ @no_type_check
+ def button(
+ self,
+ text: str,
+ url: Optional[str] = None,
+ login_url: Optional[LoginUrl] = None,
+ callback_data: Optional[Union[str, CallbackData]] = None,
+ switch_inline_query: Optional[str] = None,
+ switch_inline_query_current_chat: Optional[str] = None,
+ callback_game: Optional[CallbackGame] = None,
+ pay: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> "KeyboardConstructor[InlineKeyboardButton]":
+ ...
+
+ def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup:
+ ...
+
+ def __init__(self) -> None:
+ super().__init__(InlineKeyboardButton)
diff --git a/aiogram/utils/link.py b/aiogram/utils/link.py
new file mode 100644
index 00000000..87d402e2
--- /dev/null
+++ b/aiogram/utils/link.py
@@ -0,0 +1,18 @@
+from typing import Any
+from urllib.parse import urlencode, urljoin
+
+
+def create_tg_link(link: str, **kwargs: Any) -> str:
+ url = f"tg://{link}"
+ if kwargs:
+ query = urlencode(kwargs)
+ url += f"?{query}"
+ return url
+
+
+def create_telegram_link(uri: str, **kwargs: Any) -> str:
+ url = urljoin("https://t.me", uri)
+ if kwargs:
+ query = urlencode(query=kwargs)
+ url += f"?{query}"
+ return url
diff --git a/aiogram/utils/text_decorations.py b/aiogram/utils/text_decorations.py
index a41e481f..23c9c2a7 100644
--- a/aiogram/utils/text_decorations.py
+++ b/aiogram/utils/text_decorations.py
@@ -183,7 +183,7 @@ class MarkdownDecoration(TextDecoration):
return f"`{value}`"
def pre(self, value: str) -> str:
- return f"```{value}```"
+ return f"```\n{value}\n```"
def pre_language(self, value: str, language: str) -> str:
return f"```{language}\n{value}\n```"
diff --git a/mypy.ini b/mypy.ini
index afe61218..a75c96cb 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,6 +1,6 @@
[mypy]
;plugins = pydantic.mypy
-python_version = 3.7
+python_version = 3.8
show_error_codes = True
show_error_context = True
pretty = True
@@ -29,3 +29,6 @@ ignore_missing_imports = True
[mypy-uvloop]
ignore_missing_imports = True
+
+[mypy-aioredis]
+ignore_missing_imports = True
diff --git a/poetry.lock b/poetry.lock
index 7c3acd24..d3b3e036 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -38,6 +38,21 @@ aiohttp = ">=2.3.2"
attrs = ">=19.2.0"
python-socks = {version = ">=1.0.1", extras = ["asyncio"]}
+[[package]]
+name = "aioredis"
+version = "2.0.0a1"
+description = "asyncio (PEP 3156) Redis support"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+async-timeout = "*"
+typing-extensions = "*"
+
+[package.extras]
+hiredis = ["hiredis (>=1.0)"]
+
[[package]]
name = "alabaster"
version = "0.7.12"
@@ -764,6 +779,17 @@ python-versions = ">=3.6"
pytest = ">=5.0,<6.0.0 || >6.0.0"
pytest-metadata = "*"
+[[package]]
+name = "pytest-lazy-fixture"
+version = "0.6.3"
+description = "It helps to use fixtures in pytest.mark.parametrize"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+pytest = ">=3.2.5"
+
[[package]]
name = "pytest-metadata"
version = "1.11.0"
@@ -1171,11 +1197,12 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pyt
docs = ["sphinx", "sphinx-intl", "sphinx-autobuild", "sphinx-copybutton", "furo", "sphinx-prompt", "Sphinx-Substitution-Extensions"]
fast = []
proxy = ["aiohttp-socks"]
+redis = ["aioredis"]
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
-content-hash = "2fcd44a8937b3ea48196c8eba8ceb0533281af34c884103bcc5b4f5f16b817d5"
+content-hash = "bc8fa6e61728e0463bdfb156aa53f52d9c4323fd3a78a73f587278bf6f908eac"
[metadata.files]
aiofiles = [
@@ -1225,6 +1252,10 @@ aiohttp-socks = [
{file = "aiohttp_socks-0.5.5-py3-none-any.whl", hash = "sha256:faaa25ed4dc34440ca888d23e089420f3b1918dc4ecf062c3fd9474827ad6a39"},
{file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"},
]
+aioredis = [
+ {file = "aioredis-2.0.0a1-py3-none-any.whl", hash = "sha256:32d7910724282a475c91b8b34403867069a4f07bf0c5ad5fe66cd797322f9a0d"},
+ {file = "aioredis-2.0.0a1.tar.gz", hash = "sha256:5884f384b8ecb143bb73320a96e7c464fd38e117950a7d48340a35db8e35e7d2"},
+]
alabaster = [
{file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"},
{file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"},
@@ -1644,6 +1675,10 @@ pytest-html = [
{file = "pytest-html-3.1.1.tar.gz", hash = "sha256:3ee1cf319c913d19fe53aeb0bc400e7b0bc2dbeb477553733db1dad12eb75ee3"},
{file = "pytest_html-3.1.1-py3-none-any.whl", hash = "sha256:b7f82f123936a3f4d2950bc993c2c1ca09ce262c9ae12f9ac763a2401380b455"},
]
+pytest-lazy-fixture = [
+ {file = "pytest-lazy-fixture-0.6.3.tar.gz", hash = "sha256:0e7d0c7f74ba33e6e80905e9bfd81f9d15ef9a790de97993e34213deb5ad10ac"},
+ {file = "pytest_lazy_fixture-0.6.3-py3-none-any.whl", hash = "sha256:e0b379f38299ff27a653f03eaa69b08a6fd4484e46fd1c9907d984b9f9daeda6"},
+]
pytest-metadata = [
{file = "pytest-metadata-1.11.0.tar.gz", hash = "sha256:71b506d49d34e539cc3cfdb7ce2c5f072bea5c953320002c95968e0238f8ecf1"},
{file = "pytest_metadata-1.11.0-py2.py3-none-any.whl", hash = "sha256:576055b8336dd4a9006dd2a47615f76f2f8c30ab12b1b1c039d99e834583523f"},
diff --git a/pyproject.toml b/pyproject.toml
index aa283b07..c66127ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,8 +38,9 @@ Babel = "^2.9.1"
aiofiles = "^0.6.0"
async_lru = "^1.0.2"
aiohttp-socks = { version = "^0.5.5", optional = true }
+aioredis = { version = "^1.3.1", optional = true }
typing-extensions = { version = "^3.7.4", python = "<3.8" }
-magic-filter = {version = "1.0.0a1", allow-prereleases = true}
+magic-filter = { version = "1.0.0a1", allow-prereleases = true }
sphinx = { version = "^3.1.0", optional = true }
sphinx-intl = { version = "^2.0.1", optional = true }
sphinx-autobuild = { version = "^2020.9.1", optional = true }
@@ -50,6 +51,7 @@ Sphinx-Substitution-Extensions = { version = "^2020.9.30", optional = true }
[tool.poetry.dev-dependencies]
aiohttp-socks = "^0.5"
+aioredis = {version = "^1.3.1", allow-prereleases = true}
ipython = "^7.22.0"
uvloop = { version = "^0.15.2", markers = "sys_platform == 'darwin' or sys_platform == 'linux'" }
black = "^21.4b2"
@@ -79,9 +81,11 @@ sphinx-copybutton = "^0.3.1"
furo = "^2020.11.15-beta.17"
sphinx-prompt = "^1.3.0"
Sphinx-Substitution-Extensions = "^2020.9.30"
+pytest-lazy-fixture = "^0.6.3"
[tool.poetry.extras]
fast = ["uvloop"]
+redis = ["aioredis"]
proxy = ["aiohttp-socks"]
docs = [
"sphinx",
diff --git a/tests/conftest.py b/tests/conftest.py
index 60d9d0fe..392a07c2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,62 @@
import pytest
+from _pytest.config import UsageError
+from aioredis.connection import parse_url as parse_redis_url
from aiogram import Bot
+from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
+from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot
+def pytest_addoption(parser):
+ parser.addoption("--redis", default=None, help="run tests which require redis connection")
+
+
+def pytest_configure(config):
+ config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
+
+
+def pytest_collection_modifyitems(config, items):
+ redis_uri = config.getoption("--redis")
+ if redis_uri is None:
+ skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
+ for item in items:
+ if "redis" in item.keywords:
+ item.add_marker(skip_redis)
+ return
+ try:
+ parse_redis_url(redis_uri)
+ except ValueError as e:
+ raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
+
+
+@pytest.fixture(scope="session")
+def redis_server(request):
+ redis_uri = request.config.getoption("--redis")
+ return redis_uri
+
+
+@pytest.fixture()
+@pytest.mark.redis
+async def redis_storage(redis_server):
+ storage = RedisStorage.from_url(redis_server)
+ try:
+ yield storage
+ finally:
+ conn = await storage.redis
+ await conn.flushdb()
+ await storage.close()
+
+
+@pytest.fixture()
+async def memory_storage():
+ storage = MemoryStorage()
+ try:
+ yield storage
+ finally:
+ await storage.close()
+
+
@pytest.fixture()
def bot():
bot = MockedBot()
diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml
new file mode 100644
index 00000000..453f5e5a
--- /dev/null
+++ b/tests/docker-compose.yml
@@ -0,0 +1,7 @@
+version: "3.9"
+
+services:
+ redis:
+ image: redis:6-alpine
+ ports:
+ - "${REDIS_PORT-6379}:6379"
diff --git a/tests/test_api/test_methods/test_get_url.py b/tests/test_api/test_methods/test_get_url.py
index 3c769ca2..76b24200 100644
--- a/tests/test_api/test_methods/test_get_url.py
+++ b/tests/test_api/test_methods/test_get_url.py
@@ -2,8 +2,8 @@ import datetime
from typing import Optional
import pytest
-from aiogram.types import Chat, Message
+from aiogram.types import Chat, Message
from tests.mocked_bot import MockedBot
diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py
index ecf44712..5356a1c1 100644
--- a/tests/test_dispatcher/test_dispatcher.py
+++ b/tests/test_dispatcher/test_dispatcher.py
@@ -423,7 +423,7 @@ class TestDispatcher:
assert User.get_current(False)
return kwargs
- result = await router.update.trigger(update, test="PASS")
+ result = await router.update.trigger(update, test="PASS", bot=None)
assert isinstance(result, dict)
assert result["event_update"] == update
assert result["event_router"] == router
@@ -526,8 +526,9 @@ class TestDispatcher:
assert len(log_records) == 1
assert "Cause exception while process update" in log_records[0]
+ @pytest.mark.parametrize("as_task", [True, False])
@pytest.mark.asyncio
- async def test_polling(self, bot: MockedBot):
+ async def test_polling(self, bot: MockedBot, as_task: bool):
dispatcher = Dispatcher()
async def _mock_updates(*_):
@@ -539,8 +540,11 @@ class TestDispatcher:
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
) as patched_listen_updates:
patched_listen_updates.return_value = _mock_updates()
- await dispatcher._polling(bot=bot)
- mocked_process_update.assert_awaited()
+ await dispatcher._polling(bot=bot, handle_as_tasks=as_task)
+ if as_task:
+ pass
+ else:
+ mocked_process_update.assert_awaited()
@pytest.mark.asyncio
async def test_exception_handler_catch_exceptions(self):
@@ -548,9 +552,12 @@ class TestDispatcher:
router = Router()
dp.include_router(router)
+ class CustomException(Exception):
+ pass
+
@router.message()
async def message_handler(message: Message):
- raise Exception("KABOOM")
+ raise CustomException("KABOOM")
update = Update(
update_id=42,
@@ -562,23 +569,23 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"),
),
)
- with pytest.raises(Exception, match="KABOOM"):
- await dp.update.trigger(update)
+ with pytest.raises(CustomException, match="KABOOM"):
+ await dp.update.trigger(update, bot=None)
@router.errors()
async def error_handler(event: Update, exception: Exception):
return "KABOOM"
- response = await dp.update.trigger(update)
+ response = await dp.update.trigger(update, bot=None)
assert response == "KABOOM"
@dp.errors()
async def root_error_handler(event: Update, exception: Exception):
return exception
- response = await dp.update.trigger(update)
+ response = await dp.update.trigger(update, bot=None)
- assert isinstance(response, Exception)
+ assert isinstance(response, CustomException)
assert str(response) == "KABOOM"
@pytest.mark.asyncio
@@ -654,20 +661,3 @@ class TestDispatcher:
log_records = [rec.message for rec in caplog.records]
assert "Cause exception while process update" in log_records[0]
-
- @pytest.mark.parametrize(
- "strategy,case,expected",
- [
- [FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
- [FSMStrategy.CHAT, (-42, 42), (-42, -42)],
- [FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
- [FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
- [FSMStrategy.CHAT, (42, 42), (42, 42)],
- [FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
- ],
- )
- def test_get_current_state_context(self, strategy, case, expected):
- dp = Dispatcher(fsm_strategy=strategy)
- chat_id, user_id = case
- state = dp.current_state(chat_id=chat_id, user_id=user_id)
- assert (state.chat_id, state.user_id) == expected
diff --git a/tests/test_dispatcher/test_fsm/storage/test_memory.py b/tests/test_dispatcher/test_fsm/storage/test_memory.py
deleted file mode 100644
index 2f587075..00000000
--- a/tests/test_dispatcher/test_fsm/storage/test_memory.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import pytest
-
-from aiogram.dispatcher.fsm.storage.memory import MemoryStorage, MemoryStorageRecord
-
-
-@pytest.fixture()
-def storage():
- return MemoryStorage()
-
-
-class TestMemoryStorage:
- @pytest.mark.asyncio
- async def test_set_state(self, storage: MemoryStorage):
- assert await storage.get_state(chat_id=-42, user_id=42) is None
-
- await storage.set_state(chat_id=-42, user_id=42, state="state")
- assert await storage.get_state(chat_id=-42, user_id=42) == "state"
-
- assert -42 in storage.storage
- assert 42 in storage.storage[-42]
- assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
- assert storage.storage[-42][42].state == "state"
-
- @pytest.mark.asyncio
- async def test_set_data(self, storage: MemoryStorage):
- assert await storage.get_data(chat_id=-42, user_id=42) == {}
-
- await storage.set_data(chat_id=-42, user_id=42, data={"foo": "bar"})
- assert await storage.get_data(chat_id=-42, user_id=42) == {"foo": "bar"}
-
- assert -42 in storage.storage
- assert 42 in storage.storage[-42]
- assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
- assert storage.storage[-42][42].data == {"foo": "bar"}
-
- @pytest.mark.asyncio
- async def test_update_data(self, storage: MemoryStorage):
- assert await storage.get_data(chat_id=-42, user_id=42) == {}
- assert await storage.update_data(chat_id=-42, user_id=42, data={"foo": "bar"}) == {
- "foo": "bar"
- }
- assert await storage.update_data(chat_id=-42, user_id=42, data={"baz": "spam"}) == {
- "foo": "bar",
- "baz": "spam",
- }
diff --git a/tests/test_dispatcher/test_fsm/storage/test_redis.py b/tests/test_dispatcher/test_fsm/storage/test_redis.py
new file mode 100644
index 00000000..7b914a33
--- /dev/null
+++ b/tests/test_dispatcher/test_fsm/storage/test_redis.py
@@ -0,0 +1,21 @@
+import pytest
+
+from aiogram.dispatcher.fsm.storage.redis import RedisStorage
+from tests.mocked_bot import MockedBot
+
+
+@pytest.mark.redis
+class TestRedisStorage:
+ @pytest.mark.parametrize(
+ "prefix_bot,result",
+ [
+ [False, "fsm:-1:2"],
+ [True, "fsm:42:-1:2"],
+ [{42: "kaboom"}, "fsm:kaboom:-1:2"],
+ [lambda bot: "kaboom", "fsm:kaboom:-1:2"],
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result):
+ storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot)
+ assert storage.generate_key(bot, -1, 2) == result
diff --git a/tests/test_dispatcher/test_fsm/storage/test_storages.py b/tests/test_dispatcher/test_fsm/storage/test_storages.py
new file mode 100644
index 00000000..fcb2deae
--- /dev/null
+++ b/tests/test_dispatcher/test_fsm/storage/test_storages.py
@@ -0,0 +1,44 @@
+import pytest
+
+from aiogram.dispatcher.fsm.storage.base import BaseStorage
+from tests.mocked_bot import MockedBot
+
+
+@pytest.mark.parametrize(
+ "storage",
+ [pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
+)
+class TestStorages:
+ @pytest.mark.asyncio
+ async def test_lock(self, bot: MockedBot, storage: BaseStorage):
+ # TODO: ?!?
+ async with storage.lock(bot=bot, chat_id=-42, user_id=42):
+ assert True, "You are kidding me?"
+
+ @pytest.mark.asyncio
+ async def test_set_state(self, bot: MockedBot, storage: BaseStorage):
+ assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
+
+ await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state")
+ assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state"
+ await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None)
+ assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
+
+ @pytest.mark.asyncio
+ async def test_set_data(self, bot: MockedBot, storage: BaseStorage):
+ assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
+
+ await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"})
+ assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"}
+ await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={})
+ assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
+
+ @pytest.mark.asyncio
+ async def test_update_data(self, bot: MockedBot, storage: BaseStorage):
+ assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
+ assert await storage.update_data(
+ bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}
+ ) == {"foo": "bar"}
+ assert await storage.update_data(
+ bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"}
+ ) == {"foo": "bar", "baz": "spam"}
diff --git a/tests/test_dispatcher/test_fsm/test_context.py b/tests/test_dispatcher/test_fsm/test_context.py
index 6c444c44..fb98c423 100644
--- a/tests/test_dispatcher/test_fsm/test_context.py
+++ b/tests/test_dispatcher/test_fsm/test_context.py
@@ -2,27 +2,28 @@ import pytest
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
+from tests.mocked_bot import MockedBot
@pytest.fixture()
-def state():
+def state(bot: MockedBot):
storage = MemoryStorage()
- ctx = storage.storage[-42][42]
+ ctx = storage.storage[bot][-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
- return FSMContext(storage=storage, user_id=-42, chat_id=42)
+ return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42)
class TestFSMContext:
@pytest.mark.asyncio
- async def test_address_mapping(self):
+ async def test_address_mapping(self, bot: MockedBot):
storage = MemoryStorage()
- ctx = storage.storage[-42][42]
+ ctx = storage.storage[bot][-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
- state = FSMContext(storage=storage, chat_id=-42, user_id=42)
- state2 = FSMContext(storage=storage, chat_id=42, user_id=42)
- state3 = FSMContext(storage=storage, chat_id=69, user_id=69)
+ state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42)
+ state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42)
+ state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69)
assert await state.get_state() == "test"
assert await state2.get_state() is None
diff --git a/tests/test_utils/test_auth_widget.py b/tests/test_utils/test_auth_widget.py
new file mode 100644
index 00000000..a6071760
--- /dev/null
+++ b/tests/test_utils/test_auth_widget.py
@@ -0,0 +1,27 @@
+import pytest
+
+from aiogram.utils.auth_widget import check_integrity
+
+TOKEN = "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11"
+
+
+@pytest.fixture
+def data():
+ return {
+ "id": "42",
+ "first_name": "John",
+ "last_name": "Smith",
+ "username": "username",
+ "photo_url": "https://t.me/i/userpic/320/picname.jpg",
+ "auth_date": "1565810688",
+ "hash": "c303db2b5a06fe41d23a9b14f7c545cfc11dcc7473c07c9c5034ae60062461ce",
+ }
+
+
+class TestCheckIntegrity:
+ def test_ok(self, data):
+ assert check_integrity(TOKEN, data) is True
+
+ def test_fail(self, data):
+ data.pop("username")
+ assert check_integrity(TOKEN, data) is False
diff --git a/tests/test_utils/test_deep_linking.py b/tests/test_utils/test_deep_linking.py
new file mode 100644
index 00000000..93ff4dab
--- /dev/null
+++ b/tests/test_utils/test_deep_linking.py
@@ -0,0 +1,97 @@
+import pytest
+from async_lru import alru_cache
+
+from aiogram.utils.deep_linking import (
+ create_start_link,
+ create_startgroup_link,
+ decode_payload,
+ encode_payload,
+)
+
+# enable asyncio mode
+from tests.mocked_bot import MockedBot
+
+pytestmark = pytest.mark.asyncio
+
+PAYLOADS = [
+ "foo",
+ "AAbbCCddEEff1122334455",
+ "aaBBccDDeeFF5544332211",
+ -12345678901234567890,
+ 12345678901234567890,
+]
+WRONG_PAYLOADS = [
+ "@BotFather",
+ "Some:special$characters#=",
+ "spaces spaces spaces",
+ 1234567890123456789.0,
+]
+
+
+@pytest.fixture(params=PAYLOADS, name="payload")
+def payload_fixture(request):
+ return request.param
+
+
+@pytest.fixture(params=WRONG_PAYLOADS, name="wrong_payload")
+def wrong_payload_fixture(request):
+ return request.param
+
+
+@pytest.fixture(autouse=True)
+def get_bot_user_fixture(monkeypatch):
+ """Monkey patching of bot.me calling."""
+
+ @alru_cache()
+ async def get_bot_user_mock(self):
+ from aiogram.types import User
+
+ return User(
+ id=12345678,
+ is_bot=True,
+ first_name="FirstName",
+ last_name="LastName",
+ username="username",
+ language_code="uk-UA",
+ )
+
+ monkeypatch.setattr(MockedBot, "me", get_bot_user_mock)
+
+
+class TestDeepLinking:
+ async def test_get_start_link(self, bot, payload):
+ link = await create_start_link(bot=bot, payload=payload)
+ assert link == f"https://t.me/username?start={payload}"
+
+ async def test_wrong_symbols(self, bot, wrong_payload):
+ with pytest.raises(ValueError):
+ await create_start_link(bot, wrong_payload)
+
+ async def test_get_startgroup_link(self, bot, payload):
+ link = await create_startgroup_link(bot, payload)
+ assert link == f"https://t.me/username?startgroup={payload}"
+
+ async def test_filter_encode_and_decode(self, payload):
+ encoded = encode_payload(payload)
+ decoded = decode_payload(encoded)
+ assert decoded == str(payload)
+
+ async def test_get_start_link_with_encoding(self, bot, wrong_payload):
+ # define link
+ link = await create_start_link(bot, wrong_payload, encode=True)
+
+ # define reference link
+ encoded_payload = encode_payload(wrong_payload)
+
+ assert link == f"https://t.me/username?start={encoded_payload}"
+
+ async def test_64_len_payload(self, bot):
+ payload = "p" * 64
+ link = await create_start_link(bot, payload)
+ assert link
+
+ async def test_too_long_payload(self, bot):
+ payload = "p" * 65
+ print(payload, len(payload))
+ with pytest.raises(ValueError):
+ await create_start_link(bot, payload)
diff --git a/tests/test_utils/test_markdown.py b/tests/test_utils/test_markdown.py
index 12e44ccf..815b1c5d 100644
--- a/tests/test_utils/test_markdown.py
+++ b/tests/test_utils/test_markdown.py
@@ -35,7 +35,7 @@ class TestMarkdown:
[hitalic, ("test", "test"), " ", "test test"],
[code, ("test", "test"), " ", "`test test`"],
[hcode, ("test", "test"), " ", "test test"],
- [pre, ("test", "test"), " ", "```test test```"],
+ [pre, ("test", "test"), " ", "```\ntest test\n```"],
[hpre, ("test", "test"), " ", "
test test"], [underline, ("test", "test"), " ", "__\rtest test__\r"], [hunderline, ("test", "test"), " ", "test test"], diff --git a/tests/test_utils/test_text_decorations.py b/tests/test_utils/test_text_decorations.py index 6cb5105d..da171575 100644 --- a/tests/test_utils/test_text_decorations.py +++ b/tests/test_utils/test_text_decorations.py @@ -55,7 +55,7 @@ class TestTextDecoration: [markdown_decoration, MessageEntity(type="bold", offset=0, length=5), "*test*"], [markdown_decoration, MessageEntity(type="italic", offset=0, length=5), "_\rtest_\r"], [markdown_decoration, MessageEntity(type="code", offset=0, length=5), "`test`"], - [markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```test```"], + [markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```\ntest\n```"], [ markdown_decoration, MessageEntity(type="pre", offset=0, length=5, language="python"),