mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Small refactoring
This commit is contained in:
parent
bda7fcd13b
commit
c999964f11
3 changed files with 67 additions and 42 deletions
|
|
@ -1,11 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Final, Generator, List, Optional, Set
|
||||
from typing import Any, Dict, Final, Generator, List, Optional, Set, Type, TYPE_CHECKING
|
||||
|
||||
from ..types import TelegramObject
|
||||
from .event.bases import REJECTED, UNHANDLED
|
||||
from .event.event import EventObserver
|
||||
from .event.telegram import TelegramEventObserver
|
||||
from ..types import TelegramObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..fsm.scene import Scene
|
||||
|
||||
INTERNAL_UPDATE_TYPES: Final[frozenset[str]] = frozenset({"update", "error"})
|
||||
|
||||
|
|
@ -218,6 +221,29 @@ class Router:
|
|||
router.parent_router = self
|
||||
return router
|
||||
|
||||
def include_scenes(self, *scenes: Type[Scene], name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Include multiple scenes to this router as sub-routers
|
||||
|
||||
:param scenes: scene instances
|
||||
:param name: optional name for router
|
||||
:return:
|
||||
"""
|
||||
if not scenes:
|
||||
raise ValueError("At least one scene must be provided")
|
||||
for scene in scenes:
|
||||
self.include_scene(scene, name=name)
|
||||
|
||||
def include_scene(self, scene: Type[Scene], name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Include a scene to this router as sub-router
|
||||
|
||||
:param scene: scene instance
|
||||
:param name: optional name for router
|
||||
:return:
|
||||
"""
|
||||
self.include_router(scene.as_router(name=name))
|
||||
|
||||
async def emit_startup(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Recursively call startup callbacks
|
||||
|
|
|
|||
|
|
@ -8,9 +8,12 @@ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
|
|||
|
||||
from typing_extensions import Self
|
||||
|
||||
from aiogram import Dispatcher, Router, loggers
|
||||
from aiogram import loggers
|
||||
from aiogram.dispatcher.dispatcher import Dispatcher
|
||||
from aiogram.dispatcher.event.bases import NextMiddlewareType
|
||||
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
|
||||
from aiogram.dispatcher.flags import extract_flags_from_object
|
||||
from aiogram.dispatcher.router import Router
|
||||
from aiogram.exceptions import SceneException
|
||||
from aiogram.filters import StateFilter
|
||||
from aiogram.fsm.context import FSMContext
|
||||
|
|
@ -107,18 +110,6 @@ class ObserverDecorator:
|
|||
self.action = action
|
||||
self.after = after
|
||||
|
||||
def _wrap_class(self, target: Type[Scene]) -> None:
|
||||
if not issubclass(target, Scene):
|
||||
raise TypeError("Only subclass of Scene is allowed")
|
||||
if self.action is not None:
|
||||
raise TypeError("This action is not allowed for class")
|
||||
|
||||
filters = getattr(target, "__aiogram_filters__", None)
|
||||
if filters is None:
|
||||
filters = defaultdict(list)
|
||||
setattr(target, "__aiogram_filters__", filters)
|
||||
filters[self.name].extend(self.filters)
|
||||
|
||||
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None:
|
||||
handlers = getattr(target, "__aiogram_handler__", None)
|
||||
if not handlers:
|
||||
|
|
@ -134,21 +125,21 @@ class ObserverDecorator:
|
|||
)
|
||||
)
|
||||
|
||||
def _wrap_action(self, target: Type[Scene] | CallbackType) -> None:
|
||||
def _wrap_action(self, target: CallbackType) -> None:
|
||||
action = getattr(target, "__aiogram_action__", None)
|
||||
if action is None:
|
||||
action = defaultdict(dict)
|
||||
setattr(target, "__aiogram_action__", action)
|
||||
action[self.action][self.name] = CallableObject(target)
|
||||
|
||||
def __call__(self, target: Type[Scene] | CallbackType) -> Type[Scene] | CallbackType:
|
||||
if inspect.isclass(target):
|
||||
self._wrap_class(target)
|
||||
elif inspect.isfunction(target):
|
||||
def __call__(self, target: CallbackType) -> CallbackType:
|
||||
if inspect.isfunction(target):
|
||||
if self.action is None:
|
||||
self._wrap_filter(target)
|
||||
else:
|
||||
self._wrap_action(target)
|
||||
else:
|
||||
raise TypeError("Only function or method is allowed")
|
||||
return target
|
||||
|
||||
def leave(self) -> ActionContainer:
|
||||
|
|
@ -213,8 +204,6 @@ class HandlerContainer:
|
|||
class SceneConfig:
|
||||
state: Optional[str]
|
||||
"""Scene state"""
|
||||
filters: Dict[str, List[CallbackType]]
|
||||
"""Global scene filters"""
|
||||
handlers: List[HandlerContainer]
|
||||
"""Scene handlers"""
|
||||
actions: Dict[SceneAction, Dict[str, CallableObject]]
|
||||
|
|
@ -316,7 +305,6 @@ class Scene:
|
|||
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
filters: defaultdict[str, List[CallbackType]] = defaultdict(list)
|
||||
handlers: list[HandlerContainer] = []
|
||||
actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict)
|
||||
|
||||
|
|
@ -328,7 +316,6 @@ class Scene:
|
|||
if not parent_scene_config:
|
||||
continue
|
||||
|
||||
filters.update(parent_scene_config.filters)
|
||||
handlers.extend(parent_scene_config.handlers)
|
||||
for action, action_handlers in parent_scene_config.actions.items():
|
||||
actions[action].update(action_handlers)
|
||||
|
|
@ -360,7 +347,6 @@ class Scene:
|
|||
|
||||
cls.__scene_config__ = SceneConfig(
|
||||
state=state_name,
|
||||
filters=dict(filters),
|
||||
handlers=handlers,
|
||||
actions=dict(actions),
|
||||
reset_data_on_enter=reset_data_on_enter,
|
||||
|
|
@ -379,10 +365,6 @@ class Scene:
|
|||
scene_config = cls.__scene_config__
|
||||
used_observers = set()
|
||||
|
||||
for observer_name, filters in scene_config.filters.items():
|
||||
router.observers[observer_name].filter(*filters)
|
||||
used_observers.add(observer_name)
|
||||
|
||||
for handler in scene_config.handlers:
|
||||
router.observers[handler.name].register(
|
||||
SceneHandlerWrapper(
|
||||
|
|
@ -391,6 +373,7 @@ class Scene:
|
|||
after=handler.after,
|
||||
),
|
||||
*handler.filters,
|
||||
flags=extract_flags_from_object(handler.handler),
|
||||
)
|
||||
used_observers.add(handler.name)
|
||||
|
||||
|
|
@ -706,15 +689,17 @@ class SceneRegistry:
|
|||
A class that represents a registry for scenes in a Telegram bot.
|
||||
"""
|
||||
|
||||
def __init__(self, router: Router) -> None:
|
||||
def __init__(self, router: Router, register_on_add: bool = False) -> None:
|
||||
"""
|
||||
Initialize a new instance of the SceneRegistry class.
|
||||
|
||||
:param router: The router instance used for scene registration.
|
||||
:param register_on_add: Whether to register the scenes to the router when they are added.
|
||||
"""
|
||||
self.router = router
|
||||
self._scenes: Dict[Optional[str], Type[Scene]] = {}
|
||||
self.register_on_add = register_on_add
|
||||
|
||||
self._scenes: Dict[Optional[str], Type[Scene]] = {}
|
||||
self._setup_middleware(router)
|
||||
|
||||
def _setup_middleware(self, router: Router) -> None:
|
||||
|
|
@ -764,23 +749,24 @@ class SceneRegistry:
|
|||
|
||||
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None:
|
||||
"""
|
||||
This method adds the specified scenes to the router.
|
||||
If a router is not provided, it uses the default router stored
|
||||
in the SceneRegistry instance.
|
||||
The scenes are included in the router by calling the `as_router()`
|
||||
method on each scene and passing the router as a parameter to this method.
|
||||
This method adds the specified scenes to the registry
|
||||
and optionally registers it to the router.
|
||||
|
||||
If a scene with the same state already exists in the registry, a SceneException is raised.
|
||||
|
||||
.. warning::
|
||||
|
||||
If the router is not specified, the scenes will not be registered to the router.
|
||||
You will need to include the scenes manually to the router or use the register method.
|
||||
|
||||
:param scenes: A variable length parameter that accepts one or more types of scenes.
|
||||
These scenes are instances of the Scene class.
|
||||
:param router: An optional parameter that specifies the router
|
||||
to which the scenes should be added. If not provided, the scenes will be
|
||||
added to the default router stored in the SceneRegistry instance.
|
||||
to which the scenes should be added.
|
||||
:return: None
|
||||
"""
|
||||
if router is None:
|
||||
router = self.router
|
||||
if not scenes:
|
||||
raise ValueError("At least one scene must be specified")
|
||||
|
||||
for scene in scenes:
|
||||
if scene.__scene_config__.state in self._scenes:
|
||||
|
|
@ -790,7 +776,19 @@ class SceneRegistry:
|
|||
|
||||
self._scenes[scene.__scene_config__.state] = scene
|
||||
|
||||
router.include_router(scene.as_router())
|
||||
if router:
|
||||
router.include_router(scene.as_router())
|
||||
elif self.register_on_add:
|
||||
self.router.include_router(scene.as_router())
|
||||
|
||||
def register(self, *scenes: Type[Scene]) -> None:
|
||||
"""
|
||||
Registers one or more scenes to the SceneRegistry.
|
||||
|
||||
:param scenes: One or more scene classes to register.
|
||||
:return: None
|
||||
"""
|
||||
self.add(*scenes, router=self.router)
|
||||
|
||||
def get(self, scene: Optional[Union[Type[Scene], str]]) -> Type[Scene]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -252,6 +252,7 @@ class QuizScene(Scene, state="quiz"):
|
|||
quiz_router = Router(name=__name__)
|
||||
# Add handler that initializes the scene
|
||||
quiz_router.message.register(QuizScene.as_handler(), Command("quiz"))
|
||||
quiz_router.include_scene(QuizScene)
|
||||
|
||||
|
||||
@quiz_router.message(Command("start"))
|
||||
|
|
@ -275,7 +276,7 @@ def create_dispatcher():
|
|||
# ... and then register a scene in the registry
|
||||
# by default, Scene will be mounted to the router that passed to the SceneRegistry,
|
||||
# but you can specify the router explicitly using the `router` argument
|
||||
scene_registry.add(QuizScene, router=quiz_router)
|
||||
scene_registry.add(QuizScene)
|
||||
|
||||
return dispatcher
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue