Skip to content

Commit

Permalink
Permit ordering of tools in register agent step (opensearch-project#283)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 18, 2023
1 parent 100c701 commit 597640f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 22 deletions.
16 changes: 16 additions & 0 deletions src/main/java/org/opensearch/flowframework/model/WorkflowNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
import static org.opensearch.flowframework.util.ParseUtils.parseStringToObjectMap;
Expand Down Expand Up @@ -93,6 +94,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
for (PipelineProcessor p : (PipelineProcessor[]) e.getValue()) {
xContentBuilder.value(p);
}
} else if (TOOLS_FIELD.equals(e.getKey())) {
for (String t : (String[]) e.getValue()) {
xContentBuilder.value(t);
}
} else {
for (Map<?, ?> map : (Map<?, ?>[]) e.getValue()) {
buildStringToObjectMap(xContentBuilder, map);
Expand Down Expand Up @@ -151,6 +156,12 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
processorList.add(PipelineProcessor.parse(parser));
}
userInputs.put(inputFieldName, processorList.toArray(new PipelineProcessor[0]));
} else if (TOOLS_FIELD.equals(inputFieldName)) {
List<String> toolsList = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
toolsList.add(parser.text());
}
userInputs.put(inputFieldName, toolsList.toArray(new String[0]));
} else {
List<Map<String, Object>> mapList = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
Expand All @@ -173,6 +184,11 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
case DOUBLE:
userInputs.put(inputFieldName, parser.doubleValue());
break;
case BIG_INTEGER:
userInputs.put(inputFieldName, parser.bigIntegerValue());
break;
default:
throw new IOException("Unable to parse field [" + inputFieldName + "] in a node object.");
}
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.common.Nullable;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.WorkflowResources;
Expand All @@ -28,6 +29,7 @@
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -155,7 +157,8 @@ public void onFailure(Exception e) {
String description = (String) inputs.get(DESCRIPTION_FIELD);
String llmModelId = (String) inputs.get(LLM_MODEL_ID);
Map<String, String> llmParameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), LLM_PARAMETERS);
List<MLToolSpec> tools = getTools(previousNodeInputs, outputs);
String[] tools = (String[]) inputs.get(TOOLS_FIELD);
List<MLToolSpec> toolsList = getTools(tools, previousNodeInputs, outputs);
Map<String, String> parameters = getStringToStringMap(inputs.get(PARAMETERS_FIELD), PARAMETERS_FIELD);
MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD));
Instant createdTime = Instant.ofEpochMilli((Long) inputs.get(CREATED_TIME));
Expand Down Expand Up @@ -188,7 +191,7 @@ public void onFailure(Exception e) {

builder.type(type)
.llm(llmSpec)
.tools(tools)
.tools(toolsList)
.parameters(parameters)
.memory(memory)
.createdTime(createdTime)
Expand All @@ -210,24 +213,25 @@ public String getName() {
return NAME;
}

private List<MLToolSpec> getTools(Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
private List<MLToolSpec> getTools(@Nullable String[] tools, Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
List<MLToolSpec> mlToolSpecList = new ArrayList<>();
List<String> previousNodes = previousNodeInputs.entrySet()
.stream()
.filter(e -> TOOLS_FIELD.equals(e.getValue()))
.map(Map.Entry::getKey)
.collect(Collectors.toList());

if (previousNodes != null) {
previousNodes.forEach((previousNode) -> {
WorkflowData previousNodeOutput = outputs.get(previousNode);
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) {
MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD);
logger.info("Tool added {}", mlToolSpec.getType());
mlToolSpecList.add(mlToolSpec);
}
});
}
// Anything in tools is sorted first, followed by anything else in previous node inputs
List<String> sortedNodes = tools == null ? new ArrayList<>() : Arrays.asList(tools);
previousNodes.removeAll(sortedNodes);
sortedNodes.addAll(previousNodes);
sortedNodes.forEach((node) -> {
WorkflowData previousNodeOutput = outputs.get(node);
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(TOOLS_FIELD)) {
MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get(TOOLS_FIELD);
logger.info("Tool added {}", mlToolSpec.getType());
mlToolSpecList.add(mlToolSpec);
}
});
return mlToolSpecList;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ public void testNode() throws IOException {
Map.entry("bar", Map.of("key", "value")),
Map.entry("baz", new Map<?, ?>[] { Map.of("A", "a"), Map.of("B", "b") }),
Map.entry("processors", new PipelineProcessor[] { new PipelineProcessor("test-type", Map.of("key2", "value2")) }),
Map.entry("created_time", 1689793598499L)
Map.entry("created_time", 1689793598499L),
Map.entry("tools", new String[] { "foo", "bar" })
)
);
assertEquals("A", nodeA.id());
Expand All @@ -46,6 +47,7 @@ public void testNode() throws IOException {
assertEquals("test-type", pp[0].type());
assertEquals(Map.of("key2", "value2"), pp[0].params());
assertEquals(1689793598499L, map.get("created_time"));
assertArrayEquals(new String[] { "foo", "bar" }, (String[]) map.get("tools"));

// node equality is based only on ID
WorkflowNode nodeA2 = new WorkflowNode("A", "a2-type", Map.of(), Map.of("bar", "baz"));
Expand All @@ -63,6 +65,7 @@ public void testNode() throws IOException {
assertTrue(json.contains("\"bar\":{\"key\":\"value\"}"));
assertTrue(json.contains("\"processors\":[{\"type\":\"test-type\",\"params\":{\"key2\":\"value2\"}}]"));
assertTrue(json.contains("\"created_time\":1689793598499"));
assertTrue(json.contains("\"tools\":[\"foo\",\"bar\"]"));

WorkflowNode nodeX = WorkflowNode.parse(TemplateTestJsonUtil.jsonToParser(json));
assertEquals("A", nodeX.id());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -54,10 +52,6 @@ public void setUp() throws Exception {
this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
MockitoAnnotations.openMocks(this);

MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false);

LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap());

Map<?, ?> mlMemorySpec = Map.ofEntries(
Map.entry(MLMemorySpec.MEMORY_TYPE_FIELD, "type"),
Map.entry(MLMemorySpec.SESSION_ID_FIELD, "abc"),
Expand All @@ -71,7 +65,7 @@ public void setUp() throws Exception {
Map.entry("type", "type"),
Map.entry("llm.model_id", "xyz"),
Map.entry("llm.parameters", Collections.emptyMap()),
Map.entry("tools", tools),
Map.entry("tools", new String[] { "abc", "xyz" }),
Map.entry("parameters", Collections.emptyMap()),
Map.entry("memory", mlMemorySpec),
Map.entry("created_time", 1689793598499L),
Expand Down

0 comments on commit 597640f

Please sign in to comment.