add feature of autoregistering routers

This commit is contained in:
drforse 2021-12-16 02:53:01 +03:00
parent 76ae5c4415
commit 8659132005
18 changed files with 235 additions and 6 deletions

View file

@ -5,6 +5,7 @@ from typing import Any, Dict, Generator, List, Optional, Set, Union
from ..types import TelegramObject
from ..utils.imports import import_module
from ..utils.mixins import KeepRefsMixin
from ..utils.warnings import CodeHasNoEffect
from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver
@ -14,7 +15,7 @@ from .filters import BUILTIN_FILTERS
INTERNAL_UPDATE_TYPES = frozenset({"update", "error"})
class Router:
class Router(KeepRefsMixin):
"""
Router can route update, and it nested update types like messages, callback query,
polls and all other event types.
@ -25,15 +26,23 @@ class Router:
- By decorator - :obj:`@router.<event_type>(<filters, ...>)`
"""
def __init__(self, use_builtin_filters: bool = True, name: Optional[str] = None) -> None:
def __init__(
self,
use_builtin_filters: bool = True,
name: Optional[str] = None,
index: int = None
) -> None:
"""
:param use_builtin_filters: `aiogram` has many builtin filters and you can controll automatic registration of this filters in factory
:param name: Optional router name, can be useful for debugging
:param index: used only for ordering in utils.routers.find_all_routers
"""
super().__init__()
self.use_builtin_filters = use_builtin_filters
self.name = name or hex(id(self))
self.index = index
self._parent_router: Optional[Router] = None
self.sub_routers: List[Router] = []
@ -90,7 +99,7 @@ class Router:
observer.bind_filter(builtin_filter)
def __str__(self) -> str:
return f"{type(self).__name__} {self.name!r}"
return f"{type(self).__name__} {self.name!r} {self.index if self.index is not None else ''}"
def __repr__(self) -> str:
return f"<{self}>"

View file

@ -1,5 +1,6 @@
import importlib
from typing import Any
from typing import Any, Set
from pathlib import Path
def import_module(target: str) -> Any:
@ -21,3 +22,34 @@ def import_module(target: str) -> Any:
raise ValueError(f'Module "{module_name}" has no attribute "{attr_name}"')
return attribute
DEFAULT_EXCLUDE_MODULES = frozenset({"__pycache__", "__init__.py", "__main__.py"})
def import_all_modules(
root: str,
package: str = None,
exclude: Set[str] = DEFAULT_EXCLUDE_MODULES
):
"""
imports all modules inside root and inside all subdirectories, sub-subdirectories etc of root
The 'package' argument is required when performing a relative import. It
specifies the package to use as the anchor point from which to resolve the
relative import to an absolute import.
:param root: root directory where function will start importing and digging to subdirectories
:param package: your top-level package name if 'root' is not absolute (starts with .)
:param exclude: set of names that will be ignored,
if it is a directory - also doesn't iterate over its insides
"""
root_module = importlib.import_module(root, package)
root_dir = root_module.__path__[0]
for sub_dir in Path(root_dir).iterdir():
if sub_dir.name in exclude:
continue
if sub_dir.is_dir():
import_all_modules(f"{root}.{sub_dir.stem}", package)
else:
importlib.import_module(f"{root}.{sub_dir.stem}", package)

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import contextvars
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Optional, TypeVar, cast, overload
if TYPE_CHECKING:
from typing_extensions import Literal
__all__ = ("ContextInstanceMixin", "DataMixin")
__all__ = ("ContextInstanceMixin", "DataMixin", "KeepRefsMixin")
class DataMixin:
@ -93,3 +95,14 @@ class ContextInstanceMixin(Generic[ContextInstance]):
@classmethod
def reset_current(cls, token: contextvars.Token[ContextInstance]) -> None:
cls.__context_instance.reset(token)
class KeepRefsMixin:
__refs__ = defaultdict(weakref.WeakSet)
def __init__(self):
self.__refs__[self.__class__].add(self)
@classmethod
def get_instances(cls) -> weakref.WeakSet:
return cls.__refs__[cls]

64
aiogram/utils/routers.py Normal file
View file

@ -0,0 +1,64 @@
from typing import Set, List
from aiogram import Router
from aiogram.utils.imports import DEFAULT_EXCLUDE_MODULES, import_all_modules
def find_all_routers(
root: str,
order_by_index: bool = True,
package: str = None,
exclude_modules: Set[str] = DEFAULT_EXCLUDE_MODULES
) -> List[Router]:
"""
if order_by_index is True: routers are ordered by index applying following rules:
- indexes as in list: 0, 1, 2, ... -3, -2, -1
- if index is None - it is going between positive and negative numbers, also filling in
empty slots in positives, ordered as was ordered in WeakSet
ex.: 0, None, 2, None, 5, None, -1
ex.: 0, 1, 2, 3, 4, 5, None, -1
ex.: 0, 1, 2, 3, 4, 5, None, None, -3, -1
:param root: root directory where function will start importing and digging to subdirectories
:param order_by_index:
:param package: your top-level package name if 'root' is not absolute (starts with .)
:param exclude_modules: set of names that will be ignored,
if it is a directory - also doesn't iterate over its insides
"""
import_all_modules(root, package, exclude_modules)
routers = list(Router.get_instances())
return routers if order_by_index is False else _order_routers(routers)
def _order_routers(routers: 'List[Router]') -> List[Router]:
unordered = []
ordered_routers = {}
negative_ordered_routers = []
checked_indexes = {}
for router in routers:
if router.index is None:
unordered.append(router)
continue
if router.index < 0:
negative_ordered_routers.append(router)
else:
ordered_routers[router.index] = router
if router.index in checked_indexes:
raise ValueError(f"Views {checked_indexes[router.index]} and {router} have equal indexes!")
checked_indexes[router.index] = router
result = []
for i in range(len(routers) - len(negative_ordered_routers)):
ordered = ordered_routers.pop(i, None)
if ordered is not None:
result.append(ordered)
continue
if unordered:
result.append(unordered.pop(0))
# for case where there is a router with an index too big (more than the amount of routers)
result += sorted(ordered_routers.values(), key=lambda x: x.index)
result += sorted(negative_ordered_routers, key=lambda x: x.index)
return result

View file

View file

@ -0,0 +1,18 @@
from aiogram import Dispatcher, Bot
from aiogram.utils.routers import find_all_routers
from . import config
if __name__ == "__main__":
bot = Bot(config.TELEGRAM_BOT_TOKEN)
dp = Dispatcher()
# find all routers in routers_autoload.handlers and subdirectories, sub-subdirectories etc.
# the routers are by default ordered by their indexes (which can be still null)
# see info about ordering rules in find_all_routers docstring
routers = find_all_routers("routers_autoload.handlers")
for r in routers:
dp.include_router(r)
dp.run_polling(bot)

View file

@ -0,0 +1,3 @@
import os
TELEGRAM_BOT_TOKEN = os.environ["TELEGRAM_BOT_TOKEN"]

View file

@ -0,0 +1,9 @@
from aiogram import Router
from aiogram.types import CallbackQuery
router = Router()
@router.callback_query()
async def process_callback_query(q: CallbackQuery):
await q.answer("Success!", show_alert=True)

View file

@ -0,0 +1,10 @@
from aiogram import Router
from aiogram.types import Message
# add index -1 for this router to be registered the latest of all found in find_all_routers
router = Router(index=-1)
@router.message()
async def process_message(m: Message):
await m.copy_to(chat_id=m.chat.id)

View file

@ -0,0 +1,9 @@
from aiogram import Router
from aiogram.types import Message
router = Router()
@router.message(commands=["start"])
async def process_message(m: Message):
await m.answer("Hi!")

View file

@ -0,0 +1,4 @@
from aiogram import Router
router = Router(name="__init__")
print("__init__ imported")

View file

@ -0,0 +1,4 @@
from aiogram import Router
router = Router(name="small_module")
print("small_module imported")

View file

@ -0,0 +1,4 @@
from aiogram import Router
router = Router(name="small_package")
print("small_package imported")

View file

@ -0,0 +1,4 @@
from aiogram import Router
router = Router(name="nested_small_module")
print("nested_small_module imported")

View file

@ -1,6 +1,6 @@
import pytest
from aiogram.utils.mixins import ContextInstanceMixin, DataMixin
from aiogram.utils.mixins import ContextInstanceMixin, DataMixin, KeepRefsMixin
class ContextObject(ContextInstanceMixin["ContextObject"]):
@ -11,6 +11,10 @@ class DataObject(DataMixin):
pass
class RefsObject(KeepRefsMixin):
pass
class TestDataMixin:
def test_store_value(self):
obj = DataObject()
@ -52,3 +56,18 @@ class TestContextInstanceMixin:
TypeError, match=r"Value should be instance of 'ContextObject' not '.+'"
):
obj.set_current(42)
class TestKeepRefsMixin:
def test_refs_are_saved(self):
obj = RefsObject()
assert obj in RefsObject.get_instances()
def test_refs_are_deleted(self):
obj = RefsObject()
size_with_obj = len(RefsObject.get_instances())
del obj
size_without_obj = len(RefsObject.get_instances())
assert size_with_obj - 1 == size_without_obj

View file

@ -0,0 +1,27 @@
from aiogram import Router
from aiogram.utils.routers import find_all_routers, _order_routers
EXPECTED_ROUTERS_NAMES = frozenset({"__init__", "small_module",
"small_package", "nested_small_module"})
def test_all_routers_are_valid():
routers = find_all_routers("tests.modules_for_tests")
for router in routers:
assert isinstance(router, Router)
def test_all_expected_routers_are_found():
routers = find_all_routers("tests.modules_for_tests")
found_names = {router.name for router in routers}
for name in EXPECTED_ROUTERS_NAMES:
assert name in found_names
def test_routers_ordering():
indexes = [None, None, None, None, 1, -2, 4]
routers = [Router(index=index) for index in indexes]
ordered_indexes = [r.index for r in _order_routers(routers)]
assert ordered_indexes == [None, 1, None, None, 4, None, -2]