mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
added new example of filter registration
This commit is contained in:
parent
0ee9cc02a6
commit
5073442337
1 changed files with 43 additions and 20 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Union
|
||||
from typing import Any, Union
|
||||
from aiogram import Bot, Dispatcher, executor, types
|
||||
from aiogram.dispatcher.filters import BoundFilter
|
||||
from aiogram.dispatcher.handler import ctx_data
|
||||
|
|
@ -22,18 +22,16 @@ dp = Dispatcher(bot)
|
|||
class UserMiddleware(LifetimeControllerMiddleware):
|
||||
skip_patterns = ["error", "update"]
|
||||
|
||||
def __init__(self):
|
||||
super(UserMiddleware, self).__init__()
|
||||
|
||||
async def pre_process(
|
||||
self,
|
||||
update: Union[types.Message, types.CallbackQuery],
|
||||
obj: Union[types.Message, types.CallbackQuery],
|
||||
data: dict,
|
||||
):
|
||||
if update.from_user.id not in fake_db:
|
||||
fake_db[update.from_user.id] = {"ban": False}
|
||||
user = obj.from_user
|
||||
if user.id not in fake_db:
|
||||
fake_db[user.id] = {"ban": False}
|
||||
|
||||
data["user_in_db"] = fake_db[update.from_user.id]
|
||||
data["user_in_db"] = fake_db[user.id]
|
||||
|
||||
|
||||
class BanUserFilter(BoundFilter):
|
||||
|
|
@ -44,14 +42,14 @@ class BanUserFilter(BoundFilter):
|
|||
key = "banned"
|
||||
|
||||
def __init__(self, banned: bool):
|
||||
if banned is False:
|
||||
raise ValueError("filter banned cannot be False")
|
||||
self.banned = banned
|
||||
|
||||
async def check(self, obj: Union[types.Message, types.CallbackQuery]):
|
||||
data = ctx_data.get()
|
||||
user_in_db = data["user_in_db"]
|
||||
|
||||
return not user_in_db["ban"]
|
||||
async def check(self, obj: Any):
|
||||
data: dict = ctx_data.get()
|
||||
user_in_db = data.get("user_in_db", None)
|
||||
if user_in_db:
|
||||
return self.banned == user_in_db["ban"]
|
||||
return False # If the user is not in the database, then he cannot be banned
|
||||
|
||||
|
||||
class IsEvenIdFilter(BoundFilter):
|
||||
|
|
@ -70,18 +68,38 @@ class IsEvenIdFilter(BoundFilter):
|
|||
return {"is_even": bool(user_id % 2 == 0)}
|
||||
|
||||
|
||||
class DatabaseIsNotEmptyFilter(BoundFilter):
|
||||
"""
|
||||
Check if the database contains the required number of records
|
||||
"""
|
||||
|
||||
key = "data_in_db"
|
||||
|
||||
def __init__(self, data_in_db: int):
|
||||
if isinstance(data_in_db, int):
|
||||
self.data_in_db = data_in_db
|
||||
else:
|
||||
raise ValueError(
|
||||
f"filter data_in_db must be a int, not {type(data_in_db).__name__}"
|
||||
)
|
||||
|
||||
async def check(self, obj: Any):
|
||||
return len(fake_db) >= self.data_in_db
|
||||
|
||||
|
||||
# Setup middleware
|
||||
dp.middleware.setup(UserMiddleware())
|
||||
|
||||
# Binding filters
|
||||
dp.filters_factory.bind(
|
||||
BanUserFilter,
|
||||
event_handlers=[dp.message_handlers, dp.callback_query_handlers],
|
||||
exclude_event_handlers=[dp.channel_post_handlers, dp.edited_channel_post_handlers],
|
||||
)
|
||||
dp.filters_factory.bind(
|
||||
IsEvenIdFilter,
|
||||
event_handlers=[dp.message_handlers, dp.callback_query_handlers],
|
||||
)
|
||||
dp.filters_factory.bind(DatabaseIsNotEmptyFilter)
|
||||
|
||||
|
||||
@dp.message_handler(banned=True)
|
||||
|
|
@ -89,12 +107,17 @@ async def handler_start1(message: types.Message):
|
|||
await message.answer("Congratulations, you are banned!")
|
||||
|
||||
|
||||
@dp.message_handler(commands="start", is_even=True)
|
||||
async def handler_start2(message: types.Message, is_even: bool):
|
||||
@dp.message_handler(data_in_db=6)
|
||||
async def handler_start2(message: types.Message):
|
||||
await message.answer("There are too many records in the database!")
|
||||
|
||||
|
||||
@dp.message_handler(is_even=True)
|
||||
async def handler_start3(message: types.Message, is_even: bool):
|
||||
if is_even:
|
||||
await message.answer("Congratulations, you id is even!")
|
||||
await message.answer("Congratulations, your id is even!")
|
||||
else:
|
||||
await message.answer("Congratulations, your id is not even")
|
||||
await message.answer("Congratulations, your id is not even :(")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue