Removed the use of the context instance (Bot.get_current) from all placements that were used previously. (#1230)

* Removed the use of the context instance (Bot.get_current) from all placements that were used previously.

* Fixed tests

* Added changelog

* Change category
This commit is contained in:
Alex Root Junior 2023-07-28 22:23:32 +03:00 committed by GitHub
parent 479e302cba
commit 2ecf9cefd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 45 additions and 49 deletions

2
CHANGES/1230.removal.rst Normal file
View file

@ -0,0 +1,2 @@
Removed the use of the context instance (Bot.get_current) from all placements that were used previously.
This is to avoid the use of the context instance in the wrong place.

View file

@ -230,7 +230,7 @@ from .session.base import BaseSession
T = TypeVar("T")
class Bot(ContextInstanceMixin["Bot"]):
class Bot:
def __init__(
self,
token: str,
@ -284,16 +284,14 @@ class Bot(ContextInstanceMixin["Bot"]):
"""
Generate bot context
:param auto_close:
:param auto_close: close session on exit
:return:
"""
token = self.set_current(self)
try:
yield self
finally:
if auto_close:
await self.session.close()
self.reset_current(token)
async def me(self) -> User:
"""

View file

@ -13,13 +13,6 @@ class BotContextController(BaseModel):
def model_post_init(self, __context: Any) -> None:
self._bot = __context.get("bot") if __context else None
def get_mounted_bot(self) -> Optional["Bot"]:
# Properties are not supported in pydantic BaseModel
# @computed_field decorator is not a solution for this case in due to
# it produces an additional field in model with validation and serialization that
# we don't need here
return self._bot
def as_(self, bot: Optional["Bot"]) -> Self:
"""
Bind object to a bot instance.
@ -29,3 +22,12 @@ class BotContextController(BaseModel):
"""
self._bot = bot
return self
@property
def bot(self) -> Optional["Bot"]:
"""
Get bot instance.
:return: Bot instance
"""
return self._bot

View file

@ -18,7 +18,7 @@ from ..fsm.strategy import FSMStrategy
from ..methods import GetUpdates, TelegramMethod
from ..methods.base import TelegramType
from ..types import Update, User
from ..types.base import UNSET_TYPE, UNSET
from ..types.base import UNSET, UNSET_TYPE
from ..types.update import UpdateTypeLookupError
from ..utils.backoff import Backoff, BackoffConfig
from .event.bases import UNHANDLED, SkipHandler
@ -143,7 +143,7 @@ class Dispatcher(Router):
handled = False
start_time = loop.time()
if update.get_mounted_bot() != bot:
if update.bot != bot:
# Re-mounting update to the current bot instance for making possible to
# use it in shortcuts.
# Here is update is re-created because we need to propagate context to
@ -184,7 +184,7 @@ class Dispatcher(Router):
:param update:
:param kwargs:
"""
parsed_update = Update(**update)
parsed_update = Update.model_validate(update, context={"bot": bot})
return await self.feed_update(bot=bot, update=parsed_update, **kwargs)
@classmethod
@ -558,7 +558,7 @@ class Dispatcher(Router):
polling_timeout: int = 10,
handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
allowed_updates: Optional[Union[List[str], UNSET_TYPE]] = UNSET,
handle_signals: bool = True,
close_bot_session: bool = True,
**kwargs: Any,

View file

@ -27,7 +27,7 @@ class StateFilter(Filter):
)
async def __call__(
self, obj: Union[TelegramObject], raw_state: Optional[str] = None
self, obj: TelegramObject, raw_state: Optional[str] = None
) -> Union[bool, Dict[str, Any]]:
allowed_states = cast(Sequence[StateType], self.states)
for allowed_state in allowed_states:

View file

@ -32,7 +32,7 @@ class BaseHandler(BaseHandlerMixin[T], ABC):
if "bot" in self.data:
return cast(Bot, self.data["bot"])
return Bot.get_current(no_error=False)
raise RuntimeError("Bot instance not found in the context")
@property
def update(self) -> Update:

View file

@ -14,6 +14,7 @@ class TelegramObject(BotContextController, BaseModel):
frozen=True,
populate_by_name=True,
arbitrary_types_allowed=True,
defer_build=True,
)

View file

@ -31,22 +31,19 @@ class ChatActionSender:
def __init__(
self,
*,
bot: Bot,
chat_id: Union[str, int],
action: str = "typing",
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
bot: Optional[Bot] = None,
) -> None:
"""
:param bot: instance of the bot
:param chat_id: target chat id
:param action: chat action type
:param interval: interval between iterations
:param initial_sleep: sleep before first iteration
:param bot: instance of the bot, can be omitted from the context
"""
if bot is None:
bot = Bot.get_current(False)
self.chat_id = chat_id
self.action = action
self.interval = interval
@ -132,7 +129,7 @@ class ChatActionSender:
def typing(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -149,7 +146,7 @@ class ChatActionSender:
def upload_photo(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -166,7 +163,7 @@ class ChatActionSender:
def record_video(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -183,7 +180,7 @@ class ChatActionSender:
def upload_video(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -200,7 +197,7 @@ class ChatActionSender:
def record_voice(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -217,7 +214,7 @@ class ChatActionSender:
def upload_voice(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -234,7 +231,7 @@ class ChatActionSender:
def upload_document(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -251,7 +248,7 @@ class ChatActionSender:
def choose_sticker(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -268,7 +265,7 @@ class ChatActionSender:
def find_location(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -285,7 +282,7 @@ class ChatActionSender:
def record_video_note(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":
@ -302,7 +299,7 @@ class ChatActionSender:
def upload_video_note(
cls,
chat_id: Union[int, str],
bot: Optional[Bot] = None,
bot: Bot,
interval: float = DEFAULT_INTERVAL,
initial_sleep: float = DEFAULT_INITIAL_SLEEP,
) -> "ChatActionSender":

View file

@ -4,7 +4,7 @@ import pytest
from _pytest.config import UsageError
from redis.asyncio.connection import parse_url as parse_redis_url
from aiogram import Bot, Dispatcher
from aiogram import Dispatcher
from aiogram.fsm.storage.memory import (
DisabledEventIsolation,
MemoryStorage,
@ -109,12 +109,7 @@ async def disabled_isolation():
@pytest.fixture()
def bot():
bot = MockedBot()
token = Bot.set_current(bot)
try:
yield bot
finally:
Bot.reset_current(token)
return MockedBot()
@pytest.fixture()

View file

@ -29,15 +29,17 @@ class TestBaseClassBasedHandler:
async def test_bot_from_context(self):
event = Update(update_id=42)
handler = MyHandler(event=event, key=42)
bot = Bot("42:TEST")
with pytest.raises(LookupError):
handler.bot
Bot.set_current(bot)
handler = MyHandler(event=event, key=42, bot=bot)
assert handler.bot == bot
async def test_bot_from_context_missing(self):
event = Update(update_id=42)
handler = MyHandler(event=event, key=42)
with pytest.raises(RuntimeError):
handler.bot
async def test_bot_from_data(self):
event = Update(update_id=42)
bot = Bot("42:TEST")

View file

@ -36,10 +36,9 @@ class TestChatActionSender:
"upload_video_note",
],
)
@pytest.mark.parametrize("pass_bot", [True, False])
async def test_factory(self, action: str, bot: MockedBot, pass_bot: bool):
async def test_factory(self, action: str, bot: MockedBot):
sender_factory = getattr(ChatActionSender, action)
sender = sender_factory(chat_id=42, bot=bot if pass_bot else None)
sender = sender_factory(chat_id=42, bot=bot)
assert isinstance(sender, ChatActionSender)
assert sender.action == action
assert sender.chat_id == 42