From eeabf3cf32be5c4fc6cc3ef14764d69306310353 Mon Sep 17 00:00:00 2001 From: Andrew <11490628+andrew000@users.noreply.github.com> Date: Wed, 22 Nov 2023 21:11:34 +0200 Subject: [PATCH] Scenes Tests. Part 2 (#1371) * ADD tests for `SceneWizard` * ADD tests for `Scene` * Refactor ObserverDecorator to use on.message syntax in test_scene.py Cover `Scene::__init_subclass__::if isinstance(value, ObserverDecorator):` * Refactor `HistoryManager` in `aiogram/fsm/scene.py` Removed condition that checked if 'history' is empty before calling 'update_data' in 'Scene'. * ADD tests for `HistoryManager` --- aiogram/fsm/scene.py | 5 +- tests/test_fsm/test_scene.py | 724 ++++++++++++++++++++++++++++++++++- 2 files changed, 724 insertions(+), 5 deletions(-) diff --git a/aiogram/fsm/scene.py b/aiogram/fsm/scene.py index 89dca081..fe4de5c4 100644 --- a/aiogram/fsm/scene.py +++ b/aiogram/fsm/scene.py @@ -38,10 +38,7 @@ class HistoryManager: history = history[-self._size :] loggers.scene.debug("Push state=%s data=%s to history", state, data) - if not history: - await self._history_state.set_data({}) - else: - await self._history_state.update_data(history=history) + await self._history_state.update_data(history=history) async def pop(self) -> Optional[MemoryStorageRecord]: history_data = await self._history_state.get_data() diff --git a/tests/test_fsm/test_scene.py b/tests/test_fsm/test_scene.py index 9cb936d1..1313301f 100644 --- a/tests/test_fsm/test_scene.py +++ b/tests/test_fsm/test_scene.py @@ -8,14 +8,18 @@ import pytest from aiogram import Dispatcher, F, Router from aiogram.dispatcher.event.bases import NextMiddlewareType from aiogram.exceptions import SceneException +from aiogram.filters import StateFilter from aiogram.fsm.context import FSMContext from aiogram.fsm.scene import ( ActionContainer, After, + HandlerContainer, + HistoryManager, ObserverDecorator, ObserverMarker, Scene, SceneAction, + SceneConfig, SceneHandlerWrapper, SceneRegistry, ScenesManager, @@ -25,7 +29,7 @@ from aiogram.fsm.scene import ( ) from aiogram.fsm.state import State, StatesGroup from aiogram.fsm.storage.base import StorageKey -from aiogram.fsm.storage.memory import MemoryStorage +from aiogram.fsm.storage.memory import MemoryStorage, MemoryStorageRecord from aiogram.types import Chat, Message, TelegramObject, Update from tests.mocked_bot import MockedBot @@ -334,6 +338,724 @@ class TestSceneHandlerWrapper: assert scene_handler_wrapper.__await__() is scene_handler_wrapper +class TestHistoryManager: + async def test_history_manager_push(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + history_data = await history_manager._history_state.get_data() + assert history_data.get("history") == [{"state": "test_state", "data": data}] + + async def test_history_manager_push_if_history_overflow(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state, size=2) + + states = ["test_state", "test_state2", "test_state3", "test_state4"] + data = {"test_data": "test_data"} + for state in states: + await history_manager.push(state, data) + + history_data = await history_manager._history_state.get_data() + assert history_data.get("history") == [ + {"state": "test_state3", "data": data}, + {"state": "test_state4", "data": data}, + ] + + async def test_history_manager_pop(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + await history_manager.push("test_state2", data) + + record = await history_manager.pop() + history_data = await history_manager._history_state.get_data() + + assert isinstance(record, MemoryStorageRecord) + assert record == MemoryStorageRecord(state="test_state2", data=data) + assert history_data.get("history") == [{"state": "test_state", "data": data}] + + async def test_history_manager_pop_if_history_empty(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + record = await history_manager.pop() + assert record is None + + async def test_history_manager_pop_if_history_become_empty_after_pop(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + await history_manager.pop() + + assert await history_manager._history_state.get_data() == {} + + async def test_history_manager_get(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + record = await history_manager.get() + + assert isinstance(record, MemoryStorageRecord) + assert record == MemoryStorageRecord(state="test_state", data=data) + + async def test_history_manager_get_if_history_empty(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + record = await history_manager.get() + assert record is None + + async def test_history_manager_all(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager_size = 10 + history_manager = HistoryManager(state=state, size=history_manager_size) + + data = {"test_data": "test_data"} + for i in range(history_manager_size): + await history_manager.push(f"test_state{i}", data) + + records = await history_manager.all() + + assert isinstance(records, list) + assert len(records) == history_manager_size + assert all(isinstance(record, MemoryStorageRecord) for record in records) + + async def test_history_manager_all_if_history_empty(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + records = await history_manager.all() + assert records == [] + + async def test_history_manager_clear(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + + history_manager = HistoryManager(state=state) + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + await history_manager.clear() + + assert await history_manager._history_state.get_data() == {} + + async def test_history_manager_snapshot(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + + history_manager = HistoryManager(state=state) + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + await history_manager.snapshot() + + assert await history_manager._history_state.get_data() == { + "history": [ + {"state": "test_state", "data": data}, + { + "state": await history_manager._state.get_state(), + "data": await history_manager._state.get_data(), + }, + ] + } + + async def test_history_manager_set_state(self): + state_mock = AsyncMock(spec=FSMContext) + state_mock.storage = MemoryStorage() + state_mock.key = StorageKey(bot_id=42, chat_id=-42, user_id=42) + state_mock.set_state = AsyncMock() + state_mock.set_data = AsyncMock() + + history_manager = HistoryManager(state=state_mock) + history_manager._state = state_mock + + state = "test_state" + data = {"test_data": "test_data"} + await history_manager._set_state(state, data) + + state_mock.set_state.assert_called_once_with(state) + state_mock.set_data.assert_called_once_with(data) + + async def test_history_manager_rollback(self): + history_manager = HistoryManager( + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + ) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + await history_manager.push("test_state2", data) + + record = await history_manager.get() + assert record == MemoryStorageRecord(state="test_state2", data=data) + + await history_manager.rollback() + + record = await history_manager.get() + assert record == MemoryStorageRecord(state="test_state", data=data) + + async def test_history_manager_rollback_if_not_previous_state(self): + history_manager = HistoryManager( + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + ) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + state = await history_manager.rollback() + assert state == "test_state" + + state = await history_manager.rollback() + assert state is None + + +class TestScene: + def test_scene_subclass_initialisation(self): + class ParentScene(Scene): + @on.message(F.text) + def parent_handler(self, *args, **kwargs): + pass + + @on.message.enter(F.text) + def parent_action(self, *args, **kwargs): + pass + + class SubScene( + ParentScene, + state="sub_state", + reset_data_on_enter=True, + reset_history_on_enter=True, + callback_query_without_state=True, + ): + general_handler = on.message( + F.text.casefold() == "test", after=After.goto("sub_state") + ) + + @on.message(F.text) + def sub_handler(self, *args, **kwargs): + pass + + @on.message.exit() + def sub_action(self, *args, **kwargs): + pass + + # Assert __scene_config__ attributes are correctly set for SubScene + assert isinstance(SubScene.__scene_config__, SceneConfig) + assert SubScene.__scene_config__.state == "sub_state" + assert SubScene.__scene_config__.reset_data_on_enter is True + assert SubScene.__scene_config__.reset_history_on_enter is True + assert SubScene.__scene_config__.callback_query_without_state is True + + # Assert handlers are correctly set + assert len(SubScene.__scene_config__.handlers) == 3 + + for handler in SubScene.__scene_config__.handlers: + assert isinstance(handler, HandlerContainer) + assert handler.name == "message" + assert handler.handler in ( + ParentScene.parent_handler, + SubScene.sub_handler, + _empty_handler, + ) + assert handler.filters == (F.text,) + + # Assert actions are correctly set + assert len(SubScene.__scene_config__.actions) == 2 + + enter_action = SubScene.__scene_config__.actions[SceneAction.enter] + assert isinstance(enter_action, dict) + assert "message" in enter_action + assert enter_action["message"].callback == ParentScene.parent_action + + exit_action = SubScene.__scene_config__.actions[SceneAction.exit] + assert isinstance(exit_action, dict) + assert "message" in exit_action + assert exit_action["message"].callback == SubScene.sub_action + + def test_scene_subclass_initialisation_bases_is_scene_subclass(self): + class NotAScene: + pass + + class MyScene(Scene, NotAScene): + pass + + class TestClass(MyScene, NotAScene): + pass + + assert MyScene in TestClass.__bases__ + assert NotAScene in TestClass.__bases__ + bases = [base for base in TestClass.__bases__ if not issubclass(base, Scene)] + assert Scene not in bases + assert NotAScene in bases + + def test_scene_add_to_router(self): + class MyScene(Scene): + @on.message(F.text) + def test_handler(self, *args, **kwargs): + pass + + router = Router() + MyScene.add_to_router(router) + + assert len(router.observers["message"].handlers) == 1 + + def test_scene_add_to_router_scene_with_callback_query_without_state(self): + class MyScene(Scene, callback_query_without_state=True): + @on.callback_query(F.data) + def test_handler(self, *args, **kwargs): + pass + + router = Router() + MyScene.add_to_router(router) + + assert len(router.observers["callback_query"].handlers) == 1 + assert ( + StateFilter(MyScene.__scene_config__.state) + not in router.observers["callback_query"].handlers[0].filters + ) + + def test_scene_as_handler(self): + class MyScene(Scene): + @on.message(F.text) + def test_handler(self, *args, **kwargs): + pass + + handler = MyScene.as_handler() + + router = Router() + router.message.register(handler) + assert router.observers["message"].handlers[0].callback == handler + + async def test_scene_as_handler_enter(self): + class MyScene(Scene): + @on.message.enter(F.text) + def test_handler(self, *args, **kwargs): + pass + + event = AsyncMock() + + scenes = ScenesManager( + registry=SceneRegistry(Router()), + update_type="message", + event=event, + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ), + data={}, + ) + scenes.enter = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + handler = MyScene.as_handler(**kwargs) + await handler(event, scenes) + + scenes.enter.assert_called_once_with(MyScene, **kwargs) + + +class TestSceneWizard: + async def test_scene_wizard_enter_with_reset_data_on_enter(self): + class MyScene(Scene, reset_data_on_enter=True): + pass + + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + state_mock = AsyncMock(spec=FSMContext) + state_mock.set_state = AsyncMock() + + wizard = SceneWizard( + scene_config=MyScene.__scene_config__, + manager=AsyncMock(spec=ScenesManager), + state=state_mock, + update_type="message", + event=AsyncMock(), + data={}, + ) + kwargs = {"test_kwargs": "test_kwargs"} + wizard._on_action = AsyncMock() + + await wizard.enter(**kwargs) + + state_mock.set_data.assert_called_once_with({}) + state_mock.set_state.assert_called_once_with(MyScene.__scene_config__.state) + wizard._on_action.assert_called_once_with(SceneAction.enter, **kwargs) + + async def test_scene_wizard_enter_with_reset_history_on_enter(self): + class MyScene(Scene, reset_history_on_enter=True): + pass + + state_mock = AsyncMock(spec=FSMContext) + state_mock.set_state = AsyncMock() + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.clear = AsyncMock() + + wizard = SceneWizard( + scene_config=MyScene.__scene_config__, + manager=manager, + state=state_mock, + update_type="message", + event=AsyncMock(), + data={}, + ) + kwargs = {"test_kwargs": "test_kwargs"} + wizard._on_action = AsyncMock() + + await wizard.enter(**kwargs) + + manager.history.clear.assert_called_once() + state_mock.set_state.assert_called_once_with(MyScene.__scene_config__.state) + wizard._on_action.assert_called_once_with(SceneAction.enter, **kwargs) + + async def test_scene_wizard_leave_with_history(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.snapshot = AsyncMock() + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=manager, + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard._on_action = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.leave(_with_history=False, **kwargs) + + manager.history.snapshot.assert_not_called() + wizard._on_action.assert_called_once_with(SceneAction.leave, **kwargs) + + async def test_scene_wizard_leave_without_history(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.snapshot = AsyncMock() + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=manager, + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard._on_action = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.leave(**kwargs) + + manager.history.snapshot.assert_called_once() + wizard._on_action.assert_called_once_with(SceneAction.leave, **kwargs) + + async def test_scene_wizard_back(self): + current_scene_config_mock = AsyncMock() + current_scene_config_mock.state = "test_state" + + previous_scene_config_mock = AsyncMock() + previous_scene_config_mock.state = "previous_test_state" + + previous_scene_mock = AsyncMock() + previous_scene_mock.__scene_config__ = previous_scene_config_mock + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.rollback = AsyncMock(return_value=previous_scene_mock) + manager.enter = AsyncMock() + + wizard = SceneWizard( + scene_config=current_scene_config_mock, + manager=manager, + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.leave = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.back(**kwargs) + + wizard.leave.assert_called_once_with(_with_history=False, **kwargs) + manager.history.rollback.assert_called_once() + manager.enter.assert_called_once_with(previous_scene_mock, _check_active=False, **kwargs) + + async def test_scene_wizard_retake(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(spec=ScenesManager), + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.goto = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.retake(**kwargs) + + wizard.goto.assert_called_once_with(scene_config_mock.state, **kwargs) + + async def test_scene_wizard_retake_exception(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = None + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(spec=ScenesManager), + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + + kwargs = {"test_kwargs": "test_kwargs"} + + with pytest.raises(AssertionError, match="Scene state is not specified"): + await wizard.retake(**kwargs) + + async def test_scene_wizard_goto(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + scene_mock = AsyncMock() + scene_mock.__scene_config__ = scene_config_mock + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(spec=ScenesManager), + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.leave = AsyncMock() + wizard.manager.enter = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.goto(scene_mock, **kwargs) + + wizard.leave.assert_called_once_with(**kwargs) + wizard.manager.enter.assert_called_once_with(scene_mock, _check_active=False, **kwargs) + + async def test_scene_wizard_on_action(self): + scene_config_mock = AsyncMock() + scene_config_mock.actions = {SceneAction.enter: {"message": AsyncMock()}} + scene_config_mock.state = "test_state" + + scene_mock = AsyncMock() + + event_mock = AsyncMock() + event_mock.type = "message" + + data = {"test_data": "test_data"} + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=event_mock, + data=data, + ) + wizard.scene = scene_mock + + kwargs = {"test_kwargs": "test_kwargs"} + result = await wizard._on_action(SceneAction.enter, **kwargs) + + scene_config_mock.actions[SceneAction.enter]["message"].call.assert_called_once_with( + scene_mock, event_mock, **{**data, **kwargs} + ) + assert result is True + + async def test_scene_wizard_on_action_no_scene(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + + with pytest.raises(SceneException, match="Scene is not initialized"): + await wizard._on_action(SceneAction.enter) + + async def test_scene_wizard_on_action_no_action_config(self): + scene_config_mock = AsyncMock() + scene_config_mock.actions = {} + scene_config_mock.state = "test_state" + + scene_mock = AsyncMock() + + event_mock = AsyncMock() + event_mock.type = "message" + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=event_mock, + data={}, + ) + wizard.scene = scene_mock + + kwargs = {"test_kwargs": "test_kwargs"} + result = await wizard._on_action(SceneAction.enter, **kwargs) + + assert result is False + + async def test_scene_wizard_on_action_event_type_not_in_action_config(self): + scene_config_mock = AsyncMock() + scene_config_mock.actions = {SceneAction.enter: {"test_update_type": AsyncMock()}} + scene_config_mock.state = "test_state" + + event_mock = AsyncMock() + event_mock.type = "message" + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=event_mock, + data={}, + ) + wizard.scene = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + result = await wizard._on_action(SceneAction.enter, **kwargs) + + assert result is False + + async def test_scene_wizard_set_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.state.set_data = AsyncMock() + + data = {"test_key": "test_value"} + await wizard.set_data(data) + + wizard.state.set_data.assert_called_once_with(data=data) + + async def test_scene_wizard_get_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.state.get_data = AsyncMock() + + await wizard.get_data() + + wizard.state.get_data.assert_called_once_with() + + async def test_scene_wizard_update_data_if_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + data = {"test_key": "test_value"} + kwargs = {"test_kwargs": "test_kwargs"} + + wizard.state.update_data = AsyncMock(return_value={**data, **kwargs}) + result = await wizard.update_data(data=data, **kwargs) + + wizard.state.update_data.assert_called_once_with(data={**data, **kwargs}) + assert result == {**data, **kwargs} + + async def test_scene_wizard_update_data_if_no_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + data = None + kwargs = {"test_kwargs": "test_kwargs"} + + wizard.state.update_data = AsyncMock(return_value={**kwargs}) + result = await wizard.update_data(data=data, **kwargs) + + wizard.state.update_data.assert_called_once_with(data=kwargs) + assert result == {**kwargs} + + async def test_scene_wizard_clear_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.set_data = AsyncMock() + + await wizard.clear_data() + + wizard.set_data.assert_called_once_with({}) + + class TestScenesManager: async def test_scenes_manager_get_scene(self, bot: MockedBot): class MyScene(Scene):