Merge pull request #172 from Birdi7/add-id-filter

Add id filter
This commit is contained in:
Alex Root Junior 2019-07-22 19:30:51 +03:00 committed by GitHub
commit 01f075d905
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 117 additions and 2 deletions

View file

@ -9,7 +9,7 @@ import aiohttp
from aiohttp.helpers import sentinel
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
RegexpCommandsFilter, StateFilter, Text
RegexpCommandsFilter, StateFilter, Text, IdFilter
from .handler import Handler
from .middlewares import MiddlewareManager
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
@ -114,6 +114,11 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
filters_factory.bind(ExceptionsFilter, event_handlers=[
self.errors_handlers
])
filters_factory.bind(IdFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
self.callback_query_handlers, self.inline_query_handlers
])
def __del__(self):
self.stop_polling()

View file

@ -1,5 +1,5 @@
from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \
ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text
ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text, IdFilter
from .factory import FiltersFactory
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, execute_filter, \
check_filters, get_filter_spec, get_filters_spec
@ -23,6 +23,7 @@ __all__ = [
'Regexp',
'StateFilter',
'Text',
'IdFilter',
'get_filter_spec',
'get_filters_spec',
'execute_filter',

View file

@ -503,3 +503,66 @@ class ExceptionsFilter(BoundFilter):
return True
except:
return False
class IdFilter(Filter):
def __init__(self,
user_id: Optional[Union[Iterable[Union[int, str]], str, int]] = None,
chat_id: Optional[Union[Iterable[Union[int, str]], str, int]] = None,
):
"""
:param user_id:
:param chat_id:
"""
if user_id is None and chat_id is None:
raise ValueError("Both user_id and chat_id can't be None")
self.user_id = None
self.chat_id = None
if user_id:
if isinstance(user_id, Iterable):
self.user_id = list(map(int, user_id))
else:
self.user_id = [int(user_id), ]
if chat_id:
if isinstance(chat_id, Iterable):
self.chat_id = list(map(int, chat_id))
else:
self.chat_id = [int(chat_id), ]
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
result = {}
if 'user_id' in full_config:
result['user_id'] = full_config.pop('user_id')
if 'chat_id' in full_config:
result['chat_id'] = full_config.pop('chat_id')
return result
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery]):
if isinstance(obj, Message):
user_id = obj.from_user.id
chat_id = obj.chat.id
elif isinstance(obj, CallbackQuery):
user_id = obj.from_user.id
chat_id = None
if obj.message is not None:
# if the button was sent with message
chat_id = obj.message.chat.id
elif isinstance(obj, InlineQuery):
user_id = obj.from_user.id
chat_id = None
else:
return False
if self.user_id and self.chat_id:
return user_id in self.user_id and chat_id in self.chat_id
elif self.user_id:
return user_id in self.user_id
elif self.chat_id:
return chat_id in self.chat_id
return False