mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
refactor StatesGroup to use State.set() at all StatesGroup methods
This commit is contained in:
parent
4d2d811386
commit
66533a84e3
4 changed files with 117 additions and 21 deletions
|
|
@ -1250,12 +1250,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
|
|||
"""
|
||||
if chat is None:
|
||||
chat_obj = types.Chat.get_current()
|
||||
chat = chat_obj.id if chat_obj else None
|
||||
chat_id = chat_obj.id if chat_obj else None
|
||||
if user is None:
|
||||
user_obj = types.User.get_current()
|
||||
user = user_obj.id if user_obj else None
|
||||
user_id = user_obj.id if user_obj else None
|
||||
|
||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||
return FSMContext(storage=self.storage, chat=chat_id, user=user_id)
|
||||
|
||||
@renamed_argument(old_name='user', new_name='user_id', until_version='3.0', stacklevel=3)
|
||||
@renamed_argument(old_name='chat', new_name='chat_id', until_version='3.0', stacklevel=4)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ class State:
|
|||
State object
|
||||
"""
|
||||
|
||||
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None):
|
||||
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None, dispatcher = Dispatcher):
|
||||
self._state = state
|
||||
self._group_name = group_name
|
||||
self._group = None
|
||||
self._dispatcher = dispatcher
|
||||
|
||||
@property
|
||||
def group(self):
|
||||
|
|
@ -41,6 +42,7 @@ class State:
|
|||
if not issubclass(group, StatesGroup):
|
||||
raise ValueError('Group must be subclass of StatesGroup')
|
||||
self._group = group
|
||||
self._dispatcher = group._dispatcher
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
if self._state is None:
|
||||
|
|
@ -53,12 +55,12 @@ class State:
|
|||
__repr__ = __str__
|
||||
|
||||
async def set(self):
|
||||
state = Dispatcher.get_current().current_state()
|
||||
state = self._dispatcher.get_current().current_state()
|
||||
await state.set_state(self.state)
|
||||
|
||||
|
||||
class StatesGroupMeta(type):
|
||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||
def __new__(mcs, name, bases, namespace, dispatcher=Dispatcher, **kwargs):
|
||||
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
|
||||
|
||||
states = []
|
||||
|
|
@ -78,6 +80,10 @@ class StatesGroupMeta(type):
|
|||
cls._childs = tuple(childs)
|
||||
cls._states = tuple(states)
|
||||
cls._state_names = tuple(state.state for state in states)
|
||||
cls._dispatcher = dispatcher
|
||||
|
||||
for state in states:
|
||||
state.set_parent(cls)
|
||||
|
||||
return cls
|
||||
|
||||
|
|
@ -142,7 +148,7 @@ class StatesGroupMeta(type):
|
|||
class StatesGroup(metaclass=StatesGroupMeta):
|
||||
@classmethod
|
||||
async def next(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
|
|
@ -151,16 +157,20 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
|||
next_step = 0
|
||||
|
||||
try:
|
||||
next_state_name = cls.states[next_step].state
|
||||
next_state = cls.states[next_step]
|
||||
next_state_name = next_state.state
|
||||
await next_state.set()
|
||||
|
||||
except IndexError:
|
||||
next_state_name = None
|
||||
await state.set_state(next_state_name)
|
||||
|
||||
await state.set_state(next_state_name)
|
||||
return next_state_name
|
||||
|
||||
|
||||
@classmethod
|
||||
async def previous(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
|
|
@ -170,27 +180,30 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
|||
|
||||
if previous_step < 0:
|
||||
previous_state_name = None
|
||||
else:
|
||||
previous_state_name = cls.states[previous_step].state
|
||||
await state.set_state(previous_state_name)
|
||||
|
||||
else:
|
||||
previous_state = cls.states[previous_step]
|
||||
previous_state_name = previous_state.state
|
||||
await previous_state.set()
|
||||
|
||||
await state.set_state(previous_state_name)
|
||||
return previous_state_name
|
||||
|
||||
@classmethod
|
||||
async def first(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
first_step_name = cls.states_names[0]
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
first_step = cls.states[0]
|
||||
|
||||
await state.set_state(first_step_name)
|
||||
return first_step_name
|
||||
await first_step.set()
|
||||
return first_step.state
|
||||
|
||||
@classmethod
|
||||
async def last(cls) -> str:
|
||||
state = Dispatcher.get_current().current_state()
|
||||
last_step_name = cls.states_names[-1]
|
||||
state = cls._dispatcher.get_current().current_state()
|
||||
last_step = cls.states[-1]
|
||||
|
||||
await state.set_state(last_step_name)
|
||||
return last_step_name
|
||||
await last_step.set()
|
||||
return last_step.state
|
||||
|
||||
|
||||
default_state = State()
|
||||
|
|
|
|||
83
tests/test_states_group_navigation.py
Normal file
83
tests/test_states_group_navigation.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.dispatcher import FSMContext
|
||||
from aiogram.dispatcher.filters.state import StatesGroup, State
|
||||
|
||||
|
||||
class FakeDispatcher:
|
||||
context = FSMContext(MemoryStorage(), chat=1, user=1)
|
||||
|
||||
@classmethod
|
||||
def get_current(cls, no_error=True):
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def current_state(cls):
|
||||
return cls.context
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
cls.context = FSMContext(MemoryStorage(), chat=1, user=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def reset_dispatcher():
|
||||
FakeDispatcher.reset()
|
||||
context = FakeDispatcher.get_current().current_state()
|
||||
current_state = await context.get_state()
|
||||
assert current_state == None
|
||||
yield FakeDispatcher
|
||||
|
||||
|
||||
class MyGroup(StatesGroup, dispatcher=FakeDispatcher()):
|
||||
state_1 = State()
|
||||
state_2 = State()
|
||||
|
||||
|
||||
class TestNavigation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first(self, reset_dispatcher):
|
||||
state = await MyGroup.first()
|
||||
assert state == 'MyGroup:state_1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last(self, reset_dispatcher):
|
||||
state = await MyGroup.last()
|
||||
assert state == 'MyGroup:state_2'
|
||||
|
||||
class TestNext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_from_none(self, reset_dispatcher):
|
||||
state = await MyGroup.next()
|
||||
assert state == 'MyGroup:state_1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_from_the_first_state(self, reset_dispatcher):
|
||||
await MyGroup.state_1.set()
|
||||
state = await MyGroup.next()
|
||||
assert state == 'MyGroup:state_2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_from_the_last_state(self, reset_dispatcher):
|
||||
await MyGroup.last()
|
||||
state = await MyGroup.next()
|
||||
assert state == None
|
||||
|
||||
class TestPrevious:
|
||||
@pytest.mark.asyncio
|
||||
async def test_previous_from_none(self, reset_dispatcher):
|
||||
state = await MyGroup.previous()
|
||||
assert state == 'MyGroup:state_1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_previous_from_the_first_state(self, reset_dispatcher):
|
||||
await MyGroup.first()
|
||||
state = await MyGroup.previous()
|
||||
assert state == None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_previous_from_the_last_state(self, reset_dispatcher):
|
||||
await MyGroup.state_2.set()
|
||||
state = await MyGroup.previous()
|
||||
assert state == 'MyGroup:state_1'
|
||||
Loading…
Add table
Add a link
Reference in a new issue