From 7db1572fd3ac348e5d88666c87fa99e792a9e04e Mon Sep 17 00:00:00 2001 From: Boger Date: Wed, 25 Mar 2020 15:49:43 +0300 Subject: [PATCH] Return DataMixin --- aiogram/utils/mixins.py | 28 ++++++++++++++++++++++++- tests/test_utils/test_mixins.py | 36 ++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/aiogram/utils/mixins.py b/aiogram/utils/mixins.py index 719cfaed..2972e1b7 100644 --- a/aiogram/utils/mixins.py +++ b/aiogram/utils/mixins.py @@ -8,13 +8,39 @@ from typing import ( TypeVar, cast, overload, + Dict, ) -__all__ = ("ContextInstanceMixin",) +__all__ = ("ContextInstanceMixin", "DataMixin") from typing_extensions import Literal +class DataMixin: + @property + def data(self) -> Dict[str, Any]: + data: Optional[Dict[str, Any]] = getattr(self, "_data", None) + if data is None: + data = {} + setattr(self, "_data", data) + return data + + def __getitem__(self, key: str) -> Any: + return self.data[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.data[key] = value + + def __delitem__(self, key: str) -> None: + del self.data[key] + + def __contains__(self, key: str) -> bool: + return key in self.data + + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + return self.data.get(key, default) + + ContextInstance = TypeVar("ContextInstance") diff --git a/tests/test_utils/test_mixins.py b/tests/test_utils/test_mixins.py index e7b77727..fe96b3e1 100644 --- a/tests/test_utils/test_mixins.py +++ b/tests/test_utils/test_mixins.py @@ -1,11 +1,45 @@ import pytest -from aiogram.utils.mixins import ContextInstanceMixin +from aiogram.utils.mixins import ( + ContextInstanceMixin, + DataMixin, +) + class ContextObject(ContextInstanceMixin['ContextObject']): pass +class DataObject(DataMixin): + pass + + +class TestDataMixin: + def test_store_value(self): + obj = DataObject() + obj["foo"] = 42 + + assert "foo" in obj + assert obj["foo"] == 42 + assert len(obj.data) == 1 + + def test_remove_value(self): + obj = DataObject() + obj["foo"] = 42 + del obj["foo"] + + assert "key" not in obj + assert len(obj.data) == 0 + + def test_getter(self): + obj = DataObject() + obj["foo"] = 42 + + assert obj.get("foo") == 42 + assert obj.get("bar") is None + assert obj.get("baz", "test") == "test" + + class TestContextInstanceMixin: def test_empty(self): obj = ContextObject()