diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 22f6700c..7be2bceb 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -1250,12 +1250,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin): """ if chat is None: chat_obj = types.Chat.get_current() - chat = chat_obj.id if chat_obj else None + chat_id = chat_obj.id if chat_obj else None if user is None: user_obj = types.User.get_current() - user = user_obj.id if user_obj else None + user_id = user_obj.id if user_obj else None - return FSMContext(storage=self.storage, chat=chat, user=user) + return FSMContext(storage=self.storage, chat=chat_id, user=user_id) @renamed_argument(old_name='user', new_name='user_id', until_version='3.0', stacklevel=3) @renamed_argument(old_name='chat', new_name='chat_id', until_version='3.0', stacklevel=4) diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py index 16937e1c..2783e5d8 100644 --- a/aiogram/dispatcher/filters/state.py +++ b/aiogram/dispatcher/filters/state.py @@ -1,7 +1,8 @@ import inspect -from typing import Optional +from typing import Optional, Callable, List from ..dispatcher import Dispatcher +from ..storage import FSMContext class State: @@ -9,10 +10,13 @@ class State: State object """ - def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None): + def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None, dispatcher = Dispatcher): self._state = state self._group_name = group_name self._group = None + self._dispatcher = dispatcher + self._pre_set_handlers: List[Callable] = [] + self._post_set_handlers: List[Callable] = [] @property def group(self): @@ -41,6 +45,7 @@ class State: if not issubclass(group, StatesGroup): raise ValueError('Group must be subclass of StatesGroup') self._group = group + self._dispatcher = group._dispatcher def __set_name__(self, owner, name): if self._state is None: @@ -52,27 +57,46 @@ class State: __repr__ = __str__ - async def set(self): - state = Dispatcher.get_current().current_state() - await state.set_state(self.state) + def pre_set(self, func_to_call): + self._pre_set_handlers.append(func_to_call) + + def post_set(self, func_to_call): + self._post_set_handlers.append(func_to_call) + + async def set(self, *args, trigger=True, **kwargs): + context = self._dispatcher.get_current().current_state() + old_state = await context.get_state() + + if trigger is True: + for func in self._pre_set_handlers: + await func(context, old_state, self.state) + + await context.set_state(self.state) + + if trigger is True: + for func in self._post_set_handlers: + await func(context, old_state, self.state) class StatesGroupMeta(type): - def __new__(mcs, name, bases, namespace, **kwargs): + def __new__(mcs, name, bases, namespace, dispatcher=Dispatcher, **kwargs): cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace) states = [] childs = [] cls._group_name = name + cls._dispatcher = dispatcher for name, prop in namespace.items(): - if isinstance(prop, State): states.append(prop) + prop.set_parent(cls) + elif inspect.isclass(prop) and issubclass(prop, StatesGroup): childs.append(prop) prop._parent = cls + prop._dispatcher = dispatcher cls._parent = None cls._childs = tuple(childs) @@ -141,8 +165,8 @@ class StatesGroupMeta(type): class StatesGroup(metaclass=StatesGroupMeta): @classmethod - async def next(cls) -> str: - state = Dispatcher.get_current().current_state() + async def next(cls, *args, trigger=True, **kwargs) -> str: + state = cls._dispatcher.get_current().current_state() state_name = await state.get_state() try: @@ -151,16 +175,20 @@ class StatesGroup(metaclass=StatesGroupMeta): next_step = 0 try: - next_state_name = cls.states[next_step].state + next_state = cls.states[next_step] + next_state_name = next_state.state + await next_state.set(trigger=trigger) + except IndexError: next_state_name = None + await state.set_state(next_state_name) - await state.set_state(next_state_name) return next_state_name + @classmethod - async def previous(cls) -> str: - state = Dispatcher.get_current().current_state() + async def previous(cls, *args, trigger=True, **kwargs) -> str: + state = cls._dispatcher.get_current().current_state() state_name = await state.get_state() try: @@ -170,27 +198,38 @@ class StatesGroup(metaclass=StatesGroupMeta): if previous_step < 0: previous_state_name = None - else: - previous_state_name = cls.states[previous_step].state + await state.set_state(previous_state_name) + + else: + previous_state = cls.states[previous_step] + previous_state_name = previous_state.state + await previous_state.set(trigger=trigger) - await state.set_state(previous_state_name) return previous_state_name @classmethod - async def first(cls) -> str: - state = Dispatcher.get_current().current_state() - first_step_name = cls.states_names[0] + async def first(cls, *args, trigger=True, **kwargs) -> str: + state = cls._dispatcher.get_current().current_state() + first_step = cls.states[0] - await state.set_state(first_step_name) - return first_step_name + await first_step.set(trigger=trigger) + return first_step.state @classmethod - async def last(cls) -> str: - state = Dispatcher.get_current().current_state() - last_step_name = cls.states_names[-1] + async def last(cls, *args, trigger=True, **kwargs) -> str: + state = cls._dispatcher.get_current().current_state() + last_step = cls.states[-1] - await state.set_state(last_step_name) - return last_step_name + await last_step.set(trigger=trigger) + return last_step.state + + @classmethod + def pre_finish(cls, func_to_call) -> str: + FSMContext.set_pre_finish_hanlder(cls._group_name, func_to_call) + + @classmethod + def post_finish(cls, func_to_call) -> str: + FSMContext.set_post_finish_hanlder(cls._group_name, func_to_call) default_state = State() diff --git a/aiogram/dispatcher/storage.py b/aiogram/dispatcher/storage.py index dcc42ed8..9fda94cc 100644 --- a/aiogram/dispatcher/storage.py +++ b/aiogram/dispatcher/storage.py @@ -290,12 +290,23 @@ class BaseStorage: class FSMContext: + _pre_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {} + _post_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {} + def __init__(self, storage, chat, user): self.storage: BaseStorage = storage self.chat, self.user = self.storage.check_address(chat=chat, user=user) def proxy(self): return FSMContextProxy(self) + + @classmethod + def set_pre_finish_hanlder(cls, group_name: str, handler: typing.Callable): + cls._pre_finish_handlers[group_name] = cls._pre_finish_handlers.get(group_name, []) + [handler] + + @classmethod + def set_post_finish_hanlder(cls, group_name: str, handler: typing.Callable): + cls._post_finish_handlers[group_name] = cls._post_finish_handlers.get(group_name, []) + [handler] async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]: return await self.storage.get_state(chat=self.chat, user=self.user, default=default) @@ -318,9 +329,23 @@ class FSMContext: async def reset_data(self): await self.storage.reset_data(chat=self.chat, user=self.user) - async def finish(self): + async def finish(self, *args, trigger=True, **kwargs): + state = await self.get_state() + + if trigger is True: + for group_name, handlers in self._pre_finish_handlers.items(): + if group_name in state: + for func in handlers: + await func(self, state) + await self.storage.finish(chat=self.chat, user=self.user) + if trigger is True: + for group_name, handlers in self._post_finish_handlers.items(): + if group_name in state: + for func in handlers: + await func(self, state) + class FSMContextProxy: def __init__(self, fsm_context: FSMContext): diff --git a/examples/finite_state_machine_signals.py b/examples/finite_state_machine_signals.py new file mode 100644 index 00000000..b05b38b0 --- /dev/null +++ b/examples/finite_state_machine_signals.py @@ -0,0 +1,167 @@ +import logging +from typing import Union + +import aiogram.utils.markdown as md +from aiogram import Bot, Dispatcher, types +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.dispatcher import FSMContext +from aiogram.dispatcher.filters import Text +from aiogram.dispatcher.filters.state import State, StatesGroup +from aiogram.types import ParseMode +from aiogram.utils import executor + +logging.basicConfig(level=logging.INFO) + +API_TOKEN = 'BOT_TOKEN_HERE' + + +bot = Bot(token=API_TOKEN) + +# For example use simple MemoryStorage for Dispatcher. +storage = MemoryStorage() +dp = Dispatcher(bot, storage=storage) + + +# States +class Form(StatesGroup): + name = State() # Will be represented in storage as 'Form:name' + age = State() # Will be represented in storage as 'Form:age' + gender = State() # Will be represented in storage as 'Form:gender' + + +# You can use pre_set decorator to call decorated functions before state will be setted +@Form.name.pre_set +async def name_pre_set(context: FSMContext, old_state: str, new_state: str): + logging.info('Calling before Form:name will be set') + await bot.send_message( + chat_id = context.chat, + text = "What's your name?" + ) + + +# You can use post_set decorator to call decorated functions after state will be setted +@Form.age.post_set +async def age_post_set(context: FSMContext, old_state: str, new_state: str): + logging.info('Calling after Form:age will be set') + markup = types.InlineKeyboardMarkup(row_width=1) + markup.row(types.InlineKeyboardButton('Back', callback_data='back')) + + await bot.send_message( + chat_id = context.chat, + text = "How old are you?", + reply_markup = markup + ) + + +@Form.gender.pre_set +async def gender_pre_set(context: FSMContext, old_state: str, new_state: str): + logging.info('Calling before gender:age will be set') + + markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) + markup.add("Male", "Female") + markup.add("Other", "Back") + + await bot.send_message( + chat_id = context.chat, + text = "What is your gender?", + reply_markup = markup + ) + + +# You can use pre_finish decorator to call decorated functions before state group will be finished +@Form.pre_finish +async def form_pre_finish(context: FSMContext, old_state: str): + async with context.proxy() as data: + markup = types.ReplyKeyboardRemove() + + await bot.send_message( + context.chat, + md.text( + md.text('Hi! Nice to meet you,', md.bold(data['name'])), + md.text('Age:', md.code(data['age'])), + md.text('Gender:', data['gender']), + sep='\n', + ), + reply_markup=markup, + parse_mode=ParseMode.MARKDOWN, + ) + + +# You can use state Form if you need to handle all states from Form group +@dp.message_handler(commands='back', state=Form) +@dp.message_handler(Text(equals='back', ignore_case=True), state=Form) +@dp.callback_query_handler(text='back', state=Form) +async def process_gender_invalid(message: Union[types.Message, types.CallbackQuery]): + await Form.previous() + + +# You can use state '*' if you need to handle all states +@dp.message_handler(state='*', commands='cancel') +@dp.message_handler(Text(equals='cancel', ignore_case=True), state='*') +async def cancel_handler(message: types.Message, state: FSMContext): + """ + Allow user to cancel any action + """ + current_state = await state.get_state() + if current_state is None: + return + + logging.info('Cancelling state %r', current_state) + # Cancel state and inform user about it + # You can set trigger to False to prevent pre_set and post_set signals + await state.finish(trigger=False) + # And remove keyboard (just in case) + await message.reply('Cancelled.', reply_markup=types.ReplyKeyboardRemove()) + + +@dp.message_handler(commands='start') +async def cmd_start(message: types.Message): + """ + Conversation's entry point + """ + await message.reply("Hi there!") + await Form.name.set() + + +@dp.message_handler(state=Form.name) +async def process_name(message: types.Message, state: FSMContext): + """ + Process user name and go to age + """ + await state.update_data(name=message.text) + await Form.next() + + +@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age) +async def process_age_invalid(message: types.Message): + """ + If age is not digit + """ + return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") + + +@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age) +async def process_age(message: types.Message, state: FSMContext): + """ + Process user age and go to gender + """ + await state.update_data(age=int(message.text)) + await Form.next() + + +@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender) +async def process_gender_invalid(message: types.Message): + """ + In this example gender has to be one of: Male, Female, Other. + """ + return await message.reply("Bad gender name. Choose your gender from the keyboard.") + + +@dp.message_handler(state=Form.gender) +async def process_gender(message: types.Message, state: FSMContext): + await state.update_data(gender=message.text) + await state.finish() + + +if __name__ == '__main__': + executor.start_polling(dp, skip_updates=True) diff --git a/tests/test_states_group.py b/tests/test_states_group_names.py similarity index 100% rename from tests/test_states_group.py rename to tests/test_states_group_names.py diff --git a/tests/test_states_group_navigation.py b/tests/test_states_group_navigation.py new file mode 100644 index 00000000..b4241ff1 --- /dev/null +++ b/tests/test_states_group_navigation.py @@ -0,0 +1,77 @@ +import pytest + +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.dispatcher import FSMContext +from aiogram.dispatcher.filters.state import StatesGroup, State + +pytestmark = pytest.mark.asyncio + + +class FakeDispatcher: + context = FSMContext(MemoryStorage(), chat=1, user=1) + + @classmethod + def get_current(cls, no_error=True): + return cls + + @classmethod + def current_state(cls): + return cls.context + + @classmethod + def reset(cls): + cls.context = FSMContext(MemoryStorage(), chat=1, user=1) + + +@pytest.fixture +async def reset_dispatcher(): + FakeDispatcher.reset() + context = FakeDispatcher.get_current().current_state() + current_state = await context.get_state() + assert current_state == None + yield FakeDispatcher + + +class MyGroup(StatesGroup, dispatcher=FakeDispatcher()): + state_1 = State() + state_2 = State() + + +class TestNavigation: + async def test_first(self, reset_dispatcher): + state = await MyGroup.first() + assert state == 'MyGroup:state_1' + + async def test_last(self, reset_dispatcher): + state = await MyGroup.last() + assert state == 'MyGroup:state_2' + + class TestNext: + async def test_next_from_none(self, reset_dispatcher): + state = await MyGroup.next() + assert state == 'MyGroup:state_1' + + async def test_next_from_the_first_state(self, reset_dispatcher): + await MyGroup.state_1.set() + state = await MyGroup.next() + assert state == 'MyGroup:state_2' + + async def test_next_from_the_last_state(self, reset_dispatcher): + await MyGroup.last() + state = await MyGroup.next() + assert state == None + + class TestPrevious: + async def test_previous_from_none(self, reset_dispatcher): + state = await MyGroup.previous() + assert state == 'MyGroup:state_1' + + async def test_previous_from_the_first_state(self, reset_dispatcher): + await MyGroup.first() + state = await MyGroup.previous() + assert state == None + + async def test_previous_from_the_last_state(self, reset_dispatcher): + await MyGroup.state_2.set() + state = await MyGroup.previous() + assert state == 'MyGroup:state_1'