feat: aioredis v1-v2 adapters #648

This commit is contained in:
Oleg A 2021-08-01 12:00:07 +03:00
parent 580fa2e499
commit 8136ce8ff8

View file

@ -5,6 +5,8 @@ This module has redis storage for finite-state machine based on `aioredis <https
import asyncio
import logging
import typing
from abc import ABC, abstractmethod
from importlib.metadata import version
import aioredis
@ -204,6 +206,138 @@ class RedisStorage(BaseStorage):
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)
class AioRedisAdapterBase(ABC):
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: typing.Optional[int] = None,
password: typing.Optional[str] = None,
ssl: typing.Optional[bool] = None,
pool_size: int = 10,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
prefix: str = "fsm",
state_ttl: int = 0,
data_ttl: int = 0,
bucket_ttl: int = 0,
**kwargs,
):
self._host = host
self._port = port
self._db = db
self._password = password
self._ssl = ssl
self._pool_size = pool_size
self._loop = loop or asyncio.get_event_loop()
self._kwargs = kwargs
self._prefix = (prefix,)
self._state_ttl = state_ttl
self._data_ttl = data_ttl
self._bucket_ttl = bucket_ttl
self._redis: typing.Optional["aioredis.Redis"] = None
self._connection_lock = asyncio.Lock(loop=self._loop)
@abstractmethod
async def redis(self) -> aioredis.Redis:
"""Get Redis connection."""
pass
def close(self):
"""Grace shutdown."""
pass
async def wait_closed(self):
pass
@abstractmethod
async def get(self, name):
pass
@abstractmethod
async def set(self, name, value, expire=None):
pass
@abstractmethod
async def delete(self, name):
pass
class AioRedisAdapterV1(AioRedisAdapterBase):
"""Redis adapter for aioredis v1."""
async def redis(self) -> aioredis.Redis:
"""Get Redis connection."""
async with self._connection_lock: # to prevent race
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_redis_pool(
(self._host, self._port),
db=self._db,
password=self._password,
ssl=self._ssl,
minsize=1,
maxsize=self._pool_size,
loop=self._loop,
**self._kwargs,
)
return self._redis
def close(self):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()
async def wait_closed(self):
async with self._connection_lock:
if self._redis:
return await self._redis.wait_closed()
return True
async def get(self, name):
redis = await self.redis()
return await redis.get(name, encoding="utf8")
async def set(self, name, value, expire=None):
redis = await self.redis()
return await redis.set(name, value, expire=expire)
async def delete(self, name):
redis = await self.redis()
return await redis.delete(name)
class AioRedisAdapterV2(AioRedisAdapterBase):
"""Redis adapter for aioredis v2."""
async def redis(self) -> aioredis.Redis:
"""Get Redis connection."""
async with self._connection_lock: # to prevent race
if self._redis is None:
self._redis = aioredis.Redis(
host=self._host,
port=self._port,
db=self._db,
password=self._password,
ssl=self._ssl,
max_connections=self._pool_size,
**self._kwargs,
)
return self._redis
async def get(self, name):
redis = await self.redis()
return await redis.get(name)
async def set(self, name, value, expire=None):
redis = await self.redis()
return await redis.set(name, value, ex=expire)
async def delete(self, name):
redis = await self.redis()
return await redis.delete(name)
class RedisStorage2(BaseStorage):
"""
Busted Redis-base storage for FSM.
@ -224,12 +358,22 @@ class RedisStorage2(BaseStorage):
await dp.storage.wait_closed()
"""
def __init__(self, host: str = 'localhost', port=6379, db=None, password=None,
ssl=None, pool_size=10, loop=None, prefix='fsm',
state_ttl: int = 0,
data_ttl: int = 0,
bucket_ttl: int = 0,
**kwargs):
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: typing.Optional[int] = None,
password: typing.Optional[str] = None,
ssl: typing.Optional[bool] = None,
pool_size: int = 10,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
prefix: str = "fsm",
state_ttl: int = 0,
data_ttl: int = 0,
bucket_ttl: int = 0,
**kwargs,
):
self._host = host
self._port = port
self._db = db
@ -244,49 +388,53 @@ class RedisStorage2(BaseStorage):
self._data_ttl = data_ttl
self._bucket_ttl = bucket_ttl
self._redis: typing.Optional[aioredis.RedisConnection] = None
self._redis: typing.Optional[AioRedisAdapterBase] = None
self._connection_lock = asyncio.Lock(loop=self._loop)
async def redis(self) -> aioredis.Redis:
"""
Get Redis connection
"""
# Use thread-safe asyncio Lock because this method without that is not safe
async with self._connection_lock:
if self._redis is None or self._redis.closed:
self._redis = await aioredis.create_redis_pool((self._host, self._port),
db=self._db, password=self._password, ssl=self._ssl,
minsize=1, maxsize=self._pool_size,
loop=self._loop, **self._kwargs)
async def redis(self) -> AioRedisAdapterBase:
"""Get adapter based on aioredis version."""
if self._redis is None:
redis_version = version("aioredis").split(".")[0]
connection_data = dict(
host=self._host,
port=self._port,
db=self._db,
password=self._password,
ssl=self._ssl,
pool_size=self._pool_size,
loop=self._loop,
**self._kwargs,
)
if redis_version == 1:
self._redis = AioRedisAdapterV1(**connection_data)
elif redis_version == 2:
self._redis = AioRedisAdapterV2(**connection_data)
return self._redis
def generate_key(self, *parts):
return ':'.join(self._prefix + tuple(map(str, parts)))
async def close(self):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()
if self._redis:
return self._redis.close()
async def wait_closed(self):
async with self._connection_lock:
if self._redis:
return await self._redis.wait_closed()
return True
if self._redis:
return await self._redis.wait_closed()
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis()
return await redis.get(key, encoding='utf8') or self.resolve_state(default)
return await redis.get(key) or self.resolve_state(default)
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self.redis()
raw_result = await redis.get(key, encoding='utf8')
raw_result = await redis.get(key)
if raw_result:
return json.loads(raw_result)
return default or {}
@ -327,7 +475,7 @@ class RedisStorage2(BaseStorage):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self.redis()
raw_result = await redis.get(key, encoding='utf8')
raw_result = await redis.get(key)
if raw_result:
return json.loads(raw_result)
return default or {}