diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 1e36f202..8913e726 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -11,7 +11,7 @@ from aiohttp.helpers import sentinel from aiogram.utils.deprecated import renamed_argument from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \ RegexpCommandsFilter, StateFilter, Text, IDFilter, AdminFilter, IsReplyFilter, ForwardedMessageFilter, \ - IsSenderContact, ChatTypeFilter, MediaGroupFilter, AbstractFilter + IsSenderContact, ChatTypeFilter, MediaGroupFilter, StorageDataFilter, AbstractFilter from .handler import Handler from .middlewares import MiddlewareManager from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ @@ -210,6 +210,11 @@ class Dispatcher(DataMixin, ContextInstanceMixin): self.channel_post_handlers, self.edited_channel_post_handlers ]) + filters_factory.bind(StorageDataFilter, event_handlers=[ + self.errors_handlers, + self.poll_handlers, + self.poll_answer_handlers, + ]) def __del__(self): self.stop_polling() diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index d07d953b..4048392c 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -1,7 +1,7 @@ from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \ ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, \ Text, IDFilter, AdminFilter, IsReplyFilter, IsSenderContact, ForwardedMessageFilter, \ - ChatTypeFilter, MediaGroupFilter + ChatTypeFilter, MediaGroupFilter, StorageDataFilter from .factory import FiltersFactory from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, execute_filter, \ check_filters, get_filter_spec, get_filters_spec @@ -26,6 +26,7 @@ __all__ = ( 'ForwardedMessageFilter', 'ChatTypeFilter', 'MediaGroupFilter', + 'StorageDataFilter', 'FiltersFactory', 'AbstractFilter', 'BoundFilter', diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index ebd38f08..804ced7e 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -757,3 +757,72 @@ class MediaGroupFilter(BoundFilter): async def check(self, message: types.Message) -> bool: return bool(getattr(message, "media_group_id")) is self.is_media_group + + +class StorageDataFilter(BoundFilter): + """ + Check if all items matches the relevant items in the current storage data. + """ + + key = 'storage_data' + ctx_storage_data = ContextVar('user_storage_data') + + def __init__(self, dispatcher, storage_data: dict): + from aiogram import Dispatcher + + self.dispatcher: Dispatcher = dispatcher + self.storage_data = storage_data + + @staticmethod + def get_target(obj) -> typing.Tuple[Optional[int], Optional[int]]: + if isinstance(obj, CallbackQuery): + try: + chat_id = obj.message.chat.id + except AttributeError: + chat_id = None + else: + try: + chat_id = obj.chat.id + except AttributeError: + chat_id = None + + try: + user_id = obj.from_user.id + except AttributeError: + user_id = None + + return chat_id, user_id + + async def get_current_storage_data(self, obj) -> Optional[dict]: + try: + return self.ctx_storage_data.get() + except LookupError: + chat_id, user_id = self.get_target(obj) + + if chat_id or user_id: + storage_data = await self.dispatcher.storage.get_data(chat=chat_id, user=user_id) + self.ctx_storage_data.set(storage_data) + return storage_data + + async def check(self, obj) -> bool: + current_storage_data = await self.get_current_storage_data(obj) + + if current_storage_data is None: + return False + + for key, value in self.storage_data.items(): + if key not in current_storage_data: + return False + + if value == '*': + continue + + if isinstance(value, (list, tuple, set)): + if current_storage_data[key] in value: + continue + + if current_storage_data[key] == value: + continue + + return False + return True