mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
add builtin StorageDataFilter
This commit is contained in:
parent
e5cce6edf8
commit
a2869e0479
3 changed files with 77 additions and 2 deletions
|
|
@ -11,7 +11,7 @@ from aiohttp.helpers import sentinel
|
|||
from aiogram.utils.deprecated import renamed_argument
|
||||
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
|
||||
RegexpCommandsFilter, StateFilter, Text, IDFilter, AdminFilter, IsReplyFilter, ForwardedMessageFilter, \
|
||||
IsSenderContact, ChatTypeFilter, MediaGroupFilter, AbstractFilter
|
||||
IsSenderContact, ChatTypeFilter, MediaGroupFilter, StorageDataFilter, AbstractFilter
|
||||
from .handler import Handler
|
||||
from .middlewares import MiddlewareManager
|
||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
||||
|
|
@ -210,6 +210,11 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
self.channel_post_handlers,
|
||||
self.edited_channel_post_handlers
|
||||
])
|
||||
filters_factory.bind(StorageDataFilter, event_handlers=[
|
||||
self.errors_handlers,
|
||||
self.poll_handlers,
|
||||
self.poll_answer_handlers,
|
||||
])
|
||||
|
||||
def __del__(self):
|
||||
self.stop_polling()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \
|
||||
ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, \
|
||||
Text, IDFilter, AdminFilter, IsReplyFilter, IsSenderContact, ForwardedMessageFilter, \
|
||||
ChatTypeFilter, MediaGroupFilter
|
||||
ChatTypeFilter, MediaGroupFilter, StorageDataFilter
|
||||
from .factory import FiltersFactory
|
||||
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, execute_filter, \
|
||||
check_filters, get_filter_spec, get_filters_spec
|
||||
|
|
@ -26,6 +26,7 @@ __all__ = (
|
|||
'ForwardedMessageFilter',
|
||||
'ChatTypeFilter',
|
||||
'MediaGroupFilter',
|
||||
'StorageDataFilter',
|
||||
'FiltersFactory',
|
||||
'AbstractFilter',
|
||||
'BoundFilter',
|
||||
|
|
|
|||
|
|
@ -757,3 +757,72 @@ class MediaGroupFilter(BoundFilter):
|
|||
|
||||
async def check(self, message: types.Message) -> bool:
|
||||
return bool(getattr(message, "media_group_id")) is self.is_media_group
|
||||
|
||||
|
||||
class StorageDataFilter(BoundFilter):
|
||||
"""
|
||||
Check if all items matches the relevant items in the current storage data.
|
||||
"""
|
||||
|
||||
key = 'storage_data'
|
||||
ctx_storage_data = ContextVar('user_storage_data')
|
||||
|
||||
def __init__(self, dispatcher, storage_data: dict):
|
||||
from aiogram import Dispatcher
|
||||
|
||||
self.dispatcher: Dispatcher = dispatcher
|
||||
self.storage_data = storage_data
|
||||
|
||||
@staticmethod
|
||||
def get_target(obj) -> typing.Tuple[Optional[int], Optional[int]]:
|
||||
if isinstance(obj, CallbackQuery):
|
||||
try:
|
||||
chat_id = obj.message.chat.id
|
||||
except AttributeError:
|
||||
chat_id = None
|
||||
else:
|
||||
try:
|
||||
chat_id = obj.chat.id
|
||||
except AttributeError:
|
||||
chat_id = None
|
||||
|
||||
try:
|
||||
user_id = obj.from_user.id
|
||||
except AttributeError:
|
||||
user_id = None
|
||||
|
||||
return chat_id, user_id
|
||||
|
||||
async def get_current_storage_data(self, obj) -> Optional[dict]:
|
||||
try:
|
||||
return self.ctx_storage_data.get()
|
||||
except LookupError:
|
||||
chat_id, user_id = self.get_target(obj)
|
||||
|
||||
if chat_id or user_id:
|
||||
storage_data = await self.dispatcher.storage.get_data(chat=chat_id, user=user_id)
|
||||
self.ctx_storage_data.set(storage_data)
|
||||
return storage_data
|
||||
|
||||
async def check(self, obj) -> bool:
|
||||
current_storage_data = await self.get_current_storage_data(obj)
|
||||
|
||||
if current_storage_data is None:
|
||||
return False
|
||||
|
||||
for key, value in self.storage_data.items():
|
||||
if key not in current_storage_data:
|
||||
return False
|
||||
|
||||
if value == '*':
|
||||
continue
|
||||
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if current_storage_data[key] in value:
|
||||
continue
|
||||
|
||||
if current_storage_data[key] == value:
|
||||
continue
|
||||
|
||||
return False
|
||||
return True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue