aiogram/aiogram/fsm/storage/redis.py

232 lines
7 KiB
Python
Raw Normal View History

2021-10-11 01:30:19 +03:00
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
2021-10-11 01:30:19 +03:00
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool
from redis.asyncio.lock import Lock
from redis.typing import ExpiryT
from aiogram import Bot
2022-07-10 00:21:34 +03:00
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
DEFAULT_DESTINY,
BaseEventIsolation,
BaseStorage,
StateType,
StorageKey,
)
# Keyword arguments forwarded to ``Redis.lock`` by RedisEventIsolation when the
# caller does not provide its own ``lock_kwargs``.
# ``timeout`` is passed to redis-py's lock (per redis-py docs: maximum lock
# lifetime in seconds).
DEFAULT_REDIS_LOCK_KWARGS: Dict[str, Any] = dict(timeout=60)
class KeyBuilder(ABC):
    """
    Abstract strategy that turns a contextual :class:`StorageKey` into a Redis key name.

    Concrete subclasses decide how the resulting string is laid out.
    """

    @abstractmethod
    def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
        """
        Produce the Redis key name for one record belonging to ``key``.

        Subclasses must override this method.

        :param key: contextual storage key
        :param part: which record is addressed: ``"data"``, ``"state"`` or ``"lock"``
        :return: key name to be used in Redis queries
        """
        ...
class DefaultKeyBuilder(KeyBuilder):
    """
    Simple Redis key builder with default prefix.

    Generates a separator-joined string of the form
    ``prefix[:bot_id]:chat_id:user_id[:destiny]:part``
    (shown with the default ``:`` separator).
    """

    def __init__(
        self,
        *,
        prefix: str = "fsm",
        separator: str = ":",
        with_bot_id: bool = False,
        with_destiny: bool = False,
    ) -> None:
        """
        :param prefix: prefix for all records
        :param separator: string placed between key components
        :param with_bot_id: include the bot id in the key
        :param with_destiny: include the key destiny; when ``False`` only keys whose
            destiny equals :data:`DEFAULT_DESTINY` can be built
        """
        self.prefix = prefix
        self.separator = separator
        self.with_bot_id = with_bot_id
        self.with_destiny = with_destiny

    def build(self, key: StorageKey, part: Literal["data", "state", "lock"]) -> str:
        """
        Join the configured components of ``key`` followed by ``part``.

        :param key: contextual storage key
        :param part: record type, appended as the last component
        :return: the Redis key name
        :raises ValueError: if ``key.destiny`` differs from :data:`DEFAULT_DESTINY`
            while the builder was created with ``with_destiny=False``
        """
        parts = [self.prefix]
        if self.with_bot_id:
            parts.append(str(key.bot_id))
        parts.extend([str(key.chat_id), str(key.user_id)])
        if self.with_destiny:
            parts.append(key.destiny)
        elif key.destiny != DEFAULT_DESTINY:
            # Refuse instead of silently dropping the destiny: without it, keys of
            # different destinies would map to the same Redis key.
            raise ValueError(
                "Redis key builder is not configured to use key destiny other than the default.\n"
                "\n"
                "Probably, you should set `with_destiny=True` for DefaultKeyBuilder.\n"
                "E.g: `RedisStorage(redis, key_builder=DefaultKeyBuilder(with_destiny=True))`"
            )
        parts.append(part)
        return self.separator.join(parts)
class RedisStorage(BaseStorage):
    """
    FSM storage backed by Redis.

    Requires the :code:`redis` package with asyncio support (:code:`pip install redis`).
    """

    def __init__(
        self,
        redis: Redis,
        key_builder: Optional[KeyBuilder] = None,
        state_ttl: Optional[ExpiryT] = None,
        data_ttl: Optional[ExpiryT] = None,
    ) -> None:
        """
        :param redis: Instance of Redis connection
        :param key_builder: builder that helps to convert contextual key to string;
            defaults to :class:`DefaultKeyBuilder`
        :param state_ttl: TTL for state records, passed as ``ex`` to ``SET``
            (``None`` passes no expiry)
        :param data_ttl: TTL for data records, passed as ``ex`` to ``SET``
            (``None`` passes no expiry)
        """
        if key_builder is None:
            key_builder = DefaultKeyBuilder()
        self.redis = redis
        self.key_builder = key_builder
        self.state_ttl = state_ttl
        self.data_ttl = data_ttl

    @classmethod
    def from_url(
        cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> "RedisStorage":
        """
        Create an instance of :class:`RedisStorage` with specifying the connection string

        :param url: for example :code:`redis://user:password@host:port/db`
        :param connection_kwargs: passed to redis-py's :code:`ConnectionPool.from_url`
        :param kwargs: arguments to be passed to :class:`RedisStorage`
        :return: an instance of :class:`RedisStorage`
        """
        if connection_kwargs is None:
            connection_kwargs = {}
        pool = ConnectionPool.from_url(url, **connection_kwargs)
        redis = Redis(connection_pool=pool)
        # redis-py does not disconnect an explicitly passed pool on close().
        # This pool is private to the storage, so mark the client as its owner;
        # otherwise close() would leave the pool's connections open.
        redis.auto_close_connection_pool = True
        return cls(redis=redis, **kwargs)

    def create_isolation(self, **kwargs: Any) -> "RedisEventIsolation":
        """
        Build a :class:`RedisEventIsolation` sharing this storage's client and key builder.

        :param kwargs: extra arguments for :class:`RedisEventIsolation` (e.g. ``lock_kwargs``)
        """
        return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs)

    async def close(self) -> None:
        """
        Close the underlying Redis client.

        Whether its connection pool is also disconnected is decided by the client
        (redis-py's ``auto_close_connection_pool``). Clients created by
        :meth:`from_url` own their pool and release it.
        """
        await self.redis.close()

    async def set_state(
        self,
        bot: Bot,
        key: StorageKey,
        state: StateType = None,
    ) -> None:
        """
        Store the state name for ``key``, or delete the record when ``state`` is ``None``.

        :param bot: bot instance (not used by this storage)
        :param key: contextual storage key
        :param state: a :class:`State` (its ``.state`` name is stored), a string, or ``None``
        """
        redis_key = self.key_builder.build(key, "state")
        if state is None:
            await self.redis.delete(redis_key)
        else:
            await self.redis.set(
                redis_key,
                cast(str, state.state if isinstance(state, State) else state),
                ex=self.state_ttl,
            )

    async def get_state(
        self,
        bot: Bot,
        key: StorageKey,
    ) -> Optional[str]:
        """
        Read the stored state name for ``key``.

        :param bot: bot instance (not used by this storage)
        :param key: contextual storage key
        :return: the state name (bytes replies are decoded as UTF-8), or whatever
            ``GET`` returned otherwise (``None`` when nothing is stored)
        """
        redis_key = self.key_builder.build(key, "state")
        value = await self.redis.get(redis_key)
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return cast(Optional[str], value)

    async def set_data(
        self,
        bot: Bot,
        key: StorageKey,
        data: Dict[str, Any],
    ) -> None:
        """
        Store ``data`` as JSON for ``key``; an empty mapping deletes the record instead.

        :param bot: bot instance whose ``session.json_dumps`` serializes the data
        :param key: contextual storage key
        :param data: mapping to store
        """
        redis_key = self.key_builder.build(key, "data")
        if not data:
            await self.redis.delete(redis_key)
            return
        await self.redis.set(
            redis_key,
            bot.session.json_dumps(data),
            ex=self.data_ttl,
        )

    async def get_data(
        self,
        bot: Bot,
        key: StorageKey,
    ) -> Dict[str, Any]:
        """
        Load the data mapping stored for ``key``.

        :param bot: bot instance whose ``session.json_loads`` deserializes the data
        :param key: contextual storage key
        :return: the stored mapping, or ``{}`` when no record exists
        """
        redis_key = self.key_builder.build(key, "data")
        value = await self.redis.get(redis_key)
        if value is None:
            return {}
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return cast(Dict[str, Any], bot.session.json_loads(value))
class RedisEventIsolation(BaseEventIsolation):
    """
    Event isolation based on Redis locks.

    :meth:`lock` holds a Redis lock named after the storage key for the duration
    of the ``async with`` body.
    """

    def __init__(
        self,
        redis: Redis,
        key_builder: Optional[KeyBuilder] = None,
        lock_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param redis: Redis client used to create locks
        :param key_builder: builder that converts a contextual key to a Redis key name;
            defaults to :class:`DefaultKeyBuilder`
        :param lock_kwargs: extra keyword arguments for ``Redis.lock``;
            defaults to a copy of :data:`DEFAULT_REDIS_LOCK_KWARGS`
        """
        if key_builder is None:
            key_builder = DefaultKeyBuilder()
        if lock_kwargs is None:
            # Copy the module-level default: storing the shared dict itself would let
            # a mutation of one instance's lock_kwargs change every other instance.
            lock_kwargs = dict(DEFAULT_REDIS_LOCK_KWARGS)
        self.redis = redis
        self.key_builder = key_builder
        self.lock_kwargs = lock_kwargs

    @classmethod
    def from_url(
        cls,
        url: str,
        connection_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> "RedisEventIsolation":
        """
        Create an instance with a new Redis client for the given connection string.

        :param url: for example :code:`redis://user:password@host:port/db`
        :param connection_kwargs: passed to redis-py's :code:`ConnectionPool.from_url`
        :param kwargs: arguments to be passed to :class:`RedisEventIsolation`
        :return: an instance of :class:`RedisEventIsolation`
        """
        if connection_kwargs is None:
            connection_kwargs = {}
        pool = ConnectionPool.from_url(url, **connection_kwargs)
        redis = Redis(connection_pool=pool)
        # The pool is private to this client; let closing the client also release it
        # (redis-py does not disconnect an explicitly passed pool by default).
        redis.auto_close_connection_pool = True
        return cls(redis=redis, **kwargs)

    @asynccontextmanager
    async def lock(
        self,
        bot: Bot,
        key: StorageKey,
    ) -> AsyncGenerator[None, None]:
        """
        Hold the Redis lock for ``key`` while the ``async with`` body runs.

        :param bot: bot instance (not used here)
        :param key: contextual storage key; its ``"lock"`` key name is the lock name
        """
        redis_key = self.key_builder.build(key, "lock")
        # lock_class pins the redis.asyncio Lock implementation; lock_kwargs must
        # therefore not contain its own "lock_class" entry.
        async with self.redis.lock(name=redis_key, **self.lock_kwargs, lock_class=Lock):
            yield None

    async def close(self) -> None:
        """
        Do nothing: the Redis client is intentionally not closed here.

        NOTE(review): RedisStorage.create_isolation shares the storage's client with
        this object, which is presumably why it is left open. A consequence is that
        an isolation created via :meth:`from_url` never closes its own client;
        confirm whether that is intended.
        """
        pass