From 924fb755cbb8cfa7c0255e96c53eba3eb36f97bd Mon Sep 17 00:00:00 2001 From: JRoot Junior Date: Tue, 9 Apr 2024 01:26:31 +0300 Subject: [PATCH] Refactor UserContextMiddleware to use EventContext class This update significantly refactors UserContextMiddleware to leverage a new class, EventContext. Instead of resolving event context as a tuple, it now produces an instance of EventContext. Additional adjustments include supporting a business connection ID for event context identification and facilitating backwards compatibility. Tests and other files were also updated accordingly for these changes. --- .../dispatcher/middlewares/user_context.py | 138 +++++++++++------- aiogram/fsm/middleware.py | 18 ++- aiogram/fsm/storage/base.py | 1 + aiogram/fsm/storage/redis.py | 5 + .../test_middlewares/test_user_context.py | 29 +++- 5 files changed, 128 insertions(+), 63 deletions(-) diff --git a/aiogram/dispatcher/middlewares/user_context.py b/aiogram/dispatcher/middlewares/user_context.py index 2cec944b..0c780048 100644 --- a/aiogram/dispatcher/middlewares/user_context.py +++ b/aiogram/dispatcher/middlewares/user_context.py @@ -1,13 +1,32 @@ -from typing import Any, Awaitable, Callable, Dict, Optional, Tuple +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Dict, Optional from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.types import Chat, InaccessibleMessage, TelegramObject, Update, User +EVENT_CONTEXT_KEY = "event_context" + EVENT_FROM_USER_KEY = "event_from_user" EVENT_CHAT_KEY = "event_chat" EVENT_THREAD_ID_KEY = "event_thread_id" +@dataclass(frozen=True) +class EventContext: + chat: Optional[Chat] = None + user: Optional[User] = None + thread_id: Optional[int] = None + business_connection_id: Optional[str] = None + + @property + def user_id(self) -> Optional[int]: + return self.user.id if self.user else None + + @property + def chat_id(self) -> Optional[int]: + return self.chat.id if self.chat else None + + class UserContextMiddleware(BaseMiddleware): async def __call__( self, @@ -17,93 +36,114 @@ class UserContextMiddleware(BaseMiddleware): ) -> Any: if not isinstance(event, Update): raise RuntimeError("UserContextMiddleware got an unexpected event type!") - chat, user, thread_id = self.resolve_event_context(event=event) - if user is not None: - data[EVENT_FROM_USER_KEY] = user - if chat is not None: - data[EVENT_CHAT_KEY] = chat - if thread_id is not None: - data[EVENT_THREAD_ID_KEY] = thread_id + event_context = data[EVENT_CONTEXT_KEY] = self.resolve_event_context(event=event) + + # Backward compatibility + if event_context.user is not None: + data[EVENT_FROM_USER_KEY] = event_context.user + if event_context.chat is not None: + data[EVENT_CHAT_KEY] = event_context.chat + if event_context.thread_id is not None: + data[EVENT_THREAD_ID_KEY] = event_context.thread_id + return await handler(event, data) @classmethod - def resolve_event_context( - cls, event: Update - ) -> Tuple[Optional[Chat], Optional[User], Optional[int]]: + def resolve_event_context(cls, event: Update) -> EventContext: """ Resolve chat and user instance from Update object """ if event.message: - return ( - event.message.chat, - event.message.from_user, - event.message.message_thread_id if event.message.is_topic_message else None, + return EventContext( + chat=event.message.chat, + user=event.message.from_user, + thread_id=event.message.message_thread_id + if event.message.is_topic_message + else None, ) if event.edited_message: - return ( - event.edited_message.chat, - event.edited_message.from_user, - event.edited_message.message_thread_id + return EventContext( + chat=event.edited_message.chat, + user=event.edited_message.from_user, + thread_id=event.edited_message.message_thread_id if event.edited_message.is_topic_message else None, ) if event.channel_post: - return event.channel_post.chat, None, None + return EventContext(chat=event.channel_post.chat) if event.edited_channel_post: - return event.edited_channel_post.chat, None, None + return EventContext(chat=event.edited_channel_post.chat) if event.inline_query: - return None, event.inline_query.from_user, None + return EventContext(user=event.inline_query.from_user) if event.chosen_inline_result: - return None, event.chosen_inline_result.from_user, None + return EventContext(user=event.chosen_inline_result.from_user) if event.callback_query: if event.callback_query.message: - return ( - event.callback_query.message.chat, - event.callback_query.from_user, - event.callback_query.message.message_thread_id + return EventContext( + chat=event.callback_query.message.chat, + user=event.callback_query.from_user, + thread_id=event.callback_query.message.message_thread_id if not isinstance(event.callback_query.message, InaccessibleMessage) and event.callback_query.message.is_topic_message else None, ) - return None, event.callback_query.from_user, None + return EventContext(user=event.callback_query.from_user) if event.shipping_query: - return None, event.shipping_query.from_user, None + return EventContext(user=event.shipping_query.from_user) if event.pre_checkout_query: - return None, event.pre_checkout_query.from_user, None + return EventContext(user=event.pre_checkout_query.from_user) if event.poll_answer: - return event.poll_answer.voter_chat, event.poll_answer.user, None + return EventContext( + chat=event.poll_answer.voter_chat, + user=event.poll_answer.user, + ) if event.my_chat_member: - return event.my_chat_member.chat, event.my_chat_member.from_user, None + return EventContext( + chat=event.my_chat_member.chat, user=event.my_chat_member.from_user + ) if event.chat_member: - return event.chat_member.chat, event.chat_member.from_user, None + return EventContext(chat=event.chat_member.chat, user=event.chat_member.from_user) if event.chat_join_request: - return event.chat_join_request.chat, event.chat_join_request.from_user, None + return EventContext( + chat=event.chat_join_request.chat, user=event.chat_join_request.from_user + ) if event.message_reaction: - return event.message_reaction.chat, event.message_reaction.user, None + return EventContext( + chat=event.message_reaction.chat, + user=event.message_reaction.user, + ) if event.message_reaction_count: - return event.message_reaction_count.chat, None, None + return EventContext(chat=event.message_reaction_count.chat) if event.chat_boost: - return event.chat_boost.chat, None, None + return EventContext(chat=event.chat_boost.chat) if event.removed_chat_boost: - return event.removed_chat_boost.chat, None, None + return EventContext(chat=event.removed_chat_boost.chat) if event.deleted_business_messages: - return event.deleted_business_messages.chat, None, None + return EventContext( + chat=event.deleted_business_messages.chat, + business_connection_id=event.deleted_business_messages.business_connection_id, + ) if event.business_connection: - return None, event.business_connection.user, None + return EventContext( + user=event.business_connection.user, + business_connection_id=event.business_connection.id, + ) if event.business_message: - return ( - event.business_message.chat, - event.business_message.from_user, - event.business_message.message_thread_id + return EventContext( + chat=event.business_message.chat, + user=event.business_message.from_user, + thread_id=event.business_message.message_thread_id if event.business_message.is_topic_message else None, + business_connection_id=event.business_message.business_connection_id, ) if event.edited_business_message: - return ( - event.edited_business_message.chat, - event.edited_business_message.from_user, - event.edited_business_message.message_thread_id + return EventContext( + chat=event.edited_business_message.chat, + user=event.edited_business_message.from_user, + thread_id=event.edited_business_message.message_thread_id if event.edited_business_message.is_topic_message else None, + business_connection_id=event.edited_business_message.business_connection_id, ) - return None, None, None + return EventContext() diff --git a/aiogram/fsm/middleware.py b/aiogram/fsm/middleware.py index d1f1d973..de934574 100644 --- a/aiogram/fsm/middleware.py +++ b/aiogram/fsm/middleware.py @@ -2,6 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, cast from aiogram import Bot from aiogram.dispatcher.middlewares.base import BaseMiddleware +from aiogram.dispatcher.middlewares.user_context import EVENT_CONTEXT_KEY, EventContext from aiogram.fsm.context import FSMContext from aiogram.fsm.storage.base import ( DEFAULT_DESTINY, @@ -47,16 +48,13 @@ class FSMContextMiddleware(BaseMiddleware): data: Dict[str, Any], destiny: str = DEFAULT_DESTINY, ) -> Optional[FSMContext]: - user = data.get("event_from_user") - chat = data.get("event_chat") - thread_id = data.get("event_thread_id") - chat_id = chat.id if chat else None - user_id = user.id if user else None + event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY)) return self.resolve_context( bot=bot, - chat_id=chat_id, - user_id=user_id, - thread_id=thread_id, + chat_id=event_context.chat_id, + user_id=event_context.user_id, + thread_id=event_context.thread_id, + business_connection_id=event_context.business_connection_id, destiny=destiny, ) @@ -66,6 +64,7 @@ class FSMContextMiddleware(BaseMiddleware): chat_id: Optional[int], user_id: Optional[int], thread_id: Optional[int] = None, + business_connection_id: Optional[str] = None, destiny: str = DEFAULT_DESTINY, ) -> Optional[FSMContext]: if chat_id is None: @@ -83,6 +82,7 @@ class FSMContextMiddleware(BaseMiddleware): chat_id=chat_id, user_id=user_id, thread_id=thread_id, + business_connection_id=business_connection_id, destiny=destiny, ) return None @@ -93,6 +93,7 @@ class FSMContextMiddleware(BaseMiddleware): chat_id: int, user_id: int, thread_id: Optional[int] = None, + business_connection_id: Optional[str] = None, destiny: str = DEFAULT_DESTINY, ) -> FSMContext: return FSMContext( @@ -102,6 +103,7 @@ class FSMContextMiddleware(BaseMiddleware): chat_id=chat_id, bot_id=bot.id, thread_id=thread_id, + business_connection_id=business_connection_id, destiny=destiny, ), ) diff --git a/aiogram/fsm/storage/base.py b/aiogram/fsm/storage/base.py index 52cb62f2..a66d56be 100644 --- a/aiogram/fsm/storage/base.py +++ b/aiogram/fsm/storage/base.py @@ -16,6 +16,7 @@ class StorageKey: chat_id: int user_id: int thread_id: Optional[int] = None + business_connection_id: Optional[str] = None destiny: str = DEFAULT_DESTINY diff --git a/aiogram/fsm/storage/redis.py b/aiogram/fsm/storage/redis.py index 33e44be4..eae71eec 100644 --- a/aiogram/fsm/storage/redis.py +++ b/aiogram/fsm/storage/redis.py @@ -53,17 +53,20 @@ class DefaultKeyBuilder(KeyBuilder): prefix: str = "fsm", separator: str = ":", with_bot_id: bool = False, + with_business_connection_id: bool = False, with_destiny: bool = False, ) -> None: """ :param prefix: prefix for all records :param separator: separator :param with_bot_id: include Bot id in the key + :param with_business_connection_id: include business connection id :param with_destiny: include destiny key """ self.prefix = prefix self.separator = separator self.with_bot_id = with_bot_id + self.with_business_connection_id = with_business_connection_id self.with_destiny = with_destiny def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str: @@ -74,6 +77,8 @@ class DefaultKeyBuilder(KeyBuilder): if key.thread_id: parts.append(str(key.thread_id)) parts.append(str(key.user_id)) + if self.with_business_connection_id and key.business_connection_id: + parts.append(str(key.business_connection_id)) if self.with_destiny: parts.append(key.destiny) elif key.destiny != DEFAULT_DESTINY: diff --git a/tests/test_dispatcher/test_middlewares/test_user_context.py b/tests/test_dispatcher/test_middlewares/test_user_context.py index 54c09ce2..b3818fbe 100644 --- a/tests/test_dispatcher/test_middlewares/test_user_context.py +++ b/tests/test_dispatcher/test_middlewares/test_user_context.py @@ -2,8 +2,11 @@ from unittest.mock import patch import pytest -from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware -from aiogram.types import Update +from aiogram.dispatcher.middlewares.user_context import ( + EventContext, + UserContextMiddleware, +) +from aiogram.types import Chat, Update, User async def next_handler(*args, **kwargs): @@ -18,9 +21,23 @@ class TestUserContextMiddleware: async def test_call(self): middleware = UserContextMiddleware() data = {} - with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]): + + chat = Chat(id=1, type="private", title="Test") + user = User(id=2, first_name="Test", is_bot=False) + thread_id = 3 + + with patch.object( + UserContextMiddleware, + "resolve_event_context", + return_value=EventContext(user=user, chat=chat, thread_id=3), + ): await middleware(next_handler, Update(update_id=42), data) - assert data["event_chat"] == 1 - assert data["event_from_user"] == 2 - assert data["event_thread_id"] == 3 + event_context = data["event_context"] + assert isinstance(event_context, EventContext) + assert event_context.chat is chat + assert event_context.user is user + assert event_context.thread_id == thread_id + assert data["event_chat"] is chat + assert data["event_from_user"] is user + assert data["event_thread_id"] == thread_id