diff --git a/aiogram/dispatcher/_typedef.py b/aiogram/dispatcher/_typedef.py index 03928ce3..5cec07bc 100644 --- a/aiogram/dispatcher/_typedef.py +++ b/aiogram/dispatcher/_typedef.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Mapping, TypeVar -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover StorageDataT = TypeVar("StorageDataT", bound=Mapping[str, Any]) else: StorageDataT = TypeVar("StorageDataT", bound=Mapping) diff --git a/aiogram/dispatcher/state/context.py b/aiogram/dispatcher/state/context.py index 757f1d6f..acf7f41d 100644 --- a/aiogram/dispatcher/state/context.py +++ b/aiogram/dispatcher/state/context.py @@ -7,7 +7,7 @@ from .._typedef import StorageDataT def _default_key_maker(chat_id: Optional[int] = None, user_id: Optional[int] = None) -> str: if chat_id is None and user_id is None: - raise ValueError("`user` or `chat` parameter is required but no one is provided!") + raise ValueError("`user_id` or `chat_id` parameter is required but no one is provided!") if user_id is None and chat_id is not None: user_id = chat_id @@ -26,12 +26,8 @@ class CurrentUserContext(Generic[StorageDataT]): user_id: Optional[int], key_maker: Callable[[Optional[int], Optional[int]], str] = _default_key_maker, ): - assert ( - chat_id or user_id - ) is not None, "Either chat_id or user_id should be non-None value" - - self.storage = storage self.key = key_maker(chat_id, user_id) + self.storage = storage async def get_state(self) -> Optional[str]: return await self.storage.get_state(self.key) @@ -41,12 +37,14 @@ class CurrentUserContext(Generic[StorageDataT]): async def update_data(self, data: Optional[StorageDataT] = None, **kwargs: Any) -> None: if data is not None and not isinstance(data, Mapping): - raise ValueError("Data is expected to be a map") # todo + raise ValueError( + "type for `data` is expected to be a subtype of `collections.Mapping`" + ) temp_data: Dict[str, Any] = {} if isinstance(data, Mapping): - temp_data.update(**data) + temp_data.update(data) temp_data.update(**kwargs) diff --git a/aiogram/dispatcher/storage/base.py b/aiogram/dispatcher/storage/base.py index f76f9181..81215ab8 100644 --- a/aiogram/dispatcher/storage/base.py +++ b/aiogram/dispatcher/storage/base.py @@ -1,5 +1,4 @@ import abc - from typing import Generic, Optional, TypeVar _DataT = TypeVar("_DataT") diff --git a/aiogram/dispatcher/storage/dict.py b/aiogram/dispatcher/storage/dict.py index f6ceaf86..6966f7eb 100644 --- a/aiogram/dispatcher/storage/dict.py +++ b/aiogram/dispatcher/storage/dict.py @@ -44,7 +44,7 @@ class DictStorage(BaseStorage[Dict[str, Any]]): async def set_data(self, key: str, data: Optional[Dict[str, Any]] = None) -> None: self._make_spot_for_key(key=key) - self._data[key]["data"] = copy.deepcopy(data) # type: ignore + self._data[key]["data"] = copy.deepcopy(data) or {} # type: ignore async def wait_closed(self) -> None: pass diff --git a/tests/test_dispatcher/test_state/__init__.py b/tests/test_dispatcher/test_state/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_dispatcher/test_state/test_context.py b/tests/test_dispatcher/test_state/test_context.py new file mode 100644 index 00000000..6f0dc132 --- /dev/null +++ b/tests/test_dispatcher/test_state/test_context.py @@ -0,0 +1,118 @@ +import contextlib +from collections import Mapping + +import pytest + +from aiogram.dispatcher.state.context import CurrentUserContext, _default_key_maker +from aiogram.dispatcher.storage.dict import DictStorage + +try: + from asynctest import CoroutineMock, patch +except ImportError: + from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore + + +@pytest.fixture(scope="function") +def storage() -> DictStorage: + return DictStorage() + + +@contextlib.contextmanager +def patch_dict_storage_method(method: str): + with patch( + f"aiogram.dispatcher.storage.dict.DictStorage.{method}", new_callable=CoroutineMock, + ) as mocked: + yield mocked + + +def test_default_key_maker(): + chat_id, user_id = None, None + with pytest.raises(ValueError): + _default_key_maker(chat_id, user_id) + + chat_id, user_id = 1, None + assert _default_key_maker(chat_id, user_id) == f"{chat_id}:{chat_id}" + + chat_id, user_id = None, 1 + assert _default_key_maker(chat_id, user_id) == f"{user_id}:{user_id}" + + chat_id, user_id = 2 ** 8, 2 ** 10 + assert _default_key_maker(chat_id, user_id) == f"{chat_id}:{user_id}" + + +class TestCurrentUserContext: + def test_init(self, storage): + chat_id, user_id = 1, 2 + ctx = CurrentUserContext(storage, chat_id, user_id) + assert not hasattr(ctx, "__dict__") + assert ctx.storage == storage + assert ctx.key == _default_key_maker(chat_id, user_id) + + def test_custom_key_maker(self, storage): + key_maker_const_result = "mpa" + + def my_key_maker(chat_id: int, user_id: int): + return key_maker_const_result + + chat_id, user_id = 1, 2 + ctx = CurrentUserContext(storage, chat_id, user_id, key_maker=my_key_maker) + assert ctx.key == my_key_maker(chat_id, user_id) == key_maker_const_result + + @pytest.mark.asyncio + @pytest.mark.parametrize("setter_method", ("set_state", "set_data")) + async def test_setters(self, storage, setter_method): + chat_id, user_id = 1, 2 + ctx = CurrentUserContext(storage, chat_id, user_id) + + with patch_dict_storage_method(setter_method) as mocked: + await getattr(ctx, setter_method)("some") + mocked.assert_awaited() + + @pytest.mark.asyncio + @pytest.mark.parametrize("getter_method", ("set_state", "set_data")) + async def test_getters(self, storage, getter_method): + chat_id, user_id = 1, 2 + ctx = CurrentUserContext(storage, chat_id, user_id) + + with patch_dict_storage_method(getter_method) as mocked: + await getattr(ctx, getter_method)() + mocked.assert_awaited() + + @pytest.mark.asyncio + @pytest.mark.parametrize("reseter_method", ("reset_data", "reset_state", "finish")) + async def test_setters(self, storage, reseter_method): + chat_id, user_id = 1, 2 + ctx = CurrentUserContext(storage, chat_id, user_id) + + with patch_dict_storage_method(reseter_method) as mocked: + await getattr(ctx, reseter_method)() + mocked.assert_awaited() + + @pytest.mark.asyncio + async def test_update_data(self, storage): + chat_id, user_id = 1, 2 + ctx = CurrentUserContext(storage, chat_id, user_id) + + with patch_dict_storage_method("update_data") as mocked: + await ctx.update_data() + mocked.assert_awaited() + + with pytest.raises( + ValueError, + match="type for `data` is expected to be a subtype of `collections.Mapping`", + ): + await ctx.update_data(data="definetely not mapping") + + class LegitMapping(Mapping): + def __getitem__(self, k): + return "value" + + def __len__(self): + return 1 + + def __iter__(self): + yield "key" + + new_data = LegitMapping() + assert await ctx.update_data(data=new_data) is None + assert await ctx.get_data() == new_data diff --git a/tests/test_dispatcher/test_storage/test_base.py b/tests/test_dispatcher/test_storage/test_base.py new file mode 100644 index 00000000..37fc6103 --- /dev/null +++ b/tests/test_dispatcher/test_storage/test_base.py @@ -0,0 +1,56 @@ +from typing import Any, Optional + +import pytest + +from aiogram.dispatcher.storage.base import BaseStorage + +try: + from asynctest import CoroutineMock, patch +except ImportError: + from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore + + +STORAGE_ABSTRACT_METHODS = { + "get_data", + "get_state", + "set_data", + "set_state", + "update_data", + "close", + "wait_closed", +} + + +class TestBaseStorage: + def test_do_not_impl_abstract_methods(self): + with pytest.raises(TypeError): + + class etcd(BaseStorage): # example of bad not implemented storage + nothing = lambda: None + + etcd() + + def test_do_impl_abstract_methods(self): + class good(BaseStorage): + async def get_state(self, key: str) -> Optional[str]: + pass + + async def set_state(self, key: str, state: Optional[str]) -> None: + pass + + async def get_data(self, key: str) -> Any: + pass + + async def set_data(self, key: str, data: Optional[Any]) -> None: + pass + + async def update_data(self, key: str, data: Any) -> None: + pass + + async def close(self) -> None: + pass + + async def wait_closed(self) -> None: + pass + + good() diff --git a/tests/test_dispatcher/test_storage/test_dict.py b/tests/test_dispatcher/test_storage/test_dict.py new file mode 100644 index 00000000..7961161f --- /dev/null +++ b/tests/test_dispatcher/test_storage/test_dict.py @@ -0,0 +1,76 @@ +from typing import Dict + +import pytest + +from aiogram.dispatcher.storage.dict import DictStorage + + +class TestDictStorage: + def test_init(self): + storage: DictStorage[Dict[str, str]] = DictStorage() + assert not storage._data + assert isinstance(storage._data, dict) + + def test_key_resolution(self): + storage: DictStorage[Dict[str, str]] = DictStorage() + key = "L:R" + assert key not in storage._data + storage._make_spot_for_key(key) + assert key in storage._data + assert storage._data.get(key) == {"state": None, "data": {}} + + @pytest.mark.asyncio + async def test_get_set_state(self): + storage: DictStorage[Dict[str, str]] = DictStorage() + + key = "L:R" + assert await storage.get_state(key=key) is None # initial state is None + + new_state = "sotm" + assert await storage.set_state(key=key, state=new_state) is None + assert await storage.get_state(key=key) == new_state + + assert await storage.reset_state(key=key) is None + assert await storage.get_state(key=key) is None + + @pytest.mark.asyncio + async def test_get_set_update_reset_data(self): + storage: DictStorage[Dict[str, str]] = DictStorage() + + key = "L:R" + assert await storage.get_data(key=key) == {} # initial data is empty dict + + new_data = {"sotm": "sotm"} + assert await storage.set_data(key=key, data=new_data) is None + assert await storage.get_data(key=key) == new_data + + updated_data = {"mpa": "mpa"} + assert await storage.update_data(key=key, data=updated_data) is None + assert await storage.get_data(key=key) == {**new_data, **updated_data} + + assert await storage.reset_data(key=key) is None + assert await storage.get_data(key=key) == {} # reset_data makes data empty dict + + @pytest.mark.asyncio + async def test_finish(self): + # finish turns data into initial one + storage: DictStorage[Dict[str, str]] = DictStorage() + + key = "L:R" + await storage.set_data(key=key, data={"mpa": "mpa"}) + await storage.set_state(key=key, state="mpa::mpa::mpa::mpa") + assert await storage.get_data(key=key) + assert await storage.get_state(key=key) + + assert await storage.finish(key=key) is None + assert await storage.get_data(key=key) == {} + assert await storage.get_state(key=key) is None + + @pytest.mark.asyncio + async def test_close_wait_closed(self): + storage: DictStorage[Dict[str, str]] = DictStorage() + + storage._data = {"corrupt": "True"} + assert await storage.close() is None + assert storage._data == {} + assert await storage.wait_closed() is None # noop