From c999964f11bc3e6174fdaca16c82503152b94c06 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Fri, 13 Oct 2023 00:00:36 +0300 Subject: [PATCH] Small refactoring --- aiogram/dispatcher/router.py | 30 +++++++++++++- aiogram/fsm/scene.py | 76 ++++++++++++++++++------------------ examples/quiz_scene.py | 3 +- 3 files changed, 67 insertions(+), 42 deletions(-) diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index 4dd4284f..bb5a8c8b 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import Any, Dict, Final, Generator, List, Optional, Set +from typing import Any, Dict, Final, Generator, List, Optional, Set, Type, TYPE_CHECKING -from ..types import TelegramObject from .event.bases import REJECTED, UNHANDLED from .event.event import EventObserver from .event.telegram import TelegramEventObserver +from ..types import TelegramObject + +if TYPE_CHECKING: + from ..fsm.scene import Scene INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"}) @@ -218,6 +221,29 @@ class Router: router.parent_router = self return router + def include_scenes(self, *scenes: Type[Scene], name: Optional[str] = None) -> None: + """ + Include multiple scenes to this router as sub-routers + + :param scenes: scene instances + :param name: optional name for router + :return: + """ + if not scenes: + raise ValueError("At least one scene must be provided") + for scene in scenes: + self.include_scene(scene, name=name) + + def include_scene(self, scene: Type[Scene], name: Optional[str] = None) -> None: + """ + Include a scene to this router as sub-router + + :param scene: scene instance + :param name: optional name for router + :return: + """ + self.include_router(scene.as_router(name=name)) + async def emit_startup(self, *args: Any, **kwargs: Any) -> None: """ Recursively call startup callbacks diff --git a/aiogram/fsm/scene.py b/aiogram/fsm/scene.py index 08ea5b2f..e47ced33 100644 --- a/aiogram/fsm/scene.py +++ b/aiogram/fsm/scene.py @@ -8,9 +8,12 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union from typing_extensions import Self -from aiogram import Dispatcher, Router, loggers +from aiogram import loggers +from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.event.bases import NextMiddlewareType from aiogram.dispatcher.event.handler import CallableObject, CallbackType +from aiogram.dispatcher.flags import extract_flags_from_object +from aiogram.dispatcher.router import Router from aiogram.exceptions import SceneException from aiogram.filters import StateFilter from aiogram.fsm.context import FSMContext @@ -107,18 +110,6 @@ class ObserverDecorator: self.action = action self.after = after - def _wrap_class(self, target: Type[Scene]) -> None: - if not issubclass(target, Scene): - raise TypeError("Only subclass of Scene is allowed") - if self.action is not None: - raise TypeError("This action is not allowed for class") - - filters = getattr(target, "__aiogram_filters__", None) - if filters is None: - filters = defaultdict(list) - setattr(target, "__aiogram_filters__", filters) - filters[self.name].extend(self.filters) - def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None: handlers = getattr(target, "__aiogram_handler__", None) if not handlers: @@ -134,21 +125,21 @@ class ObserverDecorator: ) ) - def _wrap_action(self, target: Type[Scene] | CallbackType) -> None: + def _wrap_action(self, target: CallbackType) -> None: action = getattr(target, "__aiogram_action__", None) if action is None: action = defaultdict(dict) setattr(target, "__aiogram_action__", action) action[self.action][self.name] = CallableObject(target) - def __call__(self, target: Type[Scene] | CallbackType) -> Type[Scene] | CallbackType: - if inspect.isclass(target): - self._wrap_class(target) - elif inspect.isfunction(target): + def __call__(self, target: CallbackType) -> CallbackType: + if inspect.isfunction(target): if self.action is None: self._wrap_filter(target) else: self._wrap_action(target) + else: + raise TypeError("Only function or method is allowed") return target def leave(self) -> ActionContainer: @@ -213,8 +204,6 @@ class HandlerContainer: class SceneConfig: state: Optional[str] """Scene state""" - filters: Dict[str, List[CallbackType]] - """Global scene filters""" handlers: List[HandlerContainer] """Scene handlers""" actions: Dict[SceneAction, Dict[str, CallableObject]] @@ -316,7 +305,6 @@ class Scene: super().__init_subclass__(**kwargs) - filters: defaultdict[str, List[CallbackType]] = defaultdict(list) handlers: list[HandlerContainer] = [] actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict) @@ -328,7 +316,6 @@ class Scene: if not parent_scene_config: continue - filters.update(parent_scene_config.filters) handlers.extend(parent_scene_config.handlers) for action, action_handlers in parent_scene_config.actions.items(): actions[action].update(action_handlers) @@ -360,7 +347,6 @@ class Scene: cls.__scene_config__ = SceneConfig( state=state_name, - filters=dict(filters), handlers=handlers, actions=dict(actions), reset_data_on_enter=reset_data_on_enter, @@ -379,10 +365,6 @@ class Scene: scene_config = cls.__scene_config__ used_observers = set() - for observer_name, filters in scene_config.filters.items(): - router.observers[observer_name].filter(*filters) - used_observers.add(observer_name) - for handler in scene_config.handlers: router.observers[handler.name].register( SceneHandlerWrapper( @@ -391,6 +373,7 @@ class Scene: after=handler.after, ), *handler.filters, + flags=extract_flags_from_object(handler.handler), ) used_observers.add(handler.name) @@ -706,15 +689,17 @@ class SceneRegistry: A class that represents a registry for scenes in a Telegram bot. """ - def __init__(self, router: Router) -> None: + def __init__(self, router: Router, register_on_add: bool = False) -> None: """ Initialize a new instance of the SceneRegistry class. :param router: The router instance used for scene registration. + :param register_on_add: Whether to register the scenes to the router when they are added. """ self.router = router - self._scenes: Dict[Optional[str], Type[Scene]] = {} + self.register_on_add = register_on_add + self._scenes: Dict[Optional[str], Type[Scene]] = {} self._setup_middleware(router) def _setup_middleware(self, router: Router) -> None: @@ -764,23 +749,24 @@ class SceneRegistry: def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None: """ - This method adds the specified scenes to the router. - If a router is not provided, it uses the default router stored - in the SceneRegistry instance. - The scenes are included in the router by calling the `as_router()` - method on each scene and passing the router as a parameter to this method. + This method adds the specified scenes to the registry + and optionally registers it to the router. If a scene with the same state already exists in the registry, a SceneException is raised. + .. warning:: + + If the router is not specified, the scenes will not be registered to the router. + You will need to include the scenes manually to the router or use the register method. + :param scenes: A variable length parameter that accepts one or more types of scenes. These scenes are instances of the Scene class. :param router: An optional parameter that specifies the router - to which the scenes should be added. If not provided, the scenes will be - added to the default router stored in the SceneRegistry instance. + to which the scenes should be added. :return: None """ - if router is None: - router = self.router + if not scenes: + raise ValueError("At least one scene must be specified") for scene in scenes: if scene.__scene_config__.state in self._scenes: @@ -790,7 +776,19 @@ class SceneRegistry: self._scenes[scene.__scene_config__.state] = scene - router.include_router(scene.as_router()) + if router: + router.include_router(scene.as_router()) + elif self.register_on_add: + self.router.include_router(scene.as_router()) + + def register(self, *scenes: Type[Scene]) -> None: + """ + Registers one or more scenes to the SceneRegistry. + + :param scenes: One or more scene classes to register. + :return: None + """ + self.add(*scenes, router=self.router) def get(self, scene: Optional[Union[Type[Scene], str]]) -> Type[Scene]: """ diff --git a/examples/quiz_scene.py b/examples/quiz_scene.py index aafdbd02..30eb71c6 100644 --- a/examples/quiz_scene.py +++ b/examples/quiz_scene.py @@ -252,6 +252,7 @@ class QuizScene(Scene, state="quiz"): quiz_router = Router(name=__name__) # Add handler that initializes the scene quiz_router.message.register(QuizScene.as_handler(), Command("quiz")) +quiz_router.include_scene(QuizScene) @quiz_router.message(Command("start")) @@ -275,7 +276,7 @@ def create_dispatcher(): # ... and then register a scene in the registry # by default, Scene will be mounted to the router that passed to the SceneRegistry, # but you can specify the router explicitly using the `router` argument - scene_registry.add(QuizScene, router=quiz_router) + scene_registry.add(QuizScene) return dispatcher