diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index b485fa49..164d6aad 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -10,7 +10,7 @@ from aiohttp.helpers import sentinel from aiogram.utils.deprecated import renamed_argument from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \ - RegexpCommandsFilter, StateFilter, Text, IDFilter, AdminFilter, IsReplyFilter + RegexpCommandsFilter, StateFilter, Text, IDFilter, AdminFilter, IsReplyFilter, ForwardedMessageFilter from .filters.builtin import IsSenderContact from .handler import Handler from .middlewares import MiddlewareManager @@ -160,6 +160,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin): self.channel_post_handlers, self.edited_channel_post_handlers, ]) + filters_factory.bind(ForwardedMessageFilter, event_handlers=[ + self.message_handlers, + self.edited_channel_post_handlers, + self.channel_post_handlers, + self.edited_channel_post_handlers + ]) def __del__(self): self.stop_polling() diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index 6de3cc7a..edd1959a 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -1,6 +1,6 @@ from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \ ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, \ - Text, IDFilter, AdminFilter, IsReplyFilter, IsSenderContact + Text, IDFilter, AdminFilter, IsReplyFilter, IsSenderContact, ForwardedMessageFilter from .factory import FiltersFactory from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, execute_filter, \ check_filters, get_filter_spec, get_filters_spec @@ -32,4 +32,5 @@ __all__ = [ 'get_filters_spec', 'execute_filter', 'check_filters', + 'ForwardedMessageFilter', ] diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 5fe01dde..c59d9b0d 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -681,3 +681,13 @@ class IsReplyFilter(BoundFilter): return {'reply': msg.reply_to_message} elif not msg.reply_to_message and not self.is_reply: return True + + +class ForwardedMessageFilter(BoundFilter): + key = 'is_forwarded' + + def __init__(self, is_forwarded: bool): + self.is_forwarded = is_forwarded + + async def check(self, message: Message): + return bool(getattr(message, "forward_date")) is self.is_forwarded diff --git a/docs/source/dispatcher/filters.rst b/docs/source/dispatcher/filters.rst index af06b73e..3681dfcb 100644 --- a/docs/source/dispatcher/filters.rst +++ b/docs/source/dispatcher/filters.rst @@ -141,6 +141,14 @@ IsReplyFilter :show-inheritance: +ForwardedMessageFilter +------------- + +.. autoclass:: aiogram.dispatcher.filters.filters.ForwardedMessageFilter + :members: + :show-inheritance: + + Making own filters (Custom filters) =================================== diff --git a/tests/test_dispatcher/test_filters/test_builtin.py b/tests/test_dispatcher/test_filters/test_builtin.py index a26fc139..4cfce465 100644 --- a/tests/test_dispatcher/test_filters/test_builtin.py +++ b/tests/test_dispatcher/test_filters/test_builtin.py @@ -1,12 +1,15 @@ from typing import Set +from datetime import datetime import pytest from aiogram.dispatcher.filters.builtin import ( Text, extract_chat_ids, - ChatIDArgumentType, + ChatIDArgumentType, ForwardedMessageFilter, ) +from aiogram.types import Message +from tests.types.dataset import MESSAGE class TestText: @@ -69,3 +72,25 @@ class TestText: ) def test_extract_chat_ids(chat_id: ChatIDArgumentType, expected: Set[int]): assert extract_chat_ids(chat_id) == expected + + +class TestForwardedMessageFilter: + async def test_filter_forwarded_messages(self): + filter = ForwardedMessageFilter(is_forwarded=True) + + forwarded_message = Message(forward_date=round(datetime(2020, 5, 21, 5, 1).timestamp()), **MESSAGE) + + not_forwarded_message = Message(**MESSAGE) + + assert await filter.check(forwarded_message) + assert not await filter.check(not_forwarded_message) + + async def test_filter_not_forwarded_messages(self): + filter = ForwardedMessageFilter(is_forwarded=False) + + forwarded_message = Message(forward_date=round(datetime(2020, 5, 21, 5, 1).timestamp()), **MESSAGE) + + not_forwarded_message = Message(**MESSAGE) + + assert await filter.check(not_forwarded_message) + assert not await filter.check(forwarded_message)