From 5b479b38450b9d29dad71fc9d6f71d35a79bc50e Mon Sep 17 00:00:00 2001 From: mpa Date: Sat, 22 Aug 2020 21:22:36 +0400 Subject: [PATCH] fix(types): use StorageDataT in storage.dict fix(tests): name setter properly --- aiogram/dispatcher/_typedef.py | 2 +- aiogram/dispatcher/storage/dict.py | 17 ++++++++--------- .../test_dispatcher/test_state/test_context.py | 17 ++++++++++++++++- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/aiogram/dispatcher/_typedef.py b/aiogram/dispatcher/_typedef.py index 5cec07bc..9332df35 100644 --- a/aiogram/dispatcher/_typedef.py +++ b/aiogram/dispatcher/_typedef.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Mapping, TypeVar if TYPE_CHECKING: # pragma: no cover - StorageDataT = TypeVar("StorageDataT", bound=Mapping[str, Any]) + StorageDataT = TypeVar("StorageDataT", bound=Mapping[Any, Any]) else: StorageDataT = TypeVar("StorageDataT", bound=Mapping) diff --git a/aiogram/dispatcher/storage/dict.py b/aiogram/dispatcher/storage/dict.py index 6966f7eb..70c82a8f 100644 --- a/aiogram/dispatcher/storage/dict.py +++ b/aiogram/dispatcher/storage/dict.py @@ -4,14 +4,15 @@ from typing import Any, Dict, Optional from typing_extensions import TypedDict from .base import BaseStorage +from .._typedef import StorageDataT class _UserStorageMetaData(TypedDict): state: Optional[str] - data: Dict[str, Any] + data: Dict[Any, Any] -class DictStorage(BaseStorage[Dict[str, Any]]): +class DictStorage(BaseStorage[StorageDataT]): """ Python dictionary data structure based state storage. Not the most persistent storage, not recommended for in-production environments. @@ -28,21 +29,19 @@ class DictStorage(BaseStorage[Dict[str, Any]]): self._make_spot_for_key(key) return self._data[key]["state"] - async def get_data(self, key: str) -> Dict[str, Any]: + async def get_data(self, key: str) -> StorageDataT: self._make_spot_for_key(key=key) - return copy.deepcopy(self._data[key]["data"]) + return copy.deepcopy(self._data[key]["data"]) # type: ignore - async def update_data(self, key: str, data: Optional[Dict[str, Any]] = None) -> None: - if data is None: - data = {} + async def update_data(self, key: str, data: Optional[StorageDataT] = None) -> None: self._make_spot_for_key(key=key) - self._data[key]["data"].update(data) + self._data[key]["data"].update({} if data is None else data) # type: ignore async def set_state(self, key: str, state: Optional[str] = None) -> None: self._make_spot_for_key(key=key) self._data[key]["state"] = state - async def set_data(self, key: str, data: Optional[Dict[str, Any]] = None) -> None: + async def set_data(self, key: str, data: Optional[StorageDataT] = None) -> None: self._make_spot_for_key(key=key) self._data[key]["data"] = copy.deepcopy(data) or {} # type: ignore diff --git a/tests/test_dispatcher/test_state/test_context.py b/tests/test_dispatcher/test_state/test_context.py index 412bf73f..dd3774bc 100644 --- a/tests/test_dispatcher/test_state/test_context.py +++ b/tests/test_dispatcher/test_state/test_context.py @@ -58,6 +58,21 @@ class TestCurrentUserContext: ctx = CurrentUserContext(storage, chat_id, user_id, key_maker=my_key_maker) assert ctx.key == my_key_maker(chat_id, user_id) == key_maker_const_result + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("setter_method", "getter_method", "value"), + ( + ("set_state", "get_state", "some state"), + ("set_data", "get_data", {"some": "data"}), + ) + ) + async def test_setters_getters(self, storage, setter_method, getter_method, value): + chat_id, user_id = 1, 2 + ctx = CurrentUserContext(storage, chat_id, user_id) + + assert await getattr(ctx, setter_method)(value) is None + assert await getattr(ctx, getter_method)() == value + @pytest.mark.asyncio @pytest.mark.parametrize("setter_method", ("set_state", "set_data")) async def test_setters(self, storage, setter_method): @@ -84,7 +99,7 @@ class TestCurrentUserContext: @pytest.mark.asyncio @pytest.mark.parametrize("reseter_method", ("reset_data", "reset_state", "finish")) - async def test_setters(self, storage, reseter_method): + async def test_resetters(self, storage, reseter_method): chat_id, user_id = 1, 2 ctx = CurrentUserContext(storage, chat_id, user_id)