mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge pull request #276 from vanyakosmos/fix-redis-closing
fix redis pool connection closing
This commit is contained in:
commit
f8fa313403
4 changed files with 89 additions and 21 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -60,3 +60,6 @@ docs/html
|
|||
|
||||
# i18n/l10n
|
||||
*.mo
|
||||
|
||||
# pynev
|
||||
.python-version
|
||||
|
|
|
|||
|
|
@ -44,19 +44,19 @@ class RedisStorage(BaseStorage):
|
|||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._kwargs = kwargs
|
||||
|
||||
self._redis: aioredis.RedisConnection = None
|
||||
self._redis: typing.Optional[aioredis.RedisConnection] = None
|
||||
self._connection_lock = asyncio.Lock(loop=self._loop)
|
||||
|
||||
async def close(self):
|
||||
if self._redis and not self._redis.closed:
|
||||
self._redis.close()
|
||||
del self._redis
|
||||
self._redis = None
|
||||
async with self._connection_lock:
|
||||
if self._redis and not self._redis.closed:
|
||||
self._redis.close()
|
||||
|
||||
async def wait_closed(self):
|
||||
if self._redis:
|
||||
return await self._redis.wait_closed()
|
||||
return True
|
||||
async with self._connection_lock:
|
||||
if self._redis:
|
||||
return await self._redis.wait_closed()
|
||||
return True
|
||||
|
||||
async def redis(self) -> aioredis.RedisConnection:
|
||||
"""
|
||||
|
|
@ -64,7 +64,7 @@ class RedisStorage(BaseStorage):
|
|||
"""
|
||||
# Use thread-safe asyncio Lock because this method without that is not safe
|
||||
async with self._connection_lock:
|
||||
if self._redis is None:
|
||||
if self._redis is None or self._redis.closed:
|
||||
self._redis = await aioredis.create_connection((self._host, self._port),
|
||||
db=self._db, password=self._password, ssl=self._ssl,
|
||||
loop=self._loop,
|
||||
|
|
@ -144,7 +144,7 @@ class RedisStorage(BaseStorage):
|
|||
record_data.update(data, **kwargs)
|
||||
await self.set_record(chat=chat, user=user, state=record['state'], data=record_data)
|
||||
|
||||
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
|
||||
async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
"""
|
||||
Get list of all stored chat's and user's
|
||||
|
||||
|
|
@ -220,11 +220,11 @@ class RedisStorage2(BaseStorage):
|
|||
|
||||
"""
|
||||
def __init__(self, host: str = 'localhost', port=6379, db=None, password=None,
|
||||
ssl=None, pool_size=10, loop=None, prefix='fsm',
|
||||
state_ttl: int = 0,
|
||||
data_ttl: int = 0,
|
||||
bucket_ttl: int = 0,
|
||||
**kwargs):
|
||||
ssl=None, pool_size=10, loop=None, prefix='fsm',
|
||||
state_ttl: int = 0,
|
||||
data_ttl: int = 0,
|
||||
bucket_ttl: int = 0,
|
||||
**kwargs):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._db = db
|
||||
|
|
@ -239,7 +239,7 @@ class RedisStorage2(BaseStorage):
|
|||
self._data_ttl = data_ttl
|
||||
self._bucket_ttl = bucket_ttl
|
||||
|
||||
self._redis: aioredis.RedisConnection = None
|
||||
self._redis: typing.Optional[aioredis.RedisConnection] = None
|
||||
self._connection_lock = asyncio.Lock(loop=self._loop)
|
||||
|
||||
async def redis(self) -> aioredis.Redis:
|
||||
|
|
@ -248,7 +248,7 @@ class RedisStorage2(BaseStorage):
|
|||
"""
|
||||
# Use thread-safe asyncio Lock because this method without that is not safe
|
||||
async with self._connection_lock:
|
||||
if self._redis is None:
|
||||
if self._redis is None or self._redis.closed:
|
||||
self._redis = await aioredis.create_redis_pool((self._host, self._port),
|
||||
db=self._db, password=self._password, ssl=self._ssl,
|
||||
minsize=1, maxsize=self._pool_size,
|
||||
|
|
@ -262,8 +262,6 @@ class RedisStorage2(BaseStorage):
|
|||
async with self._connection_lock:
|
||||
if self._redis and not self._redis.closed:
|
||||
self._redis.close()
|
||||
del self._redis
|
||||
self._redis = None
|
||||
|
||||
async def wait_closed(self):
|
||||
async with self._connection_lock:
|
||||
|
|
@ -357,7 +355,7 @@ class RedisStorage2(BaseStorage):
|
|||
keys = await conn.keys(self.generate_key('*'))
|
||||
await conn.delete(*keys)
|
||||
|
||||
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
|
||||
async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
"""
|
||||
Get list of all stored chat's and user's
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1,35 @@
|
|||
# pytest_plugins = "pytest_asyncio.plugin"
|
||||
import pytest
|
||||
from _pytest.config import UsageError
|
||||
import aioredis.util
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--redis", default=None,
|
||||
help="run tests which require redis connection")
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
redis_uri = config.getoption("--redis")
|
||||
if redis_uri is None:
|
||||
skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
|
||||
for item in items:
|
||||
if "redis" in item.keywords:
|
||||
item.add_marker(skip_redis)
|
||||
return
|
||||
try:
|
||||
address, options = aioredis.util.parse_url(redis_uri)
|
||||
assert isinstance(address, tuple), "Only redis and rediss schemas are supported, eg redis://foo."
|
||||
except AssertionError as e:
|
||||
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def redis_options(request):
|
||||
redis_uri = request.config.getoption("--redis")
|
||||
(host, port), options = aioredis.util.parse_url(redis_uri)
|
||||
options.update({'host': host, 'port': port})
|
||||
return options
|
||||
|
|
|
|||
33
tests/contrib/fsm_storage/test_redis.py
Normal file
33
tests/contrib/fsm_storage/test_redis.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def store(redis_options):
|
||||
s = RedisStorage2(**redis_options)
|
||||
try:
|
||||
yield s
|
||||
finally:
|
||||
conn = await s.redis()
|
||||
await conn.flushdb()
|
||||
await s.close()
|
||||
await s.wait_closed()
|
||||
|
||||
|
||||
@pytest.mark.redis
|
||||
class TestRedisStorage2:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_get(self, store):
|
||||
assert await store.get_data(chat='1234') == {}
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
assert await store.get_data(chat='1234') == {'foo': 'bar'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_and_open_connection(self, store):
|
||||
await store.set_data(chat='1234', data={'foo': 'bar'})
|
||||
assert await store.get_data(chat='1234') == {'foo': 'bar'}
|
||||
pool_id = id(store._redis)
|
||||
await store.close()
|
||||
assert await store.get_data(chat='1234') == {'foo': 'bar'} # new pool was opened at this point
|
||||
assert id(store._redis) != pool_id
|
||||
Loading…
Add table
Add a link
Reference in a new issue