enh: Text filter contains_any parameter added

This commit is contained in:
Ilya Smirnov 2021-07-08 20:04:24 +03:00
parent 2b4e3ad5c6
commit 72b416c8bf

View file

@ -247,6 +247,7 @@ class Text(Filter):
_default_params = (
('text', 'equals'),
('text_contains', 'contains'),
('text_contains_any', 'contains_any'),
('text_startswith', 'startswith'),
('text_endswith', 'endswith'),
)
@ -254,6 +255,7 @@ class Text(Filter):
def __init__(self,
equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
contains_any: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
ignore_case=False):
@ -268,10 +270,11 @@ class Text(Filter):
:param ignore_case: case insensitive
"""
# Only one mode can be used. check it.
check = sum(map(lambda s: s is not None, (equals, contains, startswith, endswith)))
check = sum(map(lambda s: s is not None, (equals, contains, contains_any, startswith, endswith)))
if check > 1:
args = "' and '".join([arg[0] for arg in [('equals', equals),
('contains', contains),
('contains_any', contains_any),
('startswith', startswith),
('endswith', endswith)
] if arg[1] is not None])
@ -279,21 +282,17 @@ class Text(Filter):
elif check == 0:
raise ValueError(f"No one mode is specified!")
equals, contains, endswith, startswith = map(lambda e: [e] if isinstance(e, str) or isinstance(e, LazyProxy)
else e,
(equals, contains, endswith, startswith))
equals, contains, contains_any, endswith, startswith = map(
lambda e: [e] if isinstance(e, str) or isinstance(e, LazyProxy)
else e,
(equals, contains, contains_any, endswith, startswith))
self.equals = equals
self.contains = contains
self.contains_any = contains_any
self.endswith = endswith
self.startswith = startswith
self.ignore_case = ignore_case
@classmethod
def validate(cls, full_config: Dict[str, Any]):
for param, key in cls._default_params:
if param in full_config:
return {key: full_config.pop(param)}
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery, Poll]):
if isinstance(obj, Message):
text = obj.text or obj.caption or ''
@ -323,6 +322,10 @@ class Text(Filter):
contains = list(map(_pre_process_func, self.contains))
return all(map(text.__contains__, contains))
if self.contains_any is not None:
contains_any = list(map(_pre_process_func, self.contains_any))
return any(map(text.__contains__, contains_any))
if self.startswith is not None:
startswith = list(map(_pre_process_func, self.startswith))
return any(map(text.startswith, startswith))