mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
add asyncio locks manager
This commit is contained in:
parent
9451a085d1
commit
8d5ae99c4b
4 changed files with 152 additions and 4 deletions
|
|
@ -29,7 +29,7 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
|
|||
if context:
|
||||
data.update({"state": context, "raw_state": await context.get_state()})
|
||||
if self.isolate_events:
|
||||
async with self.storage.lock():
|
||||
async with self.storage.lock(self.lock_key(data)):
|
||||
return await handler(event, data)
|
||||
return await handler(event, data)
|
||||
|
||||
|
|
@ -55,3 +55,15 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
|
|||
|
||||
def get_context(self, chat_id: int, user_id: int) -> FSMContext:
|
||||
return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id)
|
||||
|
||||
def lock_key(self, data: Dict[str, Any]) -> str:
|
||||
user = data.get("event_from_user")
|
||||
chat = data.get("event_chat")
|
||||
user_id = user.id if user else None
|
||||
chat_id = chat.id if chat else user_id
|
||||
|
||||
if chat_id is not None and user_id is not None:
|
||||
chat_id, user_id = apply_strategy(
|
||||
chat_id=chat_id, user_id=user_id, strategy=self.strategy
|
||||
)
|
||||
return f"FSM:{chat_id}:{user_id}"
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
|
|||
|
||||
from aiogram.dispatcher.fsm.state import State
|
||||
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
|
||||
from aiogram.utils.lockmanager import LockManager
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -19,11 +20,11 @@ class MemoryStorage(BaseStorage):
|
|||
self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict(
|
||||
lambda: defaultdict(MemoryStorageRecord)
|
||||
)
|
||||
self._lock = Lock()
|
||||
self._lock_storage: Dict[str, Lock] = {}
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self) -> AsyncGenerator[None, None]:
|
||||
async with self._lock:
|
||||
async def lock(self, key: str) -> AsyncGenerator[None, None]:
|
||||
async with LockManager(storage_data=self._lock_storage, key=key):
|
||||
yield None
|
||||
|
||||
async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None:
|
||||
|
|
|
|||
70
aiogram/utils/lockmanager.py
Normal file
70
aiogram/utils/lockmanager.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from asyncio import Lock
|
||||
from types import TracebackType
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CantDeleteWithWaiters(Exception):
|
||||
"""Sync primitive with waiters cant be deleted"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LockManager:
|
||||
def __init__(self, storage_data: Dict[str, Lock], key: str = "global"):
|
||||
self.storage_data = storage_data
|
||||
self.key = key
|
||||
self._current_lock: Optional[Lock] = None
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.acquire()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.release()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
self._current_lock = self.get_lock(self.key)
|
||||
await self._current_lock.acquire()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._current_lock:
|
||||
self._current_lock = self.get_lock(self.key)
|
||||
try:
|
||||
self._current_lock.release()
|
||||
finally:
|
||||
if not self._current_lock.locked() and not self._current_lock._waiters: # type: ignore
|
||||
self.del_lock(self.key)
|
||||
|
||||
def get_lock(self, key: str) -> Lock:
|
||||
"""
|
||||
Return Lock from storage,
|
||||
if key not exist in storage then create lock
|
||||
:param key: key
|
||||
:return: Lock
|
||||
"""
|
||||
return self.storage_data.setdefault(key, Lock())
|
||||
|
||||
def del_lock(self, key: str) -> None:
|
||||
"""
|
||||
Delete lock from storage,
|
||||
raise CantDeleteWithWaiters exception
|
||||
if key for deleting not found logging this
|
||||
:param key: key
|
||||
:return:
|
||||
"""
|
||||
lock = self.storage_data.get(key)
|
||||
if not lock:
|
||||
logger.warning("Can`t find Lock by key to delete %s" % key)
|
||||
elif lock._waiters: # type: ignore
|
||||
raise CantDeleteWithWaiters("Can`t delete Lock with waiters %s" % lock)
|
||||
else:
|
||||
del self.storage_data[key]
|
||||
65
tests/test_utils/test_lockmanager.py
Normal file
65
tests/test_utils/test_lockmanager.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import asyncio
|
||||
import pytest
|
||||
|
||||
from aiogram.utils.lockmanager import LockManager, CantDeleteWithWaiters
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def storage_data():
|
||||
return {}
|
||||
|
||||
|
||||
def test_lock_manager_without_context_add_lock(storage_data):
|
||||
lock_manager = LockManager(storage_data)
|
||||
lock_manager.get_lock(key="test")
|
||||
assert lock_manager.storage_data
|
||||
assert "test" in lock_manager.storage_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_manager_with_context(storage_data):
|
||||
lock_manager = LockManager(storage_data, key="test")
|
||||
locks_acquire_result = []
|
||||
expected = [False, True, False]
|
||||
|
||||
locks_acquire_result.append(bool(lock_manager.storage_data))
|
||||
async with lock_manager:
|
||||
assert "test" in lock_manager.storage_data
|
||||
locks_acquire_result.append(bool(lock_manager.storage_data))
|
||||
locks_acquire_result.append(bool(lock_manager.storage_data))
|
||||
assert locks_acquire_result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_manager_raise_waiters_exc(storage_data):
|
||||
async def task(lock_manager):
|
||||
async with lock_manager:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def task2(lock_manager):
|
||||
async with lock_manager:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def task_del(lock_manager):
|
||||
await asyncio.sleep(0.05)
|
||||
with pytest.raises(CantDeleteWithWaiters):
|
||||
lock_manager.del_lock("test")
|
||||
|
||||
lock_manager = LockManager(storage_data, key="test")
|
||||
await asyncio.gather(task(lock_manager), task2(lock_manager), task_del(lock_manager))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_manager_log_miss_key(caplog, storage_data):
|
||||
lock_manager = LockManager(storage_data)
|
||||
lock_manager.del_lock("test")
|
||||
assert "Can`t find Lock by key to delete" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_manager_release_first(storage_data):
|
||||
lock_manager = LockManager(storage_data, key="test")
|
||||
with pytest.raises(RuntimeError):
|
||||
lock_manager.release()
|
||||
|
||||
assert not lock_manager.storage_data
|
||||
Loading…
Add table
Add a link
Reference in a new issue