mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Add Router
This commit is contained in:
parent
0acdb24c3c
commit
de115452fd
12 changed files with 409 additions and 2 deletions
|
|
@ -17,3 +17,12 @@ class BaseBot(ContextInstanceMixin):
|
|||
|
||||
async def emit(self, method: TelegramMethod[T]) -> T:
|
||||
return await self.session.make_request(self.token, method)
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.session.close()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from pydantic import BaseConfig, BaseModel, Extra
|
||||
|
||||
from aiogram.utils.mixins import ContextInstanceMixin
|
||||
|
||||
class TelegramObject(BaseModel):
|
||||
|
||||
class TelegramObject(ContextInstanceMixin, BaseModel):
|
||||
class Config(BaseConfig):
|
||||
use_enum_values = True
|
||||
orm_mode = True
|
||||
|
|
|
|||
0
aiogram/dispatcher/__init__.py
Normal file
0
aiogram/dispatcher/__init__.py
Normal file
47
aiogram/dispatcher/dispatcher.py
Normal file
47
aiogram/dispatcher/dispatcher.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from ..api.client.bot import Bot
|
||||
from ..api.methods import TelegramMethod
|
||||
from ..api.types import Update
|
||||
from .router import Router
|
||||
|
||||
|
||||
class Dispatcher(Router):
|
||||
@property
|
||||
def parent_router(self) -> Optional[Router]:
|
||||
return super().parent_router
|
||||
|
||||
@parent_router.setter
|
||||
def parent_router(self, value) -> Optional[Router]:
|
||||
# Dispatcher is root Router then configuring parent router is not allowed
|
||||
raise RuntimeError("Dispatcher can not be attached to another Router.")
|
||||
|
||||
async def feed_update(self, bot: Bot, update: Update):
|
||||
"""
|
||||
Main entry point for incoming updates
|
||||
|
||||
:param bot:
|
||||
:param update:
|
||||
:return:
|
||||
"""
|
||||
Bot.set_current(bot)
|
||||
async for result in self.update_handler.trigger(update, bot=bot):
|
||||
yield result
|
||||
|
||||
@classmethod
|
||||
async def listen_updates(cls, bot: Bot) -> AsyncGenerator[Update, None]:
|
||||
update_id: Optional[int] = None
|
||||
while True:
|
||||
for update in await bot.get_updates(offset=update_id):
|
||||
yield update
|
||||
update_id = update.update_id + 1
|
||||
|
||||
async def polling(self, bot: Bot):
|
||||
self.emit_startup(bot=bot)
|
||||
try:
|
||||
async for update in self.listen_updates(bot):
|
||||
async for result in self.feed_update(bot, update):
|
||||
if isinstance(result, TelegramMethod):
|
||||
await result
|
||||
finally:
|
||||
self.emit_shutdown(bot=bot)
|
||||
0
aiogram/dispatcher/event/__init__.py
Normal file
0
aiogram/dispatcher/event/__init__.py
Normal file
55
aiogram/dispatcher/event/handler.py
Normal file
55
aiogram/dispatcher/event/handler.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union
|
||||
|
||||
CallbackType = Callable[[Any], Awaitable[Any]]
|
||||
SyncFilter = Callable[[Any], Any]
|
||||
AsyncFilter = Callable[[Any], Awaitable[Any]]
|
||||
FilterType = Union[SyncFilter, AsyncFilter]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallableMixin:
|
||||
callback: Callable[[Any], Any]
|
||||
awaitable: bool = field(init=False)
|
||||
spec: inspect.FullArgSpec = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
callback = self.callback
|
||||
self.awaitable = inspect.isawaitable(callback) or inspect.iscoroutinefunction(callback)
|
||||
while hasattr(callback, "__wrapped__"): # Try to resolve decorated callbacks
|
||||
callback = callback.__wrapped__
|
||||
self.spec = inspect.getfullargspec(callback)
|
||||
|
||||
def _prepare_kwargs(self, kwargs):
|
||||
if self.spec.varkw:
|
||||
return kwargs
|
||||
|
||||
return {k: v for k, v in kwargs.items() if k in self.spec.args}
|
||||
|
||||
async def call(self, *args, **kwargs):
|
||||
wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs))
|
||||
if self.awaitable:
|
||||
return await wrapped()
|
||||
return wrapped()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterObject(CallableMixin):
|
||||
callback: FilterType
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandlerObject(CallableMixin):
|
||||
callback: CallbackType
|
||||
filters: List[FilterObject]
|
||||
|
||||
async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]:
|
||||
for event_filter in self.filters:
|
||||
check = await event_filter.call(*args, **kwargs)
|
||||
if not check:
|
||||
return False, {}
|
||||
if isinstance(check, dict):
|
||||
kwargs.update(check)
|
||||
return True, kwargs
|
||||
127
aiogram/dispatcher/event/observer.py
Normal file
127
aiogram/dispatcher/event/observer.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Type
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ..filters.base import BaseFilter
|
||||
from .handler import CallbackType, FilterObject, FilterType, HandlerObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.dispatcher.router import Router
|
||||
|
||||
|
||||
class SkipHandler(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class EventObserver:
|
||||
"""
|
||||
Base events observer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.handlers: List[HandlerObject] = []
|
||||
|
||||
def register(self, callback: CallbackType, *filters: FilterType):
|
||||
"""
|
||||
Register callback with filters
|
||||
|
||||
:param callback:
|
||||
:param filters:
|
||||
:return:
|
||||
"""
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
callback=callback, filters=[FilterObject(filter_) for filter_ in filters]
|
||||
)
|
||||
)
|
||||
|
||||
async def trigger(self, *args, **kwargs):
|
||||
for handler in self.handlers:
|
||||
result, data = await handler.check(*args, **kwargs)
|
||||
if result:
|
||||
try:
|
||||
yield await handler.call(*args, **data)
|
||||
except SkipHandler:
|
||||
continue
|
||||
|
||||
def __call__(self, *args: FilterType):
|
||||
def wrapper(callback: CallbackType):
|
||||
self.register(callback, *args)
|
||||
return callback
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class TelegramEventObserver(EventObserver):
|
||||
"""
|
||||
Event observer for Telegram events
|
||||
"""
|
||||
|
||||
def __init__(self, router: Router, event_name: str):
|
||||
super().__init__()
|
||||
|
||||
self.router: Router = router
|
||||
self.event_name: str = event_name
|
||||
self.filters: List[Type[BaseFilter]] = []
|
||||
|
||||
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
||||
if not isinstance(bound_filter, BaseFilter):
|
||||
pass
|
||||
self.filters.append(bound_filter)
|
||||
|
||||
def _resolve_filters_chain(self):
|
||||
registry: List[FilterType] = []
|
||||
routers: List[Router] = []
|
||||
|
||||
router = self.router
|
||||
while router and router not in routers:
|
||||
observer = router.observers[self.event_name]
|
||||
routers.append(router)
|
||||
router = router.parent_router
|
||||
|
||||
for filter_ in observer.filters:
|
||||
if filter_ in registry:
|
||||
continue
|
||||
yield filter_
|
||||
registry.append(filter_)
|
||||
|
||||
def register(self, callback: CallbackType, *filters: FilterType, **bound_filters):
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
return super().register(callback, *filters, *resolved_filters)
|
||||
|
||||
async def trigger(self, *args, **kwargs):
|
||||
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
||||
yield result
|
||||
break
|
||||
|
||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||
filters: List[BaseFilter] = []
|
||||
if not full_config:
|
||||
return filters
|
||||
|
||||
for bound_filter in self._resolve_filters_chain():
|
||||
# Try to initialize filter.
|
||||
try:
|
||||
f = bound_filter(**full_config)
|
||||
except ValidationError:
|
||||
continue
|
||||
|
||||
# Clean full config to prevent to re-initialize another filter with the same configuration
|
||||
for key in f.__fields__:
|
||||
full_config.pop(key)
|
||||
|
||||
filters.append(f)
|
||||
|
||||
if full_config:
|
||||
raise ValueError(f"Unknown filters: {set(full_config.keys())}")
|
||||
|
||||
return filters
|
||||
|
||||
def __call__(self, *args: FilterType, **bound_filters):
|
||||
def wrapper(callback: CallbackType):
|
||||
self.register(callback, *args, **bound_filters)
|
||||
return callback
|
||||
|
||||
return wrapper
|
||||
0
aiogram/dispatcher/filters/__init__.py
Normal file
0
aiogram/dispatcher/filters/__init__.py
Normal file
13
aiogram/dispatcher/filters/base.py
Normal file
13
aiogram/dispatcher/filters/base.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseFilter(ABC, BaseModel):
|
||||
@abstractmethod
|
||||
async def __call__(self, *args, **kwargs) -> Union[bool, Dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def __await__(self):
|
||||
return self.__call__
|
||||
154
aiogram/dispatcher/router.py
Normal file
154
aiogram/dispatcher/router.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from ..api.types import Chat, Update, User
|
||||
from .event.observer import EventObserver, SkipHandler, TelegramEventObserver
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(self):
|
||||
self._parent_router: Optional[Router] = None
|
||||
self.sub_routers: List[Router] = []
|
||||
|
||||
self.update_handler = TelegramEventObserver(router=self, event_name="update")
|
||||
self.message_handler = TelegramEventObserver(router=self, event_name="message")
|
||||
self.edited_message_handler = TelegramEventObserver(
|
||||
router=self, event_name="edited_message"
|
||||
)
|
||||
self.channel_post_handler = TelegramEventObserver(router=self, event_name="channel_post")
|
||||
self.edited_channel_post_handler = TelegramEventObserver(
|
||||
router=self, event_name="edited_channel_post"
|
||||
)
|
||||
self.inline_query_handler = TelegramEventObserver(router=self, event_name="inline_query")
|
||||
self.chosen_inline_result_handler = TelegramEventObserver(
|
||||
router=self, event_name="chosen_inline_result"
|
||||
)
|
||||
self.callback_query_handler = TelegramEventObserver(
|
||||
router=self, event_name="callback_query"
|
||||
)
|
||||
self.shipping_query_handler = TelegramEventObserver(
|
||||
router=self, event_name="shipping_query"
|
||||
)
|
||||
self.pre_checkout_query_handler = TelegramEventObserver(
|
||||
router=self, event_name="pre_checkout_query"
|
||||
)
|
||||
self.poll_handler = TelegramEventObserver(router=self, event_name="poll")
|
||||
|
||||
self.startup = EventObserver()
|
||||
self.shutdown = EventObserver()
|
||||
|
||||
self.observers = {
|
||||
"message": self.message_handler,
|
||||
"edited_message": self.edited_message_handler,
|
||||
"channel_post": self.channel_post_handler,
|
||||
"edited_channel_post": self.edited_channel_post_handler,
|
||||
"inline_query": self.inline_query_handler,
|
||||
"chosen_inline_result": self.chosen_inline_result_handler,
|
||||
"callback_query": self.callback_query_handler,
|
||||
"shipping_query": self.shipping_query_handler,
|
||||
"pre_checkout_query": self.pre_checkout_query_handler,
|
||||
"poll": self.poll_handler,
|
||||
}
|
||||
|
||||
self.update_handler.register(self._listen_update)
|
||||
|
||||
@property
|
||||
def parent_router(self) -> Optional[Router]:
|
||||
return self._parent_router
|
||||
|
||||
@parent_router.setter
|
||||
def parent_router(self, router: Router) -> None:
|
||||
if self._parent_router:
|
||||
raise RuntimeError(f"Router is already attached to {self._parent_router!r}")
|
||||
if self == router:
|
||||
raise RuntimeError("Self-referencing routers is not allowed")
|
||||
parent: Optional[Router] = router
|
||||
while parent is not None:
|
||||
if parent == self:
|
||||
raise RuntimeError("Circular referencing of Router is not allowed")
|
||||
parent = parent.parent_router
|
||||
|
||||
self._parent_router = router
|
||||
|
||||
def include_router(self, router: Router) -> Router:
|
||||
router.parent_router = self
|
||||
self.sub_routers.append(router)
|
||||
return router
|
||||
|
||||
async def _listen_update(self, update: Update, **kwargs: Any) -> Any:
|
||||
kwargs.update(event_update=update, event_router=self)
|
||||
|
||||
chat: Optional[Chat] = None
|
||||
from_user: Optional[User] = None
|
||||
|
||||
if update.message:
|
||||
update_type = "message"
|
||||
from_user = update.message.from_user
|
||||
chat = update.message.chat
|
||||
event = update.message
|
||||
elif update.edited_message:
|
||||
update_type = "edited_message"
|
||||
from_user = update.edited_message.from_user
|
||||
chat = update.edited_message.chat
|
||||
event = update.edited_message
|
||||
elif update.channel_post:
|
||||
update_type = "channel_post"
|
||||
chat = update.channel_post.chat
|
||||
event = update.channel_post
|
||||
elif update.edited_channel_post:
|
||||
update_type = "edited_channel_post"
|
||||
chat = update.edited_channel_post.chat
|
||||
event = update.edited_channel_post
|
||||
elif update.inline_query:
|
||||
update_type = "inline_query"
|
||||
from_user = update.inline_query.from_user
|
||||
event = update.inline_query
|
||||
elif update.chosen_inline_result:
|
||||
update_type = "chosen_inline_result"
|
||||
from_user = update.chosen_inline_result.from_user
|
||||
event = update.chosen_inline_result
|
||||
elif update.callback_query:
|
||||
update_type = "callback_query"
|
||||
if update.callback_query.message:
|
||||
chat = update.callback_query.message.chat
|
||||
from_user = update.callback_query.from_user
|
||||
event = update.callback_query
|
||||
elif update.shipping_query:
|
||||
update_type = "shipping_query"
|
||||
from_user = update.shipping_query.from_user
|
||||
event = update.shipping_query
|
||||
elif update.pre_checkout_query:
|
||||
update_type = "pre_checkout_query"
|
||||
from_user = update.pre_checkout_query.from_user
|
||||
event = update.pre_checkout_query
|
||||
elif update.poll:
|
||||
update_type = "poll"
|
||||
event = update.poll
|
||||
else:
|
||||
raise SkipHandler
|
||||
|
||||
observer = self.observers[update_type]
|
||||
if from_user:
|
||||
User.set_current(from_user)
|
||||
if chat:
|
||||
Chat.set_current(chat)
|
||||
async for result in observer.trigger(event, **kwargs):
|
||||
return result
|
||||
|
||||
for router in self.sub_routers:
|
||||
kwargs.update(event_router=router)
|
||||
async for result in router.update_handler.trigger(update, **kwargs):
|
||||
return result
|
||||
|
||||
raise SkipHandler
|
||||
|
||||
def emit_startup(self, *args, **kwargs):
|
||||
self.startup.trigger(*args, **kwargs)
|
||||
for router in self.sub_routers:
|
||||
router.emit_startup(*args, **kwargs)
|
||||
|
||||
def emit_shutdown(self, *args, **kwargs):
|
||||
self.startup.trigger(*args, **kwargs)
|
||||
for router in self.sub_routers:
|
||||
router.emit_startup(*args, **kwargs)
|
||||
|
|
@ -34,6 +34,7 @@ T = TypeVar("T")
|
|||
|
||||
class ContextInstanceMixin:
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls.__context_instance = contextvars.ContextVar(f"instance_{cls.__name__}")
|
||||
return cls
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue