diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py index 1309a39a..2783e5d8 100644 --- a/aiogram/dispatcher/filters/state.py +++ b/aiogram/dispatcher/filters/state.py @@ -1,7 +1,8 @@ import inspect -from typing import Optional +from typing import Optional, Callable, List from ..dispatcher import Dispatcher +from ..storage import FSMContext class State: @@ -14,6 +15,8 @@ class State: self._group_name = group_name self._group = None self._dispatcher = dispatcher + self._pre_set_handlers: List[Callable] = [] + self._post_set_handlers: List[Callable] = [] @property def group(self): @@ -54,9 +57,25 @@ class State: __repr__ = __str__ - async def set(self): - state = self._dispatcher.get_current().current_state() - await state.set_state(self.state) + def pre_set(self, func_to_call): + self._pre_set_handlers.append(func_to_call) + + def post_set(self, func_to_call): + self._post_set_handlers.append(func_to_call) + + async def set(self, *args, trigger=True, **kwargs): + context = self._dispatcher.get_current().current_state() + old_state = await context.get_state() + + if trigger is True: + for func in self._pre_set_handlers: + await func(context, old_state, self.state) + + await context.set_state(self.state) + + if trigger is True: + for func in self._post_set_handlers: + await func(context, old_state, self.state) class StatesGroupMeta(type): @@ -67,23 +86,22 @@ class StatesGroupMeta(type): childs = [] cls._group_name = name + cls._dispatcher = dispatcher for name, prop in namespace.items(): - if isinstance(prop, State): states.append(prop) + prop.set_parent(cls) + elif inspect.isclass(prop) and issubclass(prop, StatesGroup): childs.append(prop) prop._parent = cls + prop._dispatcher = dispatcher cls._parent = None cls._childs = tuple(childs) cls._states = tuple(states) cls._state_names = tuple(state.state for state in states) - cls._dispatcher = dispatcher - - for state in states: - state.set_parent(cls) return cls @@ -147,7 +165,7 @@ class StatesGroupMeta(type): class StatesGroup(metaclass=StatesGroupMeta): @classmethod - async def next(cls) -> str: + async def next(cls, *args, trigger=True, **kwargs) -> str: state = cls._dispatcher.get_current().current_state() state_name = await state.get_state() @@ -159,7 +177,7 @@ class StatesGroup(metaclass=StatesGroupMeta): try: next_state = cls.states[next_step] next_state_name = next_state.state - await next_state.set() + await next_state.set(trigger=trigger) except IndexError: next_state_name = None @@ -169,7 +187,7 @@ class StatesGroup(metaclass=StatesGroupMeta): @classmethod - async def previous(cls) -> str: + async def previous(cls, *args, trigger=True, **kwargs) -> str: state = cls._dispatcher.get_current().current_state() state_name = await state.get_state() @@ -185,26 +203,34 @@ class StatesGroup(metaclass=StatesGroupMeta): else: previous_state = cls.states[previous_step] previous_state_name = previous_state.state - await previous_state.set() + await previous_state.set(trigger=trigger) return previous_state_name @classmethod - async def first(cls) -> str: + async def first(cls, *args, trigger=True, **kwargs) -> str: state = cls._dispatcher.get_current().current_state() first_step = cls.states[0] - await first_step.set() + await first_step.set(trigger=trigger) return first_step.state @classmethod - async def last(cls) -> str: + async def last(cls, *args, trigger=True, **kwargs) -> str: state = cls._dispatcher.get_current().current_state() last_step = cls.states[-1] - await last_step.set() + await last_step.set(trigger=trigger) return last_step.state + @classmethod + def pre_finish(cls, func_to_call) -> str: + FSMContext.set_pre_finish_hanlder(cls._group_name, func_to_call) + + @classmethod + def post_finish(cls, func_to_call) -> str: + FSMContext.set_post_finish_hanlder(cls._group_name, func_to_call) + default_state = State() any_state = State(state='*') diff --git a/aiogram/dispatcher/storage.py b/aiogram/dispatcher/storage.py index 63ce25b2..acb1459a 100644 --- a/aiogram/dispatcher/storage.py +++ b/aiogram/dispatcher/storage.py @@ -290,12 +290,23 @@ class BaseStorage: class FSMContext: + _pre_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {} + _post_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {} + def __init__(self, storage, chat, user): self.storage: BaseStorage = storage self.chat, self.user = self.storage.check_address(chat=chat, user=user) def proxy(self): return FSMContextProxy(self) + + @classmethod + def set_pre_finish_hanlder(cls, group_name: str, handler: typing.Callable): + cls._pre_finish_handlers[group_name] = cls._pre_finish_handlers.get(group_name, []) + [handler] + + @classmethod + def set_post_finish_hanlder(cls, group_name: str, handler: typing.Callable): + cls._post_finish_handlers[group_name] = cls._post_finish_handlers.get(group_name, []) + [handler] async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]: return await self.storage.get_state(chat=self.chat, user=self.user, default=default) @@ -318,9 +329,23 @@ class FSMContext: async def reset_data(self): await self.storage.reset_data(chat=self.chat, user=self.user) - async def finish(self): + async def finish(self, *args, trigger=True, **kwargs): + state = await self.get_state() + + if trigger is True: + for group_name, handlers in self._pre_finish_handlers.items(): + if group_name in state: + for func in handlers: + await func(self, state) + await self.storage.finish(chat=self.chat, user=self.user) + if trigger is True: + for group_name, handlers in self._post_finish_handlers.items(): + if group_name in state: + for func in handlers: + await func(self, state) + class FSMContextProxy: def __init__(self, fsm_context: FSMContext): diff --git a/examples/finite_state_machine_signals.py b/examples/finite_state_machine_signals.py new file mode 100644 index 00000000..d8442855 --- /dev/null +++ b/examples/finite_state_machine_signals.py @@ -0,0 +1,167 @@ +import logging +from typing import Union + +import aiogram.utils.markdown as md +from aiogram import Bot, Dispatcher, types +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.dispatcher import FSMContext +from aiogram.dispatcher.filters import Text +from aiogram.dispatcher.filters.state import State, StatesGroup +from aiogram.types import ParseMode +from aiogram.utils import executor + +logging.basicConfig(level=logging.INFO) + +API_TOKEN = 'BOT_TOKEN_HERE' + + +bot = Bot(token=API_TOKEN) + +# For example use simple MemoryStorage for Dispatcher. +storage = MemoryStorage() +dp = Dispatcher(bot, storage=storage) + + +# States +class Form(StatesGroup): + name = State() # Will be represented in storage as 'Form:name' + age = State() # Will be represented in storage as 'Form:age' + gender = State() # Will be represented in storage as 'Form:gender' + + +# You can use pre_set decorator to call decorated functions before state will be setted +@Form.name.pre_set +async def name_pre_set(context: FSMContext, old_state: str, new_state: str): + logging.info('Calling before Form:name will be set') + await bot.send_message( + chat_id = context.chat, + text = "What's your name?" + ) + + +# You can use post_set decorator to call decorated functions after state will be setted +@Form.age.post_set +async def age_post_set(context: FSMContext, old_state: str, new_state: str): + logging.info('Calling after Form:age will be set') + markup = types.InlineKeyboardMarkup(row_width=1) + markup.row(types.InlineKeyboardButton('Back', callback_data='back')) + + await bot.send_message( + chat_id = context.chat, + text = "How old are you?", + reply_markup = markup + ) + + +@Form.gender.pre_set +async def gender_pre_set(context: FSMContext, old_state: str, new_state: str): + logging.info('Calling before gender:age will be set') + + markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) + markup.add("Male", "Female") + markup.add("Other", "Back") + + await bot.send_message( + chat_id = context.chat, + text = "What is your gender?", + reply_markup = markup + ) + + +# You can use pre_finish decorator to call decorated functions before state group will be finished +@Form.pre_finish +async def form_post_finish(context: FSMContext, old_state: str): + async with context.proxy() as data: + markup = types.ReplyKeyboardRemove() + + await bot.send_message( + context.chat, + md.text( + md.text('Hi! Nice to meet you,', md.bold(data['name'])), + md.text('Age:', md.code(data['age'])), + md.text('Gender:', data['gender']), + sep='\n', + ), + reply_markup=markup, + parse_mode=ParseMode.MARKDOWN, + ) + + +# You can use state Form if you need to handle all states from Form group +@dp.message_handler(commands='back', state=Form) +@dp.message_handler(Text(equals='back', ignore_case=True), state=Form) +@dp.callback_query_handler(text='back', state=Form) +async def process_gender_invalid(message: Union[types.Message, types.CallbackQuery]): + await Form.previous() + + +# You can use state '*' if you need to handle all states +@dp.message_handler(state='*', commands='cancel') +@dp.message_handler(Text(equals='cancel', ignore_case=True), state='*') +async def cancel_handler(message: types.Message, state: FSMContext): + """ + Allow user to cancel any action + """ + current_state = await state.get_state() + if current_state is None: + return + + logging.info('Cancelling state %r', current_state) + # Cancel state and inform user about it + # You can set trigger to False to prevent pre_set and post_set signals + await state.finish(trigger=False) + # And remove keyboard (just in case) + await message.reply('Cancelled.', reply_markup=types.ReplyKeyboardRemove()) + + +@dp.message_handler(commands='start') +async def cmd_start(message: types.Message): + """ + Conversation's entry point + """ + await message.reply("Hi there!") + await Form.name.set() + + +@dp.message_handler(state=Form.name) +async def process_name(message: types.Message, state: FSMContext): + """ + Process user name and go to age + """ + await state.update_data(name=message.text) + await Form.next() + + +@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age) +async def process_age_invalid(message: types.Message): + """ + If age is not digit + """ + return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") + + +@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age) +async def process_age(message: types.Message, state: FSMContext): + """ + Process user age and go to gender + """ + await state.update_data(age=int(message.text)) + await Form.next() + + +@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender) +async def process_gender_invalid(message: types.Message): + """ + In this example gender has to be one of: Male, Female, Other. + """ + return await message.reply("Bad gender name. Choose your gender from the keyboard.") + + +@dp.message_handler(state=Form.gender) +async def process_gender(message: types.Message, state: FSMContext): + await state.update_data(gender=message.text) + await state.finish() + + +if __name__ == '__main__': + executor.start_polling(dp, skip_updates=True)