Skip to content

Commit

Permalink
Merge branch 'release/1.0-20240809'
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 9, 2024
2 parents aa515bc + bffa8a4 commit 92bf3df
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,42 @@
import lombok.Value;
import lombok.experimental.Accessors;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

@Value
@Accessors( fluent = true)
class AgentAction {
public class AgentAction {
public static final Serializer SERIALIZER = new Serializer();
@NonNull
ToolExecutionRequest toolExecutionRequest;
String log;

public static class Serializer implements org.bsc.langgraph4j.serializer.Serializer<AgentAction> {

private Serializer() {}

@Override
public void write(AgentAction action, ObjectOutput out) throws IOException {
ToolExecutionRequest ter = action.toolExecutionRequest();
out.writeUTF( ter.id() );
out.writeUTF( ter.name() );
out.writeUTF( ter.arguments() );
out.writeUTF( action.log() );

}

@Override
public AgentAction read(ObjectInput in) throws IOException, ClassNotFoundException {
ToolExecutionRequest ter = ToolExecutionRequest.builder()
.id(in.readUTF())
.name(in.readUTF())
.arguments(in.readUTF())
.build();

return new AgentAction( ter, in.readUTF() );

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import dev.langchain4j.model.output.FinishReason;
import lombok.var;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.*;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppendableValue;
import org.bsc.langgraph4j.state.AppenderChannel;
Expand All @@ -24,6 +23,69 @@

public class AgentExecutor {

public class GraphBuilder {
private BaseCheckpointSaver checkpointSaver;
private ChatLanguageModel chatLanguageModel;
private List<Object> objectsWithTools;

public GraphBuilder checkpointSaver(BaseCheckpointSaver checkpointSaver) {
this.checkpointSaver = checkpointSaver;
return this;
}
public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
this.chatLanguageModel = chatLanguageModel;
return this;
}
public GraphBuilder objectsWithTools(List<Object> objectsWithTools) {
this.objectsWithTools = objectsWithTools;
return this;
}

public CompiledGraph<State> build() throws GraphStateException {
Objects.requireNonNull(objectsWithTools, "objectsWithTools is required!");
Objects.requireNonNull(chatLanguageModel, "chatLanguageModel is required!");


var toolInfoList = ToolInfo.fromList( objectsWithTools );

final List<ToolSpecification> toolSpecifications = toolInfoList.stream()
.map(ToolInfo::specification)
.collect(Collectors.toList());

var agentRunnable = Agent.builder()
.chatLanguageModel(chatLanguageModel)
.tools( toolSpecifications )
.build();

CompileConfig.Builder config = new CompileConfig.Builder();

if( checkpointSaver != null ) {
config.checkpointSaver(checkpointSaver);
}

return new StateGraph<>(State.SCHEMA,State::new)
.addEdge(START,"agent")
.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
)
.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
)
.addConditionalEdges(
"agent",
edge_async(AgentExecutor.this::shouldContinue),
mapOf("continue", "action", "end", END)
)
.addEdge("action", "agent")
.compile( config.build() );

}
}

public final GraphBuilder builder() {
return new GraphBuilder();
}

public static class State extends AgentState {
static Map<String, Channel<?>> SCHEMA = mapOf(
"intermediate_steps", AppenderChannel.<IntermediateStep>of(ArrayList::new)
Expand All @@ -43,7 +105,6 @@ List<IntermediateStep> intermediateSteps() {
return this.<List<IntermediateStep>>value("intermediate_steps").orElseGet(ArrayList::new);
}


}

Map<String,Object> runAgent( Agent agentRunnable, State state ) throws Exception {
Expand Down Expand Up @@ -100,40 +161,4 @@ String shouldContinue(State state) {
return "continue";
}

public CompiledGraph<State> compile(ChatLanguageModel chatLanguageModel, List<Object> objectsWithTools) throws Exception {
var toolInfoList = ToolInfo.fromList( objectsWithTools );

final List<ToolSpecification> toolSpecifications = toolInfoList.stream()
.map(ToolInfo::specification)
.collect(Collectors.toList());

var agentRunnable = Agent.builder()
.chatLanguageModel(chatLanguageModel)
.tools( toolSpecifications )
.build();

return new StateGraph<>(State.SCHEMA,State::new)
.addEdge(START,"agent")
.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
)
.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
)
.addConditionalEdges(
"agent",
edge_async(this::shouldContinue),
mapOf("continue", "action", "end", END)
)
.addEdge("action", "agent")
.compile();

}

public AsyncGenerator<NodeOutput<State>> execute(ChatLanguageModel chatLanguageModel, Map<String, Object> inputs, List<Object> objectsWithTools) throws Exception {

var app = compile(chatLanguageModel, objectsWithTools);

return app.stream( inputs );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,34 @@
import lombok.Value;
import lombok.experimental.Accessors;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Map;
import java.util.Objects;

@Value
@Accessors( fluent = true)
class AgentFinish {
Map<String,Object> returnValues;
public class AgentFinish {
public static final Serializer SERIALIZER = new Serializer();

Map<String, Object> returnValues;
String log;

public static class Serializer implements org.bsc.langgraph4j.serializer.Serializer<AgentFinish> {

private Serializer() {
}

@Override
public void write(AgentFinish object, ObjectOutput out) throws IOException {
out.writeObject(object.returnValues);
out.writeUTF(object.log);
}

@Override
public AgentFinish read(ObjectInput in) throws IOException, ClassNotFoundException {
return new AgentFinish((Map<String, Object>) in.readObject(), in.readUTF());
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
package dev.langchain4j.agentexecutor;

import lombok.Value;
import lombok.experimental.Accessors;

@Value
@Accessors( fluent = true)
class AgentOutcome {
AgentAction action;
AgentFinish finish;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

public class AgentOutcome implements Externalizable {
private AgentAction action;
private AgentFinish finish;

AgentAction action() { return action; }
AgentFinish finish() { return finish; }

public AgentOutcome() {}
public AgentOutcome( AgentAction action, AgentFinish finish) {
this.action = action;
this.finish = finish;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
AgentAction.SERIALIZER.writeNullable(action, out);
AgentFinish.SERIALIZER.writeNullable(finish, out);

}

@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
AgentAction.SERIALIZER.readNullable(in)
.ifPresent( value -> action = value );
AgentFinish.SERIALIZER.readNullable(in)
.ifPresent( value -> finish = value );
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,31 @@
package dev.langchain4j.agentexecutor;

import lombok.Value;
import lombok.experimental.Accessors;

@Value
@Accessors( fluent = true)
public class IntermediateStep {
AgentAction action;
String observation;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

public class IntermediateStep implements Externalizable {
private AgentAction action;
private String observation;

AgentAction action() { return action; }
String observation() { return observation; }

public IntermediateStep() {}
public IntermediateStep( AgentAction action, String observation) {
this.action = action;
this.observation = observation;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
AgentAction.SERIALIZER.write(action, out);
out.writeUTF(observation);
}

@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
action = AgentAction.SERIALIZER.read(in);
observation = in.readUTF();
}
}
Loading

0 comments on commit 92bf3df

Please sign in to comment.