diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 84eb327a8..999ba460f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -26,7 +26,10 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; -import static org.opensearch.flowframework.util.ParseUtils.*; +import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap; +import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.parseLLM; +import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** * This represents a process node (step) in a workflow graph in the {@link Template}. diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 3ae59eff8..44270d8e6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -68,7 +68,12 @@ public RegisterAgentStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture registerAgentModelFuture = new CompletableFuture<>(); @@ -77,7 +82,11 @@ public CompletableFuture execute(List data) throws I public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { logger.info("Remote Agent registration successful"); registerAgentModelFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), data.get(0).getWorkflowId()) + new WorkflowData( + Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) ); } @@ -99,6 +108,12 @@ public void onFailure(Exception e) { Instant lastUpdateTime = null; String appType = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 8b4e901fc..339142139 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -15,6 +15,7 @@ import org.opensearch.ml.common.agent.MLToolSpec; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -39,13 +40,24 @@ public class ToolStep implements WorkflowStep { static final String NAME = "create_tool"; @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { String type = null; String name = null; String description = null; Map parameters = Collections.emptyMap(); Boolean includeOutputInAgentResponse = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); @@ -95,7 +107,13 @@ public CompletableFuture execute(List data) throws I MLToolSpec mlToolSpec = builder.build(); - toolFuture.complete(new WorkflowData(Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)), data.get(0).getWorkflowId())); + toolFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); } logger.info("Tool registered successfully {}", type); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index e04aa63d7..0f4b33471 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -69,7 +68,8 @@ public void setUp() throws Exception { Map.entry("last_updated_time", 1689793598499L), Map.entry("app_type", "app") ), - "test-id" + "test-id", + "test-node-id" ); } @@ -87,7 +87,12 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup return null; }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); - CompletableFuture future = registerAgentStep.execute(List.of(inputData)); + CompletableFuture future = registerAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); @@ -108,7 +113,12 @@ public void testRegisterAgentFailure() throws IOException { return null; }).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); - CompletableFuture future = registerAgentStep.execute(List.of(inputData)); + CompletableFuture future = registerAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index 079e1c9e2..c7e8df2d8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -13,7 +13,6 @@ import java.io.IOException; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -33,14 +32,20 @@ public void setUp() throws Exception { Map.entry("parameters", Collections.emptyMap()), Map.entry("include_output_in_agent_response", false) ), - "test-id" + "test-id", + "test-node-id" ); } public void testTool() throws IOException, ExecutionException, InterruptedException { ToolStep toolStep = new ToolStep(); - CompletableFuture future = toolStep.execute(List.of(inputData)); + CompletableFuture future = toolStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());