diff --git a/tests/test_fsm/test_scene.py b/tests/test_fsm/test_scene.py index d45dc5de..51b211c3 100644 --- a/tests/test_fsm/test_scene.py +++ b/tests/test_fsm/test_scene.py @@ -5,7 +5,9 @@ from unittest.mock import ANY, AsyncMock, patch import pytest -from aiogram import F +from aiogram import Dispatcher, F, Router +from aiogram.dispatcher.event.bases import NextMiddlewareType +from aiogram.exceptions import SceneException from aiogram.fsm.context import FSMContext from aiogram.fsm.scene import ( ActionContainer, @@ -15,12 +17,17 @@ from aiogram.fsm.scene import ( Scene, SceneAction, SceneHandlerWrapper, + SceneRegistry, ScenesManager, SceneWizard, _empty_handler, on, ) +from aiogram.fsm.state import State, StatesGroup +from aiogram.fsm.storage.memory import MemoryStorage +from aiogram.fsm.storage.base import StorageKey from aiogram.types import Chat, Message, TelegramObject, Update +from tests.mocked_bot import MockedBot class TestOnMarker: @@ -325,3 +332,243 @@ class TestSceneHandlerWrapper: assert hasattr(scene_handler_wrapper, "__await__") assert scene_handler_wrapper.__await__() is scene_handler_wrapper + + +class TestSceneRegistry: + def test_scene_registry_initialization(self): + router = Router() + register_on_add = True + + registry = SceneRegistry(router, register_on_add) + + assert registry.router == router + assert registry.register_on_add == register_on_add + assert registry._scenes == {} + + def test_scene_registry_add_scene(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_add_scene_pass_router(self): + class MyScene(Scene): + pass + + router = Router() + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router) + registry.add(MyScene, router=router) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_add_scene_already_exists(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + with pytest.raises(SceneException): + registry.add(MyScene) + + def test_scene_registry_add_scene_no_scenes(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry._scenes = {} + + with pytest.raises(ValueError, match="At least one scene must be specified"): + registry.add() + + def test_scene_registry_register(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.register(MyScene) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_get_scene_if_scene_type_is_str(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get("test_scene") + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_type_is_scene(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get(MyScene) + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_state_is_state(self): + class MyStates(StatesGroup): + test_state = State() + + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = MyStates.test_state + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get(MyScene) + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_state_is_not_str(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = 42 + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + with pytest.raises(SceneException, match="Scene must be a subclass of Scene or a string"): + registry.get(MyScene) + + def test_scene_registry_get_scene_not_registered(self): + router = Router() + register_on_add = True + + registry = SceneRegistry(router, register_on_add) + + with pytest.raises(SceneException): + registry.get("test_scene") + + def test_scene_registry_setup_middleware_with_dispatcher(self): + router = Router() + + registry = SceneRegistry(router) + + dispatcher = Dispatcher() + registry._setup_middleware(dispatcher) + + assert registry._update_middleware in dispatcher.update.outer_middleware + + for name, observer in dispatcher.observers.items(): + if name == "update": + continue + assert registry._update_middleware not in observer.outer_middleware + + def test_scene_registry_setup_middleware_with_router(self): + inner_router = Router() + + registry = SceneRegistry(inner_router) + + outer_router = Router() + registry._setup_middleware(outer_router) + + for name, observer in outer_router.observers.items(): + if name in ("update", "error"): + continue + assert registry._middleware in observer.outer_middleware + + @pytest.mark.asyncio + async def test_scene_registry_update_middleware(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Update(update_id=42, message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + )) + data = {"state": FSMContext(storage=MemoryStorage(), + key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id))} + + result = await registry._update_middleware(handler, event, data) + + assert "scenes" in data + assert isinstance(data["scenes"], ScenesManager) + handler.assert_called_once_with(event, data) + assert result == handler.return_value + + @pytest.mark.asyncio + async def test_scene_registry_update_middleware_not_update(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + ) + data = {"state": FSMContext(storage=MemoryStorage(), + key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id))} + + with pytest.raises(AssertionError, match="Event must be an Update instance"): + await registry._update_middleware(handler, event, data) + + @pytest.mark.asyncio + async def test_scene_registry_middleware(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Update(update_id=42, message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + )) + data = {"state": FSMContext(storage=MemoryStorage(), + key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id)), + "event_update": event} + + result = await registry._middleware(handler, event, data) + + assert "scenes" in data + assert isinstance(data["scenes"], ScenesManager) + handler.assert_called_once_with(event, data) + assert result == handler.return_value