Skip to content

Commit

Permalink
lib: Add Command(graph=Command.PARENT, ...)
Browse files Browse the repository at this point in the history
- This makes the command bubble up out of the current graph and be handled by the calling graph (the immediate parent)
- This could be extended to support eg. ROOT graph, or some other level
  • Loading branch information
nfcampos committed Nov 23, 2024
1 parent 4f4e7a6 commit fed60e7
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 17 deletions.
17 changes: 14 additions & 3 deletions libs/langgraph/langgraph/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Sequence

from langgraph.checkpoint.base import EmptyChannelError # noqa: F401
from langgraph.types import Interrupt
from langgraph.types import Command, Interrupt

# EmptyChannelError re-exported for backwards compatibility

Expand Down Expand Up @@ -58,7 +58,11 @@ class InvalidUpdateError(Exception):
pass


class GraphInterrupt(Exception):
class GraphBubbleUp(Exception):
pass


class GraphInterrupt(GraphBubbleUp):
"""Raised when a subgraph is interrupted, suppressed by the root graph.
Never raised directly, or surfaced to the user."""

Expand All @@ -73,13 +77,20 @@ def __init__(self, value: Any) -> None:
super().__init__([Interrupt(value=value)])


class GraphDelegate(Exception):
class GraphDelegate(GraphBubbleUp):
"""Raised when a graph is delegated (for distributed mode)."""

def __init__(self, *args: dict[str, Any]) -> None:
super().__init__(*args)


class ParentCommand(GraphBubbleUp):
args: tuple[Command]

def __init__(self, command: Command) -> None:
super().__init__(command)


class EmptyInputError(Exception):
"""Raised when graph receives an empty input."""

Expand Down
15 changes: 14 additions & 1 deletion libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
from langgraph.channels.last_value import LastValue
from langgraph.channels.named_barrier_value import NamedBarrierValue
from langgraph.constants import EMPTY_SEQ, NS_END, NS_SEP, SELF, TAG_HIDDEN
from langgraph.errors import ErrorCode, InvalidUpdateError, create_error_message
from langgraph.errors import (
ErrorCode,
InvalidUpdateError,
ParentCommand,
create_error_message,
)
from langgraph.graph.graph import END, START, Branch, CompiledGraph, Graph, Send
from langgraph.managed.base import (
ChannelKeyPlaceholder,
Expand Down Expand Up @@ -623,6 +628,8 @@ def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:

def _get_root(input: Any) -> Any:
if isinstance(input, Command):
if input.graph == Command.PARENT:
return SKIP_WRITE
return input.update
else:
return input
Expand All @@ -640,6 +647,8 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
)
return input.get(key, SKIP_WRITE)
elif isinstance(input, Command):
if input.graph == Command.PARENT:
return SKIP_WRITE
return _get_state_key(input.update, key=key)
elif get_type_hints(type(input)):
value = getattr(input, key, SKIP_WRITE)
Expand Down Expand Up @@ -822,6 +831,8 @@ def _control_branch(value: Any) -> Sequence[Union[str, Send]]:
return [value]
if not isinstance(value, GraphCommand):
return EMPTY_SEQ
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, str):
rtn.append(value.goto)
Expand All @@ -839,6 +850,8 @@ async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]:
return [value]
if not isinstance(value, GraphCommand):
return EMPTY_SEQ
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, str):
rtn.append(value.goto)
Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from langchain_core.tools.base import get_all_basemodel_annotations
from typing_extensions import Annotated, get_args, get_origin

from langgraph.errors import GraphInterrupt
from langgraph.errors import GraphBubbleUp
from langgraph.store.base import BaseStore
from langgraph.utils.runnable import RunnableCallable

Expand Down Expand Up @@ -275,7 +275,7 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except GraphInterrupt as e:
except GraphBubbleUp as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
Expand Down Expand Up @@ -316,7 +316,7 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except GraphInterrupt as e:
except GraphBubbleUp as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
Expand Down
2 changes: 2 additions & 0 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def prepare_single_task(
None,
task_id,
task_path,
writers=proc.flat_writers,
)

else:
Expand Down Expand Up @@ -720,6 +721,7 @@ def prepare_single_task(
None,
task_id,
task_path,
writers=proc.flat_writers,
)
else:
return PregelTask(task_id, name, task_path)
Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/langgraph/pregel/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from langchain_core.runnables.config import get_executor_for_config
from typing_extensions import ParamSpec

from langgraph.errors import GraphInterrupt
from langgraph.errors import GraphBubbleUp

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -68,7 +68,7 @@ def submit( # type: ignore[valid-type]
def done(self, task: concurrent.futures.Future) -> None:
try:
task.result()
except GraphInterrupt:
except GraphBubbleUp:
# This exception is an interruption signal, not an error
# so we don't want to re-raise it on exit
self.tasks.pop(task)
Expand Down Expand Up @@ -155,7 +155,7 @@ def done(self, task: asyncio.Task) -> None:
if exc := task.exception():
# This exception is an interruption signal, not an error
# so we don't want to re-raise it on exit
if isinstance(exc, GraphInterrupt):
if isinstance(exc, GraphBubbleUp):
self.tasks.pop(task)
else:
self.tasks.pop(task)
Expand Down
3 changes: 3 additions & 0 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TAG_HIDDEN,
TASKS,
)
from langgraph.errors import InvalidUpdateError
from langgraph.pregel.log import logger
from langgraph.types import Command, PregelExecutableTask, Send

Expand Down Expand Up @@ -68,6 +69,8 @@ def map_command(
cmd: Command,
) -> Iterator[tuple[str, str, Any]]:
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
raise InvalidUpdateError("There is not parent graph")
if cmd.send:
if isinstance(cmd.send, (tuple, list)):
sends = cmd.send
Expand Down
38 changes: 34 additions & 4 deletions libs/langgraph/langgraph/pregel/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import random
import time
from dataclasses import replace
from functools import partial
from typing import Any, Callable, Optional, Sequence

Expand All @@ -10,9 +11,10 @@
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_RESUMING,
CONFIG_KEY_SEND,
NS_SEP,
)
from langgraph.errors import _SEEN_CHECKPOINT_NS, GraphInterrupt
from langgraph.types import PregelExecutableTask, RetryPolicy
from langgraph.errors import _SEEN_CHECKPOINT_NS, GraphBubbleUp, ParentCommand
from langgraph.types import Command, PregelExecutableTask, RetryPolicy
from langgraph.utils.config import patch_configurable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -40,7 +42,21 @@ def run_with_retry(
task.proc.invoke(task.input, config)
# if successful, end
break
except GraphInterrupt:
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
if cmd.graph == ns:
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent
parent_ns = NS_SEP.join(ns.split(NS_SEP)[:-1])
exc.args = (replace(cmd, graph=parent_ns),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
Expand Down Expand Up @@ -118,7 +134,21 @@ async def arun_with_retry(
await task.proc.ainvoke(task.input, config)
# if successful, end
break
except GraphInterrupt:
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
if cmd.graph == ns:
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent
parent_ns = NS_SEP.join(ns.split(NS_SEP)[:-1])
exc.args = (replace(cmd, graph=parent_ns),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
PUSH,
TAG_HIDDEN,
)
from langgraph.errors import GraphDelegate, GraphInterrupt
from langgraph.errors import GraphBubbleUp, GraphInterrupt
from langgraph.pregel.executor import Submit
from langgraph.pregel.retry import arun_with_retry, run_with_retry
from langgraph.types import PregelExecutableTask, RetryPolicy
Expand Down Expand Up @@ -298,7 +298,7 @@ def commit(
# save interrupt to checkpointer
if interrupts := [(INTERRUPT, i) for i in exception.args[0]]:
self.put_writes(task.id, interrupts)
elif isinstance(exception, GraphDelegate):
elif isinstance(exception, GraphBubbleUp):
raise exception
else:
# save error to checkpointer
Expand All @@ -324,7 +324,7 @@ def _should_stop_others(
if fut.cancelled():
return True
if exc := fut.exception():
return not isinstance(exc, GraphInterrupt)
return not isinstance(exc, GraphBubbleUp)
else:
return False

Expand Down
6 changes: 6 additions & 0 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Hashable,
Literal,
Expand Down Expand Up @@ -140,6 +141,7 @@ class PregelExecutableTask(NamedTuple):
id: str
path: tuple[Union[str, int, tuple], ...]
scheduled: bool = False
writers: Sequence[Runnable] = ()


class StateSnapshot(NamedTuple):
Expand Down Expand Up @@ -233,12 +235,14 @@ def __eq__(self, value: object) -> bool:


N = TypeVar("N", bound=Hashable)
PARENT = Literal["__parent__"]


@dataclasses.dataclass(**_DC_KWARGS)
class Command(Generic[N]):
"""One or more commands to update the graph's state and send messages to nodes."""

graph: Optional[Union[PARENT, str]] = None
update: Optional[dict[str, Any]] = None
send: Union[Send, Sequence[Send]] = ()
resume: Optional[Union[Any, dict[str, Any]]] = None
Expand All @@ -252,6 +256,8 @@ def __repr__(self) -> str:
)
return f"Command({contents})"

PARENT = ClassVar[PARENT] = "__parent__"


StreamChunk = tuple[tuple[str, ...], str, Any]

Expand Down

0 comments on commit fed60e7

Please sign in to comment.