From 507344233750f9fc103f4803fabc3a98cd539b7f Mon Sep 17 00:00:00 2001 From: Fenicu Date: Sun, 12 Sep 2021 22:14:19 +0300 Subject: [PATCH] added new example of filter registration --- examples/custom_filter_example.py | 63 +++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/examples/custom_filter_example.py b/examples/custom_filter_example.py index b6e1f744..4099ac27 100644 --- a/examples/custom_filter_example.py +++ b/examples/custom_filter_example.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Union from aiogram import Bot, Dispatcher, executor, types from aiogram.dispatcher.filters import BoundFilter from aiogram.dispatcher.handler import ctx_data @@ -22,18 +22,16 @@ dp = Dispatcher(bot) class UserMiddleware(LifetimeControllerMiddleware): skip_patterns = ["error", "update"] - def __init__(self): - super(UserMiddleware, self).__init__() - async def pre_process( self, - update: Union[types.Message, types.CallbackQuery], + obj: Union[types.Message, types.CallbackQuery], data: dict, ): - if update.from_user.id not in fake_db: - fake_db[update.from_user.id] = {"ban": False} + user = obj.from_user + if user.id not in fake_db: + fake_db[user.id] = {"ban": False} - data["user_in_db"] = fake_db[update.from_user.id] + data["user_in_db"] = fake_db[user.id] class BanUserFilter(BoundFilter): @@ -44,14 +42,14 @@ class BanUserFilter(BoundFilter): key = "banned" def __init__(self, banned: bool): - if banned is False: - raise ValueError("filter banned cannot be False") + self.banned = banned - async def check(self, obj: Union[types.Message, types.CallbackQuery]): - data = ctx_data.get() - user_in_db = data["user_in_db"] - - return not user_in_db["ban"] + async def check(self, obj: Any): + data: dict = ctx_data.get() + user_in_db = data.get("user_in_db", None) + if user_in_db: + return self.banned == user_in_db["ban"] + return False # If the user is not in the database, then he cannot be banned class IsEvenIdFilter(BoundFilter): @@ -70,18 +68,38 @@ class IsEvenIdFilter(BoundFilter): return {"is_even": bool(user_id % 2 == 0)} +class DatabaseIsNotEmptyFilter(BoundFilter): + """ + Check if the database contains the required number of records + """ + + key = "data_in_db" + + def __init__(self, data_in_db: int): + if isinstance(data_in_db, int): + self.data_in_db = data_in_db + else: + raise ValueError( + f"filter data_in_db must be a int, not {type(data_in_db).__name__}" + ) + + async def check(self, obj: Any): + return len(fake_db) >= self.data_in_db + + # Setup middleware dp.middleware.setup(UserMiddleware()) # Binding filters dp.filters_factory.bind( BanUserFilter, - event_handlers=[dp.message_handlers, dp.callback_query_handlers], + exclude_event_handlers=[dp.channel_post_handlers, dp.edited_channel_post_handlers], ) dp.filters_factory.bind( IsEvenIdFilter, event_handlers=[dp.message_handlers, dp.callback_query_handlers], ) +dp.filters_factory.bind(DatabaseIsNotEmptyFilter) @dp.message_handler(banned=True) @@ -89,12 +107,17 @@ async def handler_start1(message: types.Message): await message.answer("Congratulations, you are banned!") -@dp.message_handler(commands="start", is_even=True) -async def handler_start2(message: types.Message, is_even: bool): +@dp.message_handler(data_in_db=6) +async def handler_start2(message: types.Message): + await message.answer("There are too many records in the database!") + + +@dp.message_handler(is_even=True) +async def handler_start3(message: types.Message, is_even: bool): if is_even: - await message.answer("Congratulations, you id is even!") + await message.answer("Congratulations, your id is even!") else: - await message.answer("Congratulations, your id is not even") + await message.answer("Congratulations, your id is not even :(") if __name__ == "__main__":