diff --git a/aiogram/client/telegram.py b/aiogram/client/telegram.py index bf7e7067..2363e24e 100644 --- a/aiogram/client/telegram.py +++ b/aiogram/client/telegram.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Any, Protocol -class WrapLocalFileCallbackCallbackProtocol(Protocol): +class WrapLocalFileCallbackCallbackProtocol(Protocol): # pragma: no cover def __call__(self, value: str) -> str: pass diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index fffe0262..9269429e 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -331,7 +331,7 @@ class Dispatcher(Router): try: try: await waiter - except CancelledError: # pragma: nocover + except CancelledError: # pragma: no cover process_updates.remove_done_callback(release_waiter) process_updates.cancel() raise diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py new file mode 100644 index 00000000..e0a0386f --- /dev/null +++ b/aiogram/dispatcher/filters/state.py @@ -0,0 +1,46 @@ +from inspect import isclass +from typing import Any, Dict, Optional, Sequence, Type, Union, cast + +from pydantic import validator + +from aiogram.dispatcher.filters import BaseFilter +from aiogram.dispatcher.fsm.state import State, StatesGroup +from aiogram.types import TelegramObject + +StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]] + + +class StateFilter(BaseFilter): + """ + State filter + """ + + state: Union[StateType, Sequence[StateType]] + + class Config: + arbitrary_types_allowed = True + + @validator("state", always=True) + def _validate_state(cls, v: Union[StateType, Sequence[StateType]]) -> Sequence[StateType]: + if ( + isinstance(v, (str, State, StatesGroup)) + or (isclass(v) and issubclass(v, StatesGroup)) + or v is None + ): + return [v] + return v + + async def __call__( + self, obj: Union[TelegramObject], raw_state: Optional[str] = None + ) -> Union[bool, Dict[str, Any]]: + allowed_states = cast(Sequence[StateType], self.state) + for allowed_state in allowed_states: + if isinstance(allowed_state, str) or allowed_state is None: + if allowed_state == "*": + return True + return raw_state == allowed_state + elif isinstance(allowed_state, (State, StatesGroup)): + return allowed_state(event=obj, raw_state=raw_state) + elif isclass(allowed_state) and issubclass(allowed_state, StatesGroup): + return allowed_state()(event=obj, raw_state=raw_state) + return False diff --git a/aiogram/utils/i18n/__init__.py b/aiogram/utils/i18n/__init__.py new file mode 100644 index 00000000..9cb3a353 --- /dev/null +++ b/aiogram/utils/i18n/__init__.py @@ -0,0 +1,21 @@ +from .babel import I18n +from .context import get_i18n, gettext, lazy_gettext, lazy_ngettext, ngettext +from .middleware import ( + ConstI18nMiddleware, + FSMI18nMiddleware, + I18nMiddleware, + SimpleI18nMiddleware, +) + +__all__ = ( + "I18n", + "I18nMiddleware", + "SimpleI18nMiddleware", + "ConstI18nMiddleware", + "FSMI18nMiddleware", + "gettext", + "lazy_gettext", + "ngettext", + "lazy_ngettext", + "get_i18n", +) diff --git a/aiogram/utils/i18n/babel.py b/aiogram/utils/i18n/babel.py new file mode 100644 index 00000000..a5c3d34e --- /dev/null +++ b/aiogram/utils/i18n/babel.py @@ -0,0 +1,98 @@ +import gettext +import os +from contextvars import ContextVar +from os import PathLike +from pathlib import Path +from typing import Dict, Optional, Tuple, Union + +from aiogram.utils.i18n.lazy_proxy import LazyProxy + + +class I18n: + def __init__( + self, + *, + path: Union[str, PathLike[str], Path], + locale: str = "en", + domain: str = "messages", + ) -> None: + self.path = path + self.locale = locale + self.domain = domain + self.ctx_locale = ContextVar("aiogram_ctx_locale", default=locale) + self.locales = self.find_locales() + + @property + def current_locale(self) -> str: + return self.ctx_locale.get() + + @current_locale.setter + def current_locale(self, value: str) -> None: + self.ctx_locale.set(value) + + def find_locales(self) -> Dict[str, gettext.GNUTranslations]: + """ + Load all compiled locales from path + + :return: dict with locales + """ + translations: Dict[str, gettext.GNUTranslations] = {} + + for name in os.listdir(self.path): + if not os.path.isdir(os.path.join(self.path, name)): + continue + mo_path = os.path.join(self.path, name, "LC_MESSAGES", self.domain + ".mo") + + if os.path.exists(mo_path): + with open(mo_path, "r") as fp: + translations[name] = gettext.GNUTranslations(fp) + elif os.path.exists(mo_path[:-2] + "po"): + raise RuntimeError(f"Found locale '{name}' but this language is not compiled!") + + return translations + + def reload(self) -> None: + """ + Hot reload locales + """ + self.locales = self.find_locales() + + @property + def available_locales(self) -> Tuple[str, ...]: + """ + list of loaded locales + + :return: + """ + return tuple(self.locales.keys()) + + def gettext( + self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None + ) -> str: + """ + Get text + + :param singular: + :param plural: + :param n: + :param locale: + :return: + """ + if locale is None: + locale = self.current_locale + + if locale not in self.locales: + if n == 1: + return singular + return plural if plural else singular + + translator = self.locales[locale] + + if plural is None: + return translator.gettext(singular) + return translator.ngettext(singular, plural, n) + + def lazy_gettext( + self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None + ) -> LazyProxy: + return LazyProxy(self.gettext, singular=singular, plural=plural, n=n, locale=locale) diff --git a/aiogram/utils/i18n/context.py b/aiogram/utils/i18n/context.py new file mode 100644 index 00000000..1c72c555 --- /dev/null +++ b/aiogram/utils/i18n/context.py @@ -0,0 +1,30 @@ +from contextvars import ContextVar +from typing import Any, Optional + +from aiogram.utils.i18n.babel import I18n +from aiogram.utils.i18n.lazy_proxy import LazyProxy + +ctx_i18n: ContextVar[Optional[I18n]] = ContextVar("aiogram_ctx_i18n", default=None) + + +def get_i18n() -> I18n: + i18n = ctx_i18n.get() + if i18n is None: + raise LookupError("I18n context is not set") + return i18n + + +def gettext(*args: Any, **kwargs: Any) -> str: + return get_i18n().gettext(*args, **kwargs) + + +def _lazy_lazy_gettext(*args: Any, **kwargs: Any) -> str: + return str(get_i18n().lazy_gettext(*args, **kwargs)) + + +def lazy_gettext(*args: Any, **kwargs: Any) -> LazyProxy: + return LazyProxy(_lazy_lazy_gettext, *args, **kwargs) + + +ngettext = gettext +lazy_ngettext = lazy_gettext diff --git a/aiogram/utils/i18n/lazy_proxy.py b/aiogram/utils/i18n/lazy_proxy.py new file mode 100644 index 00000000..6852540d --- /dev/null +++ b/aiogram/utils/i18n/lazy_proxy.py @@ -0,0 +1,13 @@ +from typing import Any + +try: + from babel.support import LazyProxy +except ImportError: # pragma: no cover + + class LazyProxy: # type: ignore + def __init__(self, func: Any, *args: Any, **kwargs: Any) -> None: + raise RuntimeError( + "LazyProxy can be used only when Babel installed\n" + "Just install Babel (`pip install Babel`) " + "or aiogram with i18n support (`pip install aiogram[i18n]`)" + ) diff --git a/aiogram/utils/i18n/middleware.py b/aiogram/utils/i18n/middleware.py new file mode 100644 index 00000000..27fac10b --- /dev/null +++ b/aiogram/utils/i18n/middleware.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, Dict, Optional, Set, TypeVar, cast + +try: + from babel import Locale +except ImportError: # pragma: no cover + Locale = None + +from aiogram import BaseMiddleware, Router +from aiogram.dispatcher.fsm.context import FSMContext +from aiogram.types import TelegramObject, User +from aiogram.utils.i18n.babel import I18n +from aiogram.utils.i18n.context import ctx_i18n + +T = TypeVar("T") + + +class I18nMiddleware(BaseMiddleware, ABC): + def __init__( + self, + i18n: I18n, + gettext_key: Optional[str] = "gettext", + middleware_key: str = "i18n_middleware", + ) -> None: + self.i18n = i18n + self.gettext_key = gettext_key + self.middleware_key = middleware_key + + def setup( + self: BaseMiddleware, router: Router, exclude: Optional[Set[str]] = None + ) -> BaseMiddleware: + if exclude is None: + exclude = {"update"} + for event_name, observer in router.observers.items(): + if event_name in exclude: + continue + observer.outer_middleware(self) + return self + + async def __call__( + self, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: TelegramObject, + data: Dict[str, Any], + ) -> Any: + self.i18n.current_locale = await self.get_locale(event=event, data=data) + + if self.gettext_key: + data[self.gettext_key] = self.i18n + if self.middleware_key: + data[self.middleware_key] = self + token = ctx_i18n.set(self.i18n) + try: + return await handler(event, data) + finally: + ctx_i18n.reset(token) + + @abstractmethod + async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: + pass + + +class SimpleI18nMiddleware(I18nMiddleware): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + if Locale is None: + raise RuntimeError( + f"{type(self).__name__} can be used only when Babel installed\n" + "Just install Babel (`pip install Babel`) " + "or aiogram with i18n support (`pip install aiogram[i18n]`)" + ) + + async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: + if Locale is None: + raise RuntimeError( + f"{type(self).__name__} can be used only when Babel installed\n" + "Just install Babel (`pip install Babel`) " + "or aiogram with i18n support (`pip install aiogram[i18n]`)" + ) + + event_from_user: Optional[User] = data.get("event_from_user", None) + if event_from_user is None: + return self.i18n.locale + locale = Locale.parse(event_from_user.language_code, sep="-") + if locale.language not in self.i18n.available_locales: + return self.i18n.locale + return cast(str, locale.language) + + +class ConstI18nMiddleware(I18nMiddleware): + def __init__(self, locale: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.locale = locale + + async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: + return self.locale + + +class FSMI18nMiddleware(SimpleI18nMiddleware): + def __init__(self, *args: Any, key: str = "locale", **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.key = key + + async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str: + fsm_context: Optional[FSMContext] = data.get("state") + locale = None + if fsm_context: + fsm_data = await fsm_context.get_data() + locale = fsm_data.get(self.key, None) + if not locale: + locale = await super().get_locale(event=event, data=data) + if fsm_context: + await fsm_context.update_data(data={self.key: locale}) + return locale + + async def set_locale(self, state: FSMContext, locale: str) -> None: + await state.update_data(data={self.key: locale}) + self.i18n.current_locale = locale diff --git a/tests/test_api/test_client/test_bot.py b/tests/test_api/test_client/test_bot.py index b36006cc..d2c9d56a 100644 --- a/tests/test_api/test_client/test_bot.py +++ b/tests/test_api/test_client/test_bot.py @@ -1,4 +1,6 @@ import io +import os +from tempfile import mkstemp import aiofiles import pytest @@ -6,6 +8,7 @@ from aresponses import ResponsesMockServer from aiogram import Bot from aiogram.client.session.aiohttp import AiohttpSession +from aiogram.client.telegram import TelegramAPIServer from aiogram.methods import GetFile, GetMe from aiogram.types import File, PhotoSize from tests.mocked_bot import MockedBot @@ -128,3 +131,15 @@ class TestBot: await bot.download( [PhotoSize(file_id="file id", file_unique_id="file id", width=123, height=123)] ) + + async def test_download_local_file(self, bot: MockedBot): + bot.session.api = TelegramAPIServer.from_base("http://localhost:8081", is_local=True) + fd, tmp = mkstemp(prefix="test-", suffix=".txt") + value = b"KABOOM" + try: + with open(fd, "wb") as f: + f.write(value) + content = await bot.download_file(tmp) + assert content.getvalue() == value + finally: + os.unlink(tmp) diff --git a/tests/test_dispatcher/test_filters/test_state.py b/tests/test_dispatcher/test_filters/test_state.py new file mode 100644 index 00000000..d551f748 --- /dev/null +++ b/tests/test_dispatcher/test_filters/test_state.py @@ -0,0 +1,49 @@ +from inspect import isclass + +import pytest + +from aiogram.dispatcher.filters import StateFilter +from aiogram.dispatcher.fsm.state import State, StatesGroup +from aiogram.types import Update + +pytestmark = pytest.mark.asyncio + + +class MyGroup(StatesGroup): + state = State() + + +class TestStateFilter: + @pytest.mark.parametrize( + "state", [None, State("test"), MyGroup, MyGroup(), "state", ["state"]] + ) + def test_validator(self, state): + f = StateFilter(state=state) + assert isinstance(f.state, list) + value = f.state[0] + assert ( + isinstance(value, (State, str, MyGroup)) + or (isclass(value) and issubclass(value, StatesGroup)) + or value is None + ) + + @pytest.mark.parametrize( + "state,current_state,result", + [ + [State("state"), "@:state", True], + [[State("state")], "@:state", True], + [MyGroup, "MyGroup:state", True], + [[MyGroup], "MyGroup:state", True], + [MyGroup(), "MyGroup:state", True], + [[MyGroup()], "MyGroup:state", True], + ["*", "state", True], + [None, None, True], + [[None], None, True], + [None, "state", False], + [[], "state", False], + ], + ) + @pytestmark + async def test_filter(self, state, current_state, result): + f = StateFilter(state=state) + assert bool(await f(obj=Update(update_id=42), raw_state=current_state)) is result diff --git a/tests/test_dispatcher/test_middlewares/__init__.py b/tests/test_dispatcher/test_middlewares/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_dispatcher/test_middlewares/test_user_context.py b/tests/test_dispatcher/test_middlewares/test_user_context.py new file mode 100644 index 00000000..8d289c2b --- /dev/null +++ b/tests/test_dispatcher/test_middlewares/test_user_context.py @@ -0,0 +1,14 @@ +import pytest + +from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware + + +async def next_handler(*args, **kwargs): + pass + + +class TestUserContextMiddleware: + @pytest.mark.asyncio + async def test_unexpected_event_type(self): + with pytest.raises(RuntimeError): + await UserContextMiddleware()(next_handler, object(), {})