diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index bcafdab9b..7dcc89de6 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -431,9 +431,10 @@ public void updateFlowFrameworkSystemIndexDocWithScript( client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); - listener.onFailure(e); + listener.onFailure( + new FlowFrameworkException("Failed to update " + indexName + "entry: " + documentId, ExceptionsHelper.status(e)) + ); } } } - } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 19e2df21a..d53ce4412 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -100,7 +100,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, testWorkflowID); + List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null); workflowProcessSorter.validateGraph(sortedNodes); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index 1a59d033b..3eb788fb1 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -70,8 +71,6 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener { - logger.debug("Completed Get Workflow Status Request, id:{}", workflowId); - if (r != null && r.isExists()) { try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -79,7 +78,7 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener context.restore())); } catch (Exception e) { logger.error("Failed to get workflow: " + workflowId, e); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 1e067a0bc..ee4dcd166 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -82,7 +82,10 @@ public CompletableFuture execute(List data) throws I @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { createConnectorFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId()))) + new WorkflowData( + Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), + data.get(0).getWorkflowId() + ) ); try { logger.info("Created connector successfully"); @@ -105,7 +108,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { workflowId, script, ActionListener.wrap(updateResponse -> { - logger.info("updated resources craeted of {}", workflowId); + logger.info("updated resources created of {}", workflowId); }, exception -> { createConnectorFuture.completeExceptionally( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 5fe47b2b0..f3a82b26c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -61,7 +61,7 @@ public CompletableFuture execute(List data) { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()))); + future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId())); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index b8cc83651..a63a800fd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -125,7 +125,9 @@ public CompletableFuture execute(List data) { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()))); + createIngestPipelineFuture.complete( + new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId()) + ); // TODO : Use node client to index response data to global context (pending global context index implementation) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index ba22f3682..8ce89176c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -50,7 +50,10 @@ public CompletableFuture execute(List data) { public void onResponse(MLDeployModelResponse mlDeployModelResponse) { logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); deployModelFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus()))) + new WorkflowData( + Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), + data.get(0).getWorkflowId() + ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java index 893f34a0d..ac84aaaa0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java @@ -56,7 +56,8 @@ public CompletableFuture execute(List data) { logger.info("ML Task retrieval successful"); getMLTaskFuture.complete( new WorkflowData( - Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())) + Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())), + data.get(0).getWorkflowId() ) ); }, exception -> { diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 35a3bdfff..89c15c445 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -66,7 +66,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse Map.ofEntries( Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 17dd0b068..ad6cbff8f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -77,7 +77,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.ofEntries( Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 4dedc8bf2..d91cfc0e8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -68,7 +68,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.ofEntries( Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 83cd33f90..4f62885e9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -8,6 +8,8 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.common.Nullable; + import java.util.Collections; import java.util.Map; @@ -23,6 +25,8 @@ public class WorkflowData { private final Map content; private final Map params; + + @Nullable private String workflowId; private WorkflowData() { @@ -32,20 +36,10 @@ private WorkflowData() { /** * Instantiate this object with content and empty params. * @param content The content map + * @param workflowId The workflow ID associated with this step */ - public WorkflowData(Map content) { - this(content, Collections.emptyMap()); - } - - /** - * Instantiate this object with content and params. - * @param content The content map - * @param params The params map - */ - public WorkflowData(Map content, Map params) { - this.content = Map.copyOf(content); - this.params = Map.copyOf(params); - this.workflowId = ""; + public WorkflowData(Map content, @Nullable String workflowId) { + this(content, Collections.emptyMap(), workflowId); } /** @@ -54,7 +48,7 @@ public WorkflowData(Map content, Map params) { * @param params The params map * @param workflowId The workflow ID associated with this step */ - public WorkflowData(Map content, Map params, String workflowId) { + public WorkflowData(Map content, Map params, @Nullable String workflowId) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); this.workflowId = workflowId; @@ -81,6 +75,7 @@ public Map getParams() { * Returns the workflowId associated with this workflow. * @return the workflowId of this data. */ + @Nullable public String getWorkflowId() { return this.workflowId; }; diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java new file mode 100644 index 000000000..7e5120849 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestGetWorkflowActionTests extends OpenSearchTestCase { + private RestSearchWorkflowAction restSearchWorkflowAction; + private String searchPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.searchPath = String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, "_search"); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restSearchWorkflowAction = new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting); + this.nodeClient = mock(NodeClient.class); + } + + public void testConstructor() { + RestSearchWorkflowAction searchWorkflowAction = new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting); + assertNotNull(searchWorkflowAction); + } + + public void testRestSearchWorkflowActionName() { + String name = restSearchWorkflowAction.getName(); + assertEquals("search_workflow_action", name); + } + + public void testRestSearchWorkflowActionRoutes() { + List routes = restSearchWorkflowAction.routes(); + assertNotNull(routes); + assertEquals(2, routes.size()); + assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); + assertEquals(RestRequest.Method.GET, routes.get(1).getMethod()); + assertEquals(this.searchPath, routes.get(0).getPath()); + assertEquals(this.searchPath, routes.get(1).getPath()); + } + + public void testInvalidSearchRequest() { + final String requestContent = "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"template\":\"1.0.0\"}}]}}}"; + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET) + .withPath(this.searchPath) + .withContent(new BytesArray(requestContent), MediaTypeRegistry.JSON) + .build(); + + XContentParseException ex = expectThrows(XContentParseException.class, () -> { + restSearchWorkflowAction.prepareRequest(request, nodeClient); + }); + assertEquals("unknown named object category [org.opensearch.index.query.QueryBuilder]", ex.getMessage()); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.searchPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restSearchWorkflowAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 111b787f8..a05a3927e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -70,7 +70,8 @@ public void setUp() throws Exception { Map.entry(CommonValue.PARAMETERS_FIELD, params), Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 7a4db70a6..67cb6cb9b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id"); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 039b0384f..194c80eb0 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -50,11 +50,12 @@ public void setUp() throws Exception { Map.entry("model_id", "model_id"), Map.entry("input_field_name", "inputField"), Map.entry("output_field_name", "outputField") - ) + ), + "test-id" ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId"))); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id"); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -113,7 +114,8 @@ public void testMissingData() throws InterruptedException { Map.entry("description", "some description"), Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") - ) + ), + "test-id" ); CompletableFuture future = createIngestPipelineStep.execute(List.of(incorrectData)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 4cdfaebae..fd856b945 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -44,7 +44,7 @@ public class DeployModelStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id"); MockitoAnnotations.openMocks(this); diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java index 3a83b1fdd..f5f5f7e7d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java @@ -48,7 +48,7 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); this.getMLTaskStep = new GetMLTaskStep(mlNodeClient); - this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test"))); + this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id"); } public void testGetMLTaskSuccess() throws Exception { diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index 8868b628e..f763c8005 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -53,7 +53,8 @@ public void setUp() throws Exception { Map.entry("backend_roles", ImmutableList.of("role-1")), Map.entry("access_mode", AccessMode.PUBLIC), Map.entry("add_all_backend_roles", false) - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 0cac95b49..6aae139e4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -58,7 +58,7 @@ public void testNode() throws InterruptedException, ExecutionException { @Override public CompletableFuture execute(List data) { CompletableFuture f = new CompletableFuture<>(); - f.complete(new WorkflowData(Map.of("test", "output"))); + f.complete(new WorkflowData(Map.of("test", "output"), "test-id")); return f; } @@ -68,7 +68,7 @@ public String getName() { } }, Map.of(), - new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")), + new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id"), List.of(successfulNode), testThreadPool, TimeValue.timeValueMillis(50) @@ -77,6 +77,7 @@ public String getName() { assertEquals("test", nodeA.workflowStep().getName()); assertEquals("input", nodeA.input().getContent().get("test")); assertEquals("bar", nodeA.input().getParams().get("foo")); + assertEquals("test-id", nodeA.input().getWorkflowId()); assertEquals(1, nodeA.predecessors().size()); assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index d41096624..bd40c50ad 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -69,7 +69,8 @@ public void setUp() throws Exception { Map.entry("embedding_dimension", "384"), Map.entry("framework_type", "sentence_transformers"), Map.entry("url", "something.com") - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index ca9d5e7a5..e60707f67 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -54,7 +54,8 @@ public void setUp() throws Exception { Map.entry("name", "xyz"), Map.entry("description", "description"), Map.entry("connector_id", "abcdefg") - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java index e2464dace..8a4a1fda9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -26,13 +26,14 @@ public void testWorkflowData() { assertTrue(empty.getContent().isEmpty()); Map expectedContent = Map.of("baz", new String[] { "qux", "quxx" }); - WorkflowData contentOnly = new WorkflowData(expectedContent); + WorkflowData contentOnly = new WorkflowData(expectedContent, "test-id-123"); assertTrue(contentOnly.getParams().isEmpty()); assertEquals(expectedContent, contentOnly.getContent()); Map expectedParams = Map.of("foo", "bar"); - WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams); + WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123"); assertEquals(expectedParams, contentAndParams.getParams()); assertEquals(expectedContent, contentAndParams.getContent()); + assertEquals("test-id-123", contentAndParams.getWorkflowId()); } }