diff --git a/aiogram/dispatcher/requirement.py b/aiogram/dispatcher/requirement.py index 1d7b0ab0..6dc1e351 100644 --- a/aiogram/dispatcher/requirement.py +++ b/aiogram/dispatcher/requirement.py @@ -1,5 +1,8 @@ import enum +import asyncio import inspect +import contextvars +from functools import partial from contextlib import AsyncExitStack, asynccontextmanager, contextmanager from typing import ( Any, @@ -46,7 +49,7 @@ async def move_to_async_gen(context_manager: Any) -> Any: class Requirement(Generic[T]): - __slots__ = "callable", "children", "cache_key", "use_cache", "generator_type" + __slots__ = "callable", "children", "cache_key", "use_cache", "generator_type", "is_async" def __init__( self, @@ -60,10 +63,14 @@ class Requirement(Generic[T]): self.generator_type = GeneratorKind.not_a_gen + self.is_async: Optional[bool] = None # unset value + if inspect.isasyncgenfunction(callable_): self.generator_type = GeneratorKind.async_gen elif inspect.isgeneratorfunction(callable_): self.generator_type = GeneratorKind.plain_gen + else: + self.is_async = inspect.iscoroutinefunction(callable_) self.cache_key = hash(callable_) if cache_key is None else cache_key self.children = get_reqs_from_callable(callable_) @@ -119,11 +126,17 @@ async def initialize_callable_requirement( if async_cm is not None: return await stack.enter_async_context(async_cm) + else: - result = required.callable(**actual_data) - if isinstance(result, Awaitable): - return cast(T, await result) - return cast(T, result) + if not required.is_async: + context = contextvars.copy_context() + wrapped = partial(context.run, partial(required.callable, **actual_data)) + + loop = asyncio.get_event_loop() + return cast(T, await loop.run_in_executor(None, wrapped)) + + else: + return cast(T, await required.callable(**actual_data)) def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]: