Added tests

This commit is contained in:
Alex Root Junior 2023-04-22 19:30:54 +03:00
parent e23839aa6c
commit 6d4e55ae3e
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
6 changed files with 39 additions and 74 deletions

View file

@ -19,30 +19,13 @@ class UserContextMiddleware(BaseMiddleware):
if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
chat, user, thread_id = self.resolve_event_context(event=event)
with self.context(chat=chat, user=user):
if user is not None:
data[EVENT_FROM_USER_KEY] = user
if chat is not None:
data[EVENT_CHAT_KEY] = chat
if thread_id is not None:
data[EVENT_THREAD_ID_KEY] = thread_id
return await handler(event, data)
@contextmanager
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
chat_token = None
user_token = None
if chat:
chat_token = chat.set_current(chat)
if user:
user_token = user.set_current(user)
try:
yield
finally:
if chat and chat_token:
chat.reset_current(chat_token)
if user and user_token:
user.reset_current(user_token)
if user is not None:
data[EVENT_FROM_USER_KEY] = user
if chat is not None:
data[EVENT_CHAT_KEY] = chat
if thread_id is not None:
data[EVENT_THREAD_ID_KEY] = thread_id
return await handler(event, data)
@classmethod
def resolve_event_context(

View file

@ -70,9 +70,10 @@ class DefaultKeyBuilder(KeyBuilder):
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.extend([str(key.chat_id), str(key.user_id)])
parts.append(str(key.chat_id))
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_destiny:
parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY:

View file

@ -26,49 +26,3 @@ def create_tg_link(link: str, **kwargs: Any) -> str:
def create_telegram_link(*path: str, **kwargs: Any) -> str:
return _format_url("https://t.me", *path, **kwargs)
def create_channel_bot_link(
username: str,
parameter: Optional[str] = None,
change_info: bool = False,
post_messages: bool = False,
edit_messages: bool = False,
delete_messages: bool = False,
restrict_members: bool = False,
invite_users: bool = False,
pin_messages: bool = False,
promote_members: bool = False,
manage_video_chats: bool = False,
anonymous: bool = False,
manage_chat: bool = False,
) -> str:
params = {}
if parameter is not None:
params["startgroup"] = parameter
permissions = []
if change_info:
permissions.append("change_info")
if post_messages:
permissions.append("post_messages")
if edit_messages:
permissions.append("edit_messages")
if delete_messages:
permissions.append("delete_messages")
if restrict_members:
permissions.append("restrict_members")
if invite_users:
permissions.append("invite_users")
if pin_messages:
permissions.append("pin_messages")
if promote_members:
permissions.append("promote_members")
if manage_video_chats:
permissions.append("manage_video_chats")
if anonymous:
permissions.append("anonymous")
if manage_chat:
permissions.append("manage_chat")
if permissions:
params["admin"] = "+".join(permissions)
return create_telegram_link(username, **params)

View file

@ -14,7 +14,7 @@ from aiogram import Bot
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
from aiogram.dispatcher.router import Router
from aiogram.methods import GetMe, GetUpdates, Request, SendMessage, TelegramMethod
from aiogram.methods import GetMe, GetUpdates, SendMessage, TelegramMethod
from aiogram.types import (
CallbackQuery,
Chat,
@ -462,9 +462,9 @@ class TestDispatcher:
async def my_handler(event: Any, **kwargs: Any):
assert event == getattr(update, event_type)
if has_chat:
assert Chat.get_current(False)
assert kwargs["event_chat"]
if has_user:
assert User.get_current(False)
assert kwargs["event_from_user"]
return kwargs
result = await router.feed_update(bot, update, test="PASS")

View file

@ -1,6 +1,9 @@
from unittest.mock import patch
import pytest
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.types import Update
async def next_handler(*args, **kwargs):
@ -11,3 +14,13 @@ class TestUserContextMiddleware:
async def test_unexpected_event_type(self):
with pytest.raises(RuntimeError):
await UserContextMiddleware()(next_handler, object(), {})
async def test_call(self):
middleware = UserContextMiddleware()
data = {}
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
await middleware(next_handler, Update(update_id=42), data)
assert data["event_chat"] == 1
assert data["event_from_user"] == 2
assert data["event_thread_id"] == 3

View file

@ -11,6 +11,7 @@ PREFIX = "test"
BOT_ID = 42
CHAT_ID = -1
USER_ID = 2
THREAD_ID = 3
FIELD = "data"
@ -46,6 +47,19 @@ class TestRedisDefaultKeyBuilder:
with pytest.raises(ValueError):
key_builder.build(key, FIELD)
def test_thread_id(self):
key_builder = DefaultKeyBuilder(
prefix=PREFIX,
)
key = StorageKey(
chat_id=CHAT_ID,
user_id=USER_ID,
bot_id=BOT_ID,
thread_id=THREAD_ID,
destiny=DEFAULT_DESTINY,
)
assert key_builder.build(key, FIELD) == f"{PREFIX}:{CHAT_ID}:{THREAD_ID}:{USER_ID}:{FIELD}"
def test_create_isolation(self):
fake_redis = object()
storage = RedisStorage(redis=fake_redis)