diff --git a/aiogram/dispatcher/state.py b/aiogram/dispatcher/state.py index 28793634..14d73653 100644 --- a/aiogram/dispatcher/state.py +++ b/aiogram/dispatcher/state.py @@ -22,6 +22,7 @@ class BaseStorage: def set_state(self, chat, user, state): """ Set state + :param chat: chat_id :param user: user_id :param state: value @@ -31,6 +32,7 @@ class BaseStorage: def get_state(self, chat, user): """ Get user state from + :param chat: :param user: :return: @@ -49,8 +51,77 @@ class BaseStorage: def all_states(self, chat=None, user=None, state=None): """ Yield all states (Can use filters) + :param chat: :param user: + :param state: + :return: + """ + raise NotImplementedError + + def set_value(self, chat, user, key, value): + """ + Set value for user in storage + + :param chat: + :param user: + :param key: + :param value: + :return: + """ + raise NotImplementedError + + def get_value(self, chat, user, key, default=None): + """ + Get value from storage + + By default, this method calls `self.get_data(chat, user).get(key, default)` + :param chat: + :param user: + :param key: + :param default: + :return: + """ + return self.get_data(chat, user).get(key, default) + + def del_value(self, chat, user, key): + """ + Delete value from storage + + :param chat: + :param user: + :param key: + """ + raise NotImplementedError + + def get_data(self, chat, user): + """ + Get all stored data for user + + :param chat: + :param user: + :return: dict + """ + raise NotImplementedError + + def update_data(self, chat, user, data): + """ + Update data in storage + + :param chat: + :param user: + :param data: + :return: + """ + raise NotImplementedError + + def clear_data(self, chat, user, key): + """ + Clear data in storage + + :param chat: + :param user: + :param key: :return: """ raise NotImplementedError @@ -85,6 +156,11 @@ class BaseStorage: return self.get_state(key, key) def __delitem__(self, key): + """ + Reset state for user + :param key: + :return: + """ if isinstance(key, slice): self.del_state(key.start, key.stop) else: @@ -117,23 +193,23 @@ class StateStorage(BaseStorage): result = True if user not in self.storage[chat]: - self.storage[chat][user] = None + self.storage[chat][user] = {'state': None, 'data': {}} result = True return result def set_state(self, chat, user, state): self._prepare(chat, user) - self.storage[chat][user] = self._prepare_state_name(state) + self.storage[chat][user]['state'] = self._prepare_state_name(state) def get_state(self, chat, user): self._prepare(chat, user) - return self.storage[chat][user] + return self.storage[chat][user]['state'] def del_state(self, chat, user): self._prepare(chat, user) if self[chat:user] is not None: - self[chat:user] = None + self.storage[chat][user]['state'] = {'state': None, 'data': {}} def all_states(self, chat=None, user=None, state=None): for chat_id, chat in self.storage.items(): @@ -146,6 +222,26 @@ class StateStorage(BaseStorage): continue yield chat_id, user_id, user_state + def set_value(self, chat, user, key, value): + self._prepare(chat, user) + self.storage[chat][user]['data'][key] = value + + def del_value(self, chat, user, key): + self._prepare(chat, user) + del self.storage[chat][user]['data'][key] + + def get_data(self, chat, user): + self._prepare(chat, user) + return self.storage[chat][user]['data'] + + def update_data(self, chat, user, data): + self._prepare(chat, user) + self.storage[chat][user]['data'].update(data) + + def clear_data(self, chat, user, key): + self._prepare(chat, user) + self.storage[chat][user]['data'].clear() + class Controller: """ @@ -160,17 +256,19 @@ class Controller: self._user = user self._state = state - def set(self, value): + def set_state(self, value): """ Set state + :param value: :return: """ self._state_machine[self._chat:self._user] = value - def get(self): + def get_state(self): """ Get current state + :return: """ return self._state_machine[self._chat:self._user] @@ -178,10 +276,78 @@ class Controller: def clear(self): """ Reset state + :return: """ del self._state_machine[self._chat:self._user] + def get(self, key, default=None): + """ + Get value from storage + + :param key: + :param default: + :return: + """ + return self._state_machine.storage.get_value(self._chat, self._user, key, default) + + def pop(self, key, default=None): + """ + Pop item from storage + + :param key: + :param default: + :return: + """ + result = self.get(key, default) + self.delete(key) + return result + + def set(self, key, value): + """ + Set new value in user storage + + :param key: + :param value: + :return: + """ + self._state_machine.storage.set_value(self._chat, self._user, key, value) + + def delete(self, key): + """ + Delete key from user storage + + :param key: + :return: + """ + self._state_machine.storage.del_value(self._chat, self._user, key) + + def update(self, data): + """ + Update user storage + + :param data: + :return: + """ + self._state_machine.storage.update_data(self._chat, self._user, data) + + @property + def data(self): + """ + User data + :return: + """ + return self._state_machine.storage.get_value + + def __setitem__(self, key, value): + self.set(key, value) + + def __getitem__(self, item): + return self.get(item) + + def __delitem__(self, key): + self.delete(key) + def __str__(self): return f"{self._chat}:{self._user} - {self._state}" @@ -262,14 +428,10 @@ class StateMachine: self.del_state(chat_id, from_user_id) raise SkipHandler() + log.debug(f"Process state for {chat_id}:{from_user_id} - '{state}'") callback = self.steps[state] controller = Controller(self, chat_id, from_user_id, state) - log.debug(f"Process state for {chat_id}:{from_user_id} - '{state}'") - result = await callback(message, controller) - # if result is True: - # controller.clear() - # elif isinstance(result, str): - # controller.set(result) + await callback(message, controller) def __setitem__(self, key, value): """ diff --git a/examples/state_machine.py b/examples/state_machine.py index a593b48e..94cfb8f9 100644 --- a/examples/state_machine.py +++ b/examples/state_machine.py @@ -6,7 +6,6 @@ from aiogram.dispatcher import Dispatcher from aiogram.dispatcher.state import StateMachine API_TOKEN = 'BOT TOKEN HERE' -API_TOKEN = '380294876:AAFbdYYgq1hBi9hQDcxD3bj8QCNnVec5aHk' logging.basicConfig(level=logging.DEBUG) @@ -14,8 +13,6 @@ loop = asyncio.get_event_loop() bot = Bot(token=API_TOKEN, loop=loop) dp = Dispatcher(bot) -users = {} - @dp.message_handler(commands=['start']) async def send_welcome(message: types.Message): @@ -24,39 +21,41 @@ async def send_welcome(message: types.Message): async def process_name(message, controller): - users[message.from_user.id] = {"name": message.text} + controller["name"] = message.text await message.reply("How old are you?") - controller.set('age') + controller.set_state('age') async def process_age(message, controller): if not message.text.isdigit(): return await message.reply("Age should be a number.\nHow old are you?") - users[message.from_user.id].update({"age": int(message.text)}) + controller["age"] = int(message.text) markup = types.ReplyKeyboardMarkup() markup.add("Male", "Female") markup.add("Other") await message.reply("What is your gender?", reply_markup=markup) - controller.set("sex") + + controller.set_state("sex") async def process_sex(message, controller): if message.text not in ["Male", "Female", "Other"]: return await message.reply("Bad gender name. Choose you gender from keyboard.") - users[message.from_user.id].update({"sex": message.text}) - controller.clear() - - user = users[message.from_user.id] + controller["sex"] = message.text markup = types.ReplyKeyboardRemove() await bot.send_message(message.chat.id, - f"Hi!\nNice to meet you, {user['name']}.\nAge: {user['age']}\nSex: {user['sex']}", + f"Hi!\n" + f"Nice to meet you, {controller['name']}.\n" + f"Age: {controller['age']}\n" + f"Sex: {controller['sex']}", reply_markup=markup) + controller.clear() state = StateMachine(dp, {