mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Small optimizations
This commit is contained in:
parent
278697297e
commit
bcdf8ea9da
3 changed files with 80 additions and 67 deletions
|
|
@ -1,22 +1,17 @@
|
|||
from dataclasses import dataclass, replace
|
||||
from dataclasses import replace
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from aiogram import loggers
|
||||
from aiogram.fsm.context import FSMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateContainer:
|
||||
state: Optional[str]
|
||||
data: Dict[str, Any]
|
||||
from aiogram.fsm.storage.memory import MemoryStorageRecord
|
||||
|
||||
|
||||
class HistoryManager:
|
||||
def __init__(self, context: FSMContext, destiny: str = "history", size: int = 10):
|
||||
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
|
||||
self._size = size
|
||||
self._state = context
|
||||
self._state = state
|
||||
self._history_state = FSMContext(
|
||||
storage=context.storage, key=replace(context.key, destiny=destiny)
|
||||
storage=state.storage, key=replace(state.key, destiny=destiny)
|
||||
)
|
||||
|
||||
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
|
||||
|
|
@ -25,10 +20,14 @@ class HistoryManager:
|
|||
history.append({"state": state, "data": data})
|
||||
if len(history) > self._size:
|
||||
history = history[-self._size :]
|
||||
loggers.scene.debug("Push state=%s data=%s", state, data)
|
||||
await self._history_state.update_data(history=history)
|
||||
loggers.scene.debug("Push state=%s data=%s to history", state, data)
|
||||
|
||||
async def pop(self) -> Optional[StateContainer]:
|
||||
if not history:
|
||||
await self._history_state.set_data({})
|
||||
else:
|
||||
await self._history_state.update_data(history=history)
|
||||
|
||||
async def pop(self) -> Optional[MemoryStorageRecord]:
|
||||
history_data = await self._history_state.get_data()
|
||||
history = history_data.setdefault("history", [])
|
||||
if not history:
|
||||
|
|
@ -36,41 +35,48 @@ class HistoryManager:
|
|||
record = history.pop()
|
||||
state = record["state"]
|
||||
data = record["data"]
|
||||
await self._history_state.update_data(history=history)
|
||||
loggers.scene.debug("Pop state=%s data=%s", state, data)
|
||||
return StateContainer(state=state, data=data)
|
||||
if not history:
|
||||
await self._history_state.set_data({})
|
||||
else:
|
||||
await self._history_state.update_data(history=history)
|
||||
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
|
||||
return MemoryStorageRecord(state=state, data=data)
|
||||
|
||||
async def get(self) -> Optional[StateContainer]:
|
||||
async def get(self) -> Optional[MemoryStorageRecord]:
|
||||
history_data = await self._history_state.get_data()
|
||||
history = history_data.setdefault("history", [])
|
||||
if not history:
|
||||
return None
|
||||
return StateContainer(**history[-1])
|
||||
return MemoryStorageRecord(**history[-1])
|
||||
|
||||
async def all(self) -> List[StateContainer]:
|
||||
async def all(self) -> List[MemoryStorageRecord]:
|
||||
history_data = await self._history_state.get_data()
|
||||
history = history_data.setdefault("history", [])
|
||||
return [StateContainer(**item) for item in history]
|
||||
return [MemoryStorageRecord(**item) for item in history]
|
||||
|
||||
async def clear(self) -> None:
|
||||
loggers.scene.debug("Clear history")
|
||||
await self._history_state.clear()
|
||||
await self._history_state.set_data({})
|
||||
|
||||
async def snapshot(self) -> None:
|
||||
state = await self._state.get_state()
|
||||
data = await self._state.get_data()
|
||||
await self.push(state, data)
|
||||
|
||||
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None:
|
||||
await self._state.set_state(state)
|
||||
await self._state.set_data(data)
|
||||
|
||||
async def rollback(self) -> Optional[str]:
|
||||
state_container = await self.pop()
|
||||
if not state_container:
|
||||
previous_state = await self.pop()
|
||||
if not previous_state:
|
||||
await self._set_state(None, {})
|
||||
return None
|
||||
|
||||
loggers.scene.debug(
|
||||
"Rollback to state=%s data=%s",
|
||||
state_container.state,
|
||||
state_container.data,
|
||||
previous_state.state,
|
||||
previous_state.data,
|
||||
)
|
||||
await self._state.set_state(state_container.state)
|
||||
await self._state.set_data(state_container.data)
|
||||
return state_container.state
|
||||
await self._set_state(previous_state.state, previous_state.data)
|
||||
return previous_state.state
|
||||
|
|
|
|||
|
|
@ -4,22 +4,11 @@ import inspect
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from aiogram import Router, loggers
|
||||
from aiogram import Dispatcher, Router, loggers
|
||||
from aiogram.dispatcher.event.bases import NextMiddlewareType
|
||||
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
|
||||
from aiogram.filters import StateFilter
|
||||
|
|
@ -325,24 +314,23 @@ class SceneWizard:
|
|||
await self.state.set_state(self.scene_config.state)
|
||||
await self._on_action(SceneAction.enter, **kwargs)
|
||||
|
||||
async def leave(self, **kwargs: Any) -> None:
|
||||
async def leave(self, _with_history: bool = True, **kwargs: Any) -> None:
|
||||
loggers.scene.debug("Leaving scene %r", self.scene_config.state)
|
||||
await self.manager.history.snapshot()
|
||||
if _with_history:
|
||||
await self.manager.history.snapshot()
|
||||
await self._on_action(SceneAction.leave, **kwargs)
|
||||
|
||||
async def exit(self, **kwargs: Any) -> None:
|
||||
loggers.scene.debug("Exiting scene %r", self.scene_config.state)
|
||||
await self.state.set_state(None)
|
||||
await self.manager.history.clear()
|
||||
await self._on_action(SceneAction.exit, **kwargs)
|
||||
await self.manager.enter(None, _check_active=False, _with_history=False, **kwargs)
|
||||
await self.manager.enter(None, _check_active=False, **kwargs)
|
||||
|
||||
async def back(self, **kwargs: Any) -> None:
|
||||
loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state)
|
||||
await self.leave()
|
||||
await self.manager.history.rollback()
|
||||
await self.leave(_with_history=False, **kwargs)
|
||||
new_scene = await self.manager.history.rollback()
|
||||
await self.manager.enter(new_scene, _check_active=False, _with_history=False, **kwargs)
|
||||
await self.manager.enter(new_scene, _check_active=False, **kwargs)
|
||||
|
||||
async def replay(self, event: TelegramObject) -> None:
|
||||
await self._on_action(SceneAction.enter, event=event)
|
||||
|
|
@ -430,16 +418,20 @@ class ScenesManager:
|
|||
self,
|
||||
scene_type: Optional[Union[Type[Scene], str]],
|
||||
_check_active: bool = True,
|
||||
_with_history: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
scene = await self._get_scene(scene_type)
|
||||
|
||||
if _check_active:
|
||||
active_scene = await self._get_active_scene()
|
||||
if active_scene is not None:
|
||||
await active_scene.wizard.exit(**kwargs)
|
||||
|
||||
await scene.wizard.enter(_with_history=_with_history, **kwargs)
|
||||
if not scene:
|
||||
loggers.scene.debug("Reset state")
|
||||
await self.state.set_state(None)
|
||||
else:
|
||||
await scene.wizard.enter(**kwargs)
|
||||
|
||||
async def close(self, **kwargs: Any) -> None:
|
||||
try:
|
||||
|
|
@ -454,13 +446,38 @@ class ScenesManager:
|
|||
class SceneRegistry:
|
||||
def __init__(self, router: Router) -> None:
|
||||
self.router = router
|
||||
self._scenes: Dict[Optional[str], Type[Scene]] = {}
|
||||
|
||||
self._setup_middleware(router)
|
||||
|
||||
def _setup_middleware(self, router: Router) -> None:
|
||||
if isinstance(router, Dispatcher):
|
||||
# Small optimization for Dispatcher
|
||||
# - we don't need to set up middleware for all observers
|
||||
router.update.outer_middleware(self._update_middleware)
|
||||
return
|
||||
|
||||
for observer in router.observers.values():
|
||||
if observer.event_name in {"update", "error"}:
|
||||
continue
|
||||
observer.outer_middleware(self._middleware)
|
||||
|
||||
self._scenes: Dict[Optional[str], Type[Scene]] = {}
|
||||
async def _update_middleware(
|
||||
self,
|
||||
handler: NextMiddlewareType[TelegramObject],
|
||||
event: TelegramObject,
|
||||
data: Dict[str, Any],
|
||||
) -> Any:
|
||||
assert isinstance(event, Update), "Event must be an Update instance"
|
||||
|
||||
data["scenes"] = ScenesManager(
|
||||
registry=self,
|
||||
update_type=event.event_type,
|
||||
event=event.event,
|
||||
state=data["state"],
|
||||
data=data,
|
||||
)
|
||||
return await handler(event, data)
|
||||
|
||||
async def _middleware(
|
||||
self,
|
||||
|
|
@ -497,12 +514,10 @@ class SceneRegistry:
|
|||
raise TypeError("Scene must be a subclass of Scene or a string")
|
||||
|
||||
try:
|
||||
result = self._scenes[scene]
|
||||
return self._scenes[scene]
|
||||
except KeyError:
|
||||
raise ValueError(f"Scene {scene!r} is not registered")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class After:
|
||||
|
|
|
|||
|
|
@ -37,11 +37,7 @@ class CancellableScene(Scene):
|
|||
|
||||
@on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back())
|
||||
async def handle_back(self, message: Message):
|
||||
await message.answer("Back.", reply_markup=ReplyKeyboardRemove())
|
||||
|
||||
@on.message.exit()
|
||||
async def on_exit(self, message: Message):
|
||||
await self.wizard.clear_data()
|
||||
await message.answer("Back.")
|
||||
|
||||
|
||||
class LanguageScene(CancellableScene, state="language"):
|
||||
|
|
@ -99,10 +95,7 @@ class LikeBotsScene(CancellableScene, state="like_bots"):
|
|||
|
||||
@on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene))
|
||||
async def process_like_write_bots(self, message: Message):
|
||||
await message.reply(
|
||||
"Cool! I'm too!",
|
||||
reply_markup=ReplyKeyboardRemove(),
|
||||
)
|
||||
await message.reply("Cool! I'm too!")
|
||||
|
||||
@on.message(F.text.casefold() == "no", after=After.exit())
|
||||
async def process_dont_like_write_bots(self, message: Message):
|
||||
|
|
@ -135,10 +128,9 @@ class NameScene(CancellableScene, state="name"):
|
|||
|
||||
@on.message.leave() # Marker for handler that should be called when a user leaves the scene.
|
||||
async def on_leave(self, message: Message):
|
||||
await message.answer(
|
||||
f"Nice to meet you, {html.quote(message.text)}!",
|
||||
reply_markup=ReplyKeyboardRemove(),
|
||||
)
|
||||
data: FSMData = await self.wizard.get_data()
|
||||
name = data.get("name", "Anonymous")
|
||||
await message.answer(f"Nice to meet you, {html.quote(name)}!")
|
||||
|
||||
@on.message(after=After.goto(LikeBotsScene))
|
||||
async def input_name(self, message: Message):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue