Skip to content

Commit

Permalink
addressed comments and added more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Nov 15, 2023
1 parent 5583855 commit 3ee4d72
Show file tree
Hide file tree
Showing 29 changed files with 317 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,10 @@ public void updateFlowFrameworkSystemIndexDocWithScript(
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage());
listener.onFailure(e);
listener.onFailure(
new FlowFrameworkException("Failed to update " + indexName + "entry: " + documentId, ExceptionsHelper.status(e))

Check warning on line 435 in src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java#L431-L435

Added lines #L431 - L435 were not covered by tests
);
}

Check warning on line 437 in src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java#L437

Added line #L437 was not covered by tests
}
}

Check warning on line 439 in src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java#L439

Added line #L439 was not covered by tests

}
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,9 @@ public void writeTo(StreamOutput output) throws IOException {
output.writeOptionalInstant(provisionEndTime);

if (user != null) {
output.writeBoolean(true); // user exists
user.writeTo(output);

Check warning on line 293 in src/main/java/org/opensearch/flowframework/model/WorkflowState.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/model/WorkflowState.java#L293

Added line #L293 was not covered by tests
} else {
output.writeBoolean(false); // user does not exist
output.writeBoolean(false);
}

if (userOutputs != null) {
Expand Down Expand Up @@ -407,7 +406,7 @@ public String getWorkflowId() {
* @return the error
*/
public String getError() {
return workflowId;
return error;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request
FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception));
XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder));

Check warning on line 88 in src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java#L86-L88

Added lines #L86 - L88 were not covered by tests

} catch (IOException e) {
logger.error("Failed to send back provision workflow exception", e);
channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage()));
}
}));

Check warning on line 94 in src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java#L90-L94

Added lines #L90 - L94 were not covered by tests

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
try {
// generating random workflowId only for validation purpose
String uniqueID = UUID.randomUUID().toString();
validateWorkflows(templateWithUser, uniqueID);
validateWorkflows(templateWithUser);
} catch (Exception e) {
if (e instanceof FlowFrameworkException) {
logger.error("Workflow validation failed for template : " + templateWithUser.name());
Expand Down Expand Up @@ -215,9 +215,9 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow,
}));
}

private void validateWorkflows(Template template, String testWorkflowID) throws Exception {
private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, testWorkflowID);
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);

Check warning on line 220 in src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java#L220

Added line #L220 was not covered by tests
workflowProcessSorter.validateGraph(sortedNodes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class GetWorkflowResponse extends ActionResponse implements ToXContentObj
public GetWorkflowResponse(StreamInput in) throws IOException {
super(in);
workflowState = new WorkflowState(in);
allStatus = false;
allStatus = in.readBoolean();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -46,7 +47,7 @@ public class GetWorkflowTransportAction extends HandledTransportAction<GetWorkfl

/**
* Intantiates a new CreateWorkflowTransportAction
* @param transportService the TransportService
* @param transportService The TransportService
* @param actionFilters action filters
* @param client The client used to make the request to OS
* @param xContentRegistry contentRegister to parse get response
Expand All @@ -70,16 +71,14 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener<G
GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
logger.debug("Completed Get Workflow Status Request, id:{}", workflowId);

if (r != null && r.isExists()) {
try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
WorkflowState workflowState = WorkflowState.parse(parser);
listener.onResponse(new GetWorkflowResponse(workflowState, request.getAll()));
} catch (Exception e) {
logger.error("Failed to parse workflowState" + r.getId(), e);
listener.onFailure(e);
listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST));
}

Check warning on line 82 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L75-L82

Added lines #L75 - L82 were not covered by tests
} else {
listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND));

Check warning on line 84 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L84

Added line #L84 was not covered by tests
Expand All @@ -88,13 +87,13 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener<G
if (e instanceof IndexNotFoundException) {
listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND));

Check warning on line 88 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L88

Added line #L88 was not covered by tests
} else {
logger.error("Failed to get workflow status of " + workflowId, e);
listener.onFailure(e);
logger.error("Failed to get workflow status of: " + workflowId, e);
listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND));

Check warning on line 91 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L90-L91

Added lines #L90 - L91 were not covered by tests
}
}), () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to get workflow: " + workflowId, e);
listener.onFailure(e);
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));

Check warning on line 96 in src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java#L93-L96

Added lines #L93 - L96 were not covered by tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

// Parse template from document source
Template template = Template.parse(response.getSourceAsString());
// TODO: Add the workflowID to the workflow data so I can update the state Index.

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
Expand Down Expand Up @@ -179,13 +177,14 @@ private void executeWorkflowAsync(String workflowId, List<ProcessNode> workflowS

}, exception -> {
logger.error("Provisioning failed for workflow {} : {}", workflowId, exception);

flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
workflowId,
ImmutableMap.of(
STATE_FIELD,
State.FAILED,
ERROR_FIELD,
"failed provision", // TODO: improve the error message here
exception.getMessage(), // TODO: potentially improve the error message here
PROVISIONING_PROGRESS_FIELD,
ProvisioningProgress.FAILED,
PROVISION_END_TIME_FIELD,
Expand Down Expand Up @@ -233,7 +232,7 @@ private void executeWorkflow(List<ProcessNode> workflowSequence, ActionListener<
// Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally
workflowFutureList.forEach(CompletableFuture::join);

// workflowListener.onResponse("READY");
workflowListener.onResponse("READY");

} catch (IllegalArgumentException e) {
workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws I
@Override
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
createConnectorFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())))
new WorkflowData(
Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())),
data.get(0).getWorkflowId()
)
);
try {
logger.info("Created connector successfully");
Expand All @@ -105,7 +108,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
workflowId,
script,
ActionListener.wrap(updateResponse -> {
logger.info("updated resources craeted of {}", workflowId);
logger.info("updated resources created of {}", workflowId);
}, exception -> {
createConnectorFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))

Check warning on line 114 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L111-L114

Added lines #L111 - L114 were not covered by tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
@Override
public void onResponse(CreateIndexResponse createIndexResponse) {
logger.info("created index: {}", createIndexResponse.index());
future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index())));
future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.info("Created ingest pipeline : " + putPipelineRequest.getId());

// PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead
createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId())));
createIngestPipelineFuture.complete(
new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId())
);

// TODO : Use node client to index response data to global context (pending global context index implementation)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())))
new WorkflowData(
Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())),
data.get(0).getWorkflowId()
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name()))
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Map.ofEntries(
Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Map.ofEntries(
Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.common.Nullable;

import java.util.Collections;
import java.util.Map;

Expand All @@ -23,6 +25,8 @@ public class WorkflowData {

private final Map<String, Object> content;
private final Map<String, String> params;

@Nullable
private String workflowId;

private WorkflowData() {
Expand All @@ -32,20 +36,10 @@ private WorkflowData() {
/**
* Instantiate this object with content and empty params.
* @param content The content map
* @param workflowId The workflow ID associated with this step
*/
public WorkflowData(Map<String, Object> content) {
this(content, Collections.emptyMap());
}

/**
* Instantiate this object with content and params.
* @param content The content map
* @param params The params map
*/
public WorkflowData(Map<String, Object> content, Map<String, String> params) {
this.content = Map.copyOf(content);
this.params = Map.copyOf(params);
this.workflowId = "";
public WorkflowData(Map<String, Object> content, @Nullable String workflowId) {
this(content, Collections.emptyMap(), workflowId);
}

/**
Expand All @@ -54,7 +48,7 @@ public WorkflowData(Map<String, Object> content, Map<String, String> params) {
* @param params The params map
* @param workflowId The workflow ID associated with this step
*/
public WorkflowData(Map<String, Object> content, Map<String, String> params, String workflowId) {
public WorkflowData(Map<String, Object> content, Map<String, String> params, @Nullable String workflowId) {
this.content = Map.copyOf(content);
this.params = Map.copyOf(params);
this.workflowId = workflowId;
Expand All @@ -81,6 +75,7 @@ public Map<String, String> getParams() {
* Returns the workflowId associated with this workflow.
* @return the workflowId of this data.
*/
@Nullable
public String getWorkflowId() {
return this.workflowId;
};
Expand Down
14 changes: 14 additions & 0 deletions src/test/java/org/opensearch/flowframework/TestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -40,4 +46,12 @@ public static ClusterSettings clusterSetting(Settings settings, Setting<?>... se
ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet);
return clusterSettings;
}

public static XContentBuilder builder() throws IOException {
return XContentBuilder.builder(XContentType.JSON.xContent());
}

public static Map<String, Object> XContentBuilderToMap(XContentBuilder builder) {
return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
}
}
Loading

0 comments on commit 3ee4d72

Please sign in to comment.