Forum topic in FSM (#1161)

* Base implementation

* Added tests, fixed arguments priority

* Use `Optional[X]` instead of `X | None`

* Added changelog

* Added tests
This commit is contained in:
Alex Root Junior 2023-04-22 19:35:41 +03:00 committed by GitHub
parent 1538bc2e2d
commit 942ba0d520
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 164 additions and 60 deletions

View file

@ -4,6 +4,10 @@ from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Tuple
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.types import Chat, TelegramObject, Update, User
EVENT_FROM_USER_KEY = "event_from_user"
EVENT_CHAT_KEY = "event_chat"
EVENT_THREAD_ID_KEY = "event_thread_id"
class UserContextMiddleware(BaseMiddleware):
async def __call__(
@ -14,61 +18,64 @@ class UserContextMiddleware(BaseMiddleware):
) -> Any:
if not isinstance(event, Update):
raise RuntimeError("UserContextMiddleware got an unexpected event type!")
chat, user = self.resolve_event_context(event=event)
with self.context(chat=chat, user=user):
if user is not None:
data["event_from_user"] = user
if chat is not None:
data["event_chat"] = chat
return await handler(event, data)
@contextmanager
def context(self, chat: Optional[Chat] = None, user: Optional[User] = None) -> Iterator[None]:
chat_token = None
user_token = None
if chat:
chat_token = chat.set_current(chat)
if user:
user_token = user.set_current(user)
try:
yield
finally:
if chat and chat_token:
chat.reset_current(chat_token)
if user and user_token:
user.reset_current(user_token)
chat, user, thread_id = self.resolve_event_context(event=event)
if user is not None:
data[EVENT_FROM_USER_KEY] = user
if chat is not None:
data[EVENT_CHAT_KEY] = chat
if thread_id is not None:
data[EVENT_THREAD_ID_KEY] = thread_id
return await handler(event, data)
@classmethod
def resolve_event_context(cls, event: Update) -> Tuple[Optional[Chat], Optional[User]]:
def resolve_event_context(
cls, event: Update
) -> Tuple[Optional[Chat], Optional[User], Optional[int]]:
"""
Resolve chat and user instance from Update object
"""
if event.message:
return event.message.chat, event.message.from_user
return (
event.message.chat,
event.message.from_user,
event.message.message_thread_id if event.message.is_topic_message else None,
)
if event.edited_message:
return event.edited_message.chat, event.edited_message.from_user
return (
event.edited_message.chat,
event.edited_message.from_user,
event.edited_message.message_thread_id
if event.edited_message.is_topic_message
else None,
)
if event.channel_post:
return event.channel_post.chat, None
return event.channel_post.chat, None, None
if event.edited_channel_post:
return event.edited_channel_post.chat, None
return event.edited_channel_post.chat, None, None
if event.inline_query:
return None, event.inline_query.from_user
return None, event.inline_query.from_user, None
if event.chosen_inline_result:
return None, event.chosen_inline_result.from_user
return None, event.chosen_inline_result.from_user, None
if event.callback_query:
if event.callback_query.message:
return event.callback_query.message.chat, event.callback_query.from_user
return None, event.callback_query.from_user
return (
event.callback_query.message.chat,
event.callback_query.from_user,
event.callback_query.message.message_thread_id
if event.callback_query.message.is_topic_message
else None,
)
return None, event.callback_query.from_user, None
if event.shipping_query:
return None, event.shipping_query.from_user
return None, event.shipping_query.from_user, None
if event.pre_checkout_query:
return None, event.pre_checkout_query.from_user
return None, event.pre_checkout_query.from_user, None
if event.poll_answer:
return None, event.poll_answer.user
return None, event.poll_answer.user, None
if event.my_chat_member:
return event.my_chat_member.chat, event.my_chat_member.from_user
return event.my_chat_member.chat, event.my_chat_member.from_user, None
if event.chat_member:
return event.chat_member.chat, event.chat_member.from_user
return event.chat_member.chat, event.chat_member.from_user, None
if event.chat_join_request:
return event.chat_join_request.chat, event.chat_join_request.from_user
return None, None
return event.chat_join_request.chat, event.chat_join_request.from_user, None
return None, None, None

View file

@ -47,25 +47,42 @@ class FSMContextMiddleware(BaseMiddleware):
) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
thread_id = data.get("event_thread_id")
chat_id = chat.id if chat else None
user_id = user.id if user else None
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return self.resolve_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
destiny=destiny,
)
def resolve_context(
self,
bot: Bot,
chat_id: Optional[int],
user_id: Optional[int],
thread_id: Optional[int] = None,
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
if chat_id is not None and user_id is not None:
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
chat_id, user_id, thread_id = apply_strategy(
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
strategy=self.strategy,
)
return self.get_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
destiny=destiny,
)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id, destiny=destiny)
return None
def get_context(
@ -73,6 +90,7 @@ class FSMContextMiddleware(BaseMiddleware):
bot: Bot,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(
@ -81,6 +99,7 @@ class FSMContextMiddleware(BaseMiddleware):
user_id=user_id,
chat_id=chat_id,
bot_id=bot.id,
thread_id=thread_id,
destiny=destiny,
),
)

View file

@ -15,6 +15,7 @@ class StorageKey:
bot_id: int
chat_id: int
user_id: int
thread_id: Optional[int] = None
destiny: str = DEFAULT_DESTINY

View file

@ -70,7 +70,10 @@ class DefaultKeyBuilder(KeyBuilder):
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
parts.extend([str(key.chat_id), str(key.user_id)])
parts.append(str(key.chat_id))
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_destiny:
parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY:

View file

@ -1,16 +1,24 @@
from enum import Enum, auto
from typing import Tuple
from typing import Optional, Tuple
class FSMStrategy(Enum):
USER_IN_CHAT = auto()
CHAT = auto()
GLOBAL_USER = auto()
USER_IN_THREAD = auto()
def apply_strategy(chat_id: int, user_id: int, strategy: FSMStrategy) -> Tuple[int, int]:
def apply_strategy(
strategy: FSMStrategy,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
) -> Tuple[int, int, Optional[int]]:
if strategy == FSMStrategy.CHAT:
return chat_id, chat_id
return chat_id, chat_id, None
if strategy == FSMStrategy.GLOBAL_USER:
return user_id, user_id
return chat_id, user_id
return user_id, user_id, None
if strategy == FSMStrategy.USER_IN_THREAD:
return chat_id, user_id, thread_id
return chat_id, user_id, None