mirror of
https://github.com/aiogram/aiogram.git
synced 2026-04-08 16:37:47 +00:00
feat(reqs): determine function types
This commit is contained in:
parent
ca44f9c01a
commit
2d4419ecdc
1 changed files with 18 additions and 5 deletions
|
|
@ -1,5 +1,8 @@
|
|||
import enum
|
||||
import asyncio
|
||||
import inspect
|
||||
import contextvars
|
||||
from functools import partial
|
||||
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
|
|
@ -46,7 +49,7 @@ async def move_to_async_gen(context_manager: Any) -> Any:
|
|||
|
||||
|
||||
class Requirement(Generic[T]):
|
||||
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
|
||||
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type", "is_async"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -60,10 +63,14 @@ class Requirement(Generic[T]):
|
|||
|
||||
self.generator_type = GeneratorKind.not_a_gen
|
||||
|
||||
self.is_async: Optional[bool] = None # unset value
|
||||
|
||||
if inspect.isasyncgenfunction(callable_):
|
||||
self.generator_type = GeneratorKind.async_gen
|
||||
elif inspect.isgeneratorfunction(callable_):
|
||||
self.generator_type = GeneratorKind.plain_gen
|
||||
else:
|
||||
self.is_async = inspect.iscoroutinefunction(callable_)
|
||||
|
||||
self.cache_key = hash(callable_) if cache_key is None else cache_key
|
||||
self.children = get_reqs_from_callable(callable_)
|
||||
|
|
@ -119,11 +126,17 @@ async def initialize_callable_requirement(
|
|||
|
||||
if async_cm is not None:
|
||||
return await stack.enter_async_context(async_cm)
|
||||
|
||||
else:
|
||||
result = required.callable(**actual_data)
|
||||
if isinstance(result, Awaitable):
|
||||
return cast(T, await result)
|
||||
return cast(T, result)
|
||||
if not required.is_async:
|
||||
context = contextvars.copy_context()
|
||||
wrapped = partial(context.run, partial(required.callable, **actual_data))
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return cast(T, await loop.run_in_executor(None, wrapped))
|
||||
|
||||
else:
|
||||
return cast(T, await required.callable(**actual_data))
|
||||
|
||||
|
||||
def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue