Skip to content

Commit

Permalink
Handled interface changes
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Nov 30, 2023
1 parent 90bf495 commit 952e825
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.*;
import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseLLM;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap;

/**
* This represents a process node (step) in a workflow graph in the {@link Template}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@ public RegisterAgentStep(MachineLearningNodeClient mlClient) {
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {
public CompletableFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {

CompletableFuture<WorkflowData> registerAgentModelFuture = new CompletableFuture<>();

Expand All @@ -77,7 +82,11 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws I
public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
logger.info("Remote Agent registration successful");
registerAgentModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())), data.get(0).getWorkflowId())
new WorkflowData(
Map.ofEntries(Map.entry(AGENT_ID, mlRegisterAgentResponse.getAgentId())),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}

Expand All @@ -99,6 +108,12 @@ public void onFailure(Exception e) {
Instant lastUpdateTime = null;
String appType = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

Expand Down
22 changes: 20 additions & 2 deletions src/main/java/org/opensearch/flowframework/workflow/ToolStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.ml.common.agent.MLToolSpec;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand All @@ -39,13 +40,24 @@ public class ToolStep implements WorkflowStep {
static final String NAME = "create_tool";

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {
public CompletableFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {
String type = null;
String name = null;
String description = null;
Map<String, String> parameters = Collections.emptyMap();
Boolean includeOutputInAgentResponse = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

Expand Down Expand Up @@ -95,7 +107,13 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws I

MLToolSpec mlToolSpec = builder.build();

toolFuture.complete(new WorkflowData(Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)), data.get(0).getWorkflowId()));
toolFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}

logger.info("Tool registered successfully {}", type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -69,7 +68,8 @@ public void setUp() throws Exception {
Map.entry("last_updated_time", 1689793598499L),
Map.entry("app_type", "app")
),
"test-id"
"test-id",
"test-node-id"
);
}

Expand All @@ -87,7 +87,12 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup
return null;
}).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = registerAgentStep.execute(List.of(inputData));
CompletableFuture<WorkflowData> future = registerAgentStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap()
);

verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture());

Expand All @@ -108,7 +113,12 @@ public void testRegisterAgentFailure() throws IOException {
return null;
}).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = registerAgentStep.execute(List.of(inputData));
CompletableFuture<WorkflowData> future = registerAgentStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap()
);

verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand All @@ -33,14 +32,20 @@ public void setUp() throws Exception {
Map.entry("parameters", Collections.emptyMap()),
Map.entry("include_output_in_agent_response", false)
),
"test-id"
"test-id",
"test-node-id"
);
}

public void testTool() throws IOException, ExecutionException, InterruptedException {
ToolStep toolStep = new ToolStep();

CompletableFuture<WorkflowData> future = toolStep.execute(List.of(inputData));
CompletableFuture<WorkflowData> future = toolStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isDone());
assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());
Expand Down

0 comments on commit 952e825

Please sign in to comment.