Fixed current coverage

This commit is contained in:
Alex Root Junior 2021-09-20 23:38:58 +03:00
parent 9866e321a3
commit a5892f63f4
12 changed files with 407 additions and 2 deletions

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Any, Protocol
class WrapLocalFileCallbackCallbackProtocol(Protocol):
class WrapLocalFileCallbackCallbackProtocol(Protocol): # pragma: no cover
def __call__(self, value: str) -> str:
pass

View file

@ -331,7 +331,7 @@ class Dispatcher(Router):
try:
try:
await waiter
except CancelledError: # pragma: nocover
except CancelledError: # pragma: no cover
process_updates.remove_done_callback(release_waiter)
process_updates.cancel()
raise

View file

@ -0,0 +1,46 @@
from inspect import isclass
from typing import Any, Dict, Optional, Sequence, Type, Union, cast
from pydantic import validator
from aiogram.dispatcher.filters import BaseFilter
from aiogram.dispatcher.fsm.state import State, StatesGroup
from aiogram.types import TelegramObject
StateType = Union[str, None, State, StatesGroup, Type[StatesGroup]]
class StateFilter(BaseFilter):
"""
State filter
"""
state: Union[StateType, Sequence[StateType]]
class Config:
arbitrary_types_allowed = True
@validator("state", always=True)
def _validate_state(cls, v: Union[StateType, Sequence[StateType]]) -> Sequence[StateType]:
if (
isinstance(v, (str, State, StatesGroup))
or (isclass(v) and issubclass(v, StatesGroup))
or v is None
):
return [v]
return v
async def __call__(
self, obj: Union[TelegramObject], raw_state: Optional[str] = None
) -> Union[bool, Dict[str, Any]]:
allowed_states = cast(Sequence[StateType], self.state)
for allowed_state in allowed_states:
if isinstance(allowed_state, str) or allowed_state is None:
if allowed_state == "*":
return True
return raw_state == allowed_state
elif isinstance(allowed_state, (State, StatesGroup)):
return allowed_state(event=obj, raw_state=raw_state)
elif isclass(allowed_state) and issubclass(allowed_state, StatesGroup):
return allowed_state()(event=obj, raw_state=raw_state)
return False

View file

@ -0,0 +1,21 @@
from .babel import I18n
from .context import get_i18n, gettext, lazy_gettext, lazy_ngettext, ngettext
from .middleware import (
ConstI18nMiddleware,
FSMI18nMiddleware,
I18nMiddleware,
SimpleI18nMiddleware,
)
__all__ = (
"I18n",
"I18nMiddleware",
"SimpleI18nMiddleware",
"ConstI18nMiddleware",
"FSMI18nMiddleware",
"gettext",
"lazy_gettext",
"ngettext",
"lazy_ngettext",
"get_i18n",
)

View file

@ -0,0 +1,98 @@
import gettext
import os
from contextvars import ContextVar
from os import PathLike
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from aiogram.utils.i18n.lazy_proxy import LazyProxy
class I18n:
def __init__(
self,
*,
path: Union[str, PathLike[str], Path],
locale: str = "en",
domain: str = "messages",
) -> None:
self.path = path
self.locale = locale
self.domain = domain
self.ctx_locale = ContextVar("aiogram_ctx_locale", default=locale)
self.locales = self.find_locales()
@property
def current_locale(self) -> str:
return self.ctx_locale.get()
@current_locale.setter
def current_locale(self, value: str) -> None:
self.ctx_locale.set(value)
def find_locales(self) -> Dict[str, gettext.GNUTranslations]:
"""
Load all compiled locales from path
:return: dict with locales
"""
translations: Dict[str, gettext.GNUTranslations] = {}
for name in os.listdir(self.path):
if not os.path.isdir(os.path.join(self.path, name)):
continue
mo_path = os.path.join(self.path, name, "LC_MESSAGES", self.domain + ".mo")
if os.path.exists(mo_path):
with open(mo_path, "r") as fp:
translations[name] = gettext.GNUTranslations(fp)
elif os.path.exists(mo_path[:-2] + "po"):
raise RuntimeError(f"Found locale '{name}' but this language is not compiled!")
return translations
def reload(self) -> None:
"""
Hot reload locales
"""
self.locales = self.find_locales()
@property
def available_locales(self) -> Tuple[str, ...]:
"""
list of loaded locales
:return:
"""
return tuple(self.locales.keys())
def gettext(
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None
) -> str:
"""
Get text
:param singular:
:param plural:
:param n:
:param locale:
:return:
"""
if locale is None:
locale = self.current_locale
if locale not in self.locales:
if n == 1:
return singular
return plural if plural else singular
translator = self.locales[locale]
if plural is None:
return translator.gettext(singular)
return translator.ngettext(singular, plural, n)
def lazy_gettext(
self, singular: str, plural: Optional[str] = None, n: int = 1, locale: Optional[str] = None
) -> LazyProxy:
return LazyProxy(self.gettext, singular=singular, plural=plural, n=n, locale=locale)

View file

@ -0,0 +1,30 @@
from contextvars import ContextVar
from typing import Any, Optional
from aiogram.utils.i18n.babel import I18n
from aiogram.utils.i18n.lazy_proxy import LazyProxy
ctx_i18n: ContextVar[Optional[I18n]] = ContextVar("aiogram_ctx_i18n", default=None)
def get_i18n() -> I18n:
i18n = ctx_i18n.get()
if i18n is None:
raise LookupError("I18n context is not set")
return i18n
def gettext(*args: Any, **kwargs: Any) -> str:
return get_i18n().gettext(*args, **kwargs)
def _lazy_lazy_gettext(*args: Any, **kwargs: Any) -> str:
return str(get_i18n().lazy_gettext(*args, **kwargs))
def lazy_gettext(*args: Any, **kwargs: Any) -> LazyProxy:
return LazyProxy(_lazy_lazy_gettext, *args, **kwargs)
ngettext = gettext
lazy_ngettext = lazy_gettext

View file

@ -0,0 +1,13 @@
from typing import Any
try:
from babel.support import LazyProxy
except ImportError: # pragma: no cover
class LazyProxy: # type: ignore
def __init__(self, func: Any, *args: Any, **kwargs: Any) -> None:
raise RuntimeError(
"LazyProxy can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
)

View file

@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable, Dict, Optional, Set, TypeVar, cast
try:
from babel import Locale
except ImportError: # pragma: no cover
Locale = None
from aiogram import BaseMiddleware, Router
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.types import TelegramObject, User
from aiogram.utils.i18n.babel import I18n
from aiogram.utils.i18n.context import ctx_i18n
T = TypeVar("T")
class I18nMiddleware(BaseMiddleware, ABC):
def __init__(
self,
i18n: I18n,
gettext_key: Optional[str] = "gettext",
middleware_key: str = "i18n_middleware",
) -> None:
self.i18n = i18n
self.gettext_key = gettext_key
self.middleware_key = middleware_key
def setup(
self: BaseMiddleware, router: Router, exclude: Optional[Set[str]] = None
) -> BaseMiddleware:
if exclude is None:
exclude = {"update"}
for event_name, observer in router.observers.items():
if event_name in exclude:
continue
observer.outer_middleware(self)
return self
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
self.i18n.current_locale = await self.get_locale(event=event, data=data)
if self.gettext_key:
data[self.gettext_key] = self.i18n
if self.middleware_key:
data[self.middleware_key] = self
token = ctx_i18n.set(self.i18n)
try:
return await handler(event, data)
finally:
ctx_i18n.reset(token)
@abstractmethod
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
pass
class SimpleI18nMiddleware(I18nMiddleware):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if Locale is None:
raise RuntimeError(
f"{type(self).__name__} can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
)
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
if Locale is None:
raise RuntimeError(
f"{type(self).__name__} can be used only when Babel installed\n"
"Just install Babel (`pip install Babel`) "
"or aiogram with i18n support (`pip install aiogram[i18n]`)"
)
event_from_user: Optional[User] = data.get("event_from_user", None)
if event_from_user is None:
return self.i18n.locale
locale = Locale.parse(event_from_user.language_code, sep="-")
if locale.language not in self.i18n.available_locales:
return self.i18n.locale
return cast(str, locale.language)
class ConstI18nMiddleware(I18nMiddleware):
def __init__(self, locale: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.locale = locale
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
return self.locale
class FSMI18nMiddleware(SimpleI18nMiddleware):
def __init__(self, *args: Any, key: str = "locale", **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.key = key
async def get_locale(self, event: TelegramObject, data: Dict[str, Any]) -> str:
fsm_context: Optional[FSMContext] = data.get("state")
locale = None
if fsm_context:
fsm_data = await fsm_context.get_data()
locale = fsm_data.get(self.key, None)
if not locale:
locale = await super().get_locale(event=event, data=data)
if fsm_context:
await fsm_context.update_data(data={self.key: locale})
return locale
async def set_locale(self, state: FSMContext, locale: str) -> None:
await state.update_data(data={self.key: locale})
self.i18n.current_locale = locale

View file

@ -1,4 +1,6 @@
import io
import os
from tempfile import mkstemp
import aiofiles
import pytest
@ -6,6 +8,7 @@ from aresponses import ResponsesMockServer
from aiogram import Bot
from aiogram.client.session.aiohttp import AiohttpSession
from aiogram.client.telegram import TelegramAPIServer
from aiogram.methods import GetFile, GetMe
from aiogram.types import File, PhotoSize
from tests.mocked_bot import MockedBot
@ -128,3 +131,15 @@ class TestBot:
await bot.download(
[PhotoSize(file_id="file id", file_unique_id="file id", width=123, height=123)]
)
async def test_download_local_file(self, bot: MockedBot):
bot.session.api = TelegramAPIServer.from_base("http://localhost:8081", is_local=True)
fd, tmp = mkstemp(prefix="test-", suffix=".txt")
value = b"KABOOM"
try:
with open(fd, "wb") as f:
f.write(value)
content = await bot.download_file(tmp)
assert content.getvalue() == value
finally:
os.unlink(tmp)

View file

@ -0,0 +1,49 @@
from inspect import isclass
import pytest
from aiogram.dispatcher.filters import StateFilter
from aiogram.dispatcher.fsm.state import State, StatesGroup
from aiogram.types import Update
pytestmark = pytest.mark.asyncio
class MyGroup(StatesGroup):
state = State()
class TestStateFilter:
@pytest.mark.parametrize(
"state", [None, State("test"), MyGroup, MyGroup(), "state", ["state"]]
)
def test_validator(self, state):
f = StateFilter(state=state)
assert isinstance(f.state, list)
value = f.state[0]
assert (
isinstance(value, (State, str, MyGroup))
or (isclass(value) and issubclass(value, StatesGroup))
or value is None
)
@pytest.mark.parametrize(
"state,current_state,result",
[
[State("state"), "@:state", True],
[[State("state")], "@:state", True],
[MyGroup, "MyGroup:state", True],
[[MyGroup], "MyGroup:state", True],
[MyGroup(), "MyGroup:state", True],
[[MyGroup()], "MyGroup:state", True],
["*", "state", True],
[None, None, True],
[[None], None, True],
[None, "state", False],
[[], "state", False],
],
)
@pytestmark
async def test_filter(self, state, current_state, result):
f = StateFilter(state=state)
assert bool(await f(obj=Update(update_id=42), raw_state=current_state)) is result

View file

@ -0,0 +1,14 @@
import pytest
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
async def next_handler(*args, **kwargs):
pass
class TestUserContextMiddleware:
@pytest.mark.asyncio
async def test_unexpected_event_type(self):
with pytest.raises(RuntimeError):
await UserContextMiddleware()(next_handler, object(), {})