diff --git a/CHANGES/1161.feature.rst b/CHANGES/1161.feature.rst new file mode 100644 index 00000000..819c697c --- /dev/null +++ b/CHANGES/1161.feature.rst @@ -0,0 +1,17 @@ +Added support for FSM in Forum topics. + +The strategy can be changed in dispatcher: + +.. code-block:: python + + from aiogram.fsm.strategy import FSMStrategy + ... + dispatcher = Dispatcher( + fsm_strategy=FSMStrategy.USER_IN_THREAD, + storage=..., # Any persistent storage + ) + +.. note:: + + If you have implemented you own storages you should extend record key generation + with new one attribute - `thread_id` diff --git a/aiogram/dispatcher/middlewares/user_context.py b/aiogram/dispatcher/middlewares/user_context.py index 3531beb7..9ede4334 100644 --- a/aiogram/dispatcher/middlewares/user_context.py +++ b/aiogram/dispatcher/middlewares/user_context.py @@ -4,6 +4,10 @@ from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.types import Chat, TelegramObject, Update, User +EVENT_FROM_USER_KEY = "event_from_user" +EVENT_CHAT_KEY = "event_chat" +EVENT_THREAD_ID_KEY = "event_thread_id" + class UserContextMiddleware(BaseMiddleware): async def __call__( @@ -14,61 +18,64 @@ class UserContextMiddleware(BaseMiddleware): ) -> Any: if not isinstance(event, Update): raise RuntimeError("UserContextMiddleware got an unexpected event type!") - chat, user = self.resolve_event_context(event=event) - with self.context(chat=chat, user=user): - if user is not None: - data["event_from_user"] = user - if chat is not None: - data["event_chat"] = chat - return await handler(event, data) - - @contextmanager - def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]: - chat_token = None - user_token = None - if chat: - chat_token = chat.set_current(chat) - if user: - user_token = user.set_current(user) - try: - yield - finally: - if chat and chat_token: - chat.reset_current(chat_token) - if user and user_token: - user.reset_current(user_token) + chat, user, thread_id = self.resolve_event_context(event=event) + if user is not None: + data[EVENT_FROM_USER_KEY] = user + if chat is not None: + data[EVENT_CHAT_KEY] = chat + if thread_id is not None: + data[EVENT_THREAD_ID_KEY] = thread_id + return await handler(event, data) @classmethod - def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]: + def resolve_event_context( + cls, event: Update + ) -> Tuple[Optional[Chat], Optional[User], Optional[int]]: """ Resolve chat and user instance from Update object """ if event.message: - return event.message.chat, event.message.from_user + return ( + event.message.chat, + event.message.from_user, + event.message.message_thread_id if event.message.is_topic_message else None, + ) if event.edited_message: - return event.edited_message.chat, event.edited_message.from_user + return ( + event.edited_message.chat, + event.edited_message.from_user, + event.edited_message.message_thread_id + if event.edited_message.is_topic_message + else None, + ) if event.channel_post: - return event.channel_post.chat, None + return event.channel_post.chat, None, None if event.edited_channel_post: - return event.edited_channel_post.chat, None + return event.edited_channel_post.chat, None, None if event.inline_query: - return None, event.inline_query.from_user + return None, event.inline_query.from_user, None if event.chosen_inline_result: - return None, event.chosen_inline_result.from_user + return None, event.chosen_inline_result.from_user, None if event.callback_query: if event.callback_query.message: - return event.callback_query.message.chat, event.callback_query.from_user - return None, event.callback_query.from_user + return ( + event.callback_query.message.chat, + event.callback_query.from_user, + event.callback_query.message.message_thread_id + if event.callback_query.message.is_topic_message + else None, + ) + return None, event.callback_query.from_user, None if event.shipping_query: - return None, event.shipping_query.from_user + return None, event.shipping_query.from_user, None if event.pre_checkout_query: - return None, event.pre_checkout_query.from_user + return None, event.pre_checkout_query.from_user, None if event.poll_answer: - return None, event.poll_answer.user + return None, event.poll_answer.user, None if event.my_chat_member: - return event.my_chat_member.chat, event.my_chat_member.from_user + return event.my_chat_member.chat, event.my_chat_member.from_user, None if event.chat_member: - return event.chat_member.chat, event.chat_member.from_user + return event.chat_member.chat, event.chat_member.from_user, None if event.chat_join_request: - return event.chat_join_request.chat, event.chat_join_request.from_user - return None, None + return event.chat_join_request.chat, event.chat_join_request.from_user, None + return None, None, None diff --git a/aiogram/fsm/middleware.py b/aiogram/fsm/middleware.py index 0232ff0a..6de91a83 100644 --- a/aiogram/fsm/middleware.py +++ b/aiogram/fsm/middleware.py @@ -47,25 +47,42 @@ class FSMContextMiddleware(BaseMiddleware): ) -> Optional[FSMContext]: user = data.get("event_from_user") chat = data.get("event_chat") + thread_id = data.get("event_thread_id") chat_id = chat.id if chat else None user_id = user.id if user else None - return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny) + return self.resolve_context( + bot=bot, + chat_id=chat_id, + user_id=user_id, + thread_id=thread_id, + destiny=destiny, + ) def resolve_context( self, bot: Bot, chat_id: Optional[int], user_id: Optional[int], + thread_id: Optional[int] = None, destiny: str = DEFAULT_DESTINY, ) -> Optional[FSMContext]: if chat_id is None: chat_id = user_id if chat_id is not None and user_id is not None: - chat_id, user_id = apply_strategy( - chat_id=chat_id, user_id=user_id, strategy=self.strategy + chat_id, user_id, thread_id = apply_strategy( + chat_id=chat_id, + user_id=user_id, + thread_id=thread_id, + strategy=self.strategy, + ) + return self.get_context( + bot=bot, + chat_id=chat_id, + user_id=user_id, + thread_id=thread_id, + destiny=destiny, ) - return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny) return None def get_context( @@ -73,6 +90,7 @@ class FSMContextMiddleware(BaseMiddleware): bot: Bot, chat_id: int, user_id: int, + thread_id: Optional[int] = None, destiny: str = DEFAULT_DESTINY, ) -> FSMContext: return FSMContext( @@ -81,6 +99,7 @@ class FSMContextMiddleware(BaseMiddleware): user_id=user_id, chat_id=chat_id, bot_id=bot.id, + thread_id=thread_id, destiny=destiny, ), ) diff --git a/aiogram/fsm/storage/base.py b/aiogram/fsm/storage/base.py index b3551060..52cb62f2 100644 --- a/aiogram/fsm/storage/base.py +++ b/aiogram/fsm/storage/base.py @@ -15,6 +15,7 @@ class StorageKey: bot_id: int chat_id: int user_id: int + thread_id: Optional[int] = None destiny: str = DEFAULT_DESTINY diff --git a/aiogram/fsm/storage/redis.py b/aiogram/fsm/storage/redis.py index 76450f86..6a55d881 100644 --- a/aiogram/fsm/storage/redis.py +++ b/aiogram/fsm/storage/redis.py @@ -70,7 +70,10 @@ class DefaultKeyBuilder(KeyBuilder): parts = [self.prefix] if self.with_bot_id: parts.append(str(key.bot_id)) - parts.extend([str(key.chat_id), str(key.user_id)]) + parts.append(str(key.chat_id)) + if key.thread_id: + parts.append(str(key.thread_id)) + parts.append(str(key.user_id)) if self.with_destiny: parts.append(key.destiny) elif key.destiny != DEFAULT_DESTINY: diff --git a/aiogram/fsm/strategy.py b/aiogram/fsm/strategy.py index 4f540a4a..227924cb 100644 --- a/aiogram/fsm/strategy.py +++ b/aiogram/fsm/strategy.py @@ -1,16 +1,24 @@ from enum import Enum, auto -from typing import Tuple +from typing import Optional, Tuple class FSMStrategy(Enum): USER_IN_CHAT = auto() CHAT = auto() GLOBAL_USER = auto() + USER_IN_THREAD = auto() -def apply_strategy(chat_id: int, user_id: int, strategy: FSMStrategy) -> Tuple[int, int]: +def apply_strategy( + strategy: FSMStrategy, + chat_id: int, + user_id: int, + thread_id: Optional[int] = None, +) -> Tuple[int, int, Optional[int]]: if strategy == FSMStrategy.CHAT: - return chat_id, chat_id + return chat_id, chat_id, None if strategy == FSMStrategy.GLOBAL_USER: - return user_id, user_id - return chat_id, user_id + return user_id, user_id, None + if strategy == FSMStrategy.USER_IN_THREAD: + return chat_id, user_id, thread_id + return chat_id, user_id, None diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index bcebfaa2..41ecef1b 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -14,7 +14,7 @@ from aiogram import Bot from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler from aiogram.dispatcher.router import Router -from aiogram.methods import GetMe, GetUpdates, Request, SendMessage, TelegramMethod +from aiogram.methods import GetMe, GetUpdates, SendMessage, TelegramMethod from aiogram.types import ( CallbackQuery, Chat, @@ -462,9 +462,9 @@ class TestDispatcher: async def my_handler(event: Any, **kwargs: Any): assert event == getattr(update, event_type) if has_chat: - assert Chat.get_current(False) + assert kwargs["event_chat"] if has_user: - assert User.get_current(False) + assert kwargs["event_from_user"] return kwargs result = await router.feed_update(bot, update, test="PASS") diff --git a/tests/test_dispatcher/test_middlewares/test_user_context.py b/tests/test_dispatcher/test_middlewares/test_user_context.py index ca2abb2d..54c09ce2 100644 --- a/tests/test_dispatcher/test_middlewares/test_user_context.py +++ b/tests/test_dispatcher/test_middlewares/test_user_context.py @@ -1,6 +1,9 @@ +from unittest.mock import patch + import pytest from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware +from aiogram.types import Update async def next_handler(*args, **kwargs): @@ -11,3 +14,13 @@ class TestUserContextMiddleware: async def test_unexpected_event_type(self): with pytest.raises(RuntimeError): await UserContextMiddleware()(next_handler, object(), {}) + + async def test_call(self): + middleware = UserContextMiddleware() + data = {} + with patch.object(UserContextMiddleware, "resolve_event_context", return_value=[1, 2, 3]): + await middleware(next_handler, Update(update_id=42), data) + + assert data["event_chat"] == 1 + assert data["event_from_user"] == 2 + assert data["event_thread_id"] == 3 diff --git a/tests/test_fsm/storage/test_redis.py b/tests/test_fsm/storage/test_redis.py index 6e42eb48..adca384a 100644 --- a/tests/test_fsm/storage/test_redis.py +++ b/tests/test_fsm/storage/test_redis.py @@ -11,6 +11,7 @@ PREFIX = "test" BOT_ID = 42 CHAT_ID = -1 USER_ID = 2 +THREAD_ID = 3 FIELD = "data" @@ -46,6 +47,19 @@ class TestRedisDefaultKeyBuilder: with pytest.raises(ValueError): key_builder.build(key, FIELD) + def test_thread_id(self): + key_builder = DefaultKeyBuilder( + prefix=PREFIX, + ) + key = StorageKey( + chat_id=CHAT_ID, + user_id=USER_ID, + bot_id=BOT_ID, + thread_id=THREAD_ID, + destiny=DEFAULT_DESTINY, + ) + assert key_builder.build(key, FIELD) == f"{PREFIX}:{CHAT_ID}:{THREAD_ID}:{USER_ID}:{FIELD}" + def test_create_isolation(self): fake_redis = object() storage = RedisStorage(redis=fake_redis) diff --git a/tests/test_fsm/test_strategy.py b/tests/test_fsm/test_strategy.py index b00a7b98..3dab2b3d 100644 --- a/tests/test_fsm/test_strategy.py +++ b/tests/test_fsm/test_strategy.py @@ -2,19 +2,41 @@ import pytest from aiogram.fsm.strategy import FSMStrategy, apply_strategy +CHAT_ID = -42 +USER_ID = 42 +THREAD_ID = 1 + +PRIVATE = (USER_ID, USER_ID, None) +CHAT = (CHAT_ID, USER_ID, None) +THREAD = (CHAT_ID, USER_ID, THREAD_ID) + class TestStrategy: @pytest.mark.parametrize( "strategy,case,expected", [ - [FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)], - [FSMStrategy.CHAT, (-42, 42), (-42, -42)], - [FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)], - [FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)], - [FSMStrategy.CHAT, (42, 42), (42, 42)], - [FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)], + [FSMStrategy.USER_IN_CHAT, CHAT, CHAT], + [FSMStrategy.USER_IN_CHAT, PRIVATE, PRIVATE], + [FSMStrategy.USER_IN_CHAT, THREAD, CHAT], + [FSMStrategy.CHAT, CHAT, (CHAT_ID, CHAT_ID, None)], + [FSMStrategy.CHAT, PRIVATE, (USER_ID, USER_ID, None)], + [FSMStrategy.CHAT, THREAD, (CHAT_ID, CHAT_ID, None)], + [FSMStrategy.GLOBAL_USER, CHAT, PRIVATE], + [FSMStrategy.GLOBAL_USER, PRIVATE, PRIVATE], + [FSMStrategy.GLOBAL_USER, THREAD, PRIVATE], + [FSMStrategy.USER_IN_THREAD, CHAT, CHAT], + [FSMStrategy.USER_IN_THREAD, PRIVATE, PRIVATE], + [FSMStrategy.USER_IN_THREAD, THREAD, THREAD], ], ) def test_strategy(self, strategy, case, expected): - chat_id, user_id = case - assert apply_strategy(chat_id=chat_id, user_id=user_id, strategy=strategy) == expected + chat_id, user_id, thread_id = case + assert ( + apply_strategy( + chat_id=chat_id, + user_id=user_id, + thread_id=thread_id, + strategy=strategy, + ) + == expected + )