From 18e331ed611900dae7710382151f1269ae00e238 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 29 Nov 2023 13:19:26 -0800 Subject: [PATCH] Addressing PR comments Signed-off-by: Owais Kazi --- .../org/opensearch/flowframework/model/WorkflowNode.java | 8 +++----- .../org/opensearch/flowframework/workflow/ToolStep.java | 2 +- src/main/resources/mappings/workflow-steps.json | 2 +- .../flowframework/workflow/RegisterAgentTests.java | 3 ++- .../opensearch/flowframework/workflow/ToolStepTests.java | 3 ++- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 1066bc09a..84eb327a8 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -26,9 +26,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.LLM_FIELD; -import static org.opensearch.flowframework.util.ParseUtils.buildLLMMap; -import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; -import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.*; /** * This represents a process node (step) in a workflow graph in the {@link Template}. @@ -100,7 +98,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } xContentBuilder.endArray(); - } else if (e.getValue() instanceof Object) { + } else if (e.getValue() instanceof LLMSpec) { if (LLM_FIELD.equals(e.getKey())) { xContentBuilder.startObject(); buildLLMMap(xContentBuilder, (LLMSpec) e.getValue()); @@ -150,7 +148,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { break; case START_OBJECT: if (LLM_FIELD.equals(inputFieldName)) { - userInputs.put(inputFieldName, LLMSpec.parse(parser)); + userInputs.put(inputFieldName, parseLLM(parser)); } else { userInputs.put(inputFieldName, parseStringToStringMap(parser)); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 2e15d56db..8b4e901fc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -95,7 +95,7 @@ public CompletableFuture execute(List data) throws I MLToolSpec mlToolSpec = builder.build(); - toolFuture.complete(new WorkflowData(Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)))); + toolFuture.complete(new WorkflowData(Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)), data.get(0).getWorkflowId())); } logger.info("Tool registered successfully {}", type); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 31bd55135..f20bdf3aa 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -109,7 +109,7 @@ "agent_id" ] }, - "tool": { + "create_tool": { "inputs": [ "type" ], diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index bdb3e9f51..e04aa63d7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -68,7 +68,8 @@ public void setUp() throws Exception { Map.entry("created_time", 1689793598499L), Map.entry("last_updated_time", 1689793598499L), Map.entry("app_type", "app") - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index a16f78364..079e1c9e2 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -32,7 +32,8 @@ public void setUp() throws Exception { Map.entry("description", "description"), Map.entry("parameters", Collections.emptyMap()), Map.entry("include_output_in_agent_response", false) - ) + ), + "test-id" ); }