Refactor UserContextMiddleware to use EventContext class

This update significantly refactors UserContextMiddleware to leverage a new class, EventContext. Instead of resolving event context as a tuple, it now produces an instance of EventContext. Additional adjustments include supporting a business connection ID for event context identification and facilitating backwards compatibility. Tests and other files were also updated accordingly for these changes.
This commit is contained in:
JRoot Junior 2024-04-09 01:26:31 +03:00
parent 7dd18bfa11
commit 924fb755cb
No known key found for this signature in database
GPG key ID: 738964250D5FF6E2
5 changed files with 128 additions and 63 deletions

View file

@ -1,13 +1,32 @@
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User
EVENT_CONTEXT_KEY = "event_context"
EVENT_FROM_USER_KEY = "event_from_user"
EVENT_CHAT_KEY = "event_chat"
EVENT_THREAD_ID_KEY = "event_thread_id"
@dataclass(frozen=True)
class EventContext:
chat: Optional[Chat] = None
user: Optional[User] = None
thread_id: Optional[int] = None
business_connection_id: Optional[str] = None
@property
def user_id(self) -> Optional[int]:
return self.user.id if self.user else None
@property
def chat_id(self) -> Optional[int]:
return self.chat.id if self.chat else None
class UserContextMiddleware(BaseMiddleware):
async def __call__(
self,
@ -17,93 +36,114 @@ class UserContextMiddleware(BaseMiddleware):
) -> Any:
if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
chat, user, thread_id = self.resolve_event_context(event=event)
if user is not None:
data[EVENT_FROM_USER_KEY] = user
if chat is not None:
data[EVENT_CHAT_KEY] = chat
if thread_id is not None:
data[EVENT_THREAD_ID_KEY] = thread_id
event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event)
# Backward compatibility
if event_context.user is not None:
data[EVENT_FROM_USER_KEY] = event_context.user
if event_context.chat is not None:
data[EVENT_CHAT_KEY] = event_context.chat
if event_context.thread_id is not None:
data[EVENT_THREAD_ID_KEY] = event_context.thread_id
return await handler(event, data)
@classmethod
def resolve_event_context(
cls, event: Update
) -> Tuple[Optional[Chat], Optional[User], Optional[int]]:
def resolve_event_context(cls, event: Update) -> EventContext:
"""
Resolve chat and user instance from Update object
"""
if event.message:
return (
event.message.chat,
event.message.from_user,
event.message.message_thread_id if event.message.is_topic_message else None,
return EventContext(
chat=event.message.chat,
user=event.message.from_user,
thread_id=event.message.message_thread_id
if event.message.is_topic_message
else None,
)
if event.edited_message:
return (
event.edited_message.chat,
event.edited_message.from_user,
event.edited_message.message_thread_id
return EventContext(
chat=event.edited_message.chat,
user=event.edited_message.from_user,
thread_id=event.edited_message.message_thread_id
if event.edited_message.is_topic_message
else None,
)
if event.channel_post:
return event.channel_post.chat, None, None
return EventContext(chat=event.channel_post.chat)
if event.edited_channel_post:
return event.edited_channel_post.chat, None, None
return EventContext(chat=event.edited_channel_post.chat)
if event.inline_query:
return None, event.inline_query.from_user, None
return EventContext(user=event.inline_query.from_user)
if event.chosen_inline_result:
return None, event.chosen_inline_result.from_user, None
return EventContext(user=event.chosen_inline_result.from_user)
if event.callback_query:
if event.callback_query.message:
return (
event.callback_query.message.chat,
event.callback_query.from_user,
event.callback_query.message.message_thread_id
return EventContext(
chat=event.callback_query.message.chat,
user=event.callback_query.from_user,
thread_id=event.callback_query.message.message_thread_id
if not isinstance(event.callback_query.message, InaccessibleMessage)
and event.callback_query.message.is_topic_message
else None,
)
return None, event.callback_query.from_user, None
return EventContext(user=event.callback_query.from_user)
if event.shipping_query:
return None, event.shipping_query.from_user, None
return EventContext(user=event.shipping_query.from_user)
if event.pre_checkout_query:
return None, event.pre_checkout_query.from_user, None
return EventContext(user=event.pre_checkout_query.from_user)
if event.poll_answer:
return event.poll_answer.voter_chat, event.poll_answer.user, None
return EventContext(
chat=event.poll_answer.voter_chat,
user=event.poll_answer.user,
)
if event.my_chat_member:
return event.my_chat_member.chat, event.my_chat_member.from_user, None
return EventContext(
chat=event.my_chat_member.chat, user=event.my_chat_member.from_user
)
if event.chat_member:
return event.chat_member.chat, event.chat_member.from_user, None
return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user)
if event.chat_join_request:
return event.chat_join_request.chat, event.chat_join_request.from_user, None
return EventContext(
chat=event.chat_join_request.chat, user=event.chat_join_request.from_user
)
if event.message_reaction:
return event.message_reaction.chat, event.message_reaction.user, None
return EventContext(
chat=event.message_reaction.chat,
user=event.message_reaction.user,
)
if event.message_reaction_count:
return event.message_reaction_count.chat, None, None
return EventContext(chat=event.message_reaction_count.chat)
if event.chat_boost:
return event.chat_boost.chat, None, None
return EventContext(chat=event.chat_boost.chat)
if event.removed_chat_boost:
return event.removed_chat_boost.chat, None, None
return EventContext(chat=event.removed_chat_boost.chat)
if event.deleted_business_messages:
return event.deleted_business_messages.chat, None, None
return EventContext(
chat=event.deleted_business_messages.chat,
business_connection_id=event.deleted_business_messages.business_connection_id,
)
if event.business_connection:
return None, event.business_connection.user, None
return EventContext(
user=event.business_connection.user,
business_connection_id=event.business_connection.id,
)
if event.business_message:
return (
event.business_message.chat,
event.business_message.from_user,
event.business_message.message_thread_id
return EventContext(
chat=event.business_message.chat,
user=event.business_message.from_user,
thread_id=event.business_message.message_thread_id
if event.business_message.is_topic_message
else None,
business_connection_id=event.business_message.business_connection_id,
)
if event.edited_business_message:
return (
event.edited_business_message.chat,
event.edited_business_message.from_user,
event.edited_business_message.message_thread_id
return EventContext(
chat=event.edited_business_message.chat,
user=event.edited_business_message.from_user,
thread_id=event.edited_business_message.message_thread_id
if event.edited_business_message.is_topic_message
else None,
business_connection_id=event.edited_business_message.business_connection_id,
)
return None, None, None
return EventContext()

View file

@ -2,6 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.user_context import EVENT_CONTEXT_KEY, EventContext
from aiogram.fsm.context import FSMContext
from aiogram.fsm.storage.base import (
DEFAULT_DESTINY,
@ -47,16 +48,13 @@ class FSMContextMiddleware(BaseMiddleware):
data: Dict[str, Any],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
thread_id = data.get("event_thread_id")
chat_id = chat.id if chat else None
user_id = user.id if user else None
event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
return self.resolve_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
chat_id=event_context.chat_id,
user_id=event_context.user_id,
thread_id=event_context.thread_id,
business_connection_id=event_context.business_connection_id,
destiny=destiny,
)
@ -66,6 +64,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id: Optional[int],
user_id: Optional[int],
thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
if chat_id is None:
@ -83,6 +82,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
business_connection_id=business_connection_id,
destiny=destiny,
)
return None
@ -93,6 +93,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(
@ -102,6 +103,7 @@ class FSMContextMiddleware(BaseMiddleware):
chat_id=chat_id,
bot_id=bot.id,
thread_id=thread_id,
business_connection_id=business_connection_id,
destiny=destiny,
),
)

View file

@ -16,6 +16,7 @@ class StorageKey:
chat_id: int
user_id: int
thread_id: Optional[int] = None
business_connection_id: Optional[str] = None
destiny: str = DEFAULT_DESTINY

View file

@ -53,17 +53,20 @@ class DefaultKeyBuilder(KeyBuilder):
prefix: str = "fsm",
separator: str = ":",
with_bot_id: bool = False,
with_business_connection_id: bool = False,
with_destiny: bool = False,
) -> None:
"""
:param prefix: prefix for all records
:param separator: separator
:param with_bot_id: include Bot id in the key
:param with_business_connection_id: include business connection id
:param with_destiny: include destiny key
"""
self.prefix = prefix
self.separator = separator
self.with_bot_id = with_bot_id
self.with_business_connection_id = with_business_connection_id
self.with_destiny = with_destiny
def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
@ -74,6 +77,8 @@ class DefaultKeyBuilder(KeyBuilder):
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_business_connection_id and key.business_connection_id:
parts.append(str(key.business_connection_id))
if self.with_destiny:
parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY:

View file

@ -2,8 +2,11 @@ from unittest.mock import patch
import pytest
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.types import Update
from aiogram.dispatcher.middlewares.user_context import (
EventContext,
UserContextMiddleware,
)
from aiogram.types import Chat, Update, User
async def next_handler(*args, **kwargs):
@ -18,9 +21,23 @@ class TestUserContextMiddleware:
async def test_call(self):
middleware = UserContextMiddleware()
data = {}
with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]):
chat = Chat(id=1, type="private", title="Test")
user = User(id=2, first_name="Test", is_bot=False)
thread_id = 3
with patch.object(
UserContextMiddleware,
"resolve_event_context",
return_value=EventContext(user=user, chat=chat, thread_id=3),
):
await middleware(next_handler, Update(update_id=42), data)
assert data["event_chat"] == 1
assert data["event_from_user"] == 2
assert data["event_thread_id"] == 3
event_context = data["event_context"]
assert isinstance(event_context, EventContext)
assert event_context.chat is chat
assert event_context.user is user
assert event_context.thread_id == thread_id
assert data["event_chat"] is chat
assert data["event_from_user"] is user
assert data["event_thread_id"] == thread_id