Add pre_set, post_set, pre_finish and post_finish state decorators

rm useless group_name

Add pre_set, post_set, pre_finish and post_finish state decorators
This commit is contained in:
PasNA6713 2022-01-23 17:39:55 +03:00
parent 66533a84e3
commit e4a0f78cc3
3 changed files with 236 additions and 18 deletions

View file

@ -1,7 +1,8 @@
import inspect
from typing import Optional
from typing import Optional, Callable, List
from ..dispatcher import Dispatcher
from ..storage import FSMContext
class State:
@ -14,6 +15,8 @@ class State:
self._group_name = group_name
self._group = None
self._dispatcher = dispatcher
self._pre_set_handlers: List[Callable] = []
self._post_set_handlers: List[Callable] = []
@property
def group(self):
@ -54,9 +57,25 @@ class State:
__repr__ = __str__
async def set(self):
state = self._dispatcher.get_current().current_state()
await state.set_state(self.state)
def pre_set(self, func_to_call):
self._pre_set_handlers.append(func_to_call)
def post_set(self, func_to_call):
self._post_set_handlers.append(func_to_call)
async def set(self, *args, trigger=True, **kwargs):
context = self._dispatcher.get_current().current_state()
old_state = await context.get_state()
if trigger is True:
for func in self._pre_set_handlers:
await func(context, old_state, self.state)
await context.set_state(self.state)
if trigger is True:
for func in self._post_set_handlers:
await func(context, old_state, self.state)
class StatesGroupMeta(type):
@ -67,23 +86,22 @@ class StatesGroupMeta(type):
childs = []
cls._group_name = name
cls._dispatcher = dispatcher
for name, prop in namespace.items():
if isinstance(prop, State):
states.append(prop)
prop.set_parent(cls)
elif inspect.isclass(prop) and issubclass(prop, StatesGroup):
childs.append(prop)
prop._parent = cls
prop._dispatcher = dispatcher
cls._parent = None
cls._childs = tuple(childs)
cls._states = tuple(states)
cls._state_names = tuple(state.state for state in states)
cls._dispatcher = dispatcher
for state in states:
state.set_parent(cls)
return cls
@ -147,7 +165,7 @@ class StatesGroupMeta(type):
class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def next(cls) -> str:
async def next(cls, *args, trigger=True, **kwargs) -> str:
state = cls._dispatcher.get_current().current_state()
state_name = await state.get_state()
@ -159,7 +177,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
try:
next_state = cls.states[next_step]
next_state_name = next_state.state
await next_state.set()
await next_state.set(trigger=trigger)
except IndexError:
next_state_name = None
@ -169,7 +187,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def previous(cls) -> str:
async def previous(cls, *args, trigger=True, **kwargs) -> str:
state = cls._dispatcher.get_current().current_state()
state_name = await state.get_state()
@ -185,26 +203,34 @@ class StatesGroup(metaclass=StatesGroupMeta):
else:
previous_state = cls.states[previous_step]
previous_state_name = previous_state.state
await previous_state.set()
await previous_state.set(trigger=trigger)
return previous_state_name
@classmethod
async def first(cls) -> str:
async def first(cls, *args, trigger=True, **kwargs) -> str:
state = cls._dispatcher.get_current().current_state()
first_step = cls.states[0]
await first_step.set()
await first_step.set(trigger=trigger)
return first_step.state
@classmethod
async def last(cls) -> str:
async def last(cls, *args, trigger=True, **kwargs) -> str:
state = cls._dispatcher.get_current().current_state()
last_step = cls.states[-1]
await last_step.set()
await last_step.set(trigger=trigger)
return last_step.state
@classmethod
def pre_finish(cls, func_to_call) -> str:
FSMContext.set_pre_finish_hanlder(cls._group_name, func_to_call)
@classmethod
def post_finish(cls, func_to_call) -> str:
FSMContext.set_post_finish_hanlder(cls._group_name, func_to_call)
default_state = State()
any_state = State(state='*')

View file

@ -290,12 +290,23 @@ class BaseStorage:
class FSMContext:
_pre_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {}
_post_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {}
def __init__(self, storage, chat, user):
self.storage: BaseStorage = storage
self.chat, self.user = self.storage.check_address(chat=chat, user=user)
def proxy(self):
return FSMContextProxy(self)
@classmethod
def set_pre_finish_hanlder(cls, group_name: str, handler: typing.Callable):
cls._pre_finish_handlers[group_name] = cls._pre_finish_handlers.get(group_name, []) + [handler]
@classmethod
def set_post_finish_hanlder(cls, group_name: str, handler: typing.Callable):
cls._post_finish_handlers[group_name] = cls._post_finish_handlers.get(group_name, []) + [handler]
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
return await self.storage.get_state(chat=self.chat, user=self.user, default=default)
@ -318,9 +329,23 @@ class FSMContext:
async def reset_data(self):
await self.storage.reset_data(chat=self.chat, user=self.user)
async def finish(self):
async def finish(self, *args, trigger=True, **kwargs):
state = await self.get_state()
if trigger is True:
for group_name, handlers in self._pre_finish_handlers.items():
if group_name in state:
for func in handlers:
await func(self, state)
await self.storage.finish(chat=self.chat, user=self.user)
if trigger is True:
for group_name, handlers in self._post_finish_handlers.items():
if group_name in state:
for func in handlers:
await func(self, state)
class FSMContextProxy:
def __init__(self, fsm_context: FSMContext):

View file

@ -0,0 +1,167 @@
import logging
from typing import Union
import aiogram.utils.markdown as md
from aiogram import Bot, Dispatcher, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import FSMContext
from aiogram.dispatcher.filters import Text
from aiogram.dispatcher.filters.state import State, StatesGroup
from aiogram.types import ParseMode
from aiogram.utils import executor
logging.basicConfig(level=logging.INFO)
API_TOKEN = 'BOT_TOKEN_HERE'
bot = Bot(token=API_TOKEN)
# For example use simple MemoryStorage for Dispatcher.
storage = MemoryStorage()
dp = Dispatcher(bot, storage=storage)
# States
class Form(StatesGroup):
name = State() # Will be represented in storage as 'Form:name'
age = State() # Will be represented in storage as 'Form:age'
gender = State() # Will be represented in storage as 'Form:gender'
# You can use pre_set decorator to call decorated functions before state will be setted
@Form.name.pre_set
async def name_pre_set(context: FSMContext, old_state: str, new_state: str):
logging.info('Calling before Form:name will be set')
await bot.send_message(
chat_id = context.chat,
text = "What's your name?"
)
# You can use post_set decorator to call decorated functions after state will be setted
@Form.age.post_set
async def age_post_set(context: FSMContext, old_state: str, new_state: str):
logging.info('Calling after Form:age will be set')
markup = types.InlineKeyboardMarkup(row_width=1)
markup.row(types.InlineKeyboardButton('Back', callback_data='back'))
await bot.send_message(
chat_id = context.chat,
text = "How old are you?",
reply_markup = markup
)
@Form.gender.pre_set
async def gender_pre_set(context: FSMContext, old_state: str, new_state: str):
logging.info('Calling before gender:age will be set')
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
markup.add("Male", "Female")
markup.add("Other", "Back")
await bot.send_message(
chat_id = context.chat,
text = "What is your gender?",
reply_markup = markup
)
# You can use pre_finish decorator to call decorated functions before state group will be finished
@Form.pre_finish
async def form_post_finish(context: FSMContext, old_state: str):
async with context.proxy() as data:
markup = types.ReplyKeyboardRemove()
await bot.send_message(
context.chat,
md.text(
md.text('Hi! Nice to meet you,', md.bold(data['name'])),
md.text('Age:', md.code(data['age'])),
md.text('Gender:', data['gender']),
sep='\n',
),
reply_markup=markup,
parse_mode=ParseMode.MARKDOWN,
)
# You can use state Form if you need to handle all states from Form group
@dp.message_handler(commands='back', state=Form)
@dp.message_handler(Text(equals='back', ignore_case=True), state=Form)
@dp.callback_query_handler(text='back', state=Form)
async def process_gender_invalid(message: Union[types.Message, types.CallbackQuery]):
await Form.previous()
# You can use state '*' if you need to handle all states
@dp.message_handler(state='*', commands='cancel')
@dp.message_handler(Text(equals='cancel', ignore_case=True), state='*')
async def cancel_handler(message: types.Message, state: FSMContext):
"""
Allow user to cancel any action
"""
current_state = await state.get_state()
if current_state is None:
return
logging.info('Cancelling state %r', current_state)
# Cancel state and inform user about it
# You can set trigger to False to prevent pre_set and post_set signals
await state.finish(trigger=False)
# And remove keyboard (just in case)
await message.reply('Cancelled.', reply_markup=types.ReplyKeyboardRemove())
@dp.message_handler(commands='start')
async def cmd_start(message: types.Message):
"""
Conversation's entry point
"""
await message.reply("Hi there!")
await Form.name.set()
@dp.message_handler(state=Form.name)
async def process_name(message: types.Message, state: FSMContext):
"""
Process user name and go to age
"""
await state.update_data(name=message.text)
await Form.next()
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
async def process_age_invalid(message: types.Message):
"""
If age is not digit
"""
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
async def process_age(message: types.Message, state: FSMContext):
"""
Process user age and go to gender
"""
await state.update_data(age=int(message.text))
await Form.next()
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
async def process_gender_invalid(message: types.Message):
"""
In this example gender has to be one of: Male, Female, Other.
"""
return await message.reply("Bad gender name. Choose your gender from the keyboard.")
@dp.message_handler(state=Form.gender)
async def process_gender(message: types.Message, state: FSMContext):
await state.update_data(gender=message.text)
await state.finish()
if __name__ == '__main__':
executor.start_polling(dp, skip_updates=True)