diff --git a/aiogram/scenes/_history.py b/aiogram/scenes/_history.py index 659c58f8..09ae2861 100644 --- a/aiogram/scenes/_history.py +++ b/aiogram/scenes/_history.py @@ -1,22 +1,17 @@ -from dataclasses import dataclass, replace +from dataclasses import replace from typing import Any, Dict, List, Optional from aiogram import loggers from aiogram.fsm.context import FSMContext - - -@dataclass -class StateContainer: - state: Optional[str] - data: Dict[str, Any] +from aiogram.fsm.storage.memory import MemoryStorageRecord class HistoryManager: - def __init__(self, context: FSMContext, destiny: str = "history", size: int = 10): + def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10): self._size = size - self._state = context + self._state = state self._history_state = FSMContext( - storage=context.storage, key=replace(context.key, destiny=destiny) + storage=state.storage, key=replace(state.key, destiny=destiny) ) async def push(self, state: Optional[str], data: Dict[str, Any]) -> None: @@ -25,10 +20,14 @@ class HistoryManager: history.append({"state": state, "data": data}) if len(history) > self._size: history = history[-self._size :] - loggers.scene.debug("Push state=%s data=%s", state, data) - await self._history_state.update_data(history=history) + loggers.scene.debug("Push state=%s data=%s to history", state, data) - async def pop(self) -> Optional[StateContainer]: + if not history: + await self._history_state.set_data({}) + else: + await self._history_state.update_data(history=history) + + async def pop(self) -> Optional[MemoryStorageRecord]: history_data = await self._history_state.get_data() history = history_data.setdefault("history", []) if not history: @@ -36,41 +35,48 @@ class HistoryManager: record = history.pop() state = record["state"] data = record["data"] - await self._history_state.update_data(history=history) - loggers.scene.debug("Pop state=%s data=%s", state, data) - return StateContainer(state=state, data=data) + if not history: + await self._history_state.set_data({}) + else: + await self._history_state.update_data(history=history) + loggers.scene.debug("Pop state=%s data=%s from history", state, data) + return MemoryStorageRecord(state=state, data=data) - async def get(self) -> Optional[StateContainer]: + async def get(self) -> Optional[MemoryStorageRecord]: history_data = await self._history_state.get_data() history = history_data.setdefault("history", []) if not history: return None - return StateContainer(**history[-1]) + return MemoryStorageRecord(**history[-1]) - async def all(self) -> List[StateContainer]: + async def all(self) -> List[MemoryStorageRecord]: history_data = await self._history_state.get_data() history = history_data.setdefault("history", []) - return [StateContainer(**item) for item in history] + return [MemoryStorageRecord(**item) for item in history] async def clear(self) -> None: loggers.scene.debug("Clear history") - await self._history_state.clear() + await self._history_state.set_data({}) async def snapshot(self) -> None: state = await self._state.get_state() data = await self._state.get_data() await self.push(state, data) + async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None: + await self._state.set_state(state) + await self._state.set_data(data) + async def rollback(self) -> Optional[str]: - state_container = await self.pop() - if not state_container: + previous_state = await self.pop() + if not previous_state: + await self._set_state(None, {}) return None loggers.scene.debug( "Rollback to state=%s data=%s", - state_container.state, - state_container.data, + previous_state.state, + previous_state.data, ) - await self._state.set_state(state_container.state) - await self._state.set_data(state_container.data) - return state_container.state + await self._set_state(previous_state.state, previous_state.data) + return previous_state.state diff --git a/aiogram/scenes/_scene.py b/aiogram/scenes/_scene.py index aa128a4d..cde7b18b 100644 --- a/aiogram/scenes/_scene.py +++ b/aiogram/scenes/_scene.py @@ -4,22 +4,11 @@ import inspect from collections import defaultdict from dataclasses import dataclass from enum import Enum, auto -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union from typing_extensions import Self -from aiogram import Router, loggers +from aiogram import Dispatcher, Router, loggers from aiogram.dispatcher.event.bases import NextMiddlewareType from aiogram.dispatcher.event.handler import CallableObject, CallbackType from aiogram.filters import StateFilter @@ -325,24 +314,23 @@ class SceneWizard: await self.state.set_state(self.scene_config.state) await self._on_action(SceneAction.enter, **kwargs) - async def leave(self, **kwargs: Any) -> None: + async def leave(self, _with_history: bool = True, **kwargs: Any) -> None: loggers.scene.debug("Leaving scene %r", self.scene_config.state) - await self.manager.history.snapshot() + if _with_history: + await self.manager.history.snapshot() await self._on_action(SceneAction.leave, **kwargs) async def exit(self, **kwargs: Any) -> None: loggers.scene.debug("Exiting scene %r", self.scene_config.state) - await self.state.set_state(None) await self.manager.history.clear() await self._on_action(SceneAction.exit, **kwargs) - await self.manager.enter(None, _check_active=False, _with_history=False, **kwargs) + await self.manager.enter(None, _check_active=False, **kwargs) async def back(self, **kwargs: Any) -> None: loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state) - await self.leave() - await self.manager.history.rollback() + await self.leave(_with_history=False, **kwargs) new_scene = await self.manager.history.rollback() - await self.manager.enter(new_scene, _check_active=False, _with_history=False, **kwargs) + await self.manager.enter(new_scene, _check_active=False, **kwargs) async def replay(self, event: TelegramObject) -> None: await self._on_action(SceneAction.enter, event=event) @@ -430,16 +418,20 @@ class ScenesManager: self, scene_type: Optional[Union[Type[Scene], str]], _check_active: bool = True, - _with_history: bool = True, **kwargs: Any, ) -> None: scene = await self._get_scene(scene_type) + if _check_active: active_scene = await self._get_active_scene() if active_scene is not None: await active_scene.wizard.exit(**kwargs) - await scene.wizard.enter(_with_history=_with_history, **kwargs) + if not scene: + loggers.scene.debug("Reset state") + await self.state.set_state(None) + else: + await scene.wizard.enter(**kwargs) async def close(self, **kwargs: Any) -> None: try: @@ -454,13 +446,38 @@ class ScenesManager: class SceneRegistry: def __init__(self, router: Router) -> None: self.router = router + self._scenes: Dict[Optional[str], Type[Scene]] = {} + + self._setup_middleware(router) + + def _setup_middleware(self, router: Router) -> None: + if isinstance(router, Dispatcher): + # Small optimization for Dispatcher + # - we don't need to set up middleware for all observers + router.update.outer_middleware(self._update_middleware) + return for observer in router.observers.values(): if observer.event_name in {"update", "error"}: continue observer.outer_middleware(self._middleware) - self._scenes: Dict[Optional[str], Type[Scene]] = {} + async def _update_middleware( + self, + handler: NextMiddlewareType[TelegramObject], + event: TelegramObject, + data: Dict[str, Any], + ) -> Any: + assert isinstance(event, Update), "Event must be an Update instance" + + data["scenes"] = ScenesManager( + registry=self, + update_type=event.event_type, + event=event.event, + state=data["state"], + data=data, + ) + return await handler(event, data) async def _middleware( self, @@ -497,12 +514,10 @@ class SceneRegistry: raise TypeError("Scene must be a subclass of Scene or a string") try: - result = self._scenes[scene] + return self._scenes[scene] except KeyError: raise ValueError(f"Scene {scene!r} is not registered") - return result - @dataclass class After: diff --git a/examples/scene.py b/examples/scene.py index a3663f7c..1e4a7980 100644 --- a/examples/scene.py +++ b/examples/scene.py @@ -37,11 +37,7 @@ class CancellableScene(Scene): @on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back()) async def handle_back(self, message: Message): - await message.answer("Back.", reply_markup=ReplyKeyboardRemove()) - - @on.message.exit() - async def on_exit(self, message: Message): - await self.wizard.clear_data() + await message.answer("Back.") class LanguageScene(CancellableScene, state="language"): @@ -99,10 +95,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"): @on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene)) async def process_like_write_bots(self, message: Message): - await message.reply( - "Cool! I'm too!", - reply_markup=ReplyKeyboardRemove(), - ) + await message.reply("Cool! I'm too!") @on.message(F.text.casefold() == "no", after=After.exit()) async def process_dont_like_write_bots(self, message: Message): @@ -135,10 +128,9 @@ class NameScene(CancellableScene, state="name"): @on.message.leave() # Marker for handler that should be called when a user leaves the scene. async def on_leave(self, message: Message): - await message.answer( - f"Nice to meet you, {html.quote(message.text)}!", - reply_markup=ReplyKeyboardRemove(), - ) + data: FSMData = await self.wizard.get_data() + name = data.get("name", "Anonymous") + await message.answer(f"Nice to meet you, {html.quote(name)}!") @on.message(after=After.goto(LikeBotsScene)) async def input_name(self, message: Message):