add asyncio locks manager

This commit is contained in:
darksidecat 2021-05-17 14:30:46 +03:00
parent 9451a085d1
commit 8d5ae99c4b
4 changed files with 152 additions and 4 deletions

View file

@ -29,7 +29,7 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
if context:
data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events:
async with self.storage.lock():
async with self.storage.lock(self.lock_key(data)):
return await handler(event, data)
return await handler(event, data)
@ -55,3 +55,15 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
def get_context(self, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id)
def lock_key(self, data: Dict[str, Any]) -> str:
user = data.get("event_from_user")
chat = data.get("event_chat")
user_id = user.id if user else None
chat_id = chat.id if chat else user_id
if chat_id is not None and user_id is not None:
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
)
return f"FSM:{chat_id}:{user_id}"

View file

@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
from aiogram.utils.lockmanager import LockManager
@dataclass
@ -19,11 +20,11 @@ class MemoryStorage(BaseStorage):
self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict(
lambda: defaultdict(MemoryStorageRecord)
)
self._lock = Lock()
self._lock_storage: Dict[str, Lock] = {}
@asynccontextmanager
async def lock(self) -> AsyncGenerator[None, None]:
async with self._lock:
async def lock(self, key: str) -> AsyncGenerator[None, None]:
async with LockManager(storage_data=self._lock_storage, key=key):
yield None
async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None:

View file

@ -0,0 +1,70 @@
from __future__ import annotations
import logging
from asyncio import Lock
from types import TracebackType
from typing import Dict, Optional, Type
logger = logging.getLogger(__name__)
class CantDeleteWithWaiters(Exception):
"""Sync primitive with waiters cant be deleted"""
pass
class LockManager:
def __init__(self, storage_data: Dict[str, Lock], key: str = "global"):
self.storage_data = storage_data
self.key = key
self._current_lock: Optional[Lock] = None
async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.release()
async def acquire(self) -> None:
self._current_lock = self.get_lock(self.key)
await self._current_lock.acquire()
def release(self) -> None:
if not self._current_lock:
self._current_lock = self.get_lock(self.key)
try:
self._current_lock.release()
finally:
if not self._current_lock.locked() and not self._current_lock._waiters: # type: ignore
self.del_lock(self.key)
def get_lock(self, key: str) -> Lock:
"""
Return Lock from storage,
if key not exist in storage then create lock
:param key: key
:return: Lock
"""
return self.storage_data.setdefault(key, Lock())
def del_lock(self, key: str) -> None:
"""
Delete lock from storage,
raise CantDeleteWithWaiters exception
if key for deleting not found logging this
:param key: key
:return:
"""
lock = self.storage_data.get(key)
if not lock:
logger.warning("Can`t find Lock by key to delete %s" % key)
elif lock._waiters: # type: ignore
raise CantDeleteWithWaiters("Can`t delete Lock with waiters %s" % lock)
else:
del self.storage_data[key]

View file

@ -0,0 +1,65 @@
import asyncio
import pytest
from aiogram.utils.lockmanager import LockManager, CantDeleteWithWaiters
@pytest.fixture()
def storage_data():
return {}
def test_lock_manager_without_context_add_lock(storage_data):
lock_manager = LockManager(storage_data)
lock_manager.get_lock(key="test")
assert lock_manager.storage_data
assert "test" in lock_manager.storage_data
@pytest.mark.asyncio
async def test_lock_manager_with_context(storage_data):
lock_manager = LockManager(storage_data, key="test")
locks_acquire_result = []
expected = [False, True, False]
locks_acquire_result.append(bool(lock_manager.storage_data))
async with lock_manager:
assert "test" in lock_manager.storage_data
locks_acquire_result.append(bool(lock_manager.storage_data))
locks_acquire_result.append(bool(lock_manager.storage_data))
assert locks_acquire_result == expected
@pytest.mark.asyncio
async def test_lock_manager_raise_waiters_exc(storage_data):
async def task(lock_manager):
async with lock_manager:
await asyncio.sleep(0.1)
async def task2(lock_manager):
async with lock_manager:
await asyncio.sleep(0.1)
async def task_del(lock_manager):
await asyncio.sleep(0.05)
with pytest.raises(CantDeleteWithWaiters):
lock_manager.del_lock("test")
lock_manager = LockManager(storage_data, key="test")
await asyncio.gather(task(lock_manager), task2(lock_manager), task_del(lock_manager))
@pytest.mark.asyncio
async def test_lock_manager_log_miss_key(caplog, storage_data):
lock_manager = LockManager(storage_data)
lock_manager.del_lock("test")
assert "Can`t find Lock by key to delete" in caplog.text
@pytest.mark.asyncio
async def test_lock_manager_release_first(storage_data):
lock_manager = LockManager(storage_data, key="test")
with pytest.raises(RuntimeError):
lock_manager.release()
assert not lock_manager.storage_data