diff --git a/aiogram/api/client/base.py b/aiogram/api/client/base.py index bbb3f974..5ded4c45 100644 --- a/aiogram/api/client/base.py +++ b/aiogram/api/client/base.py @@ -17,3 +17,12 @@ class BaseBot(ContextInstanceMixin): async def emit(self, method: TelegramMethod[T]) -> T: return await self.session.make_request(self.token, method) + + async def close(self): + await self.session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.session.close() diff --git a/aiogram/api/client/session/base.py b/aiogram/api/client/session/base.py index 837b8709..30b4400e 100644 --- a/aiogram/api/client/session/base.py +++ b/aiogram/api/client/session/base.py @@ -1,5 +1,4 @@ import abc -import asyncio import datetime import json from typing import Any, Callable, Optional, TypeVar, Union diff --git a/aiogram/api/types/base.py b/aiogram/api/types/base.py index 4b5ca3be..b5e905ef 100644 --- a/aiogram/api/types/base.py +++ b/aiogram/api/types/base.py @@ -1,7 +1,9 @@ from pydantic import BaseConfig, BaseModel, Extra +from aiogram.utils.mixins import ContextInstanceMixin -class TelegramObject(BaseModel): + +class TelegramObject(ContextInstanceMixin, BaseModel): class Config(BaseConfig): use_enum_values = True orm_mode = True diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py new file mode 100644 index 00000000..72e3cda9 --- /dev/null +++ b/aiogram/dispatcher/dispatcher.py @@ -0,0 +1,47 @@ +from typing import AsyncGenerator, Optional + +from ..api.client.bot import Bot +from ..api.methods import TelegramMethod +from ..api.types import Update +from .router import Router + + +class Dispatcher(Router): + @property + def parent_router(self) -> Optional[Router]: + return super().parent_router + + @parent_router.setter + def parent_router(self, value) -> Optional[Router]: + # Dispatcher is root Router then configuring parent router is not allowed + raise RuntimeError("Dispatcher can not be attached to another Router.") + + async def feed_update(self, bot: Bot, update: Update): + """ + Main entry point for incoming updates + + :param bot: + :param update: + :return: + """ + Bot.set_current(bot) + async for result in self.update_handler.trigger(update, bot=bot): + yield result + + @classmethod + async def listen_updates(cls, bot: Bot) -> AsyncGenerator[Update, None]: + update_id: Optional[int] = None + while True: + for update in await bot.get_updates(offset=update_id): + yield update + update_id = update.update_id + 1 + + async def polling(self, bot: Bot): + self.emit_startup(bot=bot) + try: + async for update in self.listen_updates(bot): + async for result in self.feed_update(bot, update): + if isinstance(result, TelegramMethod): + await result + finally: + self.emit_shutdown(bot=bot) diff --git a/aiogram/dispatcher/event/__init__.py b/aiogram/dispatcher/event/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py new file mode 100644 index 00000000..bac21309 --- /dev/null +++ b/aiogram/dispatcher/event/handler.py @@ -0,0 +1,55 @@ +import inspect +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union + +CallbackType = Callable[[Any], Awaitable[Any]] +SyncFilter = Callable[[Any], Any] +AsyncFilter = Callable[[Any], Awaitable[Any]] +FilterType = Union[SyncFilter, AsyncFilter] + + +@dataclass +class CallableMixin: + callback: Callable[[Any], Any] + awaitable: bool = field(init=False) + spec: inspect.FullArgSpec = field(init=False) + + def __post_init__(self): + callback = self.callback + self.awaitable = inspect.isawaitable(callback) or inspect.iscoroutinefunction(callback) + while hasattr(callback, "__wrapped__"): # Try to resolve decorated callbacks + callback = callback.__wrapped__ + self.spec = inspect.getfullargspec(callback) + + def _prepare_kwargs(self, kwargs): + if self.spec.varkw: + return kwargs + + return {k: v for k, v in kwargs.items() if k in self.spec.args} + + async def call(self, *args, **kwargs): + wrapped = partial(self.callback, *args, **self._prepare_kwargs(kwargs)) + if self.awaitable: + return await wrapped() + return wrapped() + + +@dataclass +class FilterObject(CallableMixin): + callback: FilterType + + +@dataclass +class HandlerObject(CallableMixin): + callback: CallbackType + filters: List[FilterObject] + + async def check(self, *args: Any, **kwargs: Any) -> Tuple[bool, Dict[str, Any]]: + for event_filter in self.filters: + check = await event_filter.call(*args, **kwargs) + if not check: + return False, {} + if isinstance(check, dict): + kwargs.update(check) + return True, kwargs diff --git a/aiogram/dispatcher/event/observer.py b/aiogram/dispatcher/event/observer.py new file mode 100644 index 00000000..aee259e5 --- /dev/null +++ b/aiogram/dispatcher/event/observer.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Type + +from pydantic import ValidationError + +from ..filters.base import BaseFilter +from .handler import CallbackType, FilterObject, FilterType, HandlerObject + +if TYPE_CHECKING: + from aiogram.dispatcher.router import Router + + +class SkipHandler(Exception): + pass + + +class EventObserver: + """ + Base events observer + """ + + def __init__(self): + self.handlers: List[HandlerObject] = [] + + def register(self, callback: CallbackType, *filters: FilterType): + """ + Register callback with filters + + :param callback: + :param filters: + :return: + """ + self.handlers.append( + HandlerObject( + callback=callback, filters=[FilterObject(filter_) for filter_ in filters] + ) + ) + + async def trigger(self, *args, **kwargs): + for handler in self.handlers: + result, data = await handler.check(*args, **kwargs) + if result: + try: + yield await handler.call(*args, **data) + except SkipHandler: + continue + + def __call__(self, *args: FilterType): + def wrapper(callback: CallbackType): + self.register(callback, *args) + return callback + + return wrapper + + +class TelegramEventObserver(EventObserver): + """ + Event observer for Telegram events + """ + + def __init__(self, router: Router, event_name: str): + super().__init__() + + self.router: Router = router + self.event_name: str = event_name + self.filters: List[Type[BaseFilter]] = [] + + def bind_filter(self, bound_filter: Type[BaseFilter]) -> None: + if not isinstance(bound_filter, BaseFilter): + pass + self.filters.append(bound_filter) + + def _resolve_filters_chain(self): + registry: List[FilterType] = [] + routers: List[Router] = [] + + router = self.router + while router and router not in routers: + observer = router.observers[self.event_name] + routers.append(router) + router = router.parent_router + + for filter_ in observer.filters: + if filter_ in registry: + continue + yield filter_ + registry.append(filter_) + + def register(self, callback: CallbackType, *filters: FilterType, **bound_filters): + resolved_filters = self.resolve_filters(bound_filters) + return super().register(callback, *filters, *resolved_filters) + + async def trigger(self, *args, **kwargs): + async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs): + yield result + break + + def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]: + filters: List[BaseFilter] = [] + if not full_config: + return filters + + for bound_filter in self._resolve_filters_chain(): + # Try to initialize filter. + try: + f = bound_filter(**full_config) + except ValidationError: + continue + + # Clean full config to prevent to re-initialize another filter with the same configuration + for key in f.__fields__: + full_config.pop(key) + + filters.append(f) + + if full_config: + raise ValueError(f"Unknown filters: {set(full_config.keys())}") + + return filters + + def __call__(self, *args: FilterType, **bound_filters): + def wrapper(callback: CallbackType): + self.register(callback, *args, **bound_filters) + return callback + + return wrapper diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/dispatcher/filters/base.py b/aiogram/dispatcher/filters/base.py new file mode 100644 index 00000000..4d12cb69 --- /dev/null +++ b/aiogram/dispatcher/filters/base.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Union + +from pydantic import BaseModel + + +class BaseFilter(ABC, BaseModel): + @abstractmethod + async def __call__(self, *args, **kwargs) -> Union[bool, Dict[str, Any]]: + pass + + def __await__(self): + return self.__call__ diff --git a/aiogram/dispatcher/router.py b/aiogram/dispatcher/router.py new file mode 100644 index 00000000..93fd29a6 --- /dev/null +++ b/aiogram/dispatcher/router.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from typing import Any, List, Optional + +from ..api.types import Chat, Update, User +from .event.observer import EventObserver, SkipHandler, TelegramEventObserver + + +class Router: + def __init__(self): + self._parent_router: Optional[Router] = None + self.sub_routers: List[Router] = [] + + self.update_handler = TelegramEventObserver(router=self, event_name="update") + self.message_handler = TelegramEventObserver(router=self, event_name="message") + self.edited_message_handler = TelegramEventObserver( + router=self, event_name="edited_message" + ) + self.channel_post_handler = TelegramEventObserver(router=self, event_name="channel_post") + self.edited_channel_post_handler = TelegramEventObserver( + router=self, event_name="edited_channel_post" + ) + self.inline_query_handler = TelegramEventObserver(router=self, event_name="inline_query") + self.chosen_inline_result_handler = TelegramEventObserver( + router=self, event_name="chosen_inline_result" + ) + self.callback_query_handler = TelegramEventObserver( + router=self, event_name="callback_query" + ) + self.shipping_query_handler = TelegramEventObserver( + router=self, event_name="shipping_query" + ) + self.pre_checkout_query_handler = TelegramEventObserver( + router=self, event_name="pre_checkout_query" + ) + self.poll_handler = TelegramEventObserver(router=self, event_name="poll") + + self.startup = EventObserver() + self.shutdown = EventObserver() + + self.observers = { + "message": self.message_handler, + "edited_message": self.edited_message_handler, + "channel_post": self.channel_post_handler, + "edited_channel_post": self.edited_channel_post_handler, + "inline_query": self.inline_query_handler, + "chosen_inline_result": self.chosen_inline_result_handler, + "callback_query": self.callback_query_handler, + "shipping_query": self.shipping_query_handler, + "pre_checkout_query": self.pre_checkout_query_handler, + "poll": self.poll_handler, + } + + self.update_handler.register(self._listen_update) + + @property + def parent_router(self) -> Optional[Router]: + return self._parent_router + + @parent_router.setter + def parent_router(self, router: Router) -> None: + if self._parent_router: + raise RuntimeError(f"Router is already attached to {self._parent_router!r}") + if self == router: + raise RuntimeError("Self-referencing routers is not allowed") + parent: Optional[Router] = router + while parent is not None: + if parent == self: + raise RuntimeError("Circular referencing of Router is not allowed") + parent = parent.parent_router + + self._parent_router = router + + def include_router(self, router: Router) -> Router: + router.parent_router = self + self.sub_routers.append(router) + return router + + async def _listen_update(self, update: Update, **kwargs: Any) -> Any: + kwargs.update(event_update=update, event_router=self) + + chat: Optional[Chat] = None + from_user: Optional[User] = None + + if update.message: + update_type = "message" + from_user = update.message.from_user + chat = update.message.chat + event = update.message + elif update.edited_message: + update_type = "edited_message" + from_user = update.edited_message.from_user + chat = update.edited_message.chat + event = update.edited_message + elif update.channel_post: + update_type = "channel_post" + chat = update.channel_post.chat + event = update.channel_post + elif update.edited_channel_post: + update_type = "edited_channel_post" + chat = update.edited_channel_post.chat + event = update.edited_channel_post + elif update.inline_query: + update_type = "inline_query" + from_user = update.inline_query.from_user + event = update.inline_query + elif update.chosen_inline_result: + update_type = "chosen_inline_result" + from_user = update.chosen_inline_result.from_user + event = update.chosen_inline_result + elif update.callback_query: + update_type = "callback_query" + if update.callback_query.message: + chat = update.callback_query.message.chat + from_user = update.callback_query.from_user + event = update.callback_query + elif update.shipping_query: + update_type = "shipping_query" + from_user = update.shipping_query.from_user + event = update.shipping_query + elif update.pre_checkout_query: + update_type = "pre_checkout_query" + from_user = update.pre_checkout_query.from_user + event = update.pre_checkout_query + elif update.poll: + update_type = "poll" + event = update.poll + else: + raise SkipHandler + + observer = self.observers[update_type] + if from_user: + User.set_current(from_user) + if chat: + Chat.set_current(chat) + async for result in observer.trigger(event, **kwargs): + return result + + for router in self.sub_routers: + kwargs.update(event_router=router) + async for result in router.update_handler.trigger(update, **kwargs): + return result + + raise SkipHandler + + def emit_startup(self, *args, **kwargs): + self.startup.trigger(*args, **kwargs) + for router in self.sub_routers: + router.emit_startup(*args, **kwargs) + + def emit_shutdown(self, *args, **kwargs): + self.startup.trigger(*args, **kwargs) + for router in self.sub_routers: + router.emit_startup(*args, **kwargs) diff --git a/aiogram/utils/mixins.py b/aiogram/utils/mixins.py index 12d22737..eaaeb1a3 100644 --- a/aiogram/utils/mixins.py +++ b/aiogram/utils/mixins.py @@ -34,6 +34,7 @@ T = TypeVar("T") class ContextInstanceMixin: def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) cls.__context_instance = contextvars.ContextVar(f"instance_{cls.__name__}") return cls