Merge branch 'dev-2.x' into refact-user-chat

This commit is contained in:
Anthony Byuraev 2020-07-05 14:54:00 +03:00
commit 352588cb06
7 changed files with 250 additions and 473 deletions

View file

@ -1 +1,217 @@
from .mongo_aiomongo import MongoStorage
"""
This module has mongo storage for finite-state machine
based on `motor <https://github.com/mongodb/motor>`_ driver
"""
from typing import Union, Dict, Optional, List, Tuple, AnyStr
import pymongo
try:
import motor
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
except ModuleNotFoundError as e:
import warnings
warnings.warn("Install motor with `pip install motor`")
raise e
from ...dispatcher.storage import BaseStorage
STATE = 'aiogram_state'
DATA = 'aiogram_data'
BUCKET = 'aiogram_bucket'
COLLECTIONS = (STATE, DATA, BUCKET)
class MongoStorage(BaseStorage):
"""
Mongo-based storage for FSM.
Usage:
.. code-block:: python3
storage = MongoStorage(host='localhost', port=27017, db_name='aiogram_fsm')
dp = Dispatcher(bot, storage=storage)
And need to close Mongo client connections when shutdown
.. code-block:: python3
await dp.storage.close()
await dp.storage.wait_closed()
"""
def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm', uri=None,
username=None, password=None, index=True, **kwargs):
self._host = host
self._port = port
self._db_name: str = db_name
self._uri = uri
self._username = username
self._password = password
self._kwargs = kwargs
self._mongo: Optional[AsyncIOMotorClient] = None
self._db: Optional[AsyncIOMotorDatabase] = None
self._index = index
async def get_client(self) -> AsyncIOMotorClient:
if isinstance(self._mongo, AsyncIOMotorClient):
return self._mongo
if self._uri:
try:
self._mongo = AsyncIOMotorClient(self._uri)
except pymongo.errors.ConfigurationError as e:
if "query() got an unexpected keyword argument 'lifetime'" in e.args[0]:
import logging
logger = logging.getLogger("aiogram")
logger.warning("Run `pip install dnspython==1.16.0` in order to fix ConfigurationError. More information: https://github.com/mongodb/mongo-python-driver/pull/423#issuecomment-528998245")
raise e
return self._mongo
uri = 'mongodb://'
# set username + password
if self._username and self._password:
uri += f'{self._username}:{self._password}@'
# set host and port (optional)
uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}'
# define and return client
self._mongo = AsyncIOMotorClient(uri)
return self._mongo
async def get_db(self) -> AsyncIOMotorDatabase:
"""
Get Mongo db
This property is awaitable.
"""
if isinstance(self._db, AsyncIOMotorDatabase):
return self._db
mongo = await self.get_client()
self._db = mongo.get_database(self._db_name)
if self._index:
await self.apply_index(self._db)
return self._db
@staticmethod
async def apply_index(db):
for collection in COLLECTIONS:
await db[collection].create_index(keys=[('chat', 1), ('user', 1)],
name="chat_user_idx", unique=True, background=True)
async def close(self):
if self._mongo:
self._mongo.close()
async def wait_closed(self):
return True
async def set_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
state: Optional[AnyStr] = None):
chat, user = self.check_address(chat=chat, user=user)
db = await self.get_db()
if state is None:
await db[STATE].delete_one(filter={'chat': chat, 'user': user})
else:
await db[STATE].update_one(filter={'chat': chat, 'user': user},
update={'$set': {'state': state}}, upsert=True)
async def get_state(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
default: Optional[str] = None) -> Optional[str]:
chat, user = self.check_address(chat=chat, user=user)
db = await self.get_db()
result = await db[STATE].find_one(filter={'chat': chat, 'user': user})
return result.get('state') if result else default
async def set_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
data: Dict = None):
chat, user = self.check_address(chat=chat, user=user)
db = await self.get_db()
await db[DATA].update_one(filter={'chat': chat, 'user': user},
update={'$set': {'data': data}}, upsert=True)
async def get_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
default: Optional[dict] = None) -> Dict:
chat, user = self.check_address(chat=chat, user=user)
db = await self.get_db()
result = await db[DATA].find_one(filter={'chat': chat, 'user': user})
return result.get('data') if result else default or {}
async def update_data(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
data: Dict = None, **kwargs):
if data is None:
data = {}
temp_data = await self.get_data(chat=chat, user=user, default={})
temp_data.update(data, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_data)
def has_bucket(self):
return True
async def get_bucket(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
default: Optional[dict] = None) -> Dict:
chat, user = self.check_address(chat=chat, user=user)
db = await self.get_db()
result = await db[BUCKET].find_one(filter={'chat': chat, 'user': user})
return result.get('bucket') if result else default or {}
async def set_bucket(self, *, chat: Union[str, int, None] = None, user: Union[str, int, None] = None,
bucket: Dict = None):
chat, user = self.check_address(chat=chat, user=user)
db = await self.get_db()
await db[BUCKET].update_one(filter={'chat': chat, 'user': user},
update={'$set': {'bucket': bucket}}, upsert=True)
async def update_bucket(self, *, chat: Union[str, int, None] = None,
user: Union[str, int, None] = None,
bucket: Dict = None, **kwargs):
if bucket is None:
bucket = {}
temp_bucket = await self.get_bucket(chat=chat, user=user)
temp_bucket.update(bucket, **kwargs)
await self.set_bucket(chat=chat, user=user, bucket=temp_bucket)
async def reset_all(self, full=True):
"""
Reset states in DB
:param full: clean DB or clean only states
:return:
"""
db = await self.get_db()
await db[STATE].drop()
if full:
await db[DATA].drop()
await db[BUCKET].drop()
async def get_states_list(self) -> List[Tuple[int, int]]:
"""
Get list of all stored chat's and user's
:return: list of tuples where first element is chat id and second is user id
"""
db = await self.get_db()
result = []
items = await db[STATE].find().to_list()
for item in items:
result.append(
(int(item['chat']), int(item['user']))
)
return result

View file

@ -1,222 +0,0 @@
"""
This module has mongo storage for finite-state machine
based on `aiomongo <https://github.com/ZeoAlliance/aiomongo`_ driver
"""
from typing import Union, Dict, Optional, List, Tuple, AnyStr
try:
import aiomongo
from aiomongo import AioMongoClient, Database
except ModuleNotFoundError as e:
import warnings
warnings.warn("Install aiomongo with `pip install aiomongo`")
raise e
from ...dispatcher.storage import BaseStorage
from ...utils.deprecated import deprecated
STATE = 'aiogram_state'
DATA = 'aiogram_data'
BUCKET = 'aiogram_bucket'
COLLECTIONS = (STATE, DATA, BUCKET)
@deprecated("mongo_aiomongo.MongoStorage is deprecated in favour to mongo_motor.MongoStorage, and will be removed in aiogram v2.11.",
stacklevel=3)
class MongoStorage(BaseStorage):
"""
Mongo-based storage for FSM.
Usage:
.. code-block:: python3
storage = MongoStorage(host='localhost', port=27017, db_name='aiogram_fsm')
dp = Dispatcher(bot, storage=storage)
And need to close Mongo client connections when shutdown
.. code-block:: python3
await dp.storage.close()
await dp.storage.wait_closed()
"""
def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm',
username=None, password=None, index=True, **kwargs):
self._host = host
self._port = port
self._db_name: str = db_name
self._username = username
self._password = password
self._kwargs = kwargs
self._mongo: Union[AioMongoClient, None] = None
self._db: Union[Database, None] = None
self._index = index
async def get_client(self) -> AioMongoClient:
if isinstance(self._mongo, AioMongoClient):
return self._mongo
uri = 'mongodb://'
# set username + password
if self._username and self._password:
uri += f'{self._username}:{self._password}@'
# set host and port (optional)
uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}'
# define and return client
self._mongo = await aiomongo.create_client(uri)
return self._mongo
async def get_db(self) -> Database:
"""
Get Mongo db
This property is awaitable.
"""
if isinstance(self._db, Database):
return self._db
mongo = await self.get_client()
self._db = mongo.get_database(self._db_name)
if self._index:
await self.apply_index(self._db)
return self._db
@staticmethod
async def apply_index(db):
for collection in COLLECTIONS:
await db[collection].create_index(keys=[('chat_id', 1), ('user_id', 1)],
name="chat_user_idx", unique=True, background=True)
async def close(self):
if self._mongo:
self._mongo.close()
async def wait_closed(self):
if self._mongo:
return await self._mongo.wait_closed()
return True
async def set_state(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
state: Optional[AnyStr] = None):
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
if state is None:
await db[STATE].delete_one(filter={'chat_id': chat_id, 'user_id': user_id})
else:
await db[STATE].update_one(filter={'chat_id': chat_id, 'user_id': user_id},
update={'$set': {'state': state}}, upsert=True)
async def get_state(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
default: Optional[str] = None) -> Optional[str]:
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
result = await db[STATE].find_one(filter={'chat_id': chat_id, 'user_id': user_id})
return result.get('state') if result else default
async def set_data(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
data: Dict = None):
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
await db[DATA].update_one(filter={'chat_id': chat_id, 'user_id': user_id},
update={'$set': {'data': data}}, upsert=True)
async def get_data(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
default: Optional[dict] = None) -> Dict:
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
result = await db[DATA].find_one(filter={'chat_id': chat_id, 'user_id': user_id})
return result.get('data') if result else default or {}
async def update_data(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
data: Dict = None, **kwargs):
if data is None:
data = {}
temp_data = await self.get_data(chat_id=chat_id, user_id=user_id, default={})
temp_data.update(data, **kwargs)
await self.set_data(chat_id=chat_id, user_id=user_id, data=temp_data)
def has_bucket(self):
return True
async def get_bucket(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
default: Optional[dict] = None) -> Dict:
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
result = await db[BUCKET].find_one(filter={'chat_id': chat_id, 'user_id': user_id})
return result.get('bucket') if result else default or {}
async def set_bucket(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
bucket: Dict = None):
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
await db[BUCKET].update_one(filter={'chat_id': chat_id, 'user_id': user_id},
update={'$set': {'bucket': bucket}}, upsert=True)
async def update_bucket(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
bucket: Dict = None, **kwargs):
if bucket is None:
bucket = {}
temp_bucket = await self.get_bucket(chat_id=chat_id, user_id=user_id)
temp_bucket.update(bucket, **kwargs)
await self.set_bucket(chat_id=chat_id, user_id=user_id, bucket=temp_bucket)
async def reset_all(self, full=True):
"""
Reset states in DB
:param full: clean DB or clean only states
:return:
"""
db = await self.get_db()
await db[STATE].drop()
if full:
await db[DATA].drop()
await db[BUCKET].drop()
async def get_states_list(self) -> List[Tuple[int, int]]:
"""
Get list of all stored chat's and user's
:return: list of tuples where first element is chat id and second is user id
"""
db = await self.get_db()
result = []
items = await db[STATE].find().to_list()
for item in items:
result.append(
(int(item['chat_id']), int(item['user_id']))
)
return result

View file

@ -1,232 +0,0 @@
"""
This module has mongo storage for finite-state machine
based on `motor <https://github.com/mongodb/motor>`_ driver
"""
from typing import Union, Dict, Optional, List, Tuple, AnyStr
import pymongo
try:
import motor
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
except ModuleNotFoundError as e:
import warnings
warnings.warn("Install motor with `pip install motor`")
raise e
from ...dispatcher.storage import BaseStorage
STATE = 'aiogram_state'
DATA = 'aiogram_data'
BUCKET = 'aiogram_bucket'
COLLECTIONS = (STATE, DATA, BUCKET)
class MongoStorage(BaseStorage):
"""
Mongo-based storage for FSM.
Usage:
.. code-block:: python3
storage = MongoStorage(host='localhost', port=27017, db_name='aiogram_fsm')
dp = Dispatcher(bot, storage=storage)
And need to close Mongo client connections when shutdown
.. code-block:: python3
await dp.storage.close()
await dp.storage.wait_closed()
"""
def __init__(self, host='localhost', port=27017, db_name='aiogram_fsm', uri=None,
username=None, password=None, index=True, **kwargs):
self._host = host
self._port = port
self._db_name: str = db_name
self._uri = uri
self._username = username
self._password = password
self._kwargs = kwargs
self._mongo: Optional[AsyncIOMotorClient] = None
self._db: Optional[AsyncIOMotorDatabase] = None
self._index = index
async def get_client(self) -> AsyncIOMotorClient:
if isinstance(self._mongo, AsyncIOMotorClient):
return self._mongo
if self._uri:
try:
self._mongo = AsyncIOMotorClient(self._uri)
except pymongo.errors.ConfigurationError as e:
if "query() got an unexpected keyword argument 'lifetime'" in e.args[0]:
import logging
logger = logging.getLogger("aiogram")
logger.warning("Run `pip install dnspython==1.16.0` in order to fix ConfigurationError. More information: https://github.com/mongodb/mongo-python-driver/pull/423#issuecomment-528998245")
raise e
return self._mongo
uri = 'mongodb://'
# set username + password
if self._username and self._password:
uri += f'{self._username}:{self._password}@'
# set host and port (optional)
uri += f'{self._host}:{self._port}' if self._host else f'localhost:{self._port}'
# define and return client
self._mongo = AsyncIOMotorClient(uri)
return self._mongo
async def get_db(self) -> AsyncIOMotorDatabase:
"""
Get Mongo db
This property is awaitable.
"""
if isinstance(self._db, AsyncIOMotorDatabase):
return self._db
mongo = await self.get_client()
self._db = mongo.get_database(self._db_name)
if self._index:
await self.apply_index(self._db)
return self._db
@staticmethod
async def apply_index(db):
for collection in COLLECTIONS:
await db[collection].create_index(keys=[('chat_id', 1), ('user_id', 1)],
name="chat_user_idx", unique=True, background=True)
async def close(self):
if self._mongo:
self._mongo.close()
async def wait_closed(self):
return True
async def set_state(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
state: Optional[AnyStr] = None):
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
if state is None:
await db[STATE].delete_one(filter={'chat_id': chat_id, 'user_id': user_id})
else:
await db[STATE].update_one(filter={'chat_id': chat_id, 'user_id': user_id},
update={'$set': {'state': state}}, upsert=True)
async def get_state(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
default: Optional[str] = None) -> Optional[str]:
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
result = await db[STATE].find_one(filter={'chat_id': chat_id, 'user_id': user_id})
return result.get('state') if result else default
async def set_data(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
data: Dict = None):
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
await db[DATA].update_one(filter={'chat_id': chat_id, 'user_id': user_id},
update={'$set': {'data': data}}, upsert=True)
async def get_data(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
default: Optional[dict] = None) -> Dict:
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
result = await db[DATA].find_one(filter={'chat_id': chat_id, 'user_id': user_id})
return result.get('data') if result else default or {}
async def update_data(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
data: Dict = None, **kwargs):
if data is None:
data = {}
temp_data = await self.get_data(chat_id=chat_id, user_id=user_id, default={})
temp_data.update(data, **kwargs)
await self.set_data(chat_id=chat_id, user_id=user_id, data=temp_data)
def has_bucket(self):
return True
async def get_bucket(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
default: Optional[dict] = None) -> Dict:
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
result = await db[BUCKET].find_one(filter={'chat_id': chat_id, 'user_id': user_id})
return result.get('bucket') if result else default or {}
async def set_bucket(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
bucket: Dict = None):
chat_id, user_id = self.check_address(chat_id=chat_id, user_id=user_id)
db = await self.get_db()
await db[BUCKET].update_one(filter={'chat_id': chat_id, 'user_id': user_id},
update={'$set': {'bucket': bucket}}, upsert=True)
async def update_bucket(self, *,
chat_id: Union[str, int, None] = None,
user_id: Union[str, int, None] = None,
bucket: Dict = None, **kwargs):
if bucket is None:
bucket = {}
temp_bucket = await self.get_bucket(chat_id=chat_id, user_id=user_id)
temp_bucket.update(bucket, **kwargs)
await self.set_bucket(chat_id=chat_id, user_id=user_id, bucket=temp_bucket)
async def reset_all(self, full=True):
"""
Reset states in DB
:param full: clean DB or clean only states
:return:
"""
db = await self.get_db()
await db[STATE].drop()
if full:
await db[DATA].drop()
await db[BUCKET].drop()
async def get_states_list(self) -> List[Tuple[int, int]]:
"""
Get list of all stored chat's and user's
:return: list of tuples where first element is chat id and second is user id
"""
db = await self.get_db()
result = []
items = await db[STATE].find().to_list()
for item in items:
result.append(
(int(item['chat_id']), int(item['user_id']))
)
return result

View file

@ -591,7 +591,9 @@ class IDFilter(Filter):
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery]):
if isinstance(obj, Message):
user_id = obj.from_user.id
user_id = None
if obj.from_user is not None:
user_id = obj.from_user.id
chat_id = obj.chat.id
elif isinstance(obj, CallbackQuery):
user_id = obj.from_user.id

View file

@ -1,15 +0,0 @@
import importlib
import aiogram
def test_file_deleted():
try:
major, minor, _ = aiogram.__version__.split(".")
except ValueError: # raised if version is major.minor
major, minor = aiogram.__version__.split(".")
if major == "2" and int(minor) >= 11:
mongo_aiomongo = importlib.util.find_spec("aiogram.contrib.fsm_storage.mongo_aiomongo")
assert mongo_aiomongo is False, "Remove aiogram.contrib.fsm_storage.mongo_aiomongo file, and replace storage " \
"in aiogram.contrib.fsm_storage.mongo with storage " \
"from aiogram.contrib.fsm_storage.mongo_motor"

View file

@ -6,10 +6,10 @@ import pytest
from aiogram.dispatcher.filters.builtin import (
Text,
extract_chat_ids,
ChatIDArgumentType, ForwardedMessageFilter,
ChatIDArgumentType, ForwardedMessageFilter, IDFilter,
)
from aiogram.types import Message
from tests.types.dataset import MESSAGE
from tests.types.dataset import MESSAGE, MESSAGE_FROM_CHANNEL
class TestText:
@ -75,6 +75,8 @@ def test_extract_chat_ids(chat_id: ChatIDArgumentType, expected: Set[int]):
class TestForwardedMessageFilter:
@pytest.mark.asyncio
async def test_filter_forwarded_messages(self):
filter = ForwardedMessageFilter(is_forwarded=True)
@ -85,6 +87,7 @@ class TestForwardedMessageFilter:
assert await filter.check(forwarded_message)
assert not await filter.check(not_forwarded_message)
@pytest.mark.asyncio
async def test_filter_not_forwarded_messages(self):
filter = ForwardedMessageFilter(is_forwarded=False)
@ -94,3 +97,14 @@ class TestForwardedMessageFilter:
assert await filter.check(not_forwarded_message)
assert not await filter.check(forwarded_message)
class TestIDFilter:
@pytest.mark.asyncio
async def test_chat_id_for_channels(self):
message_from_channel = Message(**MESSAGE_FROM_CHANNEL)
filter = IDFilter(chat_id=message_from_channel.chat.id)
assert await filter.check(message_from_channel)

View file

@ -409,6 +409,20 @@ MESSAGE_WITH_VOICE = {
"voice": VOICE,
}
CHANNEL = {
"type": "channel",
"username": "best_channel_ever",
"id": -1001065170817,
}
MESSAGE_FROM_CHANNEL = {
"message_id": 123432,
"from": None,
"chat": CHANNEL,
"date": 1508768405,
"text": "Hi, world!",
}
PRE_CHECKOUT_QUERY = {
"id": "262181558630368727",
"from": USER,