Skip to content

Commit

Permalink
Merge branch 'release/1.0-20240807-1'
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 7, 2024
2 parents 22ca263 + 475ffab commit dc47678
Show file tree
Hide file tree
Showing 16 changed files with 571 additions and 217 deletions.
91 changes: 46 additions & 45 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,19 @@

The main type of graph in `langgraph` is the `StatefulGraph`. This graph is parameterized by a state object that it passes around to each node.
Each node then returns operations to update that state. These operations can either SET specific attributes on the state (e.g. overwrite the existing values) or ADD to the existing attribute.
Whether to set or add is denoted by initialize the property with a `AppendableValue`. The State must inherit from `AgentState` base class (that essentially is a `Map` wrapper).
Whether to set or add is described in the state's schema provided to the graph. The schema is a Map of Channels, each Channel represent an attribute in the state. If an attribute is described with an `AppendeChannel` it will be a List and each element referring the attribute will be automaically added by graph during processing. The State must inherit from `AgentState` base class (that essentially is a `Map` wrapper).

```java
public class AgentState {

public AgentState( Map<String,Object> initData ) { ... };
public AgentState( Map<String,Object> initData ) { ... }

public final java.util.Map<String,Object> data() { ... };
public final java.util.Map<String,Object> data() { ... }

public final <T> Optional<T> value(String key) { ... };

public final <T> AppendableValue<T> appendableValue(String key ) { ... };
public final <T> Optional<T> value(String key) { ... }
public final <T> T value(String key, T defaultValue ) { ... }
public final <T> T value(String key, Supplier<T> defaultProvider ) { ... }


}
```
Expand Down Expand Up @@ -128,19 +129,24 @@ Below you can find a piece of code of the `AgentExecutor` to give you an idea of

public static class State implements AgentState {

public State(Map<String, Object> initData) {
super(initData);
}

Optional<String> input() {
return value("input");
}
Optional<AgentOutcome> agentOutcome() {
return value("agent_outcome");
}
AppendableValue<IntermediateStep> intermediateSteps() {
return appendableValue("intermediate_steps");
}
// the state's (partial) schema
static Map<String, Channel<?>> SCHEMA = mapOf(
"intermediate_steps", AppenderChannel.<IntermediateStep>of(ArrayList::new)
);

public State(Map<String, Object> initData) {
super(initData);
}

Optional<String> input() {
return value("input");
}
Optional<AgentOutcome> agentOutcome() {
return value("agent_outcome");
}
List<IntermediateStep> intermediateSteps() {
return this.<List<IntermediateStep>>value("intermediate_steps").orElseGet(emptyList());
}

}

Expand All @@ -155,32 +161,27 @@ var agentRunnable = Agent.builder()
.tools( toolSpecifications )
.build();

var workflow = new StateGraph<>(State::new);

workflow.setEntryPoint("agent");

workflow.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state)) // see implementation in the repo code
);

workflow.addNode( "action", node_async( state ->
executeTools(toolInfoList, state)) // see implementation in the repo code
);

workflow.addConditionalEdge(
"agent",
edge_async( state -> {
if (state.agentOutcome().map(AgentOutcome::finish).isPresent()) {
return "end";
}
return "continue";
}),
Map.of("continue", "action", "end", END)
);

workflow.addEdge("action", "agent");

var app = workflow.compile();
// Fluent Interface
var app = new StateGraph<>(State.SCHEMA,State::new)
.addEdge(START,"agent")
.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
)
.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
)
.addConditionalEdges(
"agent",
edge_async( state -> {
if (state.agentOutcome().map(AgentOutcome::finish).isPresent()) {
return "end";
}
return "continue";
}),
mapOf("continue", "action", "end", END)
)
.addEdge("action", "agent")
.compile();

return app.stream( inputs );

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.state.AgentState;

Expand All @@ -20,6 +19,7 @@

import static java.util.Collections.emptyList;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.listOf;
Expand Down Expand Up @@ -50,8 +50,7 @@ public Optional<String> generation() {

}
public List<String> documents() {
Optional<List<String>> result = value("documents");
return result.orElse(emptyList());
return this.<List<String>>value("documents").orElse(emptyList());
}

}
Expand Down Expand Up @@ -248,43 +247,40 @@ private String gradeGeneration_v_documentsAndQuestion( State state ) {
}

public CompiledGraph<State> buildGraph() throws Exception {
var workflow = new StateGraph<>(State::new);

// Define the nodes
workflow.addNode("web_search", node_async(this::webSearch) ); // web search
workflow.addNode("retrieve", node_async(this::retrieve) ); // retrieve
workflow.addNode("grade_documents", node_async(this::gradeDocuments) ); // grade documents
workflow.addNode("generate", node_async(this::generate) ); // generatae
workflow.addNode("transform_query", node_async(this::transformQuery)); // transform_query

// Build graph
workflow.setConditionalEntryPoint(
edge_async(this::routeQuestion),
mapOf(
"web_search", "web_search",
"vectorstore", "retrieve"
));

workflow.addEdge("web_search", "generate");
workflow.addEdge("retrieve", "grade_documents");
workflow.addConditionalEdges(
"grade_documents",
edge_async(this::decideToGenerate),
mapOf(
"transform_query","transform_query",
"generate", "generate"
));
workflow.addEdge("transform_query", "retrieve");
workflow.addConditionalEdges(
"generate",
edge_async(this::gradeGeneration_v_documentsAndQuestion),
mapOf(
"not supported", "generate",
"useful", END,
"not useful", "transform_query"
));

return workflow.compile();
return new StateGraph<>(State::new)
// Define the nodes
.addNode("web_search", node_async(this::webSearch) ) // web search
.addNode("retrieve", node_async(this::retrieve) ) // retrieve
.addNode("grade_documents", node_async(this::gradeDocuments) ) // grade documents
.addNode("generate", node_async(this::generate) ) // generatae
.addNode("transform_query", node_async(this::transformQuery)) // transform_query
// Build graph
.addConditionalEdges(START,
edge_async(this::routeQuestion),
mapOf(
"web_search", "web_search",
"vectorstore", "retrieve"
))

.addEdge("web_search", "generate")
.addEdge("retrieve", "grade_documents")
.addConditionalEdges(
"grade_documents",
edge_async(this::decideToGenerate),
mapOf(
"transform_query","transform_query",
"generate", "generate"
))
.addEdge("transform_query", "retrieve")
.addConditionalEdges(
"generate",
edge_async(this::gradeGeneration_v_documentsAndQuestion),
mapOf(
"not supported", "generate",
"useful", END,
"not useful", "transform_query"
))
.compile();
}

public static void main( String[] args ) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppendableValue;
import org.bsc.langgraph4j.state.AppenderChannel;
import org.bsc.langgraph4j.state.Channel;

import java.util.*;
import java.util.stream.Collectors;

import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;

public class AgentExecutor {

public static class State extends AgentState {
static Map<String, Channel<?>> SCHEMA = mapOf(
"intermediate_steps", AppenderChannel.<IntermediateStep>of(ArrayList::new)
);

public State(Map<String, Object> initData) {
super(initData);
Expand All @@ -33,8 +39,8 @@ Optional<String> input() {
Optional<AgentOutcome> agentOutcome() {
return value("agent_outcome");
}
AppendableValue<IntermediateStep> intermediateSteps() {
return appendableValue("intermediate_steps");
List<IntermediateStep> intermediateSteps() {
return this.<List<IntermediateStep>>value("intermediate_steps").orElseGet(ArrayList::new);
}


Expand All @@ -45,7 +51,7 @@ Map<String,Object> runAgent( Agent agentRunnable, State state ) throws Exception
var input = state.input()
.orElseThrow(() -> new IllegalArgumentException("no input provided!"));

var intermediateSteps = state.intermediateSteps().values();
var intermediateSteps = state.intermediateSteps();

var response = agentRunnable.execute( input, intermediateSteps );

Expand Down Expand Up @@ -106,27 +112,21 @@ public CompiledGraph<State> compile(ChatLanguageModel chatLanguageModel, List<Ob
.tools( toolSpecifications )
.build();

var workflow = new StateGraph<>(State::new);

workflow.setEntryPoint("agent");

workflow.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
);

workflow.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
);

workflow.addConditionalEdges(
"agent",
edge_async(this::shouldContinue),
mapOf("continue", "action", "end", END)
);

workflow.addEdge("action", "agent");

return workflow.compile();
return new StateGraph<>(State.SCHEMA,State::new)
.addEdge(START,"agent")
.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
)
.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
)
.addConditionalEdges(
"agent",
edge_async(this::shouldContinue),
mapOf("continue", "action", "end", END)
)
.addEdge("action", "agent")
.compile();

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.junit.jupiter.api.Test;

import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.listOf;
Expand Down Expand Up @@ -93,18 +94,17 @@ void executeAgentWithDoubleToolInvocation() throws Exception {
@Test
public void getGraphTest() throws Exception {

var workflow = new StateGraph<>(AgentState::new);

workflow.setEntryPoint("agent");
workflow.addNode( "agent", node_async( state -> mapOf() ));
workflow.addNode( "action", node_async( state -> mapOf() ));
workflow.addConditionalEdges(
"agent",
edge_async(state -> ""),
mapOf("continue", "action", "end", END)
);
workflow.addEdge("action", "agent");
var app = workflow.compile();
var app = new StateGraph<>(AgentState::new)
.addEdge(START,"agent")
.addNode( "agent", node_async( state -> mapOf() ))
.addNode( "action", node_async( state -> mapOf() ))
.addConditionalEdges(
"agent",
edge_async(state -> ""),
mapOf("continue", "action", "end", END)
)
.addEdge("action", "agent")
.compile();

var plantUml = app.getGraph( GraphRepresentation.Type.PLANTUML, "Agent Executor" );

Expand Down
Loading

0 comments on commit dc47678

Please sign in to comment.