Small refactoring

This commit is contained in:
Alex Root Junior 2023-10-13 00:00:36 +03:00
parent bda7fcd13b
commit c999964f11
No known key found for this signature in database
GPG key ID: 074C1D455EBEA4AC
3 changed files with 67 additions and 42 deletions

View file

@ -1,11 +1,14 @@
from __future__ import annotations
from typing import Any, Dict, Final, Generator, List, Optional, Set
from typing import Any, Dict, Final, Generator, List, Optional, Set, Type, TYPE_CHECKING
from ..types import TelegramObject
from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver
from .event.telegram import TelegramEventObserver
from ..types import TelegramObject
if TYPE_CHECKING:
from ..fsm.scene import Scene
INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"})
@ -218,6 +221,29 @@ class Router:
router.parent_router = self
return router
def include_scenes(self, *scenes: Type[Scene], name: Optional[str] = None) -> None:
"""
Include multiple scenes to this router as sub-routers
:param scenes: scene instances
:param name: optional name for router
:return:
"""
if not scenes:
raise ValueError("At least one scene must be provided")
for scene in scenes:
self.include_scene(scene, name=name)
def include_scene(self, scene: Type[Scene], name: Optional[str] = None) -> None:
"""
Include a scene to this router as sub-router
:param scene: scene instance
:param name: optional name for router
:return:
"""
self.include_router(scene.as_router(name=name))
async def emit_startup(self, *args: Any, **kwargs: Any) -> None:
"""
Recursively call startup callbacks

View file

@ -8,9 +8,12 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
from typing_extensions import Self
from aiogram import Dispatcher, Router, loggers
from aiogram import loggers
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NextMiddlewareType
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.dispatcher.flags import extract_flags_from_object
from aiogram.dispatcher.router import Router
from aiogram.exceptions import SceneException
from aiogram.filters import StateFilter
from aiogram.fsm.context import FSMContext
@ -107,18 +110,6 @@ class ObserverDecorator:
self.action = action
self.after = after
def _wrap_class(self, target: Type[Scene]) -> None:
if not issubclass(target, Scene):
raise TypeError("Only subclass of Scene is allowed")
if self.action is not None:
raise TypeError("This action is not allowed for class")
filters = getattr(target, "__aiogram_filters__", None)
if filters is None:
filters = defaultdict(list)
setattr(target, "__aiogram_filters__", filters)
filters[self.name].extend(self.filters)
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None:
handlers = getattr(target, "__aiogram_handler__", None)
if not handlers:
@ -134,21 +125,21 @@ class ObserverDecorator:
)
)
def _wrap_action(self, target: Type[Scene] | CallbackType) -> None:
def _wrap_action(self, target: CallbackType) -> None:
action = getattr(target, "__aiogram_action__", None)
if action is None:
action = defaultdict(dict)
setattr(target, "__aiogram_action__", action)
action[self.action][self.name] = CallableObject(target)
def __call__(self, target: Type[Scene] | CallbackType) -> Type[Scene] | CallbackType:
if inspect.isclass(target):
self._wrap_class(target)
elif inspect.isfunction(target):
def __call__(self, target: CallbackType) -> CallbackType:
if inspect.isfunction(target):
if self.action is None:
self._wrap_filter(target)
else:
self._wrap_action(target)
else:
raise TypeError("Only function or method is allowed")
return target
def leave(self) -> ActionContainer:
@ -213,8 +204,6 @@ class HandlerContainer:
class SceneConfig:
state: Optional[str]
"""Scene state"""
filters: Dict[str, List[CallbackType]]
"""Global scene filters"""
handlers: List[HandlerContainer]
"""Scene handlers"""
actions: Dict[SceneAction, Dict[str, CallableObject]]
@ -316,7 +305,6 @@ class Scene:
super().__init_subclass__(**kwargs)
filters: defaultdict[str, List[CallbackType]] = defaultdict(list)
handlers: list[HandlerContainer] = []
actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict)
@ -328,7 +316,6 @@ class Scene:
if not parent_scene_config:
continue
filters.update(parent_scene_config.filters)
handlers.extend(parent_scene_config.handlers)
for action, action_handlers in parent_scene_config.actions.items():
actions[action].update(action_handlers)
@ -360,7 +347,6 @@ class Scene:
cls.__scene_config__ = SceneConfig(
state=state_name,
filters=dict(filters),
handlers=handlers,
actions=dict(actions),
reset_data_on_enter=reset_data_on_enter,
@ -379,10 +365,6 @@ class Scene:
scene_config = cls.__scene_config__
used_observers = set()
for observer_name, filters in scene_config.filters.items():
router.observers[observer_name].filter(*filters)
used_observers.add(observer_name)
for handler in scene_config.handlers:
router.observers[handler.name].register(
SceneHandlerWrapper(
@ -391,6 +373,7 @@ class Scene:
after=handler.after,
),
*handler.filters,
flags=extract_flags_from_object(handler.handler),
)
used_observers.add(handler.name)
@ -706,15 +689,17 @@ class SceneRegistry:
A class that represents a registry for scenes in a Telegram bot.
"""
def __init__(self, router: Router) -> None:
def __init__(self, router: Router, register_on_add: bool = False) -> None:
"""
Initialize a new instance of the SceneRegistry class.
:param router: The router instance used for scene registration.
:param register_on_add: Whether to register the scenes to the router when they are added.
"""
self.router = router
self._scenes: Dict[Optional[str], Type[Scene]] = {}
self.register_on_add = register_on_add
self._scenes: Dict[Optional[str], Type[Scene]] = {}
self._setup_middleware(router)
def _setup_middleware(self, router: Router) -> None:
@ -764,23 +749,24 @@ class SceneRegistry:
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None:
"""
This method adds the specified scenes to the router.
If a router is not provided, it uses the default router stored
in the SceneRegistry instance.
The scenes are included in the router by calling the `as_router()`
method on each scene and passing the router as a parameter to this method.
This method adds the specified scenes to the registry
and optionally registers it to the router.
If a scene with the same state already exists in the registry, a SceneException is raised.
.. warning::
If the router is not specified, the scenes will not be registered to the router.
You will need to include the scenes manually to the router or use the register method.
:param scenes: A variable length parameter that accepts one or more types of scenes.
These scenes are instances of the Scene class.
:param router: An optional parameter that specifies the router
to which the scenes should be added. If not provided, the scenes will be
added to the default router stored in the SceneRegistry instance.
to which the scenes should be added.
:return: None
"""
if router is None:
router = self.router
if not scenes:
raise ValueError("At least one scene must be specified")
for scene in scenes:
if scene.__scene_config__.state in self._scenes:
@ -790,7 +776,19 @@ class SceneRegistry:
self._scenes[scene.__scene_config__.state] = scene
router.include_router(scene.as_router())
if router:
router.include_router(scene.as_router())
elif self.register_on_add:
self.router.include_router(scene.as_router())
def register(self, *scenes: Type[Scene]) -> None:
"""
Registers one or more scenes to the SceneRegistry.
:param scenes: One or more scene classes to register.
:return: None
"""
self.add(*scenes, router=self.router)
def get(self, scene: Optional[Union[Type[Scene], str]]) -> Type[Scene]:
"""

View file

@ -252,6 +252,7 @@ class QuizScene(Scene, state="quiz"):
quiz_router = Router(name=__name__)
# Add handler that initializes the scene
quiz_router.message.register(QuizScene.as_handler(), Command("quiz"))
quiz_router.include_scene(QuizScene)
@quiz_router.message(Command("start"))
@ -275,7 +276,7 @@ def create_dispatcher():
# ... and then register a scene in the registry
# by default, Scene will be mounted to the router that passed to the SceneRegistry,
# but you can specify the router explicitly using the `router` argument
scene_registry.add(QuizScene, router=quiz_router)
scene_registry.add(QuizScene)
return dispatcher