diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index bcafdab9b..7dcc89de6 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -431,9 +431,10 @@ public void updateFlowFrameworkSystemIndexDocWithScript( client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); - listener.onFailure(e); + listener.onFailure( + new FlowFrameworkException("Failed to update " + indexName + "entry: " + documentId, ExceptionsHelper.status(e)) + ); } } } - } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java index 3d40c225e..cea14461b 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java @@ -290,10 +290,9 @@ public void writeTo(StreamOutput output) throws IOException { output.writeOptionalInstant(provisionEndTime); if (user != null) { - output.writeBoolean(true); // user exists user.writeTo(output); } else { - output.writeBoolean(false); // user does not exist + output.writeBoolean(false); } if (userOutputs != null) { @@ -407,7 +406,7 @@ public String getWorkflowId() { * @return the error */ public String getError() { - return workflowId; + return error; } /** diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java index 83c662e53..6d9d5e3b5 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -86,8 +86,10 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { logger.error("Failed to send back provision workflow exception", e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); } })); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 19e2df21a..d53ce4412 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -100,7 +100,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, testWorkflowID); + List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null); workflowProcessSorter.validateGraph(sortedNodes); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java index 0d0f96c78..b2c9bb884 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java @@ -33,7 +33,7 @@ public class GetWorkflowResponse extends ActionResponse implements ToXContentObj public GetWorkflowResponse(StreamInput in) throws IOException { super(in); workflowState = new WorkflowState(in); - allStatus = false; + allStatus = in.readBoolean(); } /** diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index 1a59d033b..f3bc1dd9e 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -46,7 +47,7 @@ public class GetWorkflowTransportAction extends HandledTransportAction { - logger.debug("Completed Get Workflow Status Request, id:{}", workflowId); - if (r != null && r.isExists()) { try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -79,7 +78,7 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener context.restore())); } catch (Exception e) { logger.error("Failed to get workflow: " + workflowId, e); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index bfde9edc3..db64bf23e 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -110,8 +110,6 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); @@ -179,13 +177,14 @@ private void executeWorkflowAsync(String workflowId, List workflowS }, exception -> { logger.error("Provisioning failed for workflow {} : {}", workflowId, exception); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, ImmutableMap.of( STATE_FIELD, State.FAILED, ERROR_FIELD, - "failed provision", // TODO: improve the error message here + exception.getMessage(), // TODO: potentially improve the error message here PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.FAILED, PROVISION_END_TIME_FIELD, @@ -233,7 +232,7 @@ private void executeWorkflow(List workflowSequence, ActionListener< // Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally workflowFutureList.forEach(CompletableFuture::join); - // workflowListener.onResponse("READY"); + workflowListener.onResponse("READY"); } catch (IllegalArgumentException e) { workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST)); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 1e067a0bc..ee4dcd166 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -82,7 +82,10 @@ public CompletableFuture execute(List data) throws I @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { createConnectorFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId()))) + new WorkflowData( + Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), + data.get(0).getWorkflowId() + ) ); try { logger.info("Created connector successfully"); @@ -105,7 +108,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { workflowId, script, ActionListener.wrap(updateResponse -> { - logger.info("updated resources craeted of {}", workflowId); + logger.info("updated resources created of {}", workflowId); }, exception -> { createConnectorFuture.completeExceptionally( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 5fe47b2b0..f3a82b26c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -61,7 +61,7 @@ public CompletableFuture execute(List data) { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()))); + future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId())); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index b8cc83651..a63a800fd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -125,7 +125,9 @@ public CompletableFuture execute(List data) { logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead - createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()))); + createIngestPipelineFuture.complete( + new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId()) + ); // TODO : Use node client to index response data to global context (pending global context index implementation) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index ba22f3682..8ce89176c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -50,7 +50,10 @@ public CompletableFuture execute(List data) { public void onResponse(MLDeployModelResponse mlDeployModelResponse) { logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); deployModelFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus()))) + new WorkflowData( + Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), + data.get(0).getWorkflowId() + ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java index 893f34a0d..ac84aaaa0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java @@ -56,7 +56,8 @@ public CompletableFuture execute(List data) { logger.info("ML Task retrieval successful"); getMLTaskFuture.complete( new WorkflowData( - Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())) + Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())), + data.get(0).getWorkflowId() ) ); }, exception -> { diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 35a3bdfff..89c15c445 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -66,7 +66,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse Map.ofEntries( Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 17dd0b068..ad6cbff8f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -77,7 +77,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.ofEntries( Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 4dedc8bf2..d91cfc0e8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -68,7 +68,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.ofEntries( Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) - ) + ), + data.get(0).getWorkflowId() ) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 83cd33f90..4f62885e9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -8,6 +8,8 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.common.Nullable; + import java.util.Collections; import java.util.Map; @@ -23,6 +25,8 @@ public class WorkflowData { private final Map content; private final Map params; + + @Nullable private String workflowId; private WorkflowData() { @@ -32,20 +36,10 @@ private WorkflowData() { /** * Instantiate this object with content and empty params. * @param content The content map + * @param workflowId The workflow ID associated with this step */ - public WorkflowData(Map content) { - this(content, Collections.emptyMap()); - } - - /** - * Instantiate this object with content and params. - * @param content The content map - * @param params The params map - */ - public WorkflowData(Map content, Map params) { - this.content = Map.copyOf(content); - this.params = Map.copyOf(params); - this.workflowId = ""; + public WorkflowData(Map content, @Nullable String workflowId) { + this(content, Collections.emptyMap(), workflowId); } /** @@ -54,7 +48,7 @@ public WorkflowData(Map content, Map params) { * @param params The params map * @param workflowId The workflow ID associated with this step */ - public WorkflowData(Map content, Map params, String workflowId) { + public WorkflowData(Map content, Map params, @Nullable String workflowId) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); this.workflowId = workflowId; @@ -81,6 +75,7 @@ public Map getParams() { * Returns the workflowId associated with this workflow. * @return the workflowId of this data. */ + @Nullable public String getWorkflowId() { return this.workflowId; }; diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index 9c3f8a07e..07221297a 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -13,8 +13,14 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import java.io.IOException; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -40,4 +46,12 @@ public static ClusterSettings clusterSetting(Settings settings, Setting... se ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); return clusterSettings; } + + public static XContentBuilder builder() throws IOException { + return XContentBuilder.builder(XContentType.JSON.xContent()); + } + + public static Map XContentBuilderToMap(XContentBuilder builder) { + return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2(); + } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java new file mode 100644 index 000000000..0f6ddab59 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestGetWorkflowActionTests extends OpenSearchTestCase { + private RestGetWorkflowAction restGetWorkflowAction; + private String getPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.getPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_status"); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restGetWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); + this.nodeClient = mock(NodeClient.class); + } + + public void testConstructor() { + RestGetWorkflowAction getWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); + assertNotNull(getWorkflowAction); + } + + public void testRestGetWorkflowActionName() { + String name = restGetWorkflowAction.getName(); + assertEquals("get_workflow", name); + } + + public void testRestGetWorkflowActionRoutes() { + List routes = restGetWorkflowAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.GET, routes.get(0).getMethod()); + assertEquals(this.getPath, routes.get(0).getPath()); + } + + public void testNullWorkflowId() throws Exception { + + // Request with no params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); + } + + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + }); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", + ex.getMessage() + ); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java new file mode 100644 index 000000000..c3991783d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.junit.Assert; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import org.mockito.Mockito; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GetWorkflowTransportActionTests extends OpenSearchTestCase { + + private GetWorkflowTransportAction getWorkflowTransportAction; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private Client client; + private ThreadPool threadPool; + private ThreadContext threadContext; + private ActionListener response; + private Task task; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.threadPool = mock(ThreadPool.class); + this.getWorkflowTransportAction = new GetWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + xContentRegistry() + ); + task = Mockito.mock(Task.class); + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + + response = new ActionListener() { + @Override + public void onResponse(GetWorkflowResponse getResponse) { + assertTrue(true); + } + + @Override + public void onFailure(Exception e) {} + }; + + } + + public void testGetTransportAction() throws IOException { + GetWorkflowRequest getWorkflowRequest = new GetWorkflowRequest("1234", false); + getWorkflowTransportAction.doExecute(task, getWorkflowRequest, response); + } + + public void testGetAction() { + Assert.assertNotNull(GetWorkflowAction.INSTANCE.name()); + Assert.assertEquals(GetWorkflowAction.INSTANCE.name(), GetWorkflowAction.NAME); + } + + public void testGetAnomalyDetectorRequest() throws IOException { + GetWorkflowRequest request = new GetWorkflowRequest("1234", false); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + GetWorkflowRequest newRequest = new GetWorkflowRequest(input); + Assert.assertEquals(request.getWorkflowId(), newRequest.getWorkflowId()); + Assert.assertEquals(request.getAll(), newRequest.getAll()); + Assert.assertNull(newRequest.validate()); + } + + public void testGetAnomalyDetectorResponse() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + String workflowId = randomAlphaOfLength(5); + WorkflowState workFlowState = new WorkflowState( + workflowId, + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + + GetWorkflowResponse response = new GetWorkflowResponse(workFlowState, false); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + GetWorkflowResponse newResponse = new GetWorkflowResponse(input); + XContentBuilder builder = TestHelpers.builder(); + Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + Map map = TestHelpers.XContentBuilderToMap(builder); + Assert.assertEquals(map.get("state"), workFlowState.getState()); + Assert.assertEquals(map.get("workflow_id"), workFlowState.getWorkflowId()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 111b787f8..a05a3927e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -70,7 +70,8 @@ public void setUp() throws Exception { Map.entry(CommonValue.PARAMETERS_FIELD, params), Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 7a4db70a6..67cb6cb9b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id"); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 039b0384f..194c80eb0 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -50,11 +50,12 @@ public void setUp() throws Exception { Map.entry("model_id", "model_id"), Map.entry("input_field_name", "inputField"), Map.entry("output_field_name", "outputField") - ) + ), + "test-id" ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId"))); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id"); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -113,7 +114,8 @@ public void testMissingData() throws InterruptedException { Map.entry("description", "some description"), Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") - ) + ), + "test-id" ); CompletableFuture future = createIngestPipelineStep.execute(List.of(incorrectData)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 4cdfaebae..fd856b945 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -44,7 +44,7 @@ public class DeployModelStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId"))); + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id"); MockitoAnnotations.openMocks(this); diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java index 3a83b1fdd..f5f5f7e7d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java @@ -48,7 +48,7 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); this.getMLTaskStep = new GetMLTaskStep(mlNodeClient); - this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test"))); + this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id"); } public void testGetMLTaskSuccess() throws Exception { diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index 8868b628e..f763c8005 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -53,7 +53,8 @@ public void setUp() throws Exception { Map.entry("backend_roles", ImmutableList.of("role-1")), Map.entry("access_mode", AccessMode.PUBLIC), Map.entry("add_all_backend_roles", false) - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 0cac95b49..6aae139e4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -58,7 +58,7 @@ public void testNode() throws InterruptedException, ExecutionException { @Override public CompletableFuture execute(List data) { CompletableFuture f = new CompletableFuture<>(); - f.complete(new WorkflowData(Map.of("test", "output"))); + f.complete(new WorkflowData(Map.of("test", "output"), "test-id")); return f; } @@ -68,7 +68,7 @@ public String getName() { } }, Map.of(), - new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")), + new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id"), List.of(successfulNode), testThreadPool, TimeValue.timeValueMillis(50) @@ -77,6 +77,7 @@ public String getName() { assertEquals("test", nodeA.workflowStep().getName()); assertEquals("input", nodeA.input().getContent().get("test")); assertEquals("bar", nodeA.input().getParams().get("foo")); + assertEquals("test-id", nodeA.input().getWorkflowId()); assertEquals(1, nodeA.predecessors().size()); assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index d41096624..bd40c50ad 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -69,7 +69,8 @@ public void setUp() throws Exception { Map.entry("embedding_dimension", "384"), Map.entry("framework_type", "sentence_transformers"), Map.entry("url", "something.com") - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index ca9d5e7a5..e60707f67 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -54,7 +54,8 @@ public void setUp() throws Exception { Map.entry("name", "xyz"), Map.entry("description", "description"), Map.entry("connector_id", "abcdefg") - ) + ), + "test-id" ); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java index e2464dace..8a4a1fda9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -26,13 +26,14 @@ public void testWorkflowData() { assertTrue(empty.getContent().isEmpty()); Map expectedContent = Map.of("baz", new String[] { "qux", "quxx" }); - WorkflowData contentOnly = new WorkflowData(expectedContent); + WorkflowData contentOnly = new WorkflowData(expectedContent, "test-id-123"); assertTrue(contentOnly.getParams().isEmpty()); assertEquals(expectedContent, contentOnly.getContent()); Map expectedParams = Map.of("foo", "bar"); - WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams); + WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123"); assertEquals(expectedParams, contentAndParams.getParams()); assertEquals(expectedContent, contentAndParams.getContent()); + assertEquals("test-id-123", contentAndParams.getWorkflowId()); } }