Skip to content

Commit

Permalink
addressed comments and added more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Nov 15, 2023
1 parent 5583855 commit e6c9ab4
Show file tree
Hide file tree
Showing 23 changed files with 153 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,10 @@ public void updateFlowFrameworkSystemIndexDocWithScript(
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage());
listener.onFailure(e);
listener.onFailure(
new FlowFrameworkException("Failed to update " + indexName + "entry: " + documentId, ExceptionsHelper.status(e))

Check warning on line 435 in src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java#L431-L435

Added lines #L431 - L435 were not covered by tests
);
}

Check warning on line 437 in src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java#L437

Added line #L437 was not covered by tests
}
}

Check warning on line 439 in src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java#L439

Added line #L439 was not covered by tests

}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
try {
// generating random workflowId only for validation purpose
String uniqueID = UUID.randomUUID().toString();
validateWorkflows(templateWithUser, uniqueID);
validateWorkflows(templateWithUser);
} catch (Exception e) {
if (e instanceof FlowFrameworkException) {
logger.error("Workflow validation failed for template : " + templateWithUser.name());
Expand Down Expand Up @@ -215,9 +215,9 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow,
}));
}

private void validateWorkflows(Template template, String testWorkflowID) throws Exception {
private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, testWorkflowID);
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);

Check warning on line 220 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L220

Added line #L220 was not covered by tests
workflowProcessSorter.validateGraph(sortedNodes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -70,16 +71,14 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener<G
GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {

Check warning on line 73 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L69-L73

Added lines #L69 - L73 were not covered by tests
logger.debug("Completed Get Workflow Status Request, id:{}", workflowId);

if (r != null && r.isExists()) {
try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
WorkflowState workflowState = WorkflowState.parse(parser);
listener.onResponse(new GetWorkflowResponse(workflowState, request.getAll()));
} catch (Exception e) {
logger.error("Failed to parse workflowState" + r.getId(), e);
listener.onFailure(e);
listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST));
}

Check warning on line 82 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L75-L82

Added lines #L75 - L82 were not covered by tests
} else {
listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND));

Check warning on line 84 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L84

Added line #L84 was not covered by tests
Expand All @@ -88,13 +87,13 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener<G
if (e instanceof IndexNotFoundException) {
listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND));

Check warning on line 88 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L88

Added line #L88 was not covered by tests
} else {
logger.error("Failed to get workflow status of " + workflowId, e);
listener.onFailure(e);
logger.error("Failed to get workflow status of: " + workflowId, e);
listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND));

Check warning on line 91 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L90-L91

Added lines #L90 - L91 were not covered by tests
}
}), () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to get workflow: " + workflowId, e);
listener.onFailure(e);
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}

Check warning on line 98 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L93-L98

Added lines #L93 - L98 were not covered by tests
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws I
@Override
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
createConnectorFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())))
new WorkflowData(
Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())),
data.get(0).getWorkflowId()
)
);
try {
logger.info("Created connector successfully");
Expand All @@ -105,7 +108,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
workflowId,
script,
ActionListener.wrap(updateResponse -> {
logger.info("updated resources craeted of {}", workflowId);
logger.info("updated resources created of {}", workflowId);
}, exception -> {
createConnectorFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))

Check warning on line 114 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L111-L114

Added lines #L111 - L114 were not covered by tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
@Override
public void onResponse(CreateIndexResponse createIndexResponse) {
logger.info("created index: {}", createIndexResponse.index());
future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index())));
future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.info("Created ingest pipeline : " + putPipelineRequest.getId());

// PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead
createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId())));
createIngestPipelineFuture.complete(
new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId())
);

// TODO : Use node client to index response data to global context (pending global context index implementation)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())))
new WorkflowData(
Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())),
data.get(0).getWorkflowId()
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name()))
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Map.ofEntries(
Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Map.ofEntries(
Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.common.Nullable;

import java.util.Collections;
import java.util.Map;

Expand All @@ -23,6 +25,8 @@ public class WorkflowData {

private final Map<String, Object> content;
private final Map<String, String> params;

@Nullable
private String workflowId;

private WorkflowData() {
Expand All @@ -32,20 +36,10 @@ private WorkflowData() {
/**
* Instantiate this object with content and empty params.
* @param content The content map
* @param workflowId The workflow ID associated with this step
*/
public WorkflowData(Map<String, Object> content) {
this(content, Collections.emptyMap());
}

/**
* Instantiate this object with content and params.
* @param content The content map
* @param params The params map
*/
public WorkflowData(Map<String, Object> content, Map<String, String> params) {
this.content = Map.copyOf(content);
this.params = Map.copyOf(params);
this.workflowId = "";
public WorkflowData(Map<String, Object> content, @Nullable String workflowId) {
this(content, Collections.emptyMap(), workflowId);
}

/**
Expand All @@ -54,7 +48,7 @@ public WorkflowData(Map<String, Object> content, Map<String, String> params) {
* @param params The params map
* @param workflowId The workflow ID associated with this step
*/
public WorkflowData(Map<String, Object> content, Map<String, String> params, String workflowId) {
public WorkflowData(Map<String, Object> content, Map<String, String> params, @Nullable String workflowId) {
this.content = Map.copyOf(content);
this.params = Map.copyOf(params);
this.workflowId = workflowId;
Expand All @@ -81,6 +75,7 @@ public Map<String, String> getParams() {
* Returns the workflowId associated with this workflow.
* @return the workflowId of this data.
*/
@Nullable
public String getWorkflowId() {
return this.workflowId;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.rest;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentParseException;
import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestChannel;
import org.opensearch.test.rest.FakeRestRequest;

import java.util.List;
import java.util.Locale;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class RestGetWorkflowActionTests extends OpenSearchTestCase {
private RestSearchWorkflowAction restSearchWorkflowAction;
private String searchPath;
private NodeClient nodeClient;
private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting;

@Override
public void setUp() throws Exception {
super.setUp();

this.searchPath = String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, "_search");
flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class);
when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true);
this.restSearchWorkflowAction = new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting);
this.nodeClient = mock(NodeClient.class);
}

public void testConstructor() {
RestSearchWorkflowAction searchWorkflowAction = new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting);
assertNotNull(searchWorkflowAction);
}

public void testRestSearchWorkflowActionName() {
String name = restSearchWorkflowAction.getName();
assertEquals("search_workflow_action", name);
}

public void testRestSearchWorkflowActionRoutes() {
List<RestHandler.Route> routes = restSearchWorkflowAction.routes();
assertNotNull(routes);
assertEquals(2, routes.size());
assertEquals(RestRequest.Method.POST, routes.get(0).getMethod());
assertEquals(RestRequest.Method.GET, routes.get(1).getMethod());
assertEquals(this.searchPath, routes.get(0).getPath());
assertEquals(this.searchPath, routes.get(1).getPath());
}

public void testInvalidSearchRequest() {
final String requestContent = "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"template\":\"1.0.0\"}}]}}}";
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET)
.withPath(this.searchPath)
.withContent(new BytesArray(requestContent), MediaTypeRegistry.JSON)
.build();

XContentParseException ex = expectThrows(XContentParseException.class, () -> {
restSearchWorkflowAction.prepareRequest(request, nodeClient);
});
assertEquals("unknown named object category [org.opensearch.index.query.QueryBuilder]", ex.getMessage());
}

public void testFeatureFlagNotEnabled() throws Exception {
when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false);
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.searchPath)
.build();
FakeRestChannel channel = new FakeRestChannel(request, false, 1);
restSearchWorkflowAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled."));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ public void setUp() throws Exception {
Map.entry(CommonValue.PARAMETERS_FIELD, params),
Map.entry(CommonValue.CREDENTIALS_FIELD, credentials),
Map.entry(CommonValue.ACTIONS_FIELD, actions)
)
),
"test-id"
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase {
public void setUp() throws Exception {
super.setUp();
MockitoAnnotations.openMocks(this);
inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")));
inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id");
clusterService = mock(ClusterService.class);
client = mock(Client.class);
adminClient = mock(AdminClient.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ public void setUp() throws Exception {
Map.entry("model_id", "model_id"),
Map.entry("input_field_name", "inputField"),
Map.entry("output_field_name", "outputField")
)
),
"test-id"
);

// Set output data to returned pipelineId
outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")));
outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id");

client = mock(Client.class);
adminClient = mock(AdminClient.class);
Expand Down Expand Up @@ -113,7 +114,8 @@ public void testMissingData() throws InterruptedException {
Map.entry("description", "some description"),
Map.entry("type", "text_embedding"),
Map.entry("model_id", "model_id")
)
),
"test-id"
);

CompletableFuture<WorkflowData> future = createIngestPipelineStep.execute(List.of(incorrectData));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class DeployModelStepTests extends OpenSearchTestCase {
public void setUp() throws Exception {
super.setUp();

inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")));
inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id");

MockitoAnnotations.openMocks(this);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void setUp() throws Exception {

MockitoAnnotations.openMocks(this);
this.getMLTaskStep = new GetMLTaskStep(mlNodeClient);
this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")));
this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id");
}

public void testGetMLTaskSuccess() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public void setUp() throws Exception {
Map.entry("backend_roles", ImmutableList.of("role-1")),
Map.entry("access_mode", AccessMode.PUBLIC),
Map.entry("add_all_backend_roles", false)
)
),
"test-id"
);

}
Expand Down
Loading

0 comments on commit e6c9ab4

Please sign in to comment.