From 3c85ae761f11d6da1dea800fbc6b7caa120db1fe Mon Sep 17 00:00:00 2001 From: Andriy Date: Wed, 16 Nov 2022 20:04:34 +0200 Subject: [PATCH] change sqlite3 to aiosqlite --- aiogram/contrib/fsm_storage/sqlite.py | 81 ++++++++++++++++----------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/aiogram/contrib/fsm_storage/sqlite.py b/aiogram/contrib/fsm_storage/sqlite.py index 34885487..29803777 100644 --- a/aiogram/contrib/fsm_storage/sqlite.py +++ b/aiogram/contrib/fsm_storage/sqlite.py @@ -1,5 +1,5 @@ import typing -import sqlite3 +import aiosqlite from ...utils import json from ...dispatcher.storage import BaseStorage @@ -8,59 +8,71 @@ class SqliteStorage(BaseStorage): def __init__(self, db_name: str = 'aiogram_fsm_storage.db', tbl_name: str = 'aiogram_fsm'): self._db_name = db_name self._tbl_name = tbl_name - self._conn = sqlite3.connect(db_name) - self._cur = self._conn.cursor() + self._conn = None + + async def _get_connect(self) -> aiosqlite.Connection: + if self._conn: + return self._conn + else: + self._conn = await aiosqlite.connect(self._db_name) + await self._conn.execute(f'CREATE TABLE IF NOT EXISTS {self._tbl_name}(chat INTEGER, user INTEGER, state VARCHAR(255), data TEXT, bucket TEXT)') + await self._conn.commit() + return self._conn - self._cur.execute(f'CREATE TABLE IF NOT EXISTS {tbl_name}(chat INTEGER, user INTEGER, state VARCHAR(255), data TEXT, bucket TEXT)') - self._conn.commit() - - def _has_in_db(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None) -> bool: - self._cur.execute(f'SELECT EXISTS(SELECT * FROM {self._tbl_name} WHERE chat=? AND user=?)', (chat, user)) - return self._cur.fetchone()[0] + async def _has_in_db(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None) -> bool: + conn = await self._get_connect() + cursor = await conn.execute(f'SELECT EXISTS(SELECT * FROM {self._tbl_name} WHERE chat=? AND user=?)', (chat, user)) + return (await cursor.fetchone())[0] async def close(self): - self._conn.close() + if self._conn: + await self._conn.close() async def wait_closed(self): pass async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[str] = None) -> typing.Optional[str]: chat, user = self.check_address(chat=chat, user=user) - if self._has_in_db(chat=chat, user=user): - self._cur.execute(f'SELECT state FROM {self._tbl_name} WHERE chat=? AND user=?', (chat, user)) - return self._cur.fetchone()[0] + conn = await self._get_connect() + if await self._has_in_db(chat=chat, user=user): + cursor = await conn.execute(f'SELECT state FROM {self._tbl_name} WHERE chat=? AND user=?', (chat, user)) + return (await cursor.fetchone())[0] else: return self.resolve_state(default) async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[typing.Dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) - if self._has_in_db(chat=chat, user=user): - self._cur.execute(f'SELECT data FROM {self._tbl_name} WHERE chat=? AND user=?', (chat, user)) - data = self._cur.fetchone()[0] + conn = await self._get_connect() + if await self._has_in_db(chat=chat, user=user): + cursor = await conn.execute(f'SELECT data FROM {self._tbl_name} WHERE chat=? AND user=?', (chat, user)) + data = (await cursor.fetchone())[0] return json.loads(data) if data else default or {} else: return default or {} async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, state: typing.Optional[typing.AnyStr] = None): chat, user = self.check_address(chat=chat, user=user) - if self._has_in_db(chat=chat, user=user): - self._cur.execute(f'UPDATE {self._tbl_name} SET state=? WHERE chat=? AND user=?', (self.resolve_state(state), chat, user)) + conn = await self._get_connect() + if await self._has_in_db(chat=chat, user=user): + await conn.execute(f'UPDATE {self._tbl_name} SET state=? WHERE chat=? AND user=?', (self.resolve_state(state), chat, user)) else: - self._cur.execute(f'INSERT INTO {self._tbl_name}(chat, user, state) VALUES(?, ?, ?)', (chat, user, self.resolve_state(state))) - self._conn.commit() + await conn.execute(f'INSERT INTO {self._tbl_name}(chat, user, state) VALUES(?, ?, ?)', (chat, user, self.resolve_state(state))) + await conn.commit() async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict = None): chat, user = self.check_address(chat=chat, user=user) - if self._has_in_db(chat=chat, user=user): - self._cur.execute(f'UPDATE {self._tbl_name} SET data=? WHERE chat=? AND user=?', (json.dumps(data) if data else None, chat, user)) + conn = await self._get_connect() + if await self._has_in_db(chat=chat, user=user): + await conn.execute(f'UPDATE {self._tbl_name} SET data=? WHERE chat=? AND user=?', (json.dumps(data) if data else None, chat, user)) else: - self._cur.execute(f'INSERT INTO {self._tbl_name}(chat, user, data) VALUES(?, ?, ?)', (chat, user, json.dumps(data) if data else None)) - self._conn.commit() + await conn.execute(f'INSERT INTO {self._tbl_name}(chat, user, data) VALUES(?, ?, ?)', (chat, user, json.dumps(data) if data else None)) + await conn.commit() async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, data: typing.Dict, **kwargs): chat, user = self.check_address(chat=chat, user=user) _data = await self.get_data(chat=chat, user=user) - _data.update(data, **kwargs) + if isinstance(data, dict): + _data.update(data, **kwargs) await self.set_data(chat=chat, user=user, data=_data) def has_bucket(self): @@ -68,23 +80,26 @@ class SqliteStorage(BaseStorage): async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, default: typing.Optional[dict] = None) -> typing.Dict: chat, user = self.check_address(chat=chat, user=user) - if self._has_in_db(chat=chat, user=user): - self._cur.execute(f'SELECT bucket FROM {self._tbl_name} WHERE chat=? AND user=?', (chat, user)) - bucket = self._cur.fetchone()[0] + conn = await self._get_connect() + if await self._has_in_db(chat=chat, user=user): + cursor = await conn.execute(f'SELECT bucket FROM {self._tbl_name} WHERE chat=? AND user=?', (chat, user)) + bucket = (await cursor.fetchone())[0] return json.loads(bucket) if bucket else default or {} else: return default or {} async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Optional[dict] = None): chat, user = self.check_address(chat=chat, user=user) - if self._has_in_db(chat=chat, user=user): - self._cur.execute(f'UPDATE {self._tbl_name} SET bucket=? WHERE chat=? AND user=?', (json.dumps(bucket) if bucket else None, chat, user)) + conn = await self._get_connect() + if await self._has_in_db(chat=chat, user=user): + await conn.execute(f'UPDATE {self._tbl_name} SET bucket=? WHERE chat=? AND user=?', (json.dumps(bucket) if bucket else None, chat, user)) else: - self._cur.execute(f'INSERT INTO {self._tbl_name}(chat, user, bucket) VALUES(?, ?, ?)', (chat, user, json.dumps(bucket) if bucket else None)) - self._conn.commit() + await conn.execute(f'INSERT INTO {self._tbl_name}(chat, user, bucket) VALUES(?, ?, ?)', (chat, user, json.dumps(bucket) if bucket else None)) + await conn.commit() async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, bucket: typing.Optional[dict] = None, **kwargs): chat, user = self.check_address(chat=chat, user=user) _bucket = await self.get_bucket(chat=chat, user=user) - _bucket.update(bucket, **kwargs) + if isinstance(bucket, dict): + _bucket.update(bucket, **kwargs) await self.set_bucket(chat=chat, user=user, bucket=_bucket)