Skip to content

Commit

Permalink
addressed comments and added more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Nov 15, 2023
1 parent 5583855 commit 3ee4d72
Show file tree
Hide file tree
Showing 29 changed files with 317 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,10 @@ public void updateFlowFrameworkSystemIndexDocWithScript(
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage());
listener.onFailure(e);
listener.onFailure(
new FlowFrameworkException("Failed to update " + indexName + "entry: " + documentId, ExceptionsHelper.status(e))
);
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,9 @@ public void writeTo(StreamOutput output) throws IOException {
output.writeOptionalInstant(provisionEndTime);

if (user != null) {
output.writeBoolean(true); // user exists
user.writeTo(output);
} else {
output.writeBoolean(false); // user does not exist
output.writeBoolean(false);
}

if (userOutputs != null) {
Expand Down Expand Up @@ -407,7 +406,7 @@ public String getWorkflowId() {
* @return the error
*/
public String getError() {
return workflowId;
return error;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request
FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception));
XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder));

} catch (IOException e) {
logger.error("Failed to send back provision workflow exception", e);
channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage()));
}
}));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
try {
// generating random workflowId only for validation purpose
String uniqueID = UUID.randomUUID().toString();
validateWorkflows(templateWithUser, uniqueID);
validateWorkflows(templateWithUser);
} catch (Exception e) {
if (e instanceof FlowFrameworkException) {
logger.error("Workflow validation failed for template : " + templateWithUser.name());
Expand Down Expand Up @@ -215,9 +215,9 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow,
}));
}

private void validateWorkflows(Template template, String testWorkflowID) throws Exception {
private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, testWorkflowID);
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);
workflowProcessSorter.validateGraph(sortedNodes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class GetWorkflowResponse extends ActionResponse implements ToXContentObj
public GetWorkflowResponse(StreamInput in) throws IOException {
super(in);
workflowState = new WorkflowState(in);
allStatus = false;
allStatus = in.readBoolean();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -46,7 +47,7 @@ public class GetWorkflowTransportAction extends HandledTransportAction<GetWorkfl

/**
* Intantiates a new CreateWorkflowTransportAction
* @param transportService the TransportService
* @param transportService The TransportService
* @param actionFilters action filters
* @param client The client used to make the request to OS
* @param xContentRegistry contentRegister to parse get response
Expand All @@ -70,16 +71,14 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener<G
GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
logger.debug("Completed Get Workflow Status Request, id:{}", workflowId);

if (r != null && r.isExists()) {
try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
WorkflowState workflowState = WorkflowState.parse(parser);
listener.onResponse(new GetWorkflowResponse(workflowState, request.getAll()));
} catch (Exception e) {
logger.error("Failed to parse workflowState" + r.getId(), e);
listener.onFailure(e);
listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST));
}
} else {
listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND));
Expand All @@ -88,13 +87,13 @@ protected void doExecute(Task task, GetWorkflowRequest request, ActionListener<G
if (e instanceof IndexNotFoundException) {
listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND));
} else {
logger.error("Failed to get workflow status of " + workflowId, e);
listener.onFailure(e);
logger.error("Failed to get workflow status of: " + workflowId, e);
listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND));
}
}), () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to get workflow: " + workflowId, e);
listener.onFailure(e);
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

// Parse template from document source
Template template = Template.parse(response.getSourceAsString());
// TODO: Add the workflowID to the workflow data so I can update the state Index.

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
Expand Down Expand Up @@ -179,13 +177,14 @@ private void executeWorkflowAsync(String workflowId, List<ProcessNode> workflowS

}, exception -> {
logger.error("Provisioning failed for workflow {} : {}", workflowId, exception);

flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
workflowId,
ImmutableMap.of(
STATE_FIELD,
State.FAILED,
ERROR_FIELD,
"failed provision", // TODO: improve the error message here
exception.getMessage(), // TODO: potentially improve the error message here
PROVISIONING_PROGRESS_FIELD,
ProvisioningProgress.FAILED,
PROVISION_END_TIME_FIELD,
Expand Down Expand Up @@ -233,7 +232,7 @@ private void executeWorkflow(List<ProcessNode> workflowSequence, ActionListener<
// Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally
workflowFutureList.forEach(CompletableFuture::join);

// workflowListener.onResponse("READY");
workflowListener.onResponse("READY");

} catch (IllegalArgumentException e) {
workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws I
@Override
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
createConnectorFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())))
new WorkflowData(
Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())),
data.get(0).getWorkflowId()
)
);
try {
logger.info("Created connector successfully");
Expand All @@ -105,7 +108,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
workflowId,
script,
ActionListener.wrap(updateResponse -> {
logger.info("updated resources craeted of {}", workflowId);
logger.info("updated resources created of {}", workflowId);
}, exception -> {
createConnectorFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
@Override
public void onResponse(CreateIndexResponse createIndexResponse) {
logger.info("created index: {}", createIndexResponse.index());
future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index())));
future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.info("Created ingest pipeline : " + putPipelineRequest.getId());

// PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead
createIngestPipelineFuture.complete(new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId())));
createIngestPipelineFuture.complete(
new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId())
);

// TODO : Use node client to index response data to global context (pending global context index implementation)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())))
new WorkflowData(
Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())),
data.get(0).getWorkflowId()
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name()))
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())),
data.get(0).getWorkflowId()
)
);
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Map.ofEntries(
Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Map.ofEntries(
Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
),
data.get(0).getWorkflowId()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.common.Nullable;

import java.util.Collections;
import java.util.Map;

Expand All @@ -23,6 +25,8 @@ public class WorkflowData {

private final Map<String, Object> content;
private final Map<String, String> params;

@Nullable
private String workflowId;

private WorkflowData() {
Expand All @@ -32,20 +36,10 @@ private WorkflowData() {
/**
* Instantiate this object with content and empty params.
* @param content The content map
* @param workflowId The workflow ID associated with this step
*/
public WorkflowData(Map<String, Object> content) {
this(content, Collections.emptyMap());
}

/**
* Instantiate this object with content and params.
* @param content The content map
* @param params The params map
*/
public WorkflowData(Map<String, Object> content, Map<String, String> params) {
this.content = Map.copyOf(content);
this.params = Map.copyOf(params);
this.workflowId = "";
public WorkflowData(Map<String, Object> content, @Nullable String workflowId) {
this(content, Collections.emptyMap(), workflowId);
}

/**
Expand All @@ -54,7 +48,7 @@ public WorkflowData(Map<String, Object> content, Map<String, String> params) {
* @param params The params map
* @param workflowId The workflow ID associated with this step
*/
public WorkflowData(Map<String, Object> content, Map<String, String> params, String workflowId) {
public WorkflowData(Map<String, Object> content, Map<String, String> params, @Nullable String workflowId) {
this.content = Map.copyOf(content);
this.params = Map.copyOf(params);
this.workflowId = workflowId;
Expand All @@ -81,6 +75,7 @@ public Map<String, String> getParams() {
* Returns the workflowId associated with this workflow.
* @return the workflowId of this data.
*/
@Nullable
public String getWorkflowId() {
return this.workflowId;
};
Expand Down
14 changes: 14 additions & 0 deletions src/test/java/org/opensearch/flowframework/TestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -40,4 +46,12 @@ public static ClusterSettings clusterSetting(Settings settings, Setting<?>... se
ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet);
return clusterSettings;
}

public static XContentBuilder builder() throws IOException {
return XContentBuilder.builder(XContentType.JSON.xContent());
}

public static Map<String, Object> XContentBuilderToMap(XContentBuilder builder) {
return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
}
}
Loading

0 comments on commit 3ee4d72

Please sign in to comment.