Removed the use of the context instance (Bot.get_current) from all placements that were used previously. (#1230)

* Removed the use of the context instance (Bot.get_current) from all placements that were used previously.

* Fixed tests

* Added changelog

* Change category
This commit is contained in:
Alex Root Junior 2023-07-28 22:23:32 +03:00 committed by GitHub
parent 479e302cba
commit 2ecf9cefd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 45 additions and 49 deletions

2
CHANGES/1230.removal.rst Normal file
View file

@ -0,0 +1,2 @@
Removed the use of the context instance (Bot.get_current) from all placements that were used previously.
This is to avoid the use of the context instance in the wrong place.

View file

@ -230,7 +230,7 @@ from .session.base import BaseSession
T = TypeVar("T") T = TypeVar("T")
class Bot(ContextInstanceMixin["Bot"]): class Bot:
def __init__( def __init__(
self, self,
token: str, token: str,
@ -284,16 +284,14 @@ class Bot(ContextInstanceMixin["Bot"]):
""" """
Generate bot context Generate bot context
:param auto_close: :param auto_close: close session on exit
:return: :return:
""" """
token = self.set_current(self)
try: try:
yield self yield self
finally: finally:
if auto_close: if auto_close:
await self.session.close() await self.session.close()
self.reset_current(token)
async def me(self) -> User: async def me(self) -> User:
""" """

View file

@ -13,13 +13,6 @@ class BotContextController(BaseModel):
def model_post_init(self, __context: Any) -> None: def model_post_init(self, __context: Any) -> None:
self._bot = __context.get("bot") if __context else None self._bot = __context.get("bot") if __context else None
def get_mounted_bot(self) -> Optional["Bot"]:
# Properties are not supported in pydantic BaseModel
# @computed_field decorator is not a solution for this case in due to
# it produces an additional field in model with validation and serialization that
# we don't need here
return self._bot
def as_(self, bot: Optional["Bot"]) -> Self: def as_(self, bot: Optional["Bot"]) -> Self:
""" """
Bind object to a bot instance. Bind object to a bot instance.
@ -29,3 +22,12 @@ class BotContextController(BaseModel):
""" """
self._bot = bot self._bot = bot
return self return self
@property
def bot(self) -> Optional["Bot"]:
"""
Get bot instance.
:return: Bot instance
"""
return self._bot

View file

@ -18,7 +18,7 @@ from ..fsm.strategy import FSMStrategy
from ..methods import GetUpdates, TelegramMethod from ..methods import GetUpdates, TelegramMethod
from ..methods.base import TelegramType from ..methods.base import TelegramType
from ..types import Update, User from ..types import Update, User
from ..types.base import UNSET_TYPE, UNSET from ..types.base import UNSET, UNSET_TYPE
from ..types.update import UpdateTypeLookupError from ..types.update import UpdateTypeLookupError
from ..utils.backoff import Backoff, BackoffConfig from ..utils.backoff import Backoff, BackoffConfig
from .event.bases import UNHANDLED, SkipHandler from .event.bases import UNHANDLED, SkipHandler
@ -143,7 +143,7 @@ class Dispatcher(Router):
handled = False handled = False
start_time = loop.time() start_time = loop.time()
if update.get_mounted_bot() != bot: if update.bot != bot:
# Re-mounting update to the current bot instance for making possible to # Re-mounting update to the current bot instance for making possible to
# use it in shortcuts. # use it in shortcuts.
# Here is update is re-created because we need to propagate context to # Here is update is re-created because we need to propagate context to
@ -184,7 +184,7 @@ class Dispatcher(Router):
:param update: :param update:
:param kwargs: :param kwargs:
""" """
parsed_update = Update(**update) parsed_update = Update.model_validate(update, context={"bot": bot})
return await self.feed_update(bot=bot, update=parsed_update, **kwargs) return await self.feed_update(bot=bot, update=parsed_update, **kwargs)
@classmethod @classmethod
@ -558,7 +558,7 @@ class Dispatcher(Router):
polling_timeout: int = 10, polling_timeout: int = 10,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None, allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET,
handle_signals: bool = True, handle_signals: bool = True,
close_bot_session: bool = True, close_bot_session: bool = True,
**kwargs: Any, **kwargs: Any,

View file

@ -27,7 +27,7 @@ class StateFilter(Filter):
) )
async def __call__( async def __call__(
self, obj: Union[TelegramObject], raw_state: Optional[str] = None self, obj: TelegramObject, raw_state: Optional[str] = None
) -> Union[bool, Dict[str, Any]]: ) -> Union[bool, Dict[str, Any]]:
allowed_states = cast(Sequence[StateType], self.states) allowed_states = cast(Sequence[StateType], self.states)
for allowed_state in allowed_states: for allowed_state in allowed_states:

View file

@ -32,7 +32,7 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
if "bot" in self.data: if "bot" in self.data:
return cast(Bot, self.data["bot"]) return cast(Bot, self.data["bot"])
return Bot.get_current(no_error=False) raise RuntimeError("Bot instance not found in the context")
@property @property
def update(self) -> Update: def update(self) -> Update:

View file

@ -14,6 +14,7 @@ class TelegramObject(BotContextController, BaseModel):
frozen=True, frozen=True,
populate_by_name=True, populate_by_name=True,
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
defer_build=True,
) )

View file

@ -31,22 +31,19 @@ class ChatActionSender:
def __init__( def __init__(
self, self,
*, *,
bot: Bot,
chat_id: Union[str, int], chat_id: Union[str, int],
action: str = "typing", action: str = "typing",
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
bot: Optional[Bot] = None,
) -> None: ) -> None:
""" """
:param bot: instance of the bot
:param chat_id: target chat id :param chat_id: target chat id
:param action: chat action type :param action: chat action type
:param interval: interval between iterations :param interval: interval between iterations
:param initial_sleep: sleep before first iteration :param initial_sleep: sleep before first iteration
:param bot: instance of the bot, can be omitted from the context
""" """
if bot is None:
bot = Bot.get_current(False)
self.chat_id = chat_id self.chat_id = chat_id
self.action = action self.action = action
self.interval = interval self.interval = interval
@ -132,7 +129,7 @@ class ChatActionSender:
def typing( def typing(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -149,7 +146,7 @@ class ChatActionSender:
def upload_photo( def upload_photo(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -166,7 +163,7 @@ class ChatActionSender:
def record_video( def record_video(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -183,7 +180,7 @@ class ChatActionSender:
def upload_video( def upload_video(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -200,7 +197,7 @@ class ChatActionSender:
def record_voice( def record_voice(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -217,7 +214,7 @@ class ChatActionSender:
def upload_voice( def upload_voice(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -234,7 +231,7 @@ class ChatActionSender:
def upload_document( def upload_document(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -251,7 +248,7 @@ class ChatActionSender:
def choose_sticker( def choose_sticker(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -268,7 +265,7 @@ class ChatActionSender:
def find_location( def find_location(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -285,7 +282,7 @@ class ChatActionSender:
def record_video_note( def record_video_note(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":
@ -302,7 +299,7 @@ class ChatActionSender:
def upload_video_note( def upload_video_note(
cls, cls,
chat_id: Union[int, str], chat_id: Union[int, str],
bot: Optional[Bot] = None, bot: Bot,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP, initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender": ) -> "ChatActionSender":

View file

@ -4,7 +4,7 @@ import pytest
from _pytest.config import UsageError from _pytest.config import UsageError
from redis.asyncio.connection import parse_url as parse_redis_url from redis.asyncio.connection import parse_url as parse_redis_url
from aiogram import Bot, Dispatcher from aiogram import Dispatcher
from aiogram.fsm.storage.memory import ( from aiogram.fsm.storage.memory import (
DisabledEventIsolation, DisabledEventIsolation,
MemoryStorage, MemoryStorage,
@ -109,12 +109,7 @@ async def disabled_isolation():
@pytest.fixture() @pytest.fixture()
def bot(): def bot():
bot = MockedBot() return MockedBot()
token = Bot.set_current(bot)
try:
yield bot
finally:
Bot.reset_current(token)
@pytest.fixture() @pytest.fixture()

View file

@ -29,15 +29,17 @@ class TestBaseClassBasedHandler:
async def test_bot_from_context(self): async def test_bot_from_context(self):
event = Update(update_id=42) event = Update(update_id=42)
handler = MyHandler(event=event, key=42)
bot = Bot("42:TEST") bot = Bot("42:TEST")
handler = MyHandler(event=event, key=42, bot=bot)
with pytest.raises(LookupError):
handler.bot
Bot.set_current(bot)
assert handler.bot == bot assert handler.bot == bot
async def test_bot_from_context_missing(self):
event = Update(update_id=42)
handler = MyHandler(event=event, key=42)
with pytest.raises(RuntimeError):
handler.bot
async def test_bot_from_data(self): async def test_bot_from_data(self):
event = Update(update_id=42) event = Update(update_id=42)
bot = Bot("42:TEST") bot = Bot("42:TEST")

View file

@ -36,10 +36,9 @@ class TestChatActionSender:
"upload_video_note", "upload_video_note",
], ],
) )
@pytest.mark.parametrize("pass_bot", [True, False]) async def test_factory(self, action: str, bot: MockedBot):
async def test_factory(self, action: str, bot: MockedBot, pass_bot: bool):
sender_factory = getattr(ChatActionSender, action) sender_factory = getattr(ChatActionSender, action)
sender = sender_factory(chat_id=42, bot=bot if pass_bot else None) sender = sender_factory(chat_id=42, bot=bot)
assert isinstance(sender, ChatActionSender) assert isinstance(sender, ChatActionSender)
assert sender.action == action assert sender.action == action
assert sender.chat_id == 42 assert sender.chat_id == 42