Small optimizations

This commit is contained in:
Alex Root Junior 2023-08-26 17:32:36 +03:00
parent 278697297e
commit bcdf8ea9da
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
3 changed files with 80 additions and 67 deletions

View file

@ -1,22 +1,17 @@
from dataclasses import dataclass, replace
from dataclasses import replace
from typing import Any, Dict, List, Optional
from aiogram import loggers
from aiogram.fsm.context import FSMContext
@dataclass
class StateContainer:
state: Optional[str]
data: Dict[str, Any]
from aiogram.fsm.storage.memory import MemoryStorageRecord
class HistoryManager:
def __init__(self, context: FSMContext, destiny: str = "history", size: int = 10):
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
self._size = size
self._state = context
self._state = state
self._history_state = FSMContext(
storage=context.storage, key=replace(context.key, destiny=destiny)
storage=state.storage, key=replace(state.key, destiny=destiny)
)
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
@ -25,10 +20,14 @@ class HistoryManager:
history.append({"state": state, "data": data})
if len(history) > self._size:
history = history[-self._size :]
loggers.scene.debug("Push state=%s data=%s", state, data)
await self._history_state.update_data(history=history)
loggers.scene.debug("Push state=%s data=%s to history", state, data)
async def pop(self) -> Optional[StateContainer]:
if not history:
await self._history_state.set_data({})
else:
await self._history_state.update_data(history=history)
async def pop(self) -> Optional[MemoryStorageRecord]:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
if not history:
@ -36,41 +35,48 @@ class HistoryManager:
record = history.pop()
state = record["state"]
data = record["data"]
await self._history_state.update_data(history=history)
loggers.scene.debug("Pop state=%s data=%s", state, data)
return StateContainer(state=state, data=data)
if not history:
await self._history_state.set_data({})
else:
await self._history_state.update_data(history=history)
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
return MemoryStorageRecord(state=state, data=data)
async def get(self) -> Optional[StateContainer]:
async def get(self) -> Optional[MemoryStorageRecord]:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
if not history:
return None
return StateContainer(**history[-1])
return MemoryStorageRecord(**history[-1])
async def all(self) -> List[StateContainer]:
async def all(self) -> List[MemoryStorageRecord]:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
return [StateContainer(**item) for item in history]
return [MemoryStorageRecord(**item) for item in history]
async def clear(self) -> None:
loggers.scene.debug("Clear history")
await self._history_state.clear()
await self._history_state.set_data({})
async def snapshot(self) -> None:
state = await self._state.get_state()
data = await self._state.get_data()
await self.push(state, data)
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None:
await self._state.set_state(state)
await self._state.set_data(data)
async def rollback(self) -> Optional[str]:
state_container = await self.pop()
if not state_container:
previous_state = await self.pop()
if not previous_state:
await self._set_state(None, {})
return None
loggers.scene.debug(
"Rollback to state=%s data=%s",
state_container.state,
state_container.data,
previous_state.state,
previous_state.data,
)
await self._state.set_state(state_container.state)
await self._state.set_data(state_container.data)
return state_container.state
await self._set_state(previous_state.state, previous_state.data)
return previous_state.state

View file

@ -4,22 +4,11 @@ import inspect
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
from typing_extensions import Self
from aiogram import Router, loggers
from aiogram import Dispatcher, Router, loggers
from aiogram.dispatcher.event.bases import NextMiddlewareType
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.filters import StateFilter
@ -325,24 +314,23 @@ class SceneWizard:
await self.state.set_state(self.scene_config.state)
await self._on_action(SceneAction.enter, **kwargs)
async def leave(self, **kwargs: Any) -> None:
async def leave(self, _with_history: bool = True, **kwargs: Any) -> None:
loggers.scene.debug("Leaving scene %r", self.scene_config.state)
await self.manager.history.snapshot()
if _with_history:
await self.manager.history.snapshot()
await self._on_action(SceneAction.leave, **kwargs)
async def exit(self, **kwargs: Any) -> None:
loggers.scene.debug("Exiting scene %r", self.scene_config.state)
await self.state.set_state(None)
await self.manager.history.clear()
await self._on_action(SceneAction.exit, **kwargs)
await self.manager.enter(None, _check_active=False, _with_history=False, **kwargs)
await self.manager.enter(None, _check_active=False, **kwargs)
async def back(self, **kwargs: Any) -> None:
loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state)
await self.leave()
await self.manager.history.rollback()
await self.leave(_with_history=False, **kwargs)
new_scene = await self.manager.history.rollback()
await self.manager.enter(new_scene, _check_active=False, _with_history=False, **kwargs)
await self.manager.enter(new_scene, _check_active=False, **kwargs)
async def replay(self, event: TelegramObject) -> None:
await self._on_action(SceneAction.enter, event=event)
@ -430,16 +418,20 @@ class ScenesManager:
self,
scene_type: Optional[Union[Type[Scene], str]],
_check_active: bool = True,
_with_history: bool = True,
**kwargs: Any,
) -> None:
scene = await self._get_scene(scene_type)
if _check_active:
active_scene = await self._get_active_scene()
if active_scene is not None:
await active_scene.wizard.exit(**kwargs)
await scene.wizard.enter(_with_history=_with_history, **kwargs)
if not scene:
loggers.scene.debug("Reset state")
await self.state.set_state(None)
else:
await scene.wizard.enter(**kwargs)
async def close(self, **kwargs: Any) -> None:
try:
@ -454,13 +446,38 @@ class ScenesManager:
class SceneRegistry:
def __init__(self, router: Router) -> None:
self.router = router
self._scenes: Dict[Optional[str], Type[Scene]] = {}
self._setup_middleware(router)
def _setup_middleware(self, router: Router) -> None:
if isinstance(router, Dispatcher):
# Small optimization for Dispatcher
# - we don't need to set up middleware for all observers
router.update.outer_middleware(self._update_middleware)
return
for observer in router.observers.values():
if observer.event_name in {"update", "error"}:
continue
observer.outer_middleware(self._middleware)
self._scenes: Dict[Optional[str], Type[Scene]] = {}
async def _update_middleware(
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
assert isinstance(event, Update), "Event must be an Update instance"
data["scenes"] = ScenesManager(
registry=self,
update_type=event.event_type,
event=event.event,
state=data["state"],
data=data,
)
return await handler(event, data)
async def _middleware(
self,
@ -497,12 +514,10 @@ class SceneRegistry:
raise TypeError("Scene must be a subclass of Scene or a string")
try:
result = self._scenes[scene]
return self._scenes[scene]
except KeyError:
raise ValueError(f"Scene {scene!r} is not registered")
return result
@dataclass
class After:

View file

@ -37,11 +37,7 @@ class CancellableScene(Scene):
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
async def handle_back(self, message: Message):
await message.answer("Back.", reply_markup=ReplyKeyboardRemove())
@on.message.exit()
async def on_exit(self, message: Message):
await self.wizard.clear_data()
await message.answer("Back.")
class LanguageScene(CancellableScene, state="language"):
@ -99,10 +95,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
async def process_like_write_bots(self, message: Message):
await message.reply(
"Cool! I'm too!",
reply_markup=ReplyKeyboardRemove(),
)
await message.reply("Cool! I'm too!")
@on.message(F.text.casefold() == "no", after=After.exit())
async def process_dont_like_write_bots(self, message: Message):
@ -135,10 +128,9 @@ class NameScene(CancellableScene, state="name"):
@on.message.leave() # Marker for handler that should be called when a user leaves the scene.
async def on_leave(self, message: Message):
await message.answer(
f"Nice to meet you, {html.quote(message.text)}!",
reply_markup=ReplyKeyboardRemove(),
)
data: FSMData = await self.wizard.get_data()
name = data.get("name", "Anonymous")
await message.answer(f"Nice to meet you, {html.quote(name)}!")
@on.message(after=After.goto(LikeBotsScene))
async def input_name(self, message: Message):