Base implementation

This commit is contained in:
Alex Root Junior 2023-08-25 22:06:11 +03:00
parent b7be9c2b81
commit f27f3eea42
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
8 changed files with 658 additions and 3 deletions

View file

@ -18,7 +18,7 @@ CallbackType = Callable[..., Any]
@dataclass
class CallableMixin:
class CallableObject:
callback: CallbackType
awaitable: bool = field(init=False)
spec: inspect.FullArgSpec = field(init=False)
@ -48,7 +48,7 @@ class CallableMixin:
@dataclass
class FilterObject(CallableMixin):
class FilterObject(CallableObject):
magic: Optional[MagicFilter] = None
def __post_init__(self) -> None:
@ -75,7 +75,7 @@ class FilterObject(CallableMixin):
@dataclass
class HandlerObject(CallableMixin):
class HandlerObject(CallableObject):
filters: Optional[List[FilterObject]] = None
flags: Dict[str, Any] = field(default_factory=dict)

View file

@ -4,3 +4,4 @@ dispatcher = logging.getLogger("aiogram.dispatcher")
event = logging.getLogger("aiogram.event")
middlewares = logging.getLogger("aiogram.middlewares")
webhook = logging.getLogger("aiogram.webhook")
scene = logging.getLogger("aiogram.scene")

View file

@ -0,0 +1,15 @@
from __future__ import annotations
__all__ = [
"Scene",
"SceneRegistry",
"SceneManager",
"on",
]
from ._manager import SceneManager
from ._marker import OnMarker
from ._registry import SceneRegistry
from ._scene import Scene
on = OnMarker()

View file

@ -0,0 +1,73 @@
from dataclasses import replace, dataclass
from typing import Any, Dict, Optional
from aiogram import loggers
from aiogram.fsm.context import FSMContext
@dataclass
class StateContainer:
state: str
data: Dict[str, Any]
class HistoryManager:
def __init__(self, context: FSMContext, destiny: str = "history", size: int = 10):
self._size = size
self._context = context
self._history_context = FSMContext(
storage=context.storage, key=replace(context.key, destiny=destiny)
)
async def push(self, state: str, data: Dict[str, Any]) -> None:
history_data = await self._history_context.get_data()
history = history_data.setdefault("history", [])
history.append({"state": state, "data": data})
if len(history) > self._size:
history = history[-self._size :]
loggers.scene.debug("Push state=%s data=%s", state, data)
await self._history_context.update_data(history=history)
async def pop(self) -> Optional[StateContainer]:
history_data = await self._history_context.get_data()
history = history_data.setdefault("history", [])
if not history:
return None
record = history.pop()
state = record["state"]
data = record["data"]
await self._history_context.update_data(history=history)
loggers.scene.debug("Pop state=%s data=%s", state, data)
return StateContainer(state=state, data=data)
async def clear(self):
loggers.scene.debug("Clear history")
await self._history_context.clear()
async def get(self) -> list[StateContainer]:
history_data = await self._history_context.get_data()
history = history_data.setdefault("history", [])
return [StateContainer(**item) for item in history]
async def snapshot(self) -> None:
state = await self._context.get_state()
data = await self._context.get_data()
await self.push(state, data)
async def rollback(self) -> Optional[str]:
state_container = await self.pop()
if not state_container:
return None
state_container = await self.pop()
if not state_container:
return None
loggers.scene.debug(
"Rollback to state=%s data=%s",
state_container.state,
state_container.data,
)
await self._context.set_state(state_container.state)
await self._context.set_data(state_container.data)
return state_container.state

View file

@ -0,0 +1,79 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from aiogram.fsm.context import FSMContext
from aiogram.types import TelegramObject, Update
from ._history import HistoryManager
if TYPE_CHECKING:
from ._registry import SceneRegistry
from ._scene import Scene
class SceneManager:
def __init__(
self,
registry: SceneRegistry,
update: Update,
event: TelegramObject,
context: FSMContext,
data: dict[str, Any],
) -> None:
self.registry = registry
self.update = update
self.event = event
self.context = context
self.data = data
self._history = HistoryManager(self.context)
async def _get_scene(self, scene_type: type[Scene] | str) -> Scene:
scene_type = self.registry.get(scene_type)
return scene_type(
manager=self,
update=self.update,
event=self.event,
context=self.context,
data=self.data,
)
async def _get_active_scene(self) -> Scene | None:
state = await self.context.get_state()
if state is None:
return None
return await self._get_scene(state)
async def enter(self, scene_type: type[Scene] | str, **kwargs: Any) -> None:
active_scene = await self._get_active_scene()
if active_scene is not None:
await active_scene.leave(**kwargs)
scene = await self._get_scene(scene_type)
await scene.enter(**kwargs)
await self._history.snapshot()
async def leave(self, **kwargs: Any) -> None:
try:
scene = await self._get_active_scene()
except ValueError:
return
if not scene:
return
await scene.leave(**kwargs)
async def exit(self, **kwargs: Any) -> None:
try:
scene = await self._get_active_scene()
except ValueError:
return
if not scene:
return
await scene.exit(**kwargs)
await self._history.clear()
async def back(self, **kwargs: Any) -> None:
previous_state = await self._history.rollback()
if previous_state is not None:
await self.enter(previous_state, **kwargs)
else:
await self.exit(**kwargs)

195
aiogram/scenes/_marker.py Normal file
View file

@ -0,0 +1,195 @@
from __future__ import annotations
import inspect
from collections import defaultdict
from enum import Enum, auto
from typing import TYPE_CHECKING, Dict, Tuple, Type, Union, Optional
from typing_extensions import Self
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.types import TelegramObject
if TYPE_CHECKING:
from ._scene import Scene
class ObserverMarker:
def __init__(self, name: str) -> None:
self.name = name
def __call__(
self,
*filters: CallbackType,
) -> ObserverDecorator:
return ObserverDecorator(
self.name,
filters,
)
def enter(self, *filters: CallbackType) -> ObserverDecorator:
return ObserverDecorator(self.name, filters, action=SceneAction.enter)
def leave(self) -> ObserverDecorator:
return ObserverDecorator(self.name, (), action=SceneAction.leave)
def exit(self) -> ObserverDecorator:
return ObserverDecorator(self.name, (), action=SceneAction.exit)
def back(self) -> ObserverDecorator:
return ObserverDecorator(self.name, (), action=SceneAction.back)
class ObserverDecorator:
def __init__(
self,
name: str,
filters: tuple[CallbackType, ...],
action: SceneAction | None = None,
) -> None:
self.name = name
self.filters = filters
self.action = action
def _wrap_class(self, target: Type[Scene]) -> None:
from ._scene import Scene
if not issubclass(target, Scene):
raise TypeError("Only subclass of Scene is allowed")
if self.action is not None:
raise TypeError("This action is not allowed for class")
filters = getattr(target, "__aiogram_filters__", None)
if filters is None:
filters = defaultdict(list)
setattr(target, "__aiogram_filters__", filters)
filters[self.name].extend(self.filters)
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None:
setattr(target, "__aiogram_handler__", True)
filters = getattr(target, "__aiogram_filters__", None)
if filters is None:
filters = defaultdict(list)
setattr(target, "__aiogram_filters__", filters)
filters[self.name].extend(self.filters)
def _wrap_action(self, target: Type[Scene] | CallbackType) -> None:
action = getattr(target, "__aiogram_action__", None)
if action is None:
action = defaultdict(dict)
setattr(target, "__aiogram_action__", action)
action[self.action][self.name] = CallableObject(target)
def __call__(self, target: Type[Scene] | CallbackType) -> Type[Scene] | CallbackType:
if inspect.isclass(target):
self._wrap_class(target)
elif inspect.isfunction(target):
if self.action is None:
self._wrap_filter(target)
else:
self._wrap_action(target)
return target
def leave(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.leave)
def enter(self, target: Type[Scene]) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.enter, target)
def exit(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.exit)
def back(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.back)
class SceneAction(Enum):
enter = auto()
leave = auto()
exit = auto()
back = auto()
class ActionContainer:
def __init__(
self,
name: str,
filters: Tuple[CallbackType, ...],
action: SceneAction,
target: Type[Scene] | None = None,
) -> None:
self.name = name
self.filters = filters
self.action = action
self.target = target
async def __call__(
self,
scene: Scene,
event: TelegramObject,
) -> None:
if self.action == SceneAction.enter and self.target is not None:
await scene.goto(self.target)
elif self.action == SceneAction.leave:
await scene.leave()
elif self.action == SceneAction.exit:
await scene.exit()
elif self.action == SceneAction.back:
await scene.back()
@property
def __aiogram_filters__(self) -> Dict[str, Tuple[CallbackType, ...]]:
return {self.name: self.filters}
def __await__(self) -> Self:
return self
class ControlActionContainer:
def __init__(self, name: str, action: SceneAction) -> None:
self.name = name
self.action = action
class OnMarker:
"""
The `_On` class is used as a marker class to define different types of events in the Scenes.
Attributes:
- :code:`message`: Event marker for handling `Message` events.
- :code:`edited_message`: Event marker for handling edited `Message` events.
- :code:`channel_post`: Event marker for handling channel `Post` events.
- :code:`edited_channel_post`: Event marker for handling edited channel `Post` events.
- :code:`inline_query`: Event marker for handling `InlineQuery` events.
- :code:`chosen_inline_result`: Event marker for handling chosen `InlineResult` events.
- :code:`callback_query`: Event marker for handling `CallbackQuery` events.
- :code:`shipping_query`: Event marker for handling `ShippingQuery` events.
- :code:`pre_checkout_query`: Event marker for handling `PreCheckoutQuery` events.
- :code:`poll`: Event marker for handling `Poll` events.
- :code:`poll_answer`: Event marker for handling `PollAnswer` events.
- :code:`my_chat_member`: Event marker for handling my chat `Member` events.
- :code:`chat_member`: Event marker for handling chat `Member` events.
- :code:`chat_join_request`: Event marker for handling chat `JoinRequest` events.
- :code:`error`: Event marker for handling `Error` events.
.. note::
This is a marker class and does not contain any methods or implementation logic.
"""
message = ObserverMarker("message")
edited_message = ObserverMarker("edited_message")
channel_post = ObserverMarker("channel_post")
edited_channel_post = ObserverMarker("edited_channel_post")
inline_query = ObserverMarker("inline_query")
chosen_inline_result = ObserverMarker("chosen_inline_result")
callback_query = ObserverMarker("callback_query")
shipping_query = ObserverMarker("shipping_query")
pre_checkout_query = ObserverMarker("pre_checkout_query")
poll = ObserverMarker("poll")
poll_answer = ObserverMarker("poll_answer")
my_chat_member = ObserverMarker("my_chat_member")
chat_member = ObserverMarker("chat_member")
chat_join_request = ObserverMarker("chat_join_request")

View file

@ -0,0 +1,85 @@
from __future__ import annotations
import inspect
from typing import Any, Dict, Optional, Type, Union
from aiogram import Dispatcher, Router
from ..dispatcher.event.bases import NextMiddlewareType
from ..types import TelegramObject
from ._manager import SceneManager
from ._scene import Scene
class SceneRegistry:
def __init__(self, router: Router) -> None:
self.router = router
for observer in router.observers.values():
if observer.event_name in {"update", "error"}:
continue
observer.outer_middleware(self._middleware)
self._scenes: Dict[str, Type[Scene]] = {}
async def _error_middleware(
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
data["scenes"] = SceneManager(
registry=self,
update=event.update,
event=event,
context=data["state"],
data=data,
)
return await handler(event, data)
async def _middleware(
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
data["scenes"] = SceneManager(
registry=self,
update=data["event_update"],
event=event,
context=data["state"],
data=data,
)
return await handler(event, data)
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None:
if router is None:
router = self.router
for scene in scenes:
if scene.__aiogram_scene_name__ in self._scenes:
raise ValueError(f"Scene {scene.__aiogram_scene_name__} already exists")
self._scenes[scene.__aiogram_scene_name__] = scene
router.include_router(scene.as_router())
def get(self, scene: Union[Type[Scene], str]) -> Type[Scene]:
if inspect.isclass(scene) and issubclass(scene, Scene):
target = scene.__aiogram_scene_name__
check_class = True
else:
target = scene
check_class = False
if not isinstance(target, str):
raise TypeError("Scene must be a string or subclass of Scene")
try:
result = self._scenes[target]
except KeyError:
raise ValueError(f"Scene {scene!r} is not registered")
if check_class and not issubclass(result, scene):
raise ValueError(f"Scene {scene!r} is not registered")
return result

207
aiogram/scenes/_scene.py Normal file
View file

@ -0,0 +1,207 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Tuple, Type, Union
from typing_extensions import Self
from aiogram import Router, loggers
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.filters import StateFilter
from aiogram.fsm.context import FSMContext
from aiogram.types import TelegramObject, Update
from ._marker import ActionContainer, SceneAction
if TYPE_CHECKING:
from ._manager import SceneManager
class _SceneMeta(type):
def __new__(
mcs,
name: str,
bases: Tuple[type],
namespace: Dict[str, Any],
**kwargs: Any,
) -> _SceneMeta:
state_name = kwargs.pop("state", f"{namespace['__module__']}:{name}")
aiogram_filters: defaultdict[str, List[CallbackType]] = defaultdict(list)
aiogram_handlers: list[CallbackType] = []
aiogram_actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict)
for base in bases:
if parent_aiogram_filters := getattr(base, "__aiogram_filters__", None):
aiogram_filters.update(parent_aiogram_filters)
if parent_aiogram_handlers := getattr(base, "__aiogram_handlers__", None):
aiogram_handlers.extend(parent_aiogram_handlers)
if parent_aiogram_actions := getattr(base, "__aiogram_actions__", None):
for action, handlers in parent_aiogram_actions.items():
aiogram_actions[action].update(handlers)
for name, value in namespace.items():
if hasattr(value, "__aiogram_handler__"):
aiogram_handlers.append(value)
elif isinstance(value, ActionContainer):
aiogram_handlers.append(value)
elif hasattr(value, "__aiogram_action__"):
for action, handlers in value.__aiogram_action__.items():
aiogram_actions[action].update(handlers)
namespace.update(
{
"__aiogram_scene_name__": state_name,
"__aiogram_filters__": aiogram_filters,
"__aiogram_handlers__": aiogram_handlers,
"__aiogram_actions__": aiogram_actions,
}
)
return super().__new__(mcs, name, bases, namespace, **kwargs)
class SceneHandlerWrapper:
def __init__(self, scene: Type[Scene], handler: CallbackType) -> None:
self.scene = scene
self.handler = CallableObject(handler)
async def __call__(
self,
event: TelegramObject,
state: FSMContext,
scenes: SceneManager,
event_update: Update,
**kwargs: Any,
) -> Any:
scene = self.scene(
manager=scenes,
update=event_update,
event=event,
context=state,
data=kwargs,
)
return await self.handler.call(
scene,
event,
state=state,
event_update=event_update,
scenes=scenes,
**kwargs,
)
def __await__(self) -> Self:
return self
class Scene(metaclass=_SceneMeta):
__aiogram_scene_name__: ClassVar[str]
__aiogram_filters__: ClassVar[Dict[str, List[CallbackType]]]
__aiogram_handlers__: ClassVar[List[CallbackType]]
__aiogram_actions__: ClassVar[Dict[SceneAction, Dict[str, CallableObject]]]
def __init__(
self,
manager: SceneManager,
update: Update,
event: TelegramObject,
context: FSMContext,
data: Dict[str, Any],
) -> None:
self.manager = manager
self.update = update
self.event = event
self.context = context
self.data = data
@classmethod
def as_router(cls) -> Router:
router = Router(name=cls.__aiogram_scene_name__)
used_observers = set()
for observer, filters in cls.__aiogram_filters__.items():
router.observers[observer].filter(*filters)
used_observers.add(observer)
for handler in cls.__aiogram_handlers__:
handler_filters = getattr(handler, "__aiogram_filters__", None)
if not handler_filters:
continue
for observer, filters in handler_filters.items():
router.observers[observer].register(SceneHandlerWrapper(cls, handler), *filters)
used_observers.add(observer)
for observer in used_observers:
router.observers[observer].filter(StateFilter(cls.__aiogram_scene_name__))
return router
@classmethod
def as_handler(cls) -> CallbackType:
async def enter_to_scene_handler(event: TelegramObject, scenes: SceneManager) -> None:
await scenes.enter(cls)
return enter_to_scene_handler
async def enter(self, **kwargs: Any) -> None:
loggers.scene.debug("Entering scene %s", self.__aiogram_scene_name__)
state = await self.context.get_state()
await self.context.set_state(self.__aiogram_scene_name__)
try:
if not await self._on_action(SceneAction.enter, **kwargs):
loggers.scene.error(
"Enter action not found in scene %s for event %r", self, type(self.event)
)
except Exception as e:
await self.context.set_state(state)
raise e
async def leave(self, **kwargs: Any) -> None:
loggers.scene.debug("Leaving scene %s", self.__aiogram_scene_name__)
state = await self.context.get_state()
await self.context.set_state(None)
try:
await self._on_action(SceneAction.leave, **kwargs)
except Exception as e:
await self.context.set_state(state)
raise e
async def exit(self, **kwargs: Any) -> None:
loggers.scene.debug("Exiting scene %s", self.__aiogram_scene_name__)
state = await self.context.get_state()
await self.context.set_state(None)
try:
await self._on_action(SceneAction.exit, **kwargs)
except Exception as e:
await self.context.set_state(state)
raise e
async def back(self, **kwargs: Any) -> None:
loggers.scene.debug("Back to previous scene from scene %s", self.__aiogram_scene_name__)
await self.manager.back(**kwargs)
async def replay(self, event: TelegramObject) -> None:
await self._on_action(SceneAction.enter, event=event)
async def goto(self, scene: Union[Type[Scene], str]) -> None:
await self.manager.enter(scene)
async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool:
loggers.scene.debug("Call action %r in scene %r", action.name, self.__aiogram_scene_name__)
action_config = self.__aiogram_actions__.get(action, {})
if not action_config:
loggers.scene.debug(
"Action %r not found in scene %r", action.name, self.__aiogram_scene_name__
)
return False
event_type = self.update.event_type
if event_type not in action_config:
loggers.scene.debug(
"Action %r for event %r not found in scene %r",
action.name,
event_type,
self.__aiogram_scene_name__,
)
return False
await action_config[event_type].call(self, self.event, **{**self.data, **kwargs})
return True