diff --git a/tests/test_fsm/test_scene.py b/tests/test_fsm/test_scene.py index d45dc5de..9cb936d1 100644 --- a/tests/test_fsm/test_scene.py +++ b/tests/test_fsm/test_scene.py @@ -5,7 +5,9 @@ from unittest.mock import ANY, AsyncMock, patch import pytest -from aiogram import F +from aiogram import Dispatcher, F, Router +from aiogram.dispatcher.event.bases import NextMiddlewareType +from aiogram.exceptions import SceneException from aiogram.fsm.context import FSMContext from aiogram.fsm.scene import ( ActionContainer, @@ -15,12 +17,17 @@ from aiogram.fsm.scene import ( Scene, SceneAction, SceneHandlerWrapper, + SceneRegistry, ScenesManager, SceneWizard, _empty_handler, on, ) +from aiogram.fsm.state import State, StatesGroup +from aiogram.fsm.storage.base import StorageKey +from aiogram.fsm.storage.memory import MemoryStorage from aiogram.types import Chat, Message, TelegramObject, Update +from tests.mocked_bot import MockedBot class TestOnMarker: @@ -325,3 +332,494 @@ class TestSceneHandlerWrapper: assert hasattr(scene_handler_wrapper, "__await__") assert scene_handler_wrapper.__await__() is scene_handler_wrapper + + +class TestScenesManager: + async def test_scenes_manager_get_scene(self, bot: MockedBot): + class MyScene(Scene): + pass + + router = Router() + + registry = SceneRegistry(router) + registry.add(MyScene) + + scenes_manager = ScenesManager( + registry=registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = await scenes_manager._get_scene(MyScene) + assert isinstance(scene, MyScene) + assert isinstance(scene.wizard, SceneWizard) + assert scene.wizard.scene_config == MyScene.__scene_config__ + assert scene.wizard.manager == scenes_manager + assert scene.wizard.update_type == "message" + assert scene.wizard.data == {} + + async def test_scenes_manager_get_active_scene(self, bot: MockedBot): + class TestScene(Scene): + pass + + class TestScene2(Scene, state="test_state2"): + pass + + registry = SceneRegistry(Router()) + registry.add(TestScene, TestScene2) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = await manager._get_active_scene() + assert isinstance(scene, TestScene) + + await manager.enter(TestScene2) + scene = await manager._get_active_scene() + assert isinstance(scene, TestScene2) + + async def test_scenes_manager_get_active_scene_with_scene_exception(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = await manager._get_active_scene() + + assert scene is None + + async def test_scenes_manager_enter_with_scene_type_none(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + assert await manager.enter(None) is None + + async def test_scenes_manager_enter_with_scene_exception(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = "invalid_scene" + with pytest.raises(SceneException, match=f"Scene {scene!r} is not registered"): + await manager.enter(scene) + + async def test_scenes_manager_close_if_active_scene(self, bot: MockedBot): + class TestScene(Scene): + pass + + registry = SceneRegistry(Router()) + registry.add(TestScene) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + manager._get_active_scene = AsyncMock( + return_value=TestScene( + SceneWizard( + scene_config=TestScene.__scene_config__, + manager=manager, + state=manager.state, + update_type="message", + event=manager.event, + data={}, + ) + ) + ) + manager._get_active_scene.return_value.wizard.exit = AsyncMock() + + await manager.close() + + manager._get_active_scene.assert_called_once() + manager._get_active_scene.return_value.wizard.exit.assert_called_once() + + async def test_scenes_manager_close_if_no_active_scene(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + manager._get_active_scene = AsyncMock(return_value=None) + + result = await manager.close() + + manager._get_active_scene.assert_called_once() + + assert result is None + + +class TestSceneRegistry: + def test_scene_registry_initialization(self): + router = Router() + register_on_add = True + + registry = SceneRegistry(router, register_on_add) + + assert registry.router == router + assert registry.register_on_add == register_on_add + assert registry._scenes == {} + + def test_scene_registry_add_scene(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_add_scene_pass_router(self): + class MyScene(Scene): + pass + + router = Router() + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router) + registry.add(MyScene, router=router) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_add_scene_already_exists(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + with pytest.raises(SceneException): + registry.add(MyScene) + + def test_scene_registry_add_scene_no_scenes(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry._scenes = {} + + with pytest.raises(ValueError, match="At least one scene must be specified"): + registry.add() + + def test_scene_registry_register(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.register(MyScene) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_get_scene_if_scene_type_is_str(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get("test_scene") + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_type_is_scene(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get(MyScene) + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_state_is_state(self): + class MyStates(StatesGroup): + test_state = State() + + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = MyStates.test_state + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get(MyScene) + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_state_is_not_str(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = 42 + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + with pytest.raises(SceneException, match="Scene must be a subclass of Scene or a string"): + registry.get(MyScene) + + def test_scene_registry_get_scene_not_registered(self): + router = Router() + register_on_add = True + + registry = SceneRegistry(router, register_on_add) + + with pytest.raises(SceneException): + registry.get("test_scene") + + def test_scene_registry_setup_middleware_with_dispatcher(self): + router = Router() + + registry = SceneRegistry(router) + + dispatcher = Dispatcher() + registry._setup_middleware(dispatcher) + + assert registry._update_middleware in dispatcher.update.outer_middleware + + for name, observer in dispatcher.observers.items(): + if name == "update": + continue + assert registry._update_middleware not in observer.outer_middleware + + def test_scene_registry_setup_middleware_with_router(self): + inner_router = Router() + + registry = SceneRegistry(inner_router) + + outer_router = Router() + registry._setup_middleware(outer_router) + + for name, observer in outer_router.observers.items(): + if name in ("update", "error"): + continue + assert registry._middleware in observer.outer_middleware + + async def test_scene_registry_update_middleware(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + ), + ) + data = { + "state": FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ) + } + + result = await registry._update_middleware(handler, event, data) + + assert "scenes" in data + assert isinstance(data["scenes"], ScenesManager) + handler.assert_called_once_with(event, data) + assert result == handler.return_value + + async def test_scene_registry_update_middleware_not_update(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + ) + data = { + "state": FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ) + } + + with pytest.raises(AssertionError, match="Event must be an Update instance"): + await registry._update_middleware(handler, event, data) + + async def test_scene_registry_middleware(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + ), + ) + data = { + "state": FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + "event_update": event, + } + + result = await registry._middleware(handler, event, data) + + assert "scenes" in data + assert isinstance(data["scenes"], ScenesManager) + handler.assert_called_once_with(event, data) + assert result == handler.return_value