mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
Merge 4484323c4d into 69c9ecb7e0
This commit is contained in:
commit
115609cabf
6 changed files with 338 additions and 30 deletions
|
|
@ -1250,12 +1250,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
"""
|
||||
if chat is None:
|
||||
chat_obj = types.Chat.get_current()
|
||||
chat = chat_obj.id if chat_obj else None
|
||||
chat_id = chat_obj.id if chat_obj else None
|
||||
if user is None:
|
||||
user_obj = types.User.get_current()
|
||||
user = user_obj.id if user_obj else None
|
||||
user_id = user_obj.id if user_obj else None
|
||||
|
||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||
return FSMContext(storage=self.storage, chat=chat_id, user=user_id)
|
||||
|
||||
@renamed_argument(old_name='user', new_name='user_id', until_version='3.0', stacklevel=3)
|
||||
@renamed_argument(old_name='chat', new_name='chat_id', until_version='3.0', stacklevel=4)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import inspect
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable, List
|
||||
|
||||
from ..dispatcher import Dispatcher
|
||||
from ..storage import FSMContext
|
||||
|
||||
|
||||
class State:
|
||||
|
|
@ -9,10 +10,13 @@ class State:
|
|||
State object
|
||||
"""
|
||||
|
||||
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None):
|
||||
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None, dispatcher = Dispatcher):
|
||||
self._state = state
|
||||
self._group_name = group_name
|
||||
self._group = None
|
||||
self._dispatcher = dispatcher
|
||||
self._pre_set_handlers: List[Callable] = []
|
||||
self._post_set_handlers: List[Callable] = []
|
||||
|
||||
@property
|
||||
def group(self):
|
||||
|
|
@ -41,6 +45,7 @@ class State:
|
|||
if not issubclass(group, StatesGroup):
|
||||
raise ValueError('Group must be subclass of StatesGroup')
|
||||
self._group = group
|
||||
self._dispatcher = group._dispatcher
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
if self._state is None:
|
||||
|
|
@ -52,27 +57,46 @@ class State:
|
|||
|
||||
__repr__ = __str__
|
||||
|
||||
async def set(self):
|
||||
state = Dispatcher.get_current().current_state()
|
||||
await state.set_state(self.state)
|
||||
def pre_set(self, func_to_call):
|
||||
self._pre_set_handlers.append(func_to_call)
|
||||
|
||||
def post_set(self, func_to_call):
|
||||
self._post_set_handlers.append(func_to_call)
|
||||
|
||||
async def set(self, *args, trigger=True, **kwargs):
|
||||
context = self._dispatcher.get_current().current_state()
|
||||
old_state = await context.get_state()
|
||||
|
||||
if trigger is True:
|
||||
for func in self._pre_set_handlers:
|
||||
await func(context, old_state, self.state)
|
||||
|
||||
await context.set_state(self.state)
|
||||
|
||||
if trigger is True:
|
||||
for func in self._post_set_handlers:
|
||||
await func(context, old_state, self.state)
|
||||
|
||||
|
||||
class StatesGroupMeta(type):
|
||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||
def __new__(mcs, name, bases, namespace, dispatcher=Dispatcher, **kwargs):
|
||||
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
|
||||
|
||||
states = []
|
||||
childs = []
|
||||
|
||||
cls._group_name = name
|
||||
cls._dispatcher = dispatcher
|
||||
|
||||
for name, prop in namespace.items():
|
||||
|
||||
if isinstance(prop, State):
|
||||
states.append(prop)
|
||||
prop.set_parent(cls)
|
||||
|
||||
elif inspect.isclass(prop) and issubclass(prop, StatesGroup):
|
||||
childs.append(prop)
|
||||
prop._parent = cls
|
||||
prop._dispatcher = dispatcher
|
||||
|
||||
cls._parent = None
|
||||
cls._childs = tuple(childs)
|
||||
|
|
@ -141,8 +165,8 @@ class StatesGroupMeta(type):
|
|||
|
||||
class StatesGroup(metaclass=StatesGroupMeta):
|
||||
@classmethod
|
||||
async def next(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
async def next(cls, *args, trigger=True, **kwargs) -> str:
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
|
|
@ -151,16 +175,20 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
|||
next_step = 0
|
||||
|
||||
try:
|
||||
next_state_name = cls.states[next_step].state
|
||||
next_state = cls.states[next_step]
|
||||
next_state_name = next_state.state
|
||||
await next_state.set(trigger=trigger)
|
||||
|
||||
except IndexError:
|
||||
next_state_name = None
|
||||
await state.set_state(next_state_name)
|
||||
|
||||
await state.set_state(next_state_name)
|
||||
return next_state_name
|
||||
|
||||
|
||||
@classmethod
|
||||
async def previous(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
async def previous(cls, *args, trigger=True, **kwargs) -> str:
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
|
|
@ -170,27 +198,38 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
|||
|
||||
if previous_step < 0:
|
||||
previous_state_name = None
|
||||
else:
|
||||
previous_state_name = cls.states[previous_step].state
|
||||
await state.set_state(previous_state_name)
|
||||
|
||||
else:
|
||||
previous_state = cls.states[previous_step]
|
||||
previous_state_name = previous_state.state
|
||||
await previous_state.set(trigger=trigger)
|
||||
|
||||
await state.set_state(previous_state_name)
|
||||
return previous_state_name
|
||||
|
||||
@classmethod
|
||||
async def first(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
first_step_name = cls.states_names[0]
|
||||
async def first(cls, *args, trigger=True, **kwargs) -> str:
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
first_step = cls.states[0]
|
||||
|
||||
await state.set_state(first_step_name)
|
||||
return first_step_name
|
||||
await first_step.set(trigger=trigger)
|
||||
return first_step.state
|
||||
|
||||
@classmethod
|
||||
async def last(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
last_step_name = cls.states_names[-1]
|
||||
async def last(cls, *args, trigger=True, **kwargs) -> str:
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
last_step = cls.states[-1]
|
||||
|
||||
await state.set_state(last_step_name)
|
||||
return last_step_name
|
||||
await last_step.set(trigger=trigger)
|
||||
return last_step.state
|
||||
|
||||
@classmethod
|
||||
def pre_finish(cls, func_to_call) -> str:
|
||||
FSMContext.set_pre_finish_hanlder(cls._group_name, func_to_call)
|
||||
|
||||
@classmethod
|
||||
def post_finish(cls, func_to_call) -> str:
|
||||
FSMContext.set_post_finish_hanlder(cls._group_name, func_to_call)
|
||||
|
||||
|
||||
default_state = State()
|
||||
|
|
|
|||
|
|
@ -290,12 +290,23 @@ class BaseStorage:
|
|||
|
||||
|
||||
class FSMContext:
|
||||
_pre_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {}
|
||||
_post_finish_handlers: typing.Dict[str, typing.List[typing.Callable]] = {}
|
||||
|
||||
def __init__(self, storage, chat, user):
|
||||
self.storage: BaseStorage = storage
|
||||
self.chat, self.user = self.storage.check_address(chat=chat, user=user)
|
||||
|
||||
def proxy(self):
|
||||
return FSMContextProxy(self)
|
||||
|
||||
@classmethod
|
||||
def set_pre_finish_hanlder(cls, group_name: str, handler: typing.Callable):
|
||||
cls._pre_finish_handlers[group_name] = cls._pre_finish_handlers.get(group_name, []) + [handler]
|
||||
|
||||
@classmethod
|
||||
def set_post_finish_hanlder(cls, group_name: str, handler: typing.Callable):
|
||||
cls._post_finish_handlers[group_name] = cls._post_finish_handlers.get(group_name, []) + [handler]
|
||||
|
||||
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
return await self.storage.get_state(chat=self.chat, user=self.user, default=default)
|
||||
|
|
@ -318,9 +329,23 @@ class FSMContext:
|
|||
async def reset_data(self):
|
||||
await self.storage.reset_data(chat=self.chat, user=self.user)
|
||||
|
||||
async def finish(self):
|
||||
async def finish(self, *args, trigger=True, **kwargs):
|
||||
state = await self.get_state()
|
||||
|
||||
if trigger is True:
|
||||
for group_name, handlers in self._pre_finish_handlers.items():
|
||||
if group_name in state:
|
||||
for func in handlers:
|
||||
await func(self, state)
|
||||
|
||||
await self.storage.finish(chat=self.chat, user=self.user)
|
||||
|
||||
if trigger is True:
|
||||
for group_name, handlers in self._post_finish_handlers.items():
|
||||
if group_name in state:
|
||||
for func in handlers:
|
||||
await func(self, state)
|
||||
|
||||
|
||||
class FSMContextProxy:
|
||||
def __init__(self, fsm_context: FSMContext):
|
||||
|
|
|
|||
167
examples/finite_state_machine_signals.py
Normal file
167
examples/finite_state_machine_signals.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
import logging
|
||||
from typing import Union
|
||||
|
||||
import aiogram.utils.markdown as md
|
||||
from aiogram import Bot, Dispatcher, types
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.dispatcher import FSMContext
|
||||
from aiogram.dispatcher.filters import Text
|
||||
from aiogram.dispatcher.filters.state import State, StatesGroup
|
||||
from aiogram.types import ParseMode
|
||||
from aiogram.utils import executor
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
API_TOKEN = 'BOT_TOKEN_HERE'
|
||||
|
||||
|
||||
bot = Bot(token=API_TOKEN)
|
||||
|
||||
# For example use simple MemoryStorage for Dispatcher.
|
||||
storage = MemoryStorage()
|
||||
dp = Dispatcher(bot, storage=storage)
|
||||
|
||||
|
||||
# States
|
||||
class Form(StatesGroup):
|
||||
name = State() # Will be represented in storage as 'Form:name'
|
||||
age = State() # Will be represented in storage as 'Form:age'
|
||||
gender = State() # Will be represented in storage as 'Form:gender'
|
||||
|
||||
|
||||
# You can use pre_set decorator to call decorated functions before state will be setted
|
||||
@Form.name.pre_set
|
||||
async def name_pre_set(context: FSMContext, old_state: str, new_state: str):
|
||||
logging.info('Calling before Form:name will be set')
|
||||
await bot.send_message(
|
||||
chat_id = context.chat,
|
||||
text = "What's your name?"
|
||||
)
|
||||
|
||||
|
||||
# You can use post_set decorator to call decorated functions after state will be setted
|
||||
@Form.age.post_set
|
||||
async def age_post_set(context: FSMContext, old_state: str, new_state: str):
|
||||
logging.info('Calling after Form:age will be set')
|
||||
markup = types.InlineKeyboardMarkup(row_width=1)
|
||||
markup.row(types.InlineKeyboardButton('Back', callback_data='back'))
|
||||
|
||||
await bot.send_message(
|
||||
chat_id = context.chat,
|
||||
text = "How old are you?",
|
||||
reply_markup = markup
|
||||
)
|
||||
|
||||
|
||||
@Form.gender.pre_set
|
||||
async def gender_pre_set(context: FSMContext, old_state: str, new_state: str):
|
||||
logging.info('Calling before gender:age will be set')
|
||||
|
||||
markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True)
|
||||
markup.add("Male", "Female")
|
||||
markup.add("Other", "Back")
|
||||
|
||||
await bot.send_message(
|
||||
chat_id = context.chat,
|
||||
text = "What is your gender?",
|
||||
reply_markup = markup
|
||||
)
|
||||
|
||||
|
||||
# You can use pre_finish decorator to call decorated functions before state group will be finished
|
||||
@Form.pre_finish
|
||||
async def form_pre_finish(context: FSMContext, old_state: str):
|
||||
async with context.proxy() as data:
|
||||
markup = types.ReplyKeyboardRemove()
|
||||
|
||||
await bot.send_message(
|
||||
context.chat,
|
||||
md.text(
|
||||
md.text('Hi! Nice to meet you,', md.bold(data['name'])),
|
||||
md.text('Age:', md.code(data['age'])),
|
||||
md.text('Gender:', data['gender']),
|
||||
sep='\n',
|
||||
),
|
||||
reply_markup=markup,
|
||||
parse_mode=ParseMode.MARKDOWN,
|
||||
)
|
||||
|
||||
|
||||
# You can use state Form if you need to handle all states from Form group
|
||||
@dp.message_handler(commands='back', state=Form)
|
||||
@dp.message_handler(Text(equals='back', ignore_case=True), state=Form)
|
||||
@dp.callback_query_handler(text='back', state=Form)
|
||||
async def process_gender_invalid(message: Union[types.Message, types.CallbackQuery]):
|
||||
await Form.previous()
|
||||
|
||||
|
||||
# You can use state '*' if you need to handle all states
|
||||
@dp.message_handler(state='*', commands='cancel')
|
||||
@dp.message_handler(Text(equals='cancel', ignore_case=True), state='*')
|
||||
async def cancel_handler(message: types.Message, state: FSMContext):
|
||||
"""
|
||||
Allow user to cancel any action
|
||||
"""
|
||||
current_state = await state.get_state()
|
||||
if current_state is None:
|
||||
return
|
||||
|
||||
logging.info('Cancelling state %r', current_state)
|
||||
# Cancel state and inform user about it
|
||||
# You can set trigger to False to prevent pre_set and post_set signals
|
||||
await state.finish(trigger=False)
|
||||
# And remove keyboard (just in case)
|
||||
await message.reply('Cancelled.', reply_markup=types.ReplyKeyboardRemove())
|
||||
|
||||
|
||||
@dp.message_handler(commands='start')
|
||||
async def cmd_start(message: types.Message):
|
||||
"""
|
||||
Conversation's entry point
|
||||
"""
|
||||
await message.reply("Hi there!")
|
||||
await Form.name.set()
|
||||
|
||||
|
||||
@dp.message_handler(state=Form.name)
|
||||
async def process_name(message: types.Message, state: FSMContext):
|
||||
"""
|
||||
Process user name and go to age
|
||||
"""
|
||||
await state.update_data(name=message.text)
|
||||
await Form.next()
|
||||
|
||||
|
||||
@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age)
|
||||
async def process_age_invalid(message: types.Message):
|
||||
"""
|
||||
If age is not digit
|
||||
"""
|
||||
return await message.reply("Age gotta be a number.\nHow old are you? (digits only)")
|
||||
|
||||
|
||||
@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age)
|
||||
async def process_age(message: types.Message, state: FSMContext):
|
||||
"""
|
||||
Process user age and go to gender
|
||||
"""
|
||||
await state.update_data(age=int(message.text))
|
||||
await Form.next()
|
||||
|
||||
|
||||
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
|
||||
async def process_gender_invalid(message: types.Message):
|
||||
"""
|
||||
In this example gender has to be one of: Male, Female, Other.
|
||||
"""
|
||||
return await message.reply("Bad gender name. Choose your gender from the keyboard.")
|
||||
|
||||
|
||||
@dp.message_handler(state=Form.gender)
|
||||
async def process_gender(message: types.Message, state: FSMContext):
|
||||
await state.update_data(gender=message.text)
|
||||
await state.finish()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
executor.start_polling(dp, skip_updates=True)
|
||||
77
tests/test_states_group_navigation.py
Normal file
77
tests/test_states_group_navigation.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.dispatcher import FSMContext
|
||||
from aiogram.dispatcher.filters.state import StatesGroup, State
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class FakeDispatcher:
|
||||
context = FSMContext(MemoryStorage(), chat=1, user=1)
|
||||
|
||||
@classmethod
|
||||
def get_current(cls, no_error=True):
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def current_state(cls):
|
||||
return cls.context
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
cls.context = FSMContext(MemoryStorage(), chat=1, user=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def reset_dispatcher():
|
||||
FakeDispatcher.reset()
|
||||
context = FakeDispatcher.get_current().current_state()
|
||||
current_state = await context.get_state()
|
||||
assert current_state == None
|
||||
yield FakeDispatcher
|
||||
|
||||
|
||||
class MyGroup(StatesGroup, dispatcher=FakeDispatcher()):
|
||||
state_1 = State()
|
||||
state_2 = State()
|
||||
|
||||
|
||||
class TestNavigation:
|
||||
async def test_first(self, reset_dispatcher):
|
||||
state = await MyGroup.first()
|
||||
assert state == 'MyGroup:state_1'
|
||||
|
||||
async def test_last(self, reset_dispatcher):
|
||||
state = await MyGroup.last()
|
||||
assert state == 'MyGroup:state_2'
|
||||
|
||||
class TestNext:
|
||||
async def test_next_from_none(self, reset_dispatcher):
|
||||
state = await MyGroup.next()
|
||||
assert state == 'MyGroup:state_1'
|
||||
|
||||
async def test_next_from_the_first_state(self, reset_dispatcher):
|
||||
await MyGroup.state_1.set()
|
||||
state = await MyGroup.next()
|
||||
assert state == 'MyGroup:state_2'
|
||||
|
||||
async def test_next_from_the_last_state(self, reset_dispatcher):
|
||||
await MyGroup.last()
|
||||
state = await MyGroup.next()
|
||||
assert state == None
|
||||
|
||||
class TestPrevious:
|
||||
async def test_previous_from_none(self, reset_dispatcher):
|
||||
state = await MyGroup.previous()
|
||||
assert state == 'MyGroup:state_1'
|
||||
|
||||
async def test_previous_from_the_first_state(self, reset_dispatcher):
|
||||
await MyGroup.first()
|
||||
state = await MyGroup.previous()
|
||||
assert state == None
|
||||
|
||||
async def test_previous_from_the_last_state(self, reset_dispatcher):
|
||||
await MyGroup.state_2.set()
|
||||
state = await MyGroup.previous()
|
||||
assert state == 'MyGroup:state_1'
|
||||
Loading…
Add table
Add a link
Reference in a new issue