Skip to content

Commit

Permalink
feat(Server): finalize thread support
Browse files Browse the repository at this point in the history
- NodeOutput Json Serialization
- read thread from get parameter
- add thread on straming result

resolve #24
  • Loading branch information
bsorrentino committed Sep 24, 2024
1 parent 1cac390 commit 3f4ee84
Showing 1 changed file with 36 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,24 @@ public CompletableFuture<Void> start() throws Exception {
}
}


class NodeOutputSerializer extends StdSerializer<NodeOutput> {
Logger log = LangGraphStreamingServer.log;

protected NodeOutputSerializer() {
super( NodeOutput.class );
}

@Override
public void serialize(NodeOutput nodeOutput, JsonGenerator gen, SerializerProvider serializerProvider) throws IOException {
log.trace( "NodeOutputSerializer start!" );
gen.writeStartObject();
gen.writeStringField("node", nodeOutput.node());
gen.writeObjectField("state", nodeOutput.state());
gen.writeEndObject();
}
}

record PersistentConfig(String sessionId, String threadId) {
public PersistentConfig {
Objects.requireNonNull(sessionId);
Expand All @@ -168,6 +186,9 @@ public GraphStreamServlet(StateGraph<? extends AgentState> stateGraph, ObjectMap
Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
this.stateGraph = stateGraph;
this.objectMapper = objectMapper;
var module = new SimpleModule();
module.addSerializer(NodeOutput.class, new NodeOutputSerializer());
objectMapper.registerModule(module);
this.saver = saver;
}

Expand All @@ -193,6 +214,9 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
var session = request.getSession(true);
Objects.requireNonNull(session, "session cannot be null");

var threadId = request.getParameter("thread");
Objects.requireNonNull(threadId, "thread cannot be null");

final PrintWriter writer = response.getWriter();

Map<String, Object> dataMap = objectMapper.readValue(request.getInputStream(), new TypeReference<Map<String, Object>>() {
Expand All @@ -202,7 +226,6 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
var asyncContext = request.startAsync();

try {
var threadId = request.getParameter("threadId");

var config = new PersistentConfig( session.getId(), threadId);

Expand All @@ -215,21 +238,18 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
compiledGraph.streamSnapshots(dataMap, runnableConfig(config) )
.forEachAsync(s -> {
try {
LangGraphStreamingServer.log.trace("{}", s);

writer.print("{");
writer.printf("\"node\": \"%s\"", s.node());
try {
var stateAsString = objectMapper.writeValueAsString(s.state().data());
writer.printf(",\"state\": %s", stateAsString);
writer.printf("[ \"%s\",", threadId);
writer.println();
var outputAsString = objectMapper.writeValueAsString(s);
writer.println(outputAsString);
writer.println( "]" );
} catch (IOException e) {
LangGraphStreamingServer.log.info("error serializing state", e);
writer.printf(",\"state\": {}");
log.warn("error serializing state", e);
}
writer.print("}");
writer.flush();
TimeUnit.SECONDS.sleep(1);
} catch (InterruptedException e) {
} catch ( InterruptedException e) {
throw new RuntimeException(e);
}

Expand Down Expand Up @@ -266,14 +286,15 @@ public InitData( String graph, String title, Map<String, ArgumentMetadata> args)
}

class InitDataSerializer extends StdSerializer<InitData> {
Logger log = LangGraphStreamingServer.log;

protected InitDataSerializer(Class<InitData> t) {
super(t);
}

@Override
public void serialize(InitData initData, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
LangGraphStreamingServer.log.trace( "InitDataSerializer start!" );
log.trace( "InitDataSerializer start!" );
jsonGenerator.writeStartObject();

jsonGenerator.writeStringField("graph", initData.graph());
Expand All @@ -299,7 +320,8 @@ public void serialize(InitData initData, JsonGenerator jsonGenerator, Serializer
*/
class GraphInitServlet extends HttpServlet {

private static final Logger log = LoggerFactory.getLogger(GraphInitServlet.class);
Logger log = LangGraphStreamingServer.log;

final StateGraph<? extends AgentState> stateGraph;
final ObjectMapper objectMapper = new ObjectMapper();
final InitData initData;
Expand All @@ -324,7 +346,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t

String resultJson = objectMapper.writeValueAsString(initData);

LangGraphStreamingServer.log.trace( "{}", resultJson);
log.trace( "{}", resultJson);

// Start asynchronous processing
final PrintWriter writer = response.getWriter();
Expand Down

0 comments on commit 3f4ee84

Please sign in to comment.