Skip to content

Commit

Permalink
[Core] Check is async callable (#21714)
Browse files Browse the repository at this point in the history
To permit proper coercion of objects like the following:


```python
class MyAsyncCallable:
    async def __call__(self, foo):
        return await ...

class MyAsyncGenerator:
    async def __call__(self, foo):
        await ...
        yield 
```
  • Loading branch information
hinthornw authored May 15, 2024
1 parent 7128c2d commit ca768c8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 6 deletions.
14 changes: 8 additions & 6 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
get_lambda_source,
get_unique_config_specs,
indent_lines_after_first,
is_async_callable,
is_async_generator,
)
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee
Expand Down Expand Up @@ -3300,7 +3302,7 @@ def __init__(
self._atransform = atransform
func_for_name: Callable = atransform

if inspect.isasyncgenfunction(transform):
if is_async_generator(transform):
self._atransform = transform # type: ignore[assignment]
func_for_name = transform
elif inspect.isgeneratorfunction(transform):
Expand Down Expand Up @@ -3513,7 +3515,7 @@ def __init__(
self.afunc = afunc
func_for_name: Callable = afunc

if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
if is_async_callable(func) or is_async_generator(func):
if afunc is not None:
raise TypeError(
"Func was provided as a coroutine function, but afunc was "
Expand Down Expand Up @@ -3774,7 +3776,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def]

afunc = f

if inspect.isasyncgenfunction(afunc):
if is_async_generator(afunc):
output: Optional[Output] = None
async for chunk in cast(
AsyncIterator[Output],
Expand Down Expand Up @@ -3992,7 +3994,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def]

afunc = f

if inspect.isasyncgenfunction(afunc):
if is_async_generator(afunc):
output: Optional[Output] = None
async for chunk in cast(
AsyncIterator[Output],
Expand Down Expand Up @@ -4034,7 +4036,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def]
),
):
yield chunk
elif not inspect.isasyncgenfunction(afunc):
elif not is_async_generator(afunc):
# Otherwise, just yield it
yield cast(Output, output)

Expand Down Expand Up @@ -4836,7 +4838,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
"""
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
return RunnableLambda(cast(Callable[[Input], Output], thing))
Expand Down
27 changes: 27 additions & 0 deletions libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility code for runnables."""

from __future__ import annotations

import ast
Expand All @@ -11,6 +12,8 @@
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Expand All @@ -27,6 +30,8 @@
Union,
)

from typing_extensions import TypeGuard

from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent
Expand Down Expand Up @@ -533,3 +538,25 @@ def _create_model_cached(
return _create_model_base(
__model_name, __config__=_SchemaConfig, **field_definitions
)


def is_async_generator(
func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
"""Check if a function is an async generator."""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
and inspect.isasyncgenfunction(func.__call__)
)


def is_async_callable(
func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
"""Check if a function is async."""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__")
and asyncio.iscoroutinefunction(func.__call__)
)
17 changes: 17 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4883,6 +4883,23 @@ async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
assert await arunnable.abatch([None, None]) == [6, 6]

class AsyncGen:
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1
yield 2
yield 3

arunnablecallable = RunnableGenerator(AsyncGen())
assert await arunnablecallable.ainvoke(None) == 6
assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3]
assert await arunnablecallable.abatch([None, None]) == [6, 6]
with pytest.raises(NotImplementedError):
arunnablecallable.invoke(None)
with pytest.raises(NotImplementedError):
arunnablecallable.stream(None)
with pytest.raises(NotImplementedError):
arunnablecallable.batch([None, None])


async def test_runnable_gen_context_config() -> None:
"""Test that a generator can call other runnables with config
Expand Down

0 comments on commit ca768c8

Please sign in to comment.