mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
add feature of autoregistering routers
This commit is contained in:
parent
76ae5c4415
commit
8659132005
18 changed files with 235 additions and 6 deletions
|
|
@ -5,6 +5,7 @@ from typing import Any, Dict, Generator, List, Optional, Set, Union
|
|||
|
||||
from ..types import TelegramObject
|
||||
from ..utils.imports import import_module
|
||||
from ..utils.mixins import KeepRefsMixin
|
||||
from ..utils.warnings import CodeHasNoEffect
|
||||
from .event.bases import REJECTED, UNHANDLED
|
||||
from .event.event import EventObserver
|
||||
|
|
@ -14,7 +15,7 @@ from .filters import BUILTIN_FILTERS
|
|||
INTERNAL_UPDATE_TYPES = frozenset({"update", "error"})
|
||||
|
||||
|
||||
class Router:
|
||||
class Router(KeepRefsMixin):
|
||||
"""
|
||||
Router can route update, and it nested update types like messages, callback query,
|
||||
polls and all other event types.
|
||||
|
|
@ -25,15 +26,23 @@ class Router:
|
|||
- By decorator - :obj:`@router.<event_type>(<filters, ...>)`
|
||||
"""
|
||||
|
||||
def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
use_builtin_filters: bool = True,
|
||||
name: Optional[str] = None,
|
||||
index: int = None
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
:param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory
|
||||
:param name: Optional router name, can be useful for debugging
|
||||
:param index: used only for ordering in utils.routers.find_all_routers
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.use_builtin_filters = use_builtin_filters
|
||||
self.name = name or hex(id(self))
|
||||
self.index = index
|
||||
|
||||
self._parent_router: Optional[Router] = None
|
||||
self.sub_routers: List[Router] = []
|
||||
|
|
@ -90,7 +99,7 @@ class Router:
|
|||
observer.bind_filter(builtin_filter)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{type(self).__name__} {self.name!r}"
|
||||
return f"{type(self).__name__} {self.name!r} {self.index if self.index is not None else ''}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self}>"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import importlib
|
||||
from typing import Any
|
||||
from typing import Any, Set
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def import_module(target: str) -> Any:
|
||||
|
|
@ -21,3 +22,34 @@ def import_module(target: str) -> Any:
|
|||
raise ValueError(f'Module "{module_name}" has no attribute "{attr_name}"')
|
||||
|
||||
return attribute
|
||||
|
||||
|
||||
DEFAULT_EXCLUDE_MODULES = frozenset({"__pycache__", "__init__.py", "__main__.py"})
|
||||
|
||||
|
||||
def import_all_modules(
|
||||
root: str,
|
||||
package: str = None,
|
||||
exclude: Set[str] = DEFAULT_EXCLUDE_MODULES
|
||||
):
|
||||
"""
|
||||
imports all modules inside root and inside all subdirectories, sub-subdirectories etc of root
|
||||
|
||||
The 'package' argument is required when performing a relative import. It
|
||||
specifies the package to use as the anchor point from which to resolve the
|
||||
relative import to an absolute import.
|
||||
|
||||
:param root: root directory where function will start importing and digging to subdirectories
|
||||
:param package: your top-level package name if 'root' is not absolute (starts with .)
|
||||
:param exclude: set of names that will be ignored,
|
||||
if it is a directory - also doesn't iterate over its insides
|
||||
"""
|
||||
root_module = importlib.import_module(root, package)
|
||||
root_dir = root_module.__path__[0]
|
||||
for sub_dir in Path(root_dir).iterdir():
|
||||
if sub_dir.name in exclude:
|
||||
continue
|
||||
if sub_dir.is_dir():
|
||||
import_all_modules(f"{root}.{sub_dir.stem}", package)
|
||||
else:
|
||||
importlib.import_module(f"{root}.{sub_dir.stem}", package)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Literal
|
||||
|
||||
__all__ = ("ContextInstanceMixin", "DataMixin")
|
||||
__all__ = ("ContextInstanceMixin", "DataMixin", "KeepRefsMixin")
|
||||
|
||||
|
||||
class DataMixin:
|
||||
|
|
@ -93,3 +95,14 @@ class ContextInstanceMixin(Generic[ContextInstance]):
|
|||
@classmethod
|
||||
def reset_current(cls, token: contextvars.Token[ContextInstance]) -> None:
|
||||
cls.__context_instance.reset(token)
|
||||
|
||||
|
||||
class KeepRefsMixin:
|
||||
__refs__ = defaultdict(weakref.WeakSet)
|
||||
|
||||
def __init__(self):
|
||||
self.__refs__[self.__class__].add(self)
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls) -> weakref.WeakSet:
|
||||
return cls.__refs__[cls]
|
||||
|
|
|
|||
64
aiogram/utils/routers.py
Normal file
64
aiogram/utils/routers.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from typing import Set, List
|
||||
|
||||
from aiogram import Router
|
||||
from aiogram.utils.imports import DEFAULT_EXCLUDE_MODULES, import_all_modules
|
||||
|
||||
|
||||
def find_all_routers(
|
||||
root: str,
|
||||
order_by_index: bool = True,
|
||||
package: str = None,
|
||||
exclude_modules: Set[str] = DEFAULT_EXCLUDE_MODULES
|
||||
) -> List[Router]:
|
||||
"""
|
||||
if order_by_index is True: routers are ordered by index applying following rules:
|
||||
- indexes as in list: 0, 1, 2, ... -3, -2, -1
|
||||
- if index is None - it is going between positive and negative numbers, also filling in
|
||||
empty slots in positives, ordered as was ordered in WeakSet
|
||||
ex.: 0, None, 2, None, 5, None, -1
|
||||
ex.: 0, 1, 2, 3, 4, 5, None, -1
|
||||
ex.: 0, 1, 2, 3, 4, 5, None, None, -3, -1
|
||||
|
||||
:param root: root directory where function will start importing and digging to subdirectories
|
||||
:param order_by_index:
|
||||
:param package: your top-level package name if 'root' is not absolute (starts with .)
|
||||
:param exclude_modules: set of names that will be ignored,
|
||||
if it is a directory - also doesn't iterate over its insides
|
||||
"""
|
||||
import_all_modules(root, package, exclude_modules)
|
||||
routers = list(Router.get_instances())
|
||||
return routers if order_by_index is False else _order_routers(routers)
|
||||
|
||||
|
||||
def _order_routers(routers: 'List[Router]') -> List[Router]:
|
||||
unordered = []
|
||||
ordered_routers = {}
|
||||
negative_ordered_routers = []
|
||||
checked_indexes = {}
|
||||
for router in routers:
|
||||
if router.index is None:
|
||||
unordered.append(router)
|
||||
continue
|
||||
if router.index < 0:
|
||||
negative_ordered_routers.append(router)
|
||||
else:
|
||||
ordered_routers[router.index] = router
|
||||
|
||||
if router.index in checked_indexes:
|
||||
raise ValueError(f"Views {checked_indexes[router.index]} and {router} have equal indexes!")
|
||||
checked_indexes[router.index] = router
|
||||
|
||||
result = []
|
||||
for i in range(len(routers) - len(negative_ordered_routers)):
|
||||
ordered = ordered_routers.pop(i, None)
|
||||
if ordered is not None:
|
||||
result.append(ordered)
|
||||
continue
|
||||
if unordered:
|
||||
result.append(unordered.pop(0))
|
||||
|
||||
# for case where there is a router with an index too big (more than the amount of routers)
|
||||
result += sorted(ordered_routers.values(), key=lambda x: x.index)
|
||||
|
||||
result += sorted(negative_ordered_routers, key=lambda x: x.index)
|
||||
return result
|
||||
0
examples/routers_autoload/__init__.py
Normal file
0
examples/routers_autoload/__init__.py
Normal file
18
examples/routers_autoload/__main__.py
Normal file
18
examples/routers_autoload/__main__.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from aiogram import Dispatcher, Bot
|
||||
from aiogram.utils.routers import find_all_routers
|
||||
|
||||
from . import config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot = Bot(config.TELEGRAM_BOT_TOKEN)
|
||||
dp = Dispatcher()
|
||||
|
||||
# find all routers in routers_autoload.handlers and subdirectories, sub-subdirectories etc.
|
||||
# the routers are by default ordered by their indexes (which can be still null)
|
||||
# see info about ordering rules in find_all_routers docstring
|
||||
routers = find_all_routers("routers_autoload.handlers")
|
||||
|
||||
for r in routers:
|
||||
dp.include_router(r)
|
||||
dp.run_polling(bot)
|
||||
3
examples/routers_autoload/config.py
Normal file
3
examples/routers_autoload/config.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
import os
|
||||
|
||||
TELEGRAM_BOT_TOKEN = os.environ["TELEGRAM_BOT_TOKEN"]
|
||||
0
examples/routers_autoload/handlers/__init__.py
Normal file
0
examples/routers_autoload/handlers/__init__.py
Normal file
9
examples/routers_autoload/handlers/callback_query.py
Normal file
9
examples/routers_autoload/handlers/callback_query.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from aiogram import Router
|
||||
from aiogram.types import CallbackQuery
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.callback_query()
|
||||
async def process_callback_query(q: CallbackQuery):
|
||||
await q.answer("Success!", show_alert=True)
|
||||
0
examples/routers_autoload/handlers/message/__init__.py
Normal file
0
examples/routers_autoload/handlers/message/__init__.py
Normal file
10
examples/routers_autoload/handlers/message/echo.py
Normal file
10
examples/routers_autoload/handlers/message/echo.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from aiogram import Router
|
||||
from aiogram.types import Message
|
||||
|
||||
# add index -1 for this router to be registered the latest of all found in find_all_routers
|
||||
router = Router(index=-1)
|
||||
|
||||
|
||||
@router.message()
|
||||
async def process_message(m: Message):
|
||||
await m.copy_to(chat_id=m.chat.id)
|
||||
9
examples/routers_autoload/handlers/message/start.py
Normal file
9
examples/routers_autoload/handlers/message/start.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from aiogram import Router
|
||||
from aiogram.types import Message
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.message(commands=["start"])
|
||||
async def process_message(m: Message):
|
||||
await m.answer("Hi!")
|
||||
4
tests/modules_for_tests/__init__.py
Normal file
4
tests/modules_for_tests/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from aiogram import Router
|
||||
|
||||
router = Router(name="__init__")
|
||||
print("__init__ imported")
|
||||
4
tests/modules_for_tests/small_module.py
Normal file
4
tests/modules_for_tests/small_module.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from aiogram import Router
|
||||
|
||||
router = Router(name="small_module")
|
||||
print("small_module imported")
|
||||
4
tests/modules_for_tests/small_package/__init__.py
Normal file
4
tests/modules_for_tests/small_package/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from aiogram import Router
|
||||
|
||||
router = Router(name="small_package")
|
||||
print("small_package imported")
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from aiogram import Router
|
||||
|
||||
router = Router(name="nested_small_module")
|
||||
print("nested_small_module imported")
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.utils.mixins import ContextInstanceMixin, DataMixin
|
||||
from aiogram.utils.mixins import ContextInstanceMixin, DataMixin, KeepRefsMixin
|
||||
|
||||
|
||||
class ContextObject(ContextInstanceMixin["ContextObject"]):
|
||||
|
|
@ -11,6 +11,10 @@ class DataObject(DataMixin):
|
|||
pass
|
||||
|
||||
|
||||
class RefsObject(KeepRefsMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestDataMixin:
|
||||
def test_store_value(self):
|
||||
obj = DataObject()
|
||||
|
|
@ -52,3 +56,18 @@ class TestContextInstanceMixin:
|
|||
TypeError, match=r"Value should be instance of 'ContextObject' not '.+'"
|
||||
):
|
||||
obj.set_current(42)
|
||||
|
||||
|
||||
class TestKeepRefsMixin:
|
||||
def test_refs_are_saved(self):
|
||||
obj = RefsObject()
|
||||
|
||||
assert obj in RefsObject.get_instances()
|
||||
|
||||
def test_refs_are_deleted(self):
|
||||
obj = RefsObject()
|
||||
size_with_obj = len(RefsObject.get_instances())
|
||||
del obj
|
||||
size_without_obj = len(RefsObject.get_instances())
|
||||
|
||||
assert size_with_obj - 1 == size_without_obj
|
||||
|
|
|
|||
27
tests/test_utils/test_routers.py
Normal file
27
tests/test_utils/test_routers.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from aiogram import Router
|
||||
from aiogram.utils.routers import find_all_routers, _order_routers
|
||||
|
||||
|
||||
EXPECTED_ROUTERS_NAMES = frozenset({"__init__", "small_module",
|
||||
"small_package", "nested_small_module"})
|
||||
|
||||
|
||||
def test_all_routers_are_valid():
|
||||
routers = find_all_routers("tests.modules_for_tests")
|
||||
for router in routers:
|
||||
assert isinstance(router, Router)
|
||||
|
||||
|
||||
def test_all_expected_routers_are_found():
|
||||
routers = find_all_routers("tests.modules_for_tests")
|
||||
found_names = {router.name for router in routers}
|
||||
for name in EXPECTED_ROUTERS_NAMES:
|
||||
assert name in found_names
|
||||
|
||||
|
||||
def test_routers_ordering():
|
||||
indexes = [None, None, None, None, 1, -2, 4]
|
||||
routers = [Router(index=index) for index in indexes]
|
||||
ordered_indexes = [r.index for r in _order_routers(routers)]
|
||||
assert ordered_indexes == [None, 1, None, None, 4, None, -2]
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue