diff --git a/aiogram/contrib/fsm_storage/rethinkdb.py b/aiogram/contrib/fsm_storage/rethinkdb.py index fd8f1e13..38d24efa 100644 --- a/aiogram/contrib/fsm_storage/rethinkdb.py +++ b/aiogram/contrib/fsm_storage/rethinkdb.py @@ -1,23 +1,18 @@ import asyncio import contextlib import typing -import weakref -import rethinkdb as r +import rethinkdb +from rethinkdb.asyncio_net.net_asyncio import Connection from ...dispatcher.storage import BaseStorage -__all__ = ['RethinkDBStorage', 'ConnectionNotClosed'] +__all__ = ['RethinkDBStorage'] +r = rethinkdb.RethinkDB() r.set_loop_type('asyncio') -class ConnectionNotClosed(Exception): - """ - Indicates that DB connection wasn't closed. - """ - - class RethinkDBStorage(BaseStorage): """ RethinkDB-based storage for FSM. @@ -37,8 +32,17 @@ class RethinkDBStorage(BaseStorage): """ - def __init__(self, host='localhost', port=28015, db='aiogram', table='aiogram', auth_key=None, - user=None, password=None, timeout=20, ssl=None, max_conn=10, loop=None): + def __init__(self, + host: str = 'localhost', + port: int = 28015, + db: str = 'aiogram', + table: str = 'aiogram', + auth_key: typing.Optional[str] = None, + user: typing.Optional[str] = None, + password: typing.Optional[str] = None, + timeout: int = 20, + ssl: typing.Optional[dict] = None, + loop: typing.Optional[asyncio.AbstractEventLoop] = None): self._host = host self._port = port self._db = db @@ -48,65 +52,37 @@ class RethinkDBStorage(BaseStorage): self._password = password self._timeout = timeout self._ssl = ssl or {} + self._loop = loop - self._queue = asyncio.Queue(max_conn) - self._outstanding_connections = weakref.WeakSet() - self._loop = loop or asyncio.get_event_loop() + self._conn: typing.Union[Connection, None] = None - async def get_connection(self): + async def connect(self) -> Connection: """ - Get or create connection. + Get or create a connection. """ - try: - while True: - conn: r.Connection = self._queue.get_nowait() - if conn.is_open(): - break - try: - await conn.close() - except r.ReqlError: - raise ConnectionNotClosed('Exception was caught while closing connection') - except asyncio.QueueEmpty: - if len(self._outstanding_connections) < self._queue.maxsize: - conn = await r.connect(host=self._host, port=self._port, db=self._db, - auth_key=self._auth_key, user=self._user, password=self._password, - timeout=self._timeout, ssl=self._ssl) - else: - conn = await self._queue.get() - - self._outstanding_connections.add(conn) - return conn - - async def put_connection(self, conn): - """ - Return connection to pool. - """ - self._queue.put_nowait(conn) - self._outstanding_connections.remove(conn) + if self._conn is None: + self._conn = await r.connect(host=self._host, + port=self._port, + db=self._db, + auth_key=self._auth_key, + user=self._user, + password=self._password, + timeout=self._timeout, + ssl=self._ssl, + io_loop=self._loop) + return self._conn @contextlib.asynccontextmanager async def connection(self): - conn = await self.get_connection() + conn = await self.connect() yield conn - await self.put_connection(conn) async def close(self): """ - Close all connections. + Close a connection. """ - while True: - try: - conn: r.Connection = self._queue.get_nowait() - except asyncio.QueueEmpty: - break - - self._outstanding_connections.add(conn) - - for conn in self._outstanding_connections: - try: - await conn.close() - except r.ReqlError: - raise ConnectionNotClosed('Exception was caught while closing connection') + self._conn.close() + self._conn = None async def wait_closed(self): """ @@ -118,24 +94,19 @@ class RethinkDBStorage(BaseStorage): default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = map(str, self.check_address(chat=chat, user=user)) async with self.connection() as conn: - result = await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn) - return result + return await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Dict: chat, user = map(str, self.check_address(chat=chat, user=user)) async with self.connection() as conn: - result = await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn) - return result + return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn) async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, state: typing.Optional[typing.AnyStr] = None): chat, user = map(str, self.check_address(chat=chat, user=user)) async with self.connection() as conn: - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'state': state}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'state': state}}).run(conn) + await r.table(self._table).insert({'id': chat, user: {'state': state}}, conflict="update").run(conn) async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): @@ -151,10 +122,7 @@ class RethinkDBStorage(BaseStorage): **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) async with self.connection() as conn: - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'data': data}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn) + await r.table(self._table).insert({'id': chat, user: {'data': data}}, conflict="update").run(conn) def has_bucket(self): return True @@ -163,8 +131,7 @@ class RethinkDBStorage(BaseStorage): default: typing.Optional[dict] = None) -> typing.Dict: chat, user = map(str, self.check_address(chat=chat, user=user)) async with self.connection() as conn: - result = await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) - return result + return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn) async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Dict = None): @@ -180,10 +147,7 @@ class RethinkDBStorage(BaseStorage): **kwargs): chat, user = map(str, self.check_address(chat=chat, user=user)) async with self.connection() as conn: - if await r.table(self._table).get(chat).run(conn): - await r.table(self._table).get(chat).update({user: {'bucket': bucket}}).run(conn) - else: - await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn) + await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}, conflict="update").run(conn) async def get_states_list(self) -> typing.List[typing.Tuple[int, int]]: """ diff --git a/dev_requirements.txt b/dev_requirements.txt index 72d6a989..ee1fc5e5 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -16,3 +16,4 @@ sphinx-rtd-theme>=0.3.0 sphinxcontrib-programoutput>=0.11 aresponses>=1.0.0 aiohttp-socks>=0.1.5 +rethinkdb>=2.4.1 \ No newline at end of file