From 8d5ae99c4ba3b0dbe5193781c7204b3cd29e457d Mon Sep 17 00:00:00 2001 From: darksidecat Date: Mon, 17 May 2021 14:30:46 +0300 Subject: [PATCH] add asyncio locks manager --- aiogram/dispatcher/fsm/middleware.py | 14 ++++- aiogram/dispatcher/fsm/storage/memory.py | 7 ++- aiogram/utils/lockmanager.py | 70 ++++++++++++++++++++++++ tests/test_utils/test_lockmanager.py | 65 ++++++++++++++++++++++ 4 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 aiogram/utils/lockmanager.py create mode 100644 tests/test_utils/test_lockmanager.py diff --git a/aiogram/dispatcher/fsm/middleware.py b/aiogram/dispatcher/fsm/middleware.py index d3d5d8c2..3158a715 100644 --- a/aiogram/dispatcher/fsm/middleware.py +++ b/aiogram/dispatcher/fsm/middleware.py @@ -29,7 +29,7 @@ class FSMContextMiddleware(BaseMiddleware[Update]): if context: data.update({"state": context, "raw_state": await context.get_state()}) if self.isolate_events: - async with self.storage.lock(): + async with self.storage.lock(self.lock_key(data)): return await handler(event, data) return await handler(event, data) @@ -55,3 +55,15 @@ class FSMContextMiddleware(BaseMiddleware[Update]): def get_context(self, chat_id: int, user_id: int) -> FSMContext: return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id) + + def lock_key(self, data: Dict[str, Any]) -> str: + user = data.get("event_from_user") + chat = data.get("event_chat") + user_id = user.id if user else None + chat_id = chat.id if chat else user_id + + if chat_id is not None and user_id is not None: + chat_id, user_id = apply_strategy( + chat_id=chat_id, user_id=user_id, strategy=self.strategy + ) + return f"FSM:{chat_id}:{user_id}" diff --git a/aiogram/dispatcher/fsm/storage/memory.py b/aiogram/dispatcher/fsm/storage/memory.py index 46b6d60b..064c8b10 100644 --- a/aiogram/dispatcher/fsm/storage/memory.py +++ b/aiogram/dispatcher/fsm/storage/memory.py @@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional from aiogram.dispatcher.fsm.state import State from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType +from aiogram.utils.lockmanager import LockManager @dataclass @@ -19,11 +20,11 @@ class MemoryStorage(BaseStorage): self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict( lambda: defaultdict(MemoryStorageRecord) ) - self._lock = Lock() + self._lock_storage: Dict[str, Lock] = {} @asynccontextmanager - async def lock(self) -> AsyncGenerator[None, None]: - async with self._lock: + async def lock(self, key: str) -> AsyncGenerator[None, None]: + async with LockManager(storage_data=self._lock_storage, key=key): yield None async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None: diff --git a/aiogram/utils/lockmanager.py b/aiogram/utils/lockmanager.py new file mode 100644 index 00000000..3b8975c1 --- /dev/null +++ b/aiogram/utils/lockmanager.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import logging +from asyncio import Lock +from types import TracebackType +from typing import Dict, Optional, Type + +logger = logging.getLogger(__name__) + + +class CantDeleteWithWaiters(Exception): + """Sync primitive with waiters cant be deleted""" + + pass + + +class LockManager: + def __init__(self, storage_data: Dict[str, Lock], key: str = "global"): + self.storage_data = storage_data + self.key = key + self._current_lock: Optional[Lock] = None + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.release() + + async def acquire(self) -> None: + self._current_lock = self.get_lock(self.key) + await self._current_lock.acquire() + + def release(self) -> None: + if not self._current_lock: + self._current_lock = self.get_lock(self.key) + try: + self._current_lock.release() + finally: + if not self._current_lock.locked() and not self._current_lock._waiters: # type: ignore + self.del_lock(self.key) + + def get_lock(self, key: str) -> Lock: + """ + Return Lock from storage, + if key not exist in storage then create lock + :param key: key + :return: Lock + """ + return self.storage_data.setdefault(key, Lock()) + + def del_lock(self, key: str) -> None: + """ + Delete lock from storage, + raise CantDeleteWithWaiters exception + if key for deleting not found logging this + :param key: key + :return: + """ + lock = self.storage_data.get(key) + if not lock: + logger.warning("Can`t find Lock by key to delete %s" % key) + elif lock._waiters: # type: ignore + raise CantDeleteWithWaiters("Can`t delete Lock with waiters %s" % lock) + else: + del self.storage_data[key] diff --git a/tests/test_utils/test_lockmanager.py b/tests/test_utils/test_lockmanager.py new file mode 100644 index 00000000..01ef5926 --- /dev/null +++ b/tests/test_utils/test_lockmanager.py @@ -0,0 +1,65 @@ +import asyncio +import pytest + +from aiogram.utils.lockmanager import LockManager, CantDeleteWithWaiters + + +@pytest.fixture() +def storage_data(): + return {} + + +def test_lock_manager_without_context_add_lock(storage_data): + lock_manager = LockManager(storage_data) + lock_manager.get_lock(key="test") + assert lock_manager.storage_data + assert "test" in lock_manager.storage_data + + +@pytest.mark.asyncio +async def test_lock_manager_with_context(storage_data): + lock_manager = LockManager(storage_data, key="test") + locks_acquire_result = [] + expected = [False, True, False] + + locks_acquire_result.append(bool(lock_manager.storage_data)) + async with lock_manager: + assert "test" in lock_manager.storage_data + locks_acquire_result.append(bool(lock_manager.storage_data)) + locks_acquire_result.append(bool(lock_manager.storage_data)) + assert locks_acquire_result == expected + + +@pytest.mark.asyncio +async def test_lock_manager_raise_waiters_exc(storage_data): + async def task(lock_manager): + async with lock_manager: + await asyncio.sleep(0.1) + + async def task2(lock_manager): + async with lock_manager: + await asyncio.sleep(0.1) + + async def task_del(lock_manager): + await asyncio.sleep(0.05) + with pytest.raises(CantDeleteWithWaiters): + lock_manager.del_lock("test") + + lock_manager = LockManager(storage_data, key="test") + await asyncio.gather(task(lock_manager), task2(lock_manager), task_del(lock_manager)) + + +@pytest.mark.asyncio +async def test_lock_manager_log_miss_key(caplog, storage_data): + lock_manager = LockManager(storage_data) + lock_manager.del_lock("test") + assert "Can`t find Lock by key to delete" in caplog.text + + +@pytest.mark.asyncio +async def test_lock_manager_release_first(storage_data): + lock_manager = LockManager(storage_data, key="test") + with pytest.raises(RuntimeError): + lock_manager.release() + + assert not lock_manager.storage_data