diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index c2dfa4201830d..2ac53cdccd67b 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -76,6 +76,8 @@ get_lambda_source, get_unique_config_specs, indent_lines_after_first, + is_async_callable, + is_async_generator, ) from langchain_core.utils.aiter import atee, py_anext from langchain_core.utils.iter import safetee @@ -3300,7 +3302,7 @@ def __init__( self._atransform = atransform func_for_name: Callable = atransform - if inspect.isasyncgenfunction(transform): + if is_async_generator(transform): self._atransform = transform # type: ignore[assignment] func_for_name = transform elif inspect.isgeneratorfunction(transform): @@ -3513,7 +3515,7 @@ def __init__( self.afunc = afunc func_for_name: Callable = afunc - if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): + if is_async_callable(func) or is_async_generator(func): if afunc is not None: raise TypeError( "Func was provided as a coroutine function, but afunc was " @@ -3774,7 +3776,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def] afunc = f - if inspect.isasyncgenfunction(afunc): + if is_async_generator(afunc): output: Optional[Output] = None async for chunk in cast( AsyncIterator[Output], @@ -3992,7 +3994,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def] afunc = f - if inspect.isasyncgenfunction(afunc): + if is_async_generator(afunc): output: Optional[Output] = None async for chunk in cast( AsyncIterator[Output], @@ -4034,7 +4036,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def] ), ): yield chunk - elif not inspect.isasyncgenfunction(afunc): + elif not is_async_generator(afunc): # Otherwise, just yield it yield cast(Output, output) @@ -4836,7 +4838,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: """ if isinstance(thing, Runnable): return thing - elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): + elif is_async_generator(thing) or inspect.isgeneratorfunction(thing): return RunnableGenerator(thing) elif callable(thing): return RunnableLambda(cast(Callable[[Input], Output], thing)) diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index f77f756e66636..3214ca6663959 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -1,4 +1,5 @@ """Utility code for runnables.""" + from __future__ import annotations import ast @@ -11,6 +12,8 @@ from typing import ( Any, AsyncIterable, + AsyncIterator, + Awaitable, Callable, Coroutine, Dict, @@ -27,6 +30,8 @@ Union, ) +from typing_extensions import TypeGuard + from langchain_core.pydantic_v1 import BaseConfig, BaseModel from langchain_core.pydantic_v1 import create_model as _create_model_base from langchain_core.runnables.schema import StreamEvent @@ -533,3 +538,25 @@ def _create_model_cached( return _create_model_base( __model_name, __config__=_SchemaConfig, **field_definitions ) + + +def is_async_generator( + func: Any, +) -> TypeGuard[Callable[..., AsyncIterator]]: + """Check if a function is an async generator.""" + return ( + inspect.isasyncgenfunction(func) + or hasattr(func, "__call__") + and inspect.isasyncgenfunction(func.__call__) + ) + + +def is_async_callable( + func: Any, +) -> TypeGuard[Callable[..., Awaitable]]: + """Check if a function is async.""" + return ( + asyncio.iscoroutinefunction(func) + or hasattr(func, "__call__") + and asyncio.iscoroutinefunction(func.__call__) + ) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 37e6bff2e3fe4..dd643a5e54cc5 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -4883,6 +4883,23 @@ async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: assert [p async for p in arunnable.astream(None)] == [1, 2, 3] assert await arunnable.abatch([None, None]) == [6, 6] + class AsyncGen: + async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]: + yield 1 + yield 2 + yield 3 + + arunnablecallable = RunnableGenerator(AsyncGen()) + assert await arunnablecallable.ainvoke(None) == 6 + assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3] + assert await arunnablecallable.abatch([None, None]) == [6, 6] + with pytest.raises(NotImplementedError): + arunnablecallable.invoke(None) + with pytest.raises(NotImplementedError): + arunnablecallable.stream(None) + with pytest.raises(NotImplementedError): + arunnablecallable.batch([None, None]) + async def test_runnable_gen_context_config() -> None: """Test that a generator can call other runnables with config