Updated filters for new event types

This commit is contained in:
Alex Root Junior 2021-03-14 17:54:14 +02:00
parent a439f698f9
commit e666922811
2 changed files with 20 additions and 8 deletions

View file

@ -165,6 +165,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
self.edited_channel_post_handlers,
self.callback_query_handlers,
self.inline_query_handlers,
self.chat_member_handlers,
])
filters_factory.bind(IDFilter, event_handlers=[
self.message_handlers,
@ -173,6 +174,8 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
self.edited_channel_post_handlers,
self.callback_query_handlers,
self.inline_query_handlers,
self.chat_member_handlers,
self.my_chat_member_handlers,
])
filters_factory.bind(IsReplyFilter, event_handlers=[
self.message_handlers,
@ -198,6 +201,8 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
self.channel_post_handlers,
self.edited_channel_post_handlers,
self.callback_query_handlers,
self.my_chat_member_handlers,
self.chat_member_handlers
])
def __del__(self):

View file

@ -10,7 +10,7 @@ from babel.support import LazyProxy
from aiogram import types
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
from aiogram.types import CallbackQuery, ChatType, InlineQuery, Message, Poll
from aiogram.types import CallbackQuery, ChatType, InlineQuery, Message, Poll, ChatMemberUpdated
ChatIDArgumentType = typing.Union[typing.Iterable[typing.Union[int, str]], str, int]
@ -604,7 +604,7 @@ class IDFilter(Filter):
return result
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery]):
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, ChatMemberUpdated]):
if isinstance(obj, Message):
user_id = None
if obj.from_user is not None:
@ -619,6 +619,9 @@ class IDFilter(Filter):
elif isinstance(obj, InlineQuery):
user_id = obj.from_user.id
chat_id = None
elif isinstance(obj, ChatMemberUpdated):
user_id = obj.from_user.id
chat_id = obj.chat.id
else:
return False
@ -663,19 +666,21 @@ class AdminFilter(Filter):
return result
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery]) -> bool:
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, ChatMemberUpdated]) -> bool:
user_id = obj.from_user.id
if self._check_current:
if isinstance(obj, Message):
message = obj
chat = obj.chat
elif isinstance(obj, CallbackQuery) and obj.message:
message = obj.message
chat = obj.message.chat
elif isinstance(obj, ChatMemberUpdated):
chat = obj.chat
else:
return False
if message.chat.type == ChatType.PRIVATE: # there is no admin in private chats
if chat.type == ChatType.PRIVATE: # there is no admin in private chats
return False
chat_ids = [message.chat.id]
chat_ids = [chat.id]
else:
chat_ids = self._chat_ids
@ -719,11 +724,13 @@ class ChatTypeFilter(BoundFilter):
self.chat_type: typing.Set[str] = set(chat_type)
async def check(self, obj: Union[Message, CallbackQuery]):
async def check(self, obj: Union[Message, CallbackQuery, ChatMemberUpdated]):
if isinstance(obj, Message):
obj = obj.chat
elif isinstance(obj, CallbackQuery):
obj = obj.message.chat
elif isinstance(obj, ChatMemberUpdated):
obj = obj.chat
else:
warnings.warn("ChatTypeFilter doesn't support %s as input", type(obj))
return False