Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 25, 2024
1 parent abc0c8c commit 8e1cd0e
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 0 deletions.
76 changes: 76 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14378,3 +14378,79 @@ async def dummy(state):
graph = graph_builder.compile()

assert graph.get_graph(xray=True).to_json() == graph.get_graph(xray=False).to_json()


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_parent_command(request: pytest.FixtureRequest, checkpointer_name: str) -> None:
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool

@tool(return_direct=True)
def get_user_name() -> GraphCommand:
"""Retrieve user name"""
return GraphCommand(update={"user_name": "Meow"}, graph=GraphCommand.PARENT)

subgraph_builder = StateGraph(MessagesState)
subgraph_builder.add_node("tool", get_user_name)
subgraph_builder.add_edge(START, "tool")
subgraph = subgraph_builder.compile()

class CustomParentState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
# this key is not available to the child graph
user_name: str

builder = StateGraph(CustomParentState)
builder.add_node("alice", subgraph)
builder.add_edge(START, "alice")
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
graph = builder.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}

assert graph.invoke({"messages": [("user", "get user name")]}, config) == {
"messages": [
_AnyIdHumanMessage(
content="get user name", additional_kwargs={}, response_metadata={}
),
],
"user_name": "Meow",
}
assert graph.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(
content="get user name", additional_kwargs={}, response_metadata={}
),
],
"user_name": "Meow",
},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {
"alice": {
"user_name": "Meow",
}
},
"thread_id": "1",
"step": 1,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)
80 changes: 80 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12565,3 +12565,83 @@ def normalize_config(config: Optional[dict]) -> Optional[dict]:
assert stream_task["interrupts"] == history_task.interrupts
assert stream_task.get("error") == history_task.error
assert stream_task.get("state") == history_task.state


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_parent_command(checkpointer_name: str) -> None:
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool

@tool(return_direct=True)
def get_user_name() -> GraphCommand:
"""Retrieve user name"""
return GraphCommand(update={"user_name": "Meow"}, graph=GraphCommand.PARENT)

subgraph_builder = StateGraph(MessagesState)
subgraph_builder.add_node("tool", get_user_name)
subgraph_builder.add_edge(START, "tool")
subgraph = subgraph_builder.compile()

class CustomParentState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
# this key is not available to the child graph
user_name: str

builder = StateGraph(CustomParentState)
builder.add_node("alice", subgraph)
builder.add_edge(START, "alice")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}

assert await graph.ainvoke(
{"messages": [("user", "get user name")]}, config
) == {
"messages": [
_AnyIdHumanMessage(
content="get user name", additional_kwargs={}, response_metadata={}
),
],
"user_name": "Meow",
}
assert await graph.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(
content="get user name",
additional_kwargs={},
response_metadata={},
),
],
"user_name": "Meow",
},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {
"alice": {
"user_name": "Meow",
}
},
"thread_id": "1",
"step": 1,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)

0 comments on commit 8e1cd0e

Please sign in to comment.