feat(reqs): determine function types

This commit is contained in:
mpa 2020-07-21 05:43:56 +04:00
parent ca44f9c01a
commit 2d4419ecdc
No known key found for this signature in database
GPG key ID: BCCFBFCCC9B754A8

View file

@ -1,5 +1,8 @@
import enum
import asyncio
import inspect
import contextvars
from functools import partial
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from typing import (
Any,
@ -46,7 +49,7 @@ async def move_to_async_gen(context_manager: Any) -> Any:
class Requirement(Generic[T]):
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type"
__slots__ = "callable", "children", "cache_key", "use_cache", "generator_type", "is_async"
def __init__(
self,
@ -60,10 +63,14 @@ class Requirement(Generic[T]):
self.generator_type = GeneratorKind.not_a_gen
self.is_async: Optional[bool] = None # unset value
if inspect.isasyncgenfunction(callable_):
self.generator_type = GeneratorKind.async_gen
elif inspect.isgeneratorfunction(callable_):
self.generator_type = GeneratorKind.plain_gen
else:
self.is_async = inspect.iscoroutinefunction(callable_)
self.cache_key = hash(callable_) if cache_key is None else cache_key
self.children = get_reqs_from_callable(callable_)
@ -119,11 +126,17 @@ async def initialize_callable_requirement(
if async_cm is not None:
return await stack.enter_async_context(async_cm)
else:
result = required.callable(**actual_data)
if isinstance(result, Awaitable):
return cast(T, await result)
return cast(T, result)
if not required.is_async:
context = contextvars.copy_context()
wrapped = partial(context.run, partial(required.callable, **actual_data))
loop = asyncio.get_event_loop()
return cast(T, await loop.run_in_executor(None, wrapped))
else:
return cast(T, await required.callable(**actual_data))
def get_reqs_from_callable(callable_: _RequiredCallback[T]) -> Dict[str, Requirement[Any]]: