Skip to content

Commit

Permalink
Add Util method to fetch inputs from parameters, content, and previous step output (opensearch-project#234)
Browse files Browse the repository at this point in the history

* Util method to get required inputs

Signed-off-by: Daniel Widdis <[email protected]>

* Implement parsing in some of the steps

Signed-off-by: Daniel Widdis <[email protected]>

* Handle parsing exceptions in the future

Signed-off-by: Daniel Widdis <[email protected]>

* Improve exception handling

Signed-off-by: Daniel Widdis <[email protected]>

* More steps using the new input parsing

Signed-off-by: Daniel Widdis <[email protected]>

* Update Delete Connector Step with parsing util

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 15, 2023
1 parent 9cb621d commit 83af2ba
Show file tree
Hide file tree
Showing 13 changed files with 327 additions and 282 deletions.
102 changes: 102 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
import org.opensearch.commons.ConfigConstants;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
Expand All @@ -39,6 +47,9 @@
public class ParseUtils {
private static final Logger logger = LogManager.getLogger(ParseUtils.class);

// Matches ${{ foo.bar }} (whitespace optional) with capturing groups 1=foo, 2=bar
private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*(.+)\\.(.+?)\\s*\\}\\}");

private ParseUtils() {}

/**
Expand Down Expand Up @@ -161,4 +172,95 @@ public static Map<String, String> getStringToStringMap(Object map, String fieldN
throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map.");
}

/**
 * Creates a map containing the specified input keys, with values derived from template data or previous node
 * output.
 *
 * @param requiredInputKeys A set of keys that must be present, or will cause an exception to be thrown
 * @param optionalInputKeys A set of keys that may be present, or will be absent in the returned map
 * @param currentNodeInputs Input params and content for this node, from workflow parsing
 * @param outputs WorkflowData content of previous steps
 * @param previousNodeInputs Input params for this node that come from previous steps
 * @return A map containing the requiredInputKeys with their corresponding values,
 *         and optionalInputKeys with their corresponding values if present.
 *         Throws a {@link FlowFrameworkException} if a required key is not present.
 */
public static Map<String, Object> getInputsFromPreviousSteps(
    Set<String> requiredInputKeys,
    Set<String> optionalInputKeys,
    WorkflowData currentNodeInputs,
    Map<String, WorkflowData> outputs,
    Map<String, String> previousNodeInputs
) {
    // Mutable set to ensure all required keys are used
    Set<String> requiredKeys = new HashSet<>(requiredInputKeys);
    // Merge input sets to add all requested keys
    Set<String> keys = new HashSet<>(requiredInputKeys);
    keys.addAll(optionalInputKeys);
    // Initialize return map
    Map<String, Object> inputs = new HashMap<>();
    for (String key : keys) {
        Object value = null;
        // Priority 1: specifically named prior step inputs
        // ... parse the previousNodeInputs map and fill in the specified keys
        // previousNodeInputs maps node-id -> key; we reverse-look-up the node that provides this key.
        // NOTE(review): findAny() picks an arbitrary match if multiple prior nodes claim the same key —
        // presumably workflows declare at most one provider per key; confirm upstream validation.
        Optional<String> previousNodeForKey = previousNodeInputs.entrySet()
            .stream()
            .filter(e -> key.equals(e.getValue()))
            .map(Map.Entry::getKey)
            .findAny();
        if (previousNodeForKey.isPresent()) {
            WorkflowData previousNodeOutput = outputs.get(previousNodeForKey.get());
            if (previousNodeOutput != null) {
                value = previousNodeOutput.getContent().get(key);
            }
        }
        // Priority 2: inputs specified in template
        // ... fetch from currentNodeInputs (params take precedence)
        if (value == null) {
            value = currentNodeInputs.getParams().get(key);
        }
        if (value == null) {
            value = currentNodeInputs.getContent().get(key);
        }
        // Priority 3: other inputs
        // Fall back to scanning all prior step outputs for any content entry with this key.
        // findAny() again returns an arbitrary match when several steps emitted the same key.
        if (value == null) {
            Optional<Object> matchedValue = outputs.values()
                .stream()
                .map(WorkflowData::getContent)
                .filter(m -> m.containsKey(key))
                .map(m -> m.get(key))
                .findAny();
            if (matchedValue.isPresent()) {
                value = matchedValue.get();
            }
        }
        // Check for substitution
        // A value of the form ${{ nodeId.fieldName }} is replaced by that node's output field,
        // but only when the referenced node and field actually exist; otherwise the literal
        // placeholder string is kept as-is (no error is raised for a dangling reference).
        if (value != null) {
            Matcher m = SUBSTITUTION_PATTERN.matcher(value.toString());
            if (m.matches()) {
                WorkflowData data = outputs.get(m.group(1));
                if (data != null && data.getContent().containsKey(m.group(2))) {
                    value = data.getContent().get(m.group(2));
                }
            }
            // Only non-null resolved values are recorded; a key that resolves to null stays
            // in requiredKeys and triggers the missing-inputs exception below.
            inputs.put(key, value);
            requiredKeys.remove(key);
        }
    }
    // After iterating is complete, throw exception if requiredKeys is not empty
    if (!requiredKeys.isEmpty()) {
        throw new FlowFrameworkException(
            "Missing required inputs "
                + requiredKeys
                + " in workflow ["
                + currentNodeInputs.getWorkflowId()
                + "] node ["
                + currentNodeInputs.getNodeId()
                + "]",
            RestStatus.BAD_REQUEST
        );
    }
    // Finally return the map
    return inputs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
Expand All @@ -33,8 +34,8 @@
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD;
Expand Down Expand Up @@ -120,57 +121,44 @@ public void onFailure(Exception e) {
}
};

String name = null;
String description = null;
String version = null;
String protocol = null;
Map<String, String> parameters = Collections.emptyMap();
Map<String, String> credentials = Collections.emptyMap();
List<ConnectorAction> actions = Collections.emptyList();

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());
Set<String> requiredKeys = Set.of(
NAME_FIELD,
DESCRIPTION_FIELD,
VERSION_FIELD,
PROTOCOL_FIELD,
PARAMETERS_FIELD,
CREDENTIAL_FIELD,
ACTIONS_FIELD
);
Set<String> optionalKeys = Collections.emptySet();

try {
for (WorkflowData workflowData : data) {
for (Entry<String, Object> entry : workflowData.getContent().entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) entry.getValue();
break;
case VERSION_FIELD:
version = (String) entry.getValue();
break;
case PROTOCOL_FIELD:
protocol = (String) entry.getValue();
break;
case PARAMETERS_FIELD:
parameters = getParameterMap(entry.getValue());
break;
case CREDENTIAL_FIELD:
credentials = getStringToStringMap(entry.getValue(), CREDENTIAL_FIELD);
break;
case ACTIONS_FIELD:
actions = getConnectorActionList(entry.getValue());
break;
}
}
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String version = (String) inputs.get(VERSION_FIELD);
String protocol = (String) inputs.get(PROTOCOL_FIELD);
Map<String, String> parameters;
Map<String, String> credentials;
List<ConnectorAction> actions;

try {
parameters = getParameterMap(inputs.get(PARAMETERS_FIELD));
credentials = getStringToStringMap(inputs.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD);
actions = getConnectorActionList(inputs.get(ACTIONS_FIELD));
} catch (IllegalArgumentException iae) {
throw new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST);
} catch (PrivilegedActionException pae) {
throw new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED);
}
} catch (IllegalArgumentException iae) {
createConnectorFuture.completeExceptionally(new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST));
return createConnectorFuture;
} catch (PrivilegedActionException pae) {
createConnectorFuture.completeExceptionally(new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED));
return createConnectorFuture;
}

if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) {
MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder()
.name(name)
.description(description)
Expand All @@ -182,12 +170,9 @@ public void onFailure(Exception e) {
.build();

mlClient.createConnector(mlInput, actionListener);
} else {
createConnectorFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
} catch (FlowFrameworkException e) {
createConnectorFuture.completeExceptionally(e);
}

return createConnectorFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;
Expand Down Expand Up @@ -72,29 +73,23 @@ public void onFailure(Exception e) {
}
};

String connectorId = null;

// Previous Node inputs defines which step the connector ID came from
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> CONNECTOR_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();
if (previousNode.isPresent()) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) {
connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString();
}
}
Set<String> requiredKeys = Set.of(CONNECTOR_ID);
Set<String> optionalKeys = Collections.emptySet();

if (connectorId != null) {
mlClient.deleteConnector(connectorId, actionListener);
} else {
deleteConnectorFuture.completeExceptionally(
new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST)
try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);
}
String connectorId = (String) inputs.get(CONNECTOR_ID);

mlClient.deleteConnector(connectorId, actionListener);
} catch (FlowFrameworkException e) {
deleteConnectorFuture.completeExceptionally(e);
}
return deleteConnectorFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

import java.util.ArrayList;
import java.util.List;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
Expand Down Expand Up @@ -71,27 +71,24 @@ public void onFailure(Exception e) {
}
};

String modelId = null;
Set<String> requiredKeys = Set.of(MODEL_ID);
Set<String> optionalKeys = Collections.emptySet();

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());
try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

for (WorkflowData workflowData : data) {
if (workflowData.getContent().containsKey(MODEL_ID)) {
modelId = (String) workflowData.getContent().get(MODEL_ID);
break;
}
}
String modelId = (String) inputs.get(MODEL_ID);

if (modelId != null) {
mlClient.deploy(modelId, actionListener);
} else {
deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST));
} catch (FlowFrameworkException e) {
deployModelFuture.completeExceptionally(e);
}

return deployModelFuture;
}

Expand Down
Loading

0 comments on commit 83af2ba

Please sign in to comment.