From 16aefea482a7aaa2590232b2160b769566205a77 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Wed, 11 Sep 2024 19:39:49 +0200 Subject: [PATCH] refactor(server): enable use of StateGraph work on #24 --- .../DiagramCorrectionProcess.java | 5 +- .../image_to_diagram/ImageToDiagram.java | 20 ++++---- .../image_to_diagram/ImageToDiagramTest.java | 5 +- .../langgraph4j/LangGraphStreamingServer.java | 47 +++++++++++-------- .../LangGraphStreamingServerTest.java | 42 +++++++---------- 5 files changed, 62 insertions(+), 57 deletions(-) diff --git a/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/DiagramCorrectionProcess.java b/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/DiagramCorrectionProcess.java index 81a0812..b39fc3a 100644 --- a/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/DiagramCorrectionProcess.java +++ b/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/DiagramCorrectionProcess.java @@ -16,6 +16,7 @@ import static java.util.Optional.ofNullable; import static org.bsc.langgraph4j.StateGraph.END; import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async; +import static org.bsc.langgraph4j.utils.CollectionsUtils.last; import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf; @Slf4j( topic="DiagramCorrectionProcess" ) @@ -44,7 +45,7 @@ CompletableFuture> reviewResult(State state) { CompletableFuture> future = new CompletableFuture<>(); try { - var diagramCode = state.diagramCode().last() + var diagramCode = last( state.diagramCode() ) .orElseThrow(() -> new IllegalArgumentException("no diagram code provided!")); var error = state.evaluationError() @@ -71,7 +72,7 @@ CompletableFuture> reviewResult(State state) { private CompletableFuture> evaluateResult(State state) { - var diagramCode = state.diagramCode().last() + var diagramCode = last( state.diagramCode() ) .orElseThrow(() -> new IllegalArgumentException("no diagram code provided!")); return PlantUMLAction.validate( diagramCode ) diff --git a/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/ImageToDiagram.java b/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/ImageToDiagram.java index 5230d94..c41006e 100644 --- a/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/ImageToDiagram.java +++ b/image-to-diagram/src/main/java/dev/langchain4j/image_to_diagram/ImageToDiagram.java @@ -6,20 +6,22 @@ import org.bsc.async.AsyncGenerator; import org.bsc.langgraph4j.NodeOutput; import org.bsc.langgraph4j.state.AgentState; -import org.bsc.langgraph4j.state.AppendableValue; +import org.bsc.langgraph4j.state.AppenderChannel; +import org.bsc.langgraph4j.state.Channel; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; -import java.util.Map; -import java.util.Optional; +import java.util.*; -import static java.util.Optional.ofNullable; +import static org.bsc.langgraph4j.utils.CollectionsUtils.*; public interface ImageToDiagram { class State extends AgentState { - + static Map> SCHEMA = mapOf( + "messages", AppenderChannel.of(ArrayList::new) + ); public State(Map initData) { super(initData); } @@ -27,8 +29,8 @@ public State(Map initData) { public Optional diagram() { return value("diagram"); } - public AppendableValue diagramCode() { - return appendableValue("diagramCode"); + public List diagramCode() { + return this.>value("diagramCode").orElseGet(Collections::emptyList); } public Optional evaluationResult() { return value("evaluationResult" ); @@ -49,10 +51,10 @@ public boolean isExecutionError() { public boolean lastTwoDiagramsAreEqual() { if( diagramCode().size() < 2 ) return false; - String last = diagramCode().last() + String last = last( diagramCode() ) .map(String::trim) .orElseThrow( () -> new IllegalStateException( "last() is null!" ) ); - String prev = diagramCode().lastMinus(1) + String prev = lastMinus( diagramCode(), 1) .map(String::trim) .orElseThrow( () -> new IllegalStateException( "last(-1) is null!" ) ); diff --git a/image-to-diagram/src/test/java/dev/langchain4j/image_to_diagram/ImageToDiagramTest.java b/image-to-diagram/src/test/java/dev/langchain4j/image_to_diagram/ImageToDiagramTest.java index 76aca4c..bdd88de 100644 --- a/image-to-diagram/src/test/java/dev/langchain4j/image_to_diagram/ImageToDiagramTest.java +++ b/image-to-diagram/src/test/java/dev/langchain4j/image_to_diagram/ImageToDiagramTest.java @@ -22,6 +22,7 @@ import static java.lang.String.format; import static java.util.Optional.ofNullable; +import static org.bsc.langgraph4j.utils.CollectionsUtils.last; import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf; import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -136,7 +137,7 @@ public void imageToDiagram() throws Exception { } System.out.println( ofNullable(state) - .flatMap( s -> s.diagramCode().last() ).orElse("NO DIAGRAM CODE") ); + .flatMap( s -> last( s.diagramCode() ) ).orElse("NO DIAGRAM CODE") ); } @@ -160,7 +161,7 @@ public String reviewDiagram( String diagramId ) throws Exception { }) .join(); - var code = result.diagramCode().last(); + var code = last( result.diagramCode() ); assertTrue( code.isPresent() ); assertEquals( expectedCode, code.get().trim() ); diff --git a/server-jetty/src/main/java/org/bsc/langgraph4j/LangGraphStreamingServer.java b/server-jetty/src/main/java/org/bsc/langgraph4j/LangGraphStreamingServer.java index 27f6843..ce204d2 100644 --- a/server-jetty/src/main/java/org/bsc/langgraph4j/LangGraphStreamingServer.java +++ b/server-jetty/src/main/java/org/bsc/langgraph4j/LangGraphStreamingServer.java @@ -32,7 +32,7 @@ * of LangGraph. * Implementations of this interface can be used to create a web server * that exposes an API for interacting with compiled language graphs. - */ + */ public interface LangGraphStreamingServer { Logger log = LoggerFactory.getLogger(LangGraphStreamingServer.class); @@ -45,7 +45,7 @@ static Builder builder() { class Builder { private int port = 8080; - private Map inputArgs = new HashMap<>(); + private final Map inputArgs = new HashMap<>(); private String title = null; private ObjectMapper objectMapper; @@ -74,7 +74,7 @@ public Builder addInputStringArg(String name) { return this; } - public LangGraphStreamingServer build(CompiledGraph compiledGraph) throws Exception { + public LangGraphStreamingServer build(StateGraph stateGraph) throws Exception { Server server = new Server(); @@ -82,28 +82,31 @@ public LangGraphStreamingServer build(CompiledGraph(stateGraph, initData)), "/init"); + // context.setContextPath("/"); // Add the streaming servlet - context.addServlet(new ServletHolder(new GraphExecutionServlet(compiledGraph, objectMapper)), "/stream"); + context.addServlet(new ServletHolder(new GraphExecutionServlet(stateGraph, objectMapper)), "/stream"); - InitData initData = new InitData(title, inputArgs); - context.addServlet(new ServletHolder(new GraphInitServlet(compiledGraph, initData)), "/init"); - - Handler.Sequence handlerList = new Handler.Sequence(resourceHandler, context); + var handlerList = new Handler.Sequence( resourceHandler, context); server.setHandler(handlerList); @@ -127,12 +130,12 @@ public CompletableFuture start() throws Exception { class GraphExecutionServlet extends HttpServlet { - final CompiledGraph compiledGraph; + final StateGraph stateGraph; final ObjectMapper objectMapper; - public GraphExecutionServlet(CompiledGraph compiledGraph, ObjectMapper objectMapper) { - Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null"); - this.compiledGraph = compiledGraph; + public GraphExecutionServlet(StateGraph stateGraph, ObjectMapper objectMapper) { + Objects.requireNonNull(stateGraph, "stateGraph cannot be null"); + this.stateGraph = stateGraph; this.objectMapper = objectMapper; } @@ -151,6 +154,10 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) var asyncContext = request.startAsync(); try { + var config = CompileConfig.builder().build(); + + var compiledGraph = stateGraph.compile(config); + compiledGraph.stream(dataMap) .forEachAsync(s -> { try { @@ -197,7 +204,7 @@ record InitData( */ class GraphInitServlet extends HttpServlet { - final CompiledGraph compiledGraph; + final StateGraph stateGraph; final ObjectMapper objectMapper = new ObjectMapper(); final InitData initData; @@ -212,9 +219,9 @@ public Result(GraphRepresentation graph, InitData initData) { } } - public GraphInitServlet(CompiledGraph compiledGraph, InitData initData) { - Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null"); - this.compiledGraph = compiledGraph; + public GraphInitServlet(StateGraph stateGraph, InitData initData) { + Objects.requireNonNull(stateGraph, "stateGraph cannot be null"); + this.stateGraph = stateGraph; this.initData = initData; } @@ -223,7 +230,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t response.setContentType("application/json"); response.setCharacterEncoding("UTF-8"); - GraphRepresentation graph = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID, initData.title(), false); + GraphRepresentation graph = stateGraph.getGraph(GraphRepresentation.Type.MERMAID, initData.title(), false); final Result result = new Result(graph, initData); String resultJson = objectMapper.writeValueAsString(result); diff --git a/server-jetty/src/test/java/org/bsc/langgraph4j/LangGraphStreamingServerTest.java b/server-jetty/src/test/java/org/bsc/langgraph4j/LangGraphStreamingServerTest.java index 7891678..556371a 100644 --- a/server-jetty/src/test/java/org/bsc/langgraph4j/LangGraphStreamingServerTest.java +++ b/server-jetty/src/test/java/org/bsc/langgraph4j/LangGraphStreamingServerTest.java @@ -4,6 +4,7 @@ import org.bsc.langgraph4j.state.AgentState; import static org.bsc.langgraph4j.StateGraph.END; +import static org.bsc.langgraph4j.StateGraph.START; import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async; import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async; import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf; @@ -12,26 +13,6 @@ public class LangGraphStreamingServerTest { public static void main(String[] args) throws Exception { - StateGraph workflow = new StateGraph<>(AgentState::new); - - workflow.setEntryPoint("agent_1"); - - workflow.addNode("agent_1", node_async((state ) -> { - System.out.println("agent_1 "); - System.out.println(state); - return mapOf("prop1", "value1"); - }) ) ; - - workflow.addNode("agent_2", node_async( state -> { - - System.out.print( "agent_2: "); - System.out.println( state ); - return mapOf("prop2", "value2"); - })); - - workflow.addEdge("agent_2", "agent_1" ); - - EdgeAction conditionalAge = new EdgeAction<>() { int steps= 0; @Override @@ -44,16 +25,29 @@ public String apply(AgentState state) { } }; - workflow.addConditionalEdges("agent_1", - edge_async(conditionalAge), mapOf( "a2", "agent_2", "end", END ) ); - CompiledGraph app = workflow.compile(); + StateGraph workflow = new StateGraph<>(AgentState::new) + .addNode("agent_1", node_async((state ) -> { + System.out.println("agent_1 "); + System.out.println(state); + return mapOf("prop1", "value1"); + }) ) + .addNode("agent_2", node_async( state -> { + System.out.print( "agent_2: "); + System.out.println( state ); + return mapOf("prop2", "value2"); + })) + .addEdge(START, "agent_1") + .addEdge("agent_2", "agent_1" ) + .addConditionalEdges("agent_1", + edge_async(conditionalAge), mapOf( "a2", "agent_2", "end", END ) ) + ; LangGraphStreamingServer server = LangGraphStreamingServer.builder() .port(8080) .title("LANGGRAPH4j - TEST") .addInputStringArg("input") - .build(app); + .build(workflow); server.start().join();