From 66533a84e3a2a2dc88edb956ccfecccc1a659147 Mon Sep 17 00:00:00 2001 From: PasNA6713 Date: Sun, 23 Jan 2022 15:00:05 +0300 Subject: [PATCH] refactor StatesGroup to use State.set() at all StatesGroup methods --- aiogram/dispatcher/dispatcher.py | 6 +- aiogram/dispatcher/filters/state.py | 49 +++++++---- ...es_group.py => test_states_group_names.py} | 0 tests/test_states_group_navigation.py | 83 +++++++++++++++++++ 4 files changed, 117 insertions(+), 21 deletions(-) rename tests/{test_states_group.py => test_states_group_names.py} (100%) create mode 100644 tests/test_states_group_navigation.py diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index e6160b3e..d549c333 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -1250,12 +1250,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin): """ if chat is None: chat_obj = types.Chat.get_current() - chat = chat_obj.id if chat_obj else None + chat_id = chat_obj.id if chat_obj else None if user is None: user_obj = types.User.get_current() - user = user_obj.id if user_obj else None + user_id = user_obj.id if user_obj else None - return FSMContext(storage=self.storage, chat=chat, user=user) + return FSMContext(storage=self.storage, chat=chat_id, user=user_id) @renamed_argument(old_name='user', new_name='user_id', until_version='3.0', stacklevel=3) @renamed_argument(old_name='chat', new_name='chat_id', until_version='3.0', stacklevel=4) diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py index 16937e1c..1309a39a 100644 --- a/aiogram/dispatcher/filters/state.py +++ b/aiogram/dispatcher/filters/state.py @@ -9,10 +9,11 @@ class State: State object """ - def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None): + def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None, dispatcher = Dispatcher): self._state = state self._group_name = group_name self._group = None + self._dispatcher = dispatcher @property def group(self): @@ -41,6 +42,7 @@ class State: if not issubclass(group, StatesGroup): raise ValueError('Group must be subclass of StatesGroup') self._group = group + self._dispatcher = group._dispatcher def __set_name__(self, owner, name): if self._state is None: @@ -53,12 +55,12 @@ class State: __repr__ = __str__ async def set(self): - state = Dispatcher.get_current().current_state() + state = self._dispatcher.get_current().current_state() await state.set_state(self.state) class StatesGroupMeta(type): - def __new__(mcs, name, bases, namespace, **kwargs): + def __new__(mcs, name, bases, namespace, dispatcher=Dispatcher, **kwargs): cls = super(StatesGroupMeta, mcs).__new__(mcs, name, bases, namespace) states = [] @@ -78,6 +80,10 @@ class StatesGroupMeta(type): cls._childs = tuple(childs) cls._states = tuple(states) cls._state_names = tuple(state.state for state in states) + cls._dispatcher = dispatcher + + for state in states: + state.set_parent(cls) return cls @@ -142,7 +148,7 @@ class StatesGroupMeta(type): class StatesGroup(metaclass=StatesGroupMeta): @classmethod async def next(cls) -> str: - state = Dispatcher.get_current().current_state() + state = cls._dispatcher.get_current().current_state() state_name = await state.get_state() try: @@ -151,16 +157,20 @@ class StatesGroup(metaclass=StatesGroupMeta): next_step = 0 try: - next_state_name = cls.states[next_step].state + next_state = cls.states[next_step] + next_state_name = next_state.state + await next_state.set() + except IndexError: next_state_name = None + await state.set_state(next_state_name) - await state.set_state(next_state_name) return next_state_name + @classmethod async def previous(cls) -> str: - state = Dispatcher.get_current().current_state() + state = cls._dispatcher.get_current().current_state() state_name = await state.get_state() try: @@ -170,27 +180,30 @@ class StatesGroup(metaclass=StatesGroupMeta): if previous_step < 0: previous_state_name = None - else: - previous_state_name = cls.states[previous_step].state + await state.set_state(previous_state_name) + + else: + previous_state = cls.states[previous_step] + previous_state_name = previous_state.state + await previous_state.set() - await state.set_state(previous_state_name) return previous_state_name @classmethod async def first(cls) -> str: - state = Dispatcher.get_current().current_state() - first_step_name = cls.states_names[0] + state = cls._dispatcher.get_current().current_state() + first_step = cls.states[0] - await state.set_state(first_step_name) - return first_step_name + await first_step.set() + return first_step.state @classmethod async def last(cls) -> str: - state = Dispatcher.get_current().current_state() - last_step_name = cls.states_names[-1] + state = cls._dispatcher.get_current().current_state() + last_step = cls.states[-1] - await state.set_state(last_step_name) - return last_step_name + await last_step.set() + return last_step.state default_state = State() diff --git a/tests/test_states_group.py b/tests/test_states_group_names.py similarity index 100% rename from tests/test_states_group.py rename to tests/test_states_group_names.py diff --git a/tests/test_states_group_navigation.py b/tests/test_states_group_navigation.py new file mode 100644 index 00000000..92339b46 --- /dev/null +++ b/tests/test_states_group_navigation.py @@ -0,0 +1,83 @@ +import pytest + +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.dispatcher import FSMContext +from aiogram.dispatcher.filters.state import StatesGroup, State + + +class FakeDispatcher: + context = FSMContext(MemoryStorage(), chat=1, user=1) + + @classmethod + def get_current(cls, no_error=True): + return cls + + @classmethod + def current_state(cls): + return cls.context + + @classmethod + def reset(cls): + cls.context = FSMContext(MemoryStorage(), chat=1, user=1) + + +@pytest.fixture +async def reset_dispatcher(): + FakeDispatcher.reset() + context = FakeDispatcher.get_current().current_state() + current_state = await context.get_state() + assert current_state == None + yield FakeDispatcher + + +class MyGroup(StatesGroup, dispatcher=FakeDispatcher()): + state_1 = State() + state_2 = State() + + +class TestNavigation: + @pytest.mark.asyncio + async def test_first(self, reset_dispatcher): + state = await MyGroup.first() + assert state == 'MyGroup:state_1' + + @pytest.mark.asyncio + async def test_last(self, reset_dispatcher): + state = await MyGroup.last() + assert state == 'MyGroup:state_2' + + class TestNext: + @pytest.mark.asyncio + async def test_next_from_none(self, reset_dispatcher): + state = await MyGroup.next() + assert state == 'MyGroup:state_1' + + @pytest.mark.asyncio + async def test_next_from_the_first_state(self, reset_dispatcher): + await MyGroup.state_1.set() + state = await MyGroup.next() + assert state == 'MyGroup:state_2' + + @pytest.mark.asyncio + async def test_next_from_the_last_state(self, reset_dispatcher): + await MyGroup.last() + state = await MyGroup.next() + assert state == None + + class TestPrevious: + @pytest.mark.asyncio + async def test_previous_from_none(self, reset_dispatcher): + state = await MyGroup.previous() + assert state == 'MyGroup:state_1' + + @pytest.mark.asyncio + async def test_previous_from_the_first_state(self, reset_dispatcher): + await MyGroup.first() + state = await MyGroup.previous() + assert state == None + + @pytest.mark.asyncio + async def test_previous_from_the_last_state(self, reset_dispatcher): + await MyGroup.state_2.set() + state = await MyGroup.previous() + assert state == 'MyGroup:state_1'