mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Added tests
This commit is contained in:
parent
e23839aa6c
commit
6d4e55ae3e
6 changed files with 39 additions and 74 deletions
|
|
@ -19,30 +19,13 @@ class UserContextMiddleware(BaseMiddleware):
|
|||
if not isinstance(event, Update):
|
||||
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
|
||||
chat, user, thread_id = self.resolve_event_context(event=event)
|
||||
with self.context(chat=chat, user=user):
|
||||
if user is not None:
|
||||
data[EVENT_FROM_USER_KEY] = user
|
||||
if chat is not None:
|
||||
data[EVENT_CHAT_KEY] = chat
|
||||
if thread_id is not None:
|
||||
data[EVENT_THREAD_ID_KEY] = thread_id
|
||||
return await handler(event, data)
|
||||
|
||||
@contextmanager
|
||||
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
|
||||
chat_token = None
|
||||
user_token = None
|
||||
if chat:
|
||||
chat_token = chat.set_current(chat)
|
||||
if user:
|
||||
user_token = user.set_current(user)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if chat and chat_token:
|
||||
chat.reset_current(chat_token)
|
||||
if user and user_token:
|
||||
user.reset_current(user_token)
|
||||
if user is not None:
|
||||
data[EVENT_FROM_USER_KEY] = user
|
||||
if chat is not None:
|
||||
data[EVENT_CHAT_KEY] = chat
|
||||
if thread_id is not None:
|
||||
data[EVENT_THREAD_ID_KEY] = thread_id
|
||||
return await handler(event, data)
|
||||
|
||||
@classmethod
|
||||
def resolve_event_context(
|
||||
|
|
|
|||
|
|
@ -70,9 +70,10 @@ class DefaultKeyBuilder(KeyBuilder):
|
|||
parts = [self.prefix]
|
||||
if self.with_bot_id:
|
||||
parts.append(str(key.bot_id))
|
||||
parts.extend([str(key.chat_id), str(key.user_id)])
|
||||
parts.append(str(key.chat_id))
|
||||
if key.thread_id:
|
||||
parts.append(str(key.thread_id))
|
||||
parts.append(str(key.user_id))
|
||||
if self.with_destiny:
|
||||
parts.append(key.destiny)
|
||||
elif key.destiny != DEFAULT_DESTINY:
|
||||
|
|
|
|||
|
|
@ -26,49 +26,3 @@ def create_tg_link(link: str, **kwargs: Any) -> str:
|
|||
|
||||
def create_telegram_link(*path: str, **kwargs: Any) -> str:
|
||||
return _format_url("https://t.me", *path, **kwargs)
|
||||
|
||||
|
||||
def create_channel_bot_link(
|
||||
username: str,
|
||||
parameter: Optional[str] = None,
|
||||
change_info: bool = False,
|
||||
post_messages: bool = False,
|
||||
edit_messages: bool = False,
|
||||
delete_messages: bool = False,
|
||||
restrict_members: bool = False,
|
||||
invite_users: bool = False,
|
||||
pin_messages: bool = False,
|
||||
promote_members: bool = False,
|
||||
manage_video_chats: bool = False,
|
||||
anonymous: bool = False,
|
||||
manage_chat: bool = False,
|
||||
) -> str:
|
||||
params = {}
|
||||
if parameter is not None:
|
||||
params["startgroup"] = parameter
|
||||
permissions = []
|
||||
if change_info:
|
||||
permissions.append("change_info")
|
||||
if post_messages:
|
||||
permissions.append("post_messages")
|
||||
if edit_messages:
|
||||
permissions.append("edit_messages")
|
||||
if delete_messages:
|
||||
permissions.append("delete_messages")
|
||||
if restrict_members:
|
||||
permissions.append("restrict_members")
|
||||
if invite_users:
|
||||
permissions.append("invite_users")
|
||||
if pin_messages:
|
||||
permissions.append("pin_messages")
|
||||
if promote_members:
|
||||
permissions.append("promote_members")
|
||||
if manage_video_chats:
|
||||
permissions.append("manage_video_chats")
|
||||
if anonymous:
|
||||
permissions.append("anonymous")
|
||||
if manage_chat:
|
||||
permissions.append("manage_chat")
|
||||
if permissions:
|
||||
params["admin"] = "+".join(permissions)
|
||||
return create_telegram_link(username, **params)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from aiogram import Bot
|
|||
from aiogram.dispatcher.dispatcher import Dispatcher
|
||||
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
|
||||
from aiogram.dispatcher.router import Router
|
||||
from aiogram.methods import GetMe, GetUpdates, Request, SendMessage, TelegramMethod
|
||||
from aiogram.methods import GetMe, GetUpdates, SendMessage, TelegramMethod
|
||||
from aiogram.types import (
|
||||
CallbackQuery,
|
||||
Chat,
|
||||
|
|
@ -462,9 +462,9 @@ class TestDispatcher:
|
|||
async def my_handler(event: Any, **kwargs: Any):
|
||||
assert event == getattr(update, event_type)
|
||||
if has_chat:
|
||||
assert Chat.get_current(False)
|
||||
assert kwargs["event_chat"]
|
||||
if has_user:
|
||||
assert User.get_current(False)
|
||||
assert kwargs["event_from_user"]
|
||||
return kwargs
|
||||
|
||||
result = await router.feed_update(bot, update, test="PASS")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
|
||||
from aiogram.types import Update
|
||||
|
||||
|
||||
async def next_handler(*args, **kwargs):
|
||||
|
|
@ -11,3 +14,13 @@ class TestUserContextMiddleware:
|
|||
async def test_unexpected_event_type(self):
|
||||
with pytest.raises(RuntimeError):
|
||||
await UserContextMiddleware()(next_handler, object(), {})
|
||||
|
||||
async def test_call(self):
|
||||
middleware = UserContextMiddleware()
|
||||
data = {}
|
||||
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
|
||||
await middleware(next_handler, Update(update_id=42), data)
|
||||
|
||||
assert data["event_chat"] == 1
|
||||
assert data["event_from_user"] == 2
|
||||
assert data["event_thread_id"] == 3
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ PREFIX = "test"
|
|||
BOT_ID = 42
|
||||
CHAT_ID = -1
|
||||
USER_ID = 2
|
||||
THREAD_ID = 3
|
||||
FIELD = "data"
|
||||
|
||||
|
||||
|
|
@ -46,6 +47,19 @@ class TestRedisDefaultKeyBuilder:
|
|||
with pytest.raises(ValueError):
|
||||
key_builder.build(key, FIELD)
|
||||
|
||||
def test_thread_id(self):
|
||||
key_builder = DefaultKeyBuilder(
|
||||
prefix=PREFIX,
|
||||
)
|
||||
key = StorageKey(
|
||||
chat_id=CHAT_ID,
|
||||
user_id=USER_ID,
|
||||
bot_id=BOT_ID,
|
||||
thread_id=THREAD_ID,
|
||||
destiny=DEFAULT_DESTINY,
|
||||
)
|
||||
assert key_builder.build(key, FIELD) == f"{PREFIX}:{CHAT_ID}:{THREAD_ID}:{USER_ID}:{FIELD}"
|
||||
|
||||
def test_create_isolation(self):
|
||||
fake_redis = object()
|
||||
storage = RedisStorage(redis=fake_redis)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue