From 86591320056dee5d1870121181d42785f939813c Mon Sep 17 00:00:00 2001 From: drforse Date: Thu, 16 Dec 2021 02:53:01 +0300 Subject: [PATCH] add feature of autoregistering routers --- aiogram/dispatcher/router.py | 15 ++++- aiogram/utils/imports.py | 34 +++++++++- aiogram/utils/mixins.py | 15 ++++- aiogram/utils/routers.py | 64 +++++++++++++++++++ examples/routers_autoload/__init__.py | 0 examples/routers_autoload/__main__.py | 18 ++++++ examples/routers_autoload/config.py | 3 + .../routers_autoload/handlers/__init__.py | 0 .../handlers/callback_query.py | 9 +++ .../handlers/message/__init__.py | 0 .../routers_autoload/handlers/message/echo.py | 10 +++ .../handlers/message/start.py | 9 +++ tests/modules_for_tests/__init__.py | 4 ++ tests/modules_for_tests/small_module.py | 4 ++ .../small_package/__init__.py | 4 ++ .../small_package/nested_small_module.py | 4 ++ tests/test_utils/test_mixins.py | 21 +++++- tests/test_utils/test_routers.py | 27 ++++++++ 18 files changed, 235 insertions(+), 6 deletions(-) create mode 100644 aiogram/utils/routers.py create mode 100644 examples/routers_autoload/__init__.py create mode 100644 examples/routers_autoload/__main__.py create mode 100644 examples/routers_autoload/config.py create mode 100644 examples/routers_autoload/handlers/__init__.py create mode 100644 examples/routers_autoload/handlers/callback_query.py create mode 100644 examples/routers_autoload/handlers/message/__init__.py create mode 100644 examples/routers_autoload/handlers/message/echo.py create mode 100644 examples/routers_autoload/handlers/message/start.py create mode 100644 tests/modules_for_tests/__init__.py create mode 100644 tests/modules_for_tests/small_module.py create mode 100644 tests/modules_for_tests/small_package/__init__.py create mode 100644 tests/modules_for_tests/small_package/nested_small_module.py create mode 100644 tests/test_utils/test_routers.py diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py index c9e7dea9..d3ee1cc8 100644 --- a/aiogram/dispatcher/router.py +++ b/aiogram/dispatcher/router.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Generator, List, Optional, Set, Union from ..types import TelegramObject from ..utils.imports import import_module +from ..utils.mixins import KeepRefsMixin from ..utils.warnings import CodeHasNoEffect from .event.bases import REJECTED, UNHANDLED from .event.event import EventObserver @@ -14,7 +15,7 @@ from .filters import BUILTIN_FILTERS INTERNAL_UPDATE_TYPES = frozenset({"update", "error"}) -class Router: +class Router(KeepRefsMixin): """ Router can route update, and it nested update types like messages, callback query, polls and all other event types. @@ -25,15 +26,23 @@ class Router: - By decorator - :obj:`@router.()` """ - def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None: + def __init__( + self, + use_builtin_filters: bool = True, + name: Optional[str] = None, + index: int = None + ) -> None: """ :param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory :param name: Optional router name, can be useful for debugging + :param index: used only for ordering in utils.routers.find_all_routers """ + super().__init__() self.use_builtin_filters = use_builtin_filters self.name = name or hex(id(self)) + self.index = index self._parent_router: Optional[Router] = None self.sub_routers: List[Router] = [] @@ -90,7 +99,7 @@ class Router: observer.bind_filter(builtin_filter) def __str__(self) -> str: - return f"{type(self).__name__} {self.name!r}" + return f"{type(self).__name__} {self.name!r} {self.index if self.index is not None else ''}" def __repr__(self) -> str: return f"<{self}>" diff --git a/aiogram/utils/imports.py b/aiogram/utils/imports.py index edc0a6a0..23f5209e 100644 --- a/aiogram/utils/imports.py +++ b/aiogram/utils/imports.py @@ -1,5 +1,6 @@ import importlib -from typing import Any +from typing import Any, Set +from pathlib import Path def import_module(target: str) -> Any: @@ -21,3 +22,34 @@ def import_module(target: str) -> Any: raise ValueError(f'Module "{module_name}" has no attribute "{attr_name}"') return attribute + + +DEFAULT_EXCLUDE_MODULES = frozenset({"__pycache__", "__init__.py", "__main__.py"}) + + +def import_all_modules( + root: str, + package: str = None, + exclude: Set[str] = DEFAULT_EXCLUDE_MODULES +): + """ + imports all modules inside root and inside all subdirectories, sub-subdirectories etc of root + + The 'package' argument is required when performing a relative import. It + specifies the package to use as the anchor point from which to resolve the + relative import to an absolute import. + + :param root: root directory where function will start importing and digging to subdirectories + :param package: your top-level package name if 'root' is not absolute (starts with .) + :param exclude: set of names that will be ignored, + if it is a directory - also doesn't iterate over its insides + """ + root_module = importlib.import_module(root, package) + root_dir = root_module.__path__[0] + for sub_dir in Path(root_dir).iterdir(): + if sub_dir.name in exclude: + continue + if sub_dir.is_dir(): + import_all_modules(f"{root}.{sub_dir.stem}", package) + else: + importlib.import_module(f"{root}.{sub_dir.stem}", package) diff --git a/aiogram/utils/mixins.py b/aiogram/utils/mixins.py index 80f5afe9..cea9b5fd 100644 --- a/aiogram/utils/mixins.py +++ b/aiogram/utils/mixins.py @@ -1,12 +1,14 @@ from __future__ import annotations import contextvars +import weakref +from collections import defaultdict from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload if TYPE_CHECKING: from typing_extensions import Literal -__all__ = ("ContextInstanceMixin", "DataMixin") +__all__ = ("ContextInstanceMixin", "DataMixin", "KeepRefsMixin") class DataMixin: @@ -93,3 +95,14 @@ class ContextInstanceMixin(Generic[ContextInstance]): @classmethod def reset_current(cls, token: contextvars.Token[ContextInstance]) -> None: cls.__context_instance.reset(token) + + +class KeepRefsMixin: + __refs__ = defaultdict(weakref.WeakSet) + + def __init__(self): + self.__refs__[self.__class__].add(self) + + @classmethod + def get_instances(cls) -> weakref.WeakSet: + return cls.__refs__[cls] diff --git a/aiogram/utils/routers.py b/aiogram/utils/routers.py new file mode 100644 index 00000000..f2c049ac --- /dev/null +++ b/aiogram/utils/routers.py @@ -0,0 +1,64 @@ +from typing import Set, List + +from aiogram import Router +from aiogram.utils.imports import DEFAULT_EXCLUDE_MODULES, import_all_modules + + +def find_all_routers( + root: str, + order_by_index: bool = True, + package: str = None, + exclude_modules: Set[str] = DEFAULT_EXCLUDE_MODULES +) -> List[Router]: + """ + if order_by_index is True: routers are ordered by index applying following rules: + - indexes as in list: 0, 1, 2, ... -3, -2, -1 + - if index is None - it is going between positive and negative numbers, also filling in + empty slots in positives, ordered as was ordered in WeakSet + ex.: 0, None, 2, None, 5, None, -1 + ex.: 0, 1, 2, 3, 4, 5, None, -1 + ex.: 0, 1, 2, 3, 4, 5, None, None, -3, -1 + + :param root: root directory where function will start importing and digging to subdirectories + :param order_by_index: + :param package: your top-level package name if 'root' is not absolute (starts with .) + :param exclude_modules: set of names that will be ignored, + if it is a directory - also doesn't iterate over its insides + """ + import_all_modules(root, package, exclude_modules) + routers = list(Router.get_instances()) + return routers if order_by_index is False else _order_routers(routers) + + +def _order_routers(routers: 'List[Router]') -> List[Router]: + unordered = [] + ordered_routers = {} + negative_ordered_routers = [] + checked_indexes = {} + for router in routers: + if router.index is None: + unordered.append(router) + continue + if router.index < 0: + negative_ordered_routers.append(router) + else: + ordered_routers[router.index] = router + + if router.index in checked_indexes: + raise ValueError(f"Views {checked_indexes[router.index]} and {router} have equal indexes!") + checked_indexes[router.index] = router + + result = [] + for i in range(len(routers) - len(negative_ordered_routers)): + ordered = ordered_routers.pop(i, None) + if ordered is not None: + result.append(ordered) + continue + if unordered: + result.append(unordered.pop(0)) + + # for case where there is a router with an index too big (more than the amount of routers) + result += sorted(ordered_routers.values(), key=lambda x: x.index) + + result += sorted(negative_ordered_routers, key=lambda x: x.index) + return result diff --git a/examples/routers_autoload/__init__.py b/examples/routers_autoload/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/routers_autoload/__main__.py b/examples/routers_autoload/__main__.py new file mode 100644 index 00000000..656a8aa9 --- /dev/null +++ b/examples/routers_autoload/__main__.py @@ -0,0 +1,18 @@ +from aiogram import Dispatcher, Bot +from aiogram.utils.routers import find_all_routers + +from . import config + + +if __name__ == "__main__": + bot = Bot(config.TELEGRAM_BOT_TOKEN) + dp = Dispatcher() + + # find all routers in routers_autoload.handlers and subdirectories, sub-subdirectories etc. + # the routers are by default ordered by their indexes (which can be still null) + # see info about ordering rules in find_all_routers docstring + routers = find_all_routers("routers_autoload.handlers") + + for r in routers: + dp.include_router(r) + dp.run_polling(bot) diff --git a/examples/routers_autoload/config.py b/examples/routers_autoload/config.py new file mode 100644 index 00000000..756cb43d --- /dev/null +++ b/examples/routers_autoload/config.py @@ -0,0 +1,3 @@ +import os + +TELEGRAM_BOT_TOKEN = os.environ["TELEGRAM_BOT_TOKEN"] diff --git a/examples/routers_autoload/handlers/__init__.py b/examples/routers_autoload/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/routers_autoload/handlers/callback_query.py b/examples/routers_autoload/handlers/callback_query.py new file mode 100644 index 00000000..cbd8038e --- /dev/null +++ b/examples/routers_autoload/handlers/callback_query.py @@ -0,0 +1,9 @@ +from aiogram import Router +from aiogram.types import CallbackQuery + +router = Router() + + +@router.callback_query() +async def process_callback_query(q: CallbackQuery): + await q.answer("Success!", show_alert=True) diff --git a/examples/routers_autoload/handlers/message/__init__.py b/examples/routers_autoload/handlers/message/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/routers_autoload/handlers/message/echo.py b/examples/routers_autoload/handlers/message/echo.py new file mode 100644 index 00000000..71a3adeb --- /dev/null +++ b/examples/routers_autoload/handlers/message/echo.py @@ -0,0 +1,10 @@ +from aiogram import Router +from aiogram.types import Message + +# add index -1 for this router to be registered the latest of all found in find_all_routers +router = Router(index=-1) + + +@router.message() +async def process_message(m: Message): + await m.copy_to(chat_id=m.chat.id) diff --git a/examples/routers_autoload/handlers/message/start.py b/examples/routers_autoload/handlers/message/start.py new file mode 100644 index 00000000..cf0b348c --- /dev/null +++ b/examples/routers_autoload/handlers/message/start.py @@ -0,0 +1,9 @@ +from aiogram import Router +from aiogram.types import Message + +router = Router() + + +@router.message(commands=["start"]) +async def process_message(m: Message): + await m.answer("Hi!") diff --git a/tests/modules_for_tests/__init__.py b/tests/modules_for_tests/__init__.py new file mode 100644 index 00000000..c1aaa4f7 --- /dev/null +++ b/tests/modules_for_tests/__init__.py @@ -0,0 +1,4 @@ +from aiogram import Router + +router = Router(name="__init__") +print("__init__ imported") diff --git a/tests/modules_for_tests/small_module.py b/tests/modules_for_tests/small_module.py new file mode 100644 index 00000000..88d5bc06 --- /dev/null +++ b/tests/modules_for_tests/small_module.py @@ -0,0 +1,4 @@ +from aiogram import Router + +router = Router(name="small_module") +print("small_module imported") diff --git a/tests/modules_for_tests/small_package/__init__.py b/tests/modules_for_tests/small_package/__init__.py new file mode 100644 index 00000000..19a3cbd6 --- /dev/null +++ b/tests/modules_for_tests/small_package/__init__.py @@ -0,0 +1,4 @@ +from aiogram import Router + +router = Router(name="small_package") +print("small_package imported") diff --git a/tests/modules_for_tests/small_package/nested_small_module.py b/tests/modules_for_tests/small_package/nested_small_module.py new file mode 100644 index 00000000..8183f2e7 --- /dev/null +++ b/tests/modules_for_tests/small_package/nested_small_module.py @@ -0,0 +1,4 @@ +from aiogram import Router + +router = Router(name="nested_small_module") +print("nested_small_module imported") diff --git a/tests/test_utils/test_mixins.py b/tests/test_utils/test_mixins.py index f9fbbade..9ad35b05 100644 --- a/tests/test_utils/test_mixins.py +++ b/tests/test_utils/test_mixins.py @@ -1,6 +1,6 @@ import pytest -from aiogram.utils.mixins import ContextInstanceMixin, DataMixin +from aiogram.utils.mixins import ContextInstanceMixin, DataMixin, KeepRefsMixin class ContextObject(ContextInstanceMixin["ContextObject"]): @@ -11,6 +11,10 @@ class DataObject(DataMixin): pass +class RefsObject(KeepRefsMixin): + pass + + class TestDataMixin: def test_store_value(self): obj = DataObject() @@ -52,3 +56,18 @@ class TestContextInstanceMixin: TypeError, match=r"Value should be instance of 'ContextObject' not '.+'" ): obj.set_current(42) + + +class TestKeepRefsMixin: + def test_refs_are_saved(self): + obj = RefsObject() + + assert obj in RefsObject.get_instances() + + def test_refs_are_deleted(self): + obj = RefsObject() + size_with_obj = len(RefsObject.get_instances()) + del obj + size_without_obj = len(RefsObject.get_instances()) + + assert size_with_obj - 1 == size_without_obj diff --git a/tests/test_utils/test_routers.py b/tests/test_utils/test_routers.py new file mode 100644 index 00000000..9210b1e2 --- /dev/null +++ b/tests/test_utils/test_routers.py @@ -0,0 +1,27 @@ +from aiogram import Router +from aiogram.utils.routers import find_all_routers, _order_routers + + +EXPECTED_ROUTERS_NAMES = frozenset({"__init__", "small_module", + "small_package", "nested_small_module"}) + + +def test_all_routers_are_valid(): + routers = find_all_routers("tests.modules_for_tests") + for router in routers: + assert isinstance(router, Router) + + +def test_all_expected_routers_are_found(): + routers = find_all_routers("tests.modules_for_tests") + found_names = {router.name for router in routers} + for name in EXPECTED_ROUTERS_NAMES: + assert name in found_names + + +def test_routers_ordering(): + indexes = [None, None, None, None, 1, -2, 4] + routers = [Router(index=index) for index in indexes] + ordered_indexes = [r.index for r in _order_routers(routers)] + assert ordered_indexes == [None, 1, None, None, 4, None, -2] +