refactor StatesGroup to use State.set() at all StatesGroup methods

This commit is contained in:
PasNA6713 2022-01-23 15:00:05 +03:00
parent 4d2d811386
commit 66533a84e3
4 changed files with 117 additions and 21 deletions

View file

@ -1250,12 +1250,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
if chat is None:
chat_obj = types.Chat.get_current()
chat = chat_obj.id if chat_obj else None
chat_id = chat_obj.id if chat_obj else None
if user is None:
user_obj = types.User.get_current()
user = user_obj.id if user_obj else None
user_id = user_obj.id if user_obj else None
return FSMContext(storage=self.storage, chat=chat, user=user)
return FSMContext(storage=self.storage, chat=chat_id, user=user_id)
@renamed_argument(old_name='user', new_name='user_id', until_version='3.0', stacklevel=3)
@renamed_argument(old_name='chat', new_name='chat_id', until_version='3.0', stacklevel=4)

View file

@ -9,10 +9,11 @@ class State:
State object
"""
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None):
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None, dispatcher = Dispatcher):
self._state = state
self._group_name = group_name
self._group = None
self._dispatcher = dispatcher
@property
def group(self):
@ -41,6 +42,7 @@ class State:
if not issubclass(group, StatesGroup):
raise ValueError('Group must be subclass of StatesGroup')
self._group = group
self._dispatcher = group._dispatcher
def __set_name__(self, owner, name):
if self._state is None:
@ -53,12 +55,12 @@ class State:
__repr__ = __str__
async def set(self):
state = Dispatcher.get_current().current_state()
state = self._dispatcher.get_current().current_state()
await state.set_state(self.state)
class StatesGroupMeta(type):
def __new__(mcs, name, bases, namespace, **kwargs):
def __new__(mcs, name, bases, namespace, dispatcher=Dispatcher, **kwargs):
cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace)
states = []
@ -78,6 +80,10 @@ class StatesGroupMeta(type):
cls._childs = tuple(childs)
cls._states = tuple(states)
cls._state_names = tuple(state.state for state in states)
cls._dispatcher = dispatcher
for state in states:
state.set_parent(cls)
return cls
@ -142,7 +148,7 @@ class StatesGroupMeta(type):
class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def next(cls) -> str:
state = Dispatcher.get_current().current_state()
state = cls._dispatcher.get_current().current_state()
state_name = await state.get_state()
try:
@ -151,16 +157,20 @@ class StatesGroup(metaclass=StatesGroupMeta):
next_step = 0
try:
next_state_name = cls.states[next_step].state
next_state = cls.states[next_step]
next_state_name = next_state.state
await next_state.set()
except IndexError:
next_state_name = None
await state.set_state(next_state_name)
await state.set_state(next_state_name)
return next_state_name
@classmethod
async def previous(cls) -> str:
state = Dispatcher.get_current().current_state()
state = cls._dispatcher.get_current().current_state()
state_name = await state.get_state()
try:
@ -170,27 +180,30 @@ class StatesGroup(metaclass=StatesGroupMeta):
if previous_step < 0:
previous_state_name = None
else:
previous_state_name = cls.states[previous_step].state
await state.set_state(previous_state_name)
else:
previous_state = cls.states[previous_step]
previous_state_name = previous_state.state
await previous_state.set()
await state.set_state(previous_state_name)
return previous_state_name
@classmethod
async def first(cls) -> str:
state = Dispatcher.get_current().current_state()
first_step_name = cls.states_names[0]
state = cls._dispatcher.get_current().current_state()
first_step = cls.states[0]
await state.set_state(first_step_name)
return first_step_name
await first_step.set()
return first_step.state
@classmethod
async def last(cls) -> str:
state = Dispatcher.get_current().current_state()
last_step_name = cls.states_names[-1]
state = cls._dispatcher.get_current().current_state()
last_step = cls.states[-1]
await state.set_state(last_step_name)
return last_step_name
await last_step.set()
return last_step.state
default_state = State()

View file

@ -0,0 +1,83 @@
import pytest
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import FSMContext
from aiogram.dispatcher.filters.state import StatesGroup, State
class FakeDispatcher:
context = FSMContext(MemoryStorage(), chat=1, user=1)
@classmethod
def get_current(cls, no_error=True):
return cls
@classmethod
def current_state(cls):
return cls.context
@classmethod
def reset(cls):
cls.context = FSMContext(MemoryStorage(), chat=1, user=1)
@pytest.fixture
async def reset_dispatcher():
FakeDispatcher.reset()
context = FakeDispatcher.get_current().current_state()
current_state = await context.get_state()
assert current_state == None
yield FakeDispatcher
class MyGroup(StatesGroup, dispatcher=FakeDispatcher()):
state_1 = State()
state_2 = State()
class TestNavigation:
@pytest.mark.asyncio
async def test_first(self, reset_dispatcher):
state = await MyGroup.first()
assert state == 'MyGroup:state_1'
@pytest.mark.asyncio
async def test_last(self, reset_dispatcher):
state = await MyGroup.last()
assert state == 'MyGroup:state_2'
class TestNext:
@pytest.mark.asyncio
async def test_next_from_none(self, reset_dispatcher):
state = await MyGroup.next()
assert state == 'MyGroup:state_1'
@pytest.mark.asyncio
async def test_next_from_the_first_state(self, reset_dispatcher):
await MyGroup.state_1.set()
state = await MyGroup.next()
assert state == 'MyGroup:state_2'
@pytest.mark.asyncio
async def test_next_from_the_last_state(self, reset_dispatcher):
await MyGroup.last()
state = await MyGroup.next()
assert state == None
class TestPrevious:
@pytest.mark.asyncio
async def test_previous_from_none(self, reset_dispatcher):
state = await MyGroup.previous()
assert state == 'MyGroup:state_1'
@pytest.mark.asyncio
async def test_previous_from_the_first_state(self, reset_dispatcher):
await MyGroup.first()
state = await MyGroup.previous()
assert state == None
@pytest.mark.asyncio
async def test_previous_from_the_last_state(self, reset_dispatcher):
await MyGroup.state_2.set()
state = await MyGroup.previous()
assert state == 'MyGroup:state_1'