diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 5d9692006..0bac15c61 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -9,28 +9,53 @@ package org.opensearch.flowframework; import com.google.common.collect.ImmutableList; +import org.opensearch.action.ActionRequest; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.flowframework.indices.GlobalContextHandler; +import org.opensearch.flowframework.rest.RestCreateWorkflowAction; +import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; +import org.opensearch.flowframework.transport.CreateWorkflowAction; +import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; +import org.opensearch.flowframework.transport.ProvisionWorkflowAction; +import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; +import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ExecutorBuilder; +import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; import java.util.Collection; +import java.util.List; import java.util.function.Supplier; +import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; + /** * An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch. */ -public class FlowFrameworkPlugin extends Plugin { +public class FlowFrameworkPlugin extends Plugin implements ActionPlugin { /** * Instantiate this plugin. @@ -54,6 +79,45 @@ public Collection createComponents( WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - return ImmutableList.of(workflowStepFactory, workflowProcessSorter); + // TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep + GlobalContextHandler globalContextHandler = new GlobalContextHandler(client, new CreateIndexStep(clusterService, client)); + + return ImmutableList.of(workflowStepFactory, workflowProcessSorter, globalContextHandler); + } + + @Override + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { + return ImmutableList.of(new RestCreateWorkflowAction(), new RestProvisionWorkflowAction()); + } + + @Override + public List> getActions() { + return ImmutableList.of( + new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), + new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class) + ); } + + @Override + public List> getExecutorBuilders(Settings settings) { + // TODO : Determine final size/queueSize values for the provision thread pool + return ImmutableList.of( + new FixedExecutorBuilder( + settings, + PROVISION_THREAD_POOL, + OpenSearchExecutors.allocatedProcessors(settings), + 10, + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL + ) + ); + } + } diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 0bf8ae890..eb67f2dc4 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -13,19 +13,51 @@ */ public class CommonValue { + /** Default value for no schema version */ public static Integer NO_SCHEMA_VERSION = 0; + /** Index mapping meta field name*/ public static final String META = "_meta"; + /** Schema Version field name */ public static final String SCHEMA_VERSION_FIELD = "schema_version"; + /** Global Context Index Name */ public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context"; + /** Global Context index mapping file path */ public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json"; + /** Global Context index mapping version */ public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1; + + /** The transport action name prefix */ + public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; + /** The base URI for this plugin's rest actions */ + public static final String FLOW_FRAMEWORK_BASE_URI = "/_plugins/_flow_framework"; + /** The URI for this plugin's workflow rest actions */ + public static final String WORKFLOW_URI = FLOW_FRAMEWORK_BASE_URI + "/workflow"; + /** Field name for workflow Id, the document Id of the indexed use case template */ + public static final String WORKFLOW_ID = "workflow_id"; + /** The field name for provision workflow within a use case template*/ + public static final String PROVISION_WORKFLOW = "provision"; + + /** Flow Framework plugin thread pool name prefix */ + public static final String FLOW_FRAMEWORK_THREAD_POOL_PREFIX = "thread_pool.flow_framework."; + /** The provision workflow thread pool name */ + public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision"; + + /** Model Id field */ public static final String MODEL_ID = "model_id"; + /** Function Name field */ public static final String FUNCTION_NAME = "function_name"; + /** Model Name field */ public static final String MODEL_NAME = "name"; + /** Model Version field */ public static final String MODEL_VERSION = "model_version"; + /** Model Group Id field */ public static final String MODEL_GROUP_ID = "model_group_id"; + /** Description field */ public static final String DESCRIPTION = "description"; + /** Connector Id field */ public static final String CONNECTOR_ID = "connector_id"; + /** Model format field */ public static final String MODEL_FORMAT = "model_format"; + /** Model config field */ public static final String MODEL_CONFIG = "model_config"; } diff --git a/src/main/java/org/opensearch/flowframework/common/TemplateUtil.java b/src/main/java/org/opensearch/flowframework/common/TemplateUtil.java new file mode 100644 index 000000000..80e59ce81 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/common/TemplateUtil.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.common; + +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * Utility methods for Template parsing + */ +public class TemplateUtil { + + /** + * Converts a JSON string into an XContentParser + * + * @param json the json string + * @return The XContent parser for the json string + * @throws IOException on failure to create the parser + */ + public static XContentParser jsonToParser(String json) throws IOException { + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + json + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return parser; + } + + /** + * Builds an XContent object representing a map of String keys to String values. + * + * @param xContentBuilder An XContent builder whose position is at the start of the map object to build + * @param map A map as key-value String pairs. + * @throws IOException on a build failure + */ + public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map map) throws IOException { + xContentBuilder.startObject(); + for (Entry e : map.entrySet()) { + xContentBuilder.field((String) e.getKey(), (String) e.getValue()); + } + xContentBuilder.endObject(); + } + + /** + * Parses an XContent object representing a map of String keys to String values. + * + * @param parser An XContent parser whose position is at the start of the map object to parse + * @return A map as identified by the key-value pairs in the XContent + * @throws IOException on a parse failure + */ + public static Map parseStringToStringMap(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + Map map = new HashMap<>(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + map.put(fieldName, parser.text()); + } + return map; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java index 6508fb9f7..f3cb55950 100644 --- a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java +++ b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java @@ -17,6 +17,7 @@ public class FlowFrameworkException extends RuntimeException { private static final long serialVersionUID = 1L; + /** The rest status code of this exception */ private final RestStatus restStatus; /** diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java index 30261ae0e..d0ef3503c 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java @@ -19,6 +19,9 @@ * An enumeration of Flow Framework indices */ public enum FlowFrameworkIndex { + /** + * Global Context Index + */ GLOBAL_CONTEXT( GLOBAL_CONTEXT_INDEX, ThrowingSupplierWrapper.throwingSupplierWrapper(GlobalContextHandler::getGlobalContextMappings), @@ -35,14 +38,26 @@ public enum FlowFrameworkIndex { this.version = version; } + /** + * Retrieves the index name + * @return the index name + */ public String getIndexName() { return indexName; } + /** + * Retrieves the index mapping + * @return the index mapping + */ public String getMapping() { return mapping; } + /** + * Retrieves the index version + * @return the index version + */ public Integer getVersion() { return version; } diff --git a/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java b/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java index 994cdaeda..53037d7ce 100644 --- a/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java @@ -81,7 +81,7 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { @@ -94,6 +94,35 @@ public void putTemplateToGlobalContext(Template template, ActionListener listener) { + if (!createIndexStep.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + String exceptionMessage = "Failed to update template for workflow_id : " + + documentId + + ", global_context index does not exist."; + logger.error(exceptionMessage); + listener.onFailure(new Exception(exceptionMessage)); + } else { + IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); + try ( + XContentBuilder builder = XContentFactory.jsonBuilder(); + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() + ) { + request.source(template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage()); + listener.onFailure(e); + } + } + } + /** * Update global context index for specific fields * @param documentId global context index document id diff --git a/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java b/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java index 1407036b3..b6da0abe5 100644 --- a/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java +++ b/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java @@ -17,6 +17,8 @@ import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.TemplateUtil.buildStringToStringMap; +import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap; /** * This represents a processor associated with search and ingest pipelines in the {@link Template}. @@ -46,7 +48,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws XContentBuilder xContentBuilder = builder.startObject(); xContentBuilder.field(TYPE_FIELD, this.type); xContentBuilder.field(PARAMS_FIELD); - Template.buildStringToStringMap(xContentBuilder, this.params); + buildStringToStringMap(xContentBuilder, this.params); return xContentBuilder.endObject(); } @@ -70,7 +72,7 @@ public static PipelineProcessor parse(XContentParser parser) throws IOException type = parser.text(); break; case PARAMS_FIELD: - params = Template.parseStringToStringMap(parser); + params = parseStringToStringMap(parser); break; default: throw new IOException("Unable to parse field [" + fieldName + "] in a pipeline processor object."); diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index b3f4478b9..7d08ef240 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -25,6 +25,8 @@ import java.util.Map.Entry; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.TemplateUtil.jsonToParser; +import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap; /** * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. @@ -159,6 +161,228 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return xContentBuilder.endObject(); } + /** + * Converts a template object into a Global Context document + * @param builder the XContentBuilder + * @param params the params + * @return the XContentBuilder + * @throws IOException if the document source fails to be generated + */ + public XContentBuilder toDocumentSource(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field(NAME_FIELD, this.name); + xContentBuilder.field(DESCRIPTION_FIELD, this.description); + xContentBuilder.field(USE_CASE_FIELD, this.useCase); + xContentBuilder.startArray(OPERATIONS_FIELD); + for (String op : this.operations) { + xContentBuilder.value(op); + } + xContentBuilder.endArray(); + + if (this.templateVersion != null || !this.compatibilityVersion.isEmpty()) { + xContentBuilder.startObject(VERSION_FIELD); + if (this.templateVersion != null) { + xContentBuilder.field(TEMPLATE_FIELD, this.templateVersion); + } + if (!this.compatibilityVersion.isEmpty()) { + xContentBuilder.startArray(COMPATIBILITY_FIELD); + for (Version v : this.compatibilityVersion) { + xContentBuilder.value(v); + } + xContentBuilder.endArray(); + } + xContentBuilder.endObject(); + } + + if (!this.userInputs.isEmpty()) { + xContentBuilder.startObject(USER_INPUTS_FIELD); + for (Entry e : userInputs.entrySet()) { + xContentBuilder.field(e.getKey(), e.getValue()); + } + xContentBuilder.endObject(); + } + + try (XContentBuilder workflowBuilder = JsonXContent.contentBuilder()) { + workflowBuilder.startObject(); + for (Entry e : workflows.entrySet()) { + workflowBuilder.field(e.getKey(), e.getValue()); + } + workflowBuilder.endObject(); + xContentBuilder.field(WORKFLOWS_FIELD, workflowBuilder.toString()); + } + + try (XContentBuilder userOutputsBuilder = JsonXContent.contentBuilder()) { + userOutputsBuilder.startObject(); + for (Entry e : userOutputs.entrySet()) { + userOutputsBuilder.field(e.getKey(), e.getValue()); + } + userOutputsBuilder.endObject(); + xContentBuilder.field(USER_OUTPUTS_FIELD, userOutputsBuilder.toString()); + } + + try (XContentBuilder resourcesCreatedBuilder = JsonXContent.contentBuilder()) { + resourcesCreatedBuilder.startObject(); + for (Entry e : resourcesCreated.entrySet()) { + resourcesCreatedBuilder.field(e.getKey(), e.getValue()); + } + resourcesCreatedBuilder.endObject(); + xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreatedBuilder.toString()); + } + + xContentBuilder.endObject(); + + return xContentBuilder; + + } + + /** + * Parse global context document source into a Template instance + * + * @param documentSource the document source string + * @return an instance of the template + * @throws IOException if content can't be parsed correctly + */ + public static Template parseFromDocumentSource(String documentSource) throws IOException { + XContentParser parser = jsonToParser(documentSource); + + String name = null; + String description = ""; + String useCase = ""; + List operations = new ArrayList<>(); + Version templateVersion = null; + List compatibilityVersion = new ArrayList<>(); + Map userInputs = new HashMap<>(); + Map workflows = new HashMap<>(); + Map userOutputs = new HashMap<>(); + Map resourcesCreated = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case USE_CASE_FIELD: + useCase = parser.text(); + break; + case OPERATIONS_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + operations.add(parser.text()); + } + break; + case VERSION_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String versionFieldName = parser.currentName(); + parser.nextToken(); + switch (versionFieldName) { + case TEMPLATE_FIELD: + templateVersion = Version.fromString(parser.text()); + break; + case COMPATIBILITY_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + compatibilityVersion.add(Version.fromString(parser.text())); + } + break; + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a version object."); + } + } + break; + case USER_INPUTS_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String inputFieldName = parser.currentName(); + switch (parser.nextToken()) { + case VALUE_STRING: + userInputs.put(inputFieldName, parser.text()); + break; + case START_OBJECT: + userInputs.put(inputFieldName, parseStringToStringMap(parser)); + break; + default: + throw new IOException("Unable to parse field [" + inputFieldName + "] in a user inputs object."); + } + } + break; + case WORKFLOWS_FIELD: + String workflowsJson = parser.text(); + XContentParser workflowsParser = jsonToParser(workflowsJson); + while (workflowsParser.nextToken() != XContentParser.Token.END_OBJECT) { + String workflowFieldName = workflowsParser.currentName(); + workflowsParser.nextToken(); + workflows.put(workflowFieldName, Workflow.parse(workflowsParser)); + } + break; + case USER_OUTPUTS_FIELD: + + String userOutputsJson = parser.text(); + XContentParser userOuputsParser = jsonToParser(userOutputsJson); + while (userOuputsParser.nextToken() != XContentParser.Token.END_OBJECT) { + String userOutputsFieldName = userOuputsParser.currentName(); + switch (userOuputsParser.nextToken()) { + case VALUE_STRING: + userOutputs.put(userOutputsFieldName, userOuputsParser.text()); + break; + case START_OBJECT: + userOutputs.put(userOutputsFieldName, parseStringToStringMap(userOuputsParser)); + break; + default: + throw new IOException("Unable to parse field [" + userOutputsFieldName + "] in a user_outputs object."); + } + } + break; + + case RESOURCES_CREATED_FIELD: + + String resourcesCreatedJson = parser.text(); + XContentParser resourcesCreatedParser = jsonToParser(resourcesCreatedJson); + while (resourcesCreatedParser.nextToken() != XContentParser.Token.END_OBJECT) { + String resourcesCreatedField = resourcesCreatedParser.currentName(); + switch (resourcesCreatedParser.nextToken()) { + case VALUE_STRING: + resourcesCreated.put(resourcesCreatedField, resourcesCreatedParser.text()); + break; + case START_OBJECT: + resourcesCreated.put(resourcesCreatedField, parseStringToStringMap(resourcesCreatedParser)); + break; + default: + throw new IOException( + "Unable to parse field [" + resourcesCreatedField + "] in a resources_created object." + ); + } + } + break; + + default: + throw new IOException("Unable to parse field [" + fieldName + "] in a template object."); + } + } + if (name == null) { + throw new IOException("An template object requires a name."); + } + + return new Template( + name, + description, + useCase, + operations, + templateVersion, + compatibilityVersion, + userInputs, + workflows, + userOutputs, + resourcesCreated + ); + } + /** * Parse raw json content into a Template instance. * @@ -317,39 +541,6 @@ public static Template parse(String json) throws IOException { return parse(parser); } - /** - * Builds an XContent object representing a map of String keys to String values. - * - * @param xContentBuilder An XContent builder whose position is at the start of the map object to build - * @param map A map as key-value String pairs. - * @throws IOException on a build failure - */ - public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map map) throws IOException { - xContentBuilder.startObject(); - for (Entry e : map.entrySet()) { - xContentBuilder.field((String) e.getKey(), (String) e.getValue()); - } - xContentBuilder.endObject(); - } - - /** - * Parses an XContent object representing a map of String keys to String values. - * - * @param parser An XContent parser whose position is at the start of the map object to parse - * @return A map as identified by the key-value pairs in the XContent - * @throws IOException on a parse failure - */ - public static Map parseStringToStringMap(XContentParser parser) throws IOException { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - Map map = new HashMap<>(); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - map.put(fieldName, parser.text()); - } - return map; - } - /** * Output this object in a compact JSON string. * diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 8c4a6ae52..e34c4ddec 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -24,6 +24,8 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.TemplateUtil.buildStringToStringMap; +import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap; /** * This represents a process node (step) in a workflow graph in the {@link Template}. @@ -75,7 +77,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (e.getValue() instanceof String) { xContentBuilder.value(e.getValue()); } else if (e.getValue() instanceof Map) { - Template.buildStringToStringMap(xContentBuilder, (Map) e.getValue()); + buildStringToStringMap(xContentBuilder, (Map) e.getValue()); } else if (e.getValue() instanceof Object[]) { xContentBuilder.startArray(); if (PROCESSORS_FIELD.equals(e.getKey())) { @@ -84,7 +86,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } else { for (Map map : (Map[]) e.getValue()) { - Template.buildStringToStringMap(xContentBuilder, map); + buildStringToStringMap(xContentBuilder, map); } } xContentBuilder.endArray(); @@ -127,7 +129,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { inputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - inputs.put(inputFieldName, Template.parseStringToStringMap(parser)); + inputs.put(inputFieldName, parseStringToStringMap(parser)); break; case START_ARRAY: if (PROCESSORS_FIELD.equals(inputFieldName)) { @@ -139,7 +141,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { } else { List> mapList = new ArrayList<>(); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - mapList.add(Template.parseStringToStringMap(parser)); + mapList.add(parseStringToStringMap(parser)); } inputs.put(inputFieldName, mapList.toArray(new Map[0])); } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java new file mode 100644 index 000000000..ace440f75 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.opensearch.client.node.NodeClient; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.transport.CreateWorkflowAction; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +/** + * Rest Action to facilitate requests to create and update a use case template + */ +public class RestCreateWorkflowAction extends BaseRestHandler { + + private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; + + /** + * Intantiates a new RestCreateWorkflowAction + */ + public RestCreateWorkflowAction() {} + + @Override + public String getName() { + return CREATE_WORKFLOW_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of( + // Create new workflow + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s", WORKFLOW_URI)), + // Update use case template + new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, WORKFLOW_ID)) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + String workflowId = request.param(WORKFLOW_ID); + Template template = Template.parse(request.content().utf8ToString()); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template); + return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java new file mode 100644 index 000000000..89471ee00 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.ProvisionWorkflowAction; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +/** + * Rest action to facilitate requests to provision a workflow from an inline defined or stored use case template + */ +public class RestProvisionWorkflowAction extends BaseRestHandler { + + private static final String PROVISION_WORKFLOW_ACTION = "provision_workflow_action"; + + /** + * Instantiates a new RestProvisionWorkflowAction + */ + public RestProvisionWorkflowAction() {} + + @Override + public String getName() { + return PROVISION_WORKFLOW_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of( + // Provision workflow from indexed use case template + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_provision")) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + } + + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + // Create request and provision + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java new file mode 100644 index 000000000..0f49c826f --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowAction.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; + +/** + * External Action for public facing RestCreateWorkflowActiom + */ +public class CreateWorkflowAction extends ActionType { + + /** The name of this action */ + public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/create"; + /** An instance of this action */ + public static final CreateWorkflowAction INSTANCE = new CreateWorkflowAction(); + + private CreateWorkflowAction() { + super(NAME, WorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java new file mode 100644 index 000000000..f4147b144 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.GlobalContextHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +/** + * Transport Action to index or update a use case template within the Global Context + */ +public class CreateWorkflowTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(CreateWorkflowTransportAction.class); + + private final GlobalContextHandler globalContextHandler; + + /** + * Intantiates a new CreateWorkflowTransportAction + * @param transportService the TransportService + * @param actionFilters action filters + * @param globalContextHandler The handler for the global context index + */ + @Inject + public CreateWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + GlobalContextHandler globalContextHandler + ) { + super(CreateWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.globalContextHandler = globalContextHandler; + } + + @Override + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + if (request.getWorkflowId() == null) { + // Create new global context and state index entries + globalContextHandler.putTemplateToGlobalContext(request.getTemplate(), ActionListener.wrap(response -> { + // TODO : Check if state index exists, create if not + // TODO : Create StateIndexRequest for workflowId, default to NOT_STARTED + listener.onResponse(new WorkflowResponse(response.getId())); + }, exception -> { + logger.error("Failed to save use case template : {}", exception.getMessage()); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + })); + } else { + // Update existing entry, full document replacement + globalContextHandler.updateTemplateInGlobalContext( + request.getWorkflowId(), + request.getTemplate(), + ActionListener.wrap(response -> { + // TODO : Create StateIndexRequest for workflowId to reset entry to NOT_STARTED + listener.onResponse(new WorkflowResponse(response.getId())); + }, exception -> { + logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), exception.getMessage()); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + }) + ); + } + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java new file mode 100644 index 000000000..022e73488 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACION_NAME_PREFIX; + +/** + * External Action for public facing RestProvisionWorkflowAction + */ +public class ProvisionWorkflowAction extends ActionType { + /** The name of this action */ + public static final String NAME = TRANSPORT_ACION_NAME_PREFIX + "workflow/provision"; + /** An instance of this action */ + public static final ProvisionWorkflowAction INSTANCE = new ProvisionWorkflowAction(); + + private ProvisionWorkflowAction() { + super(NAME, WorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java new file mode 100644 index 000000000..0dbec5bf2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; + +/** + * Transport Action to provision a workflow from a stored use case template + */ +public class ProvisionWorkflowTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(ProvisionWorkflowTransportAction.class); + + private final ThreadPool threadPool; + private final Client client; + private final WorkflowProcessSorter workflowProcessSorter; + + /** + * Instantiates a new ProvisionWorkflowTransportAction + * @param transportService The TransportService + * @param actionFilters action filters + * @param threadPool The OpenSearch thread pool + * @param client The node client to retrieve a stored use case template + * @param workflowProcessSorter Utility class to generate a togologically sorted list of Process nodes + */ + @Inject + public ProvisionWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ThreadPool threadPool, + Client client, + WorkflowProcessSorter workflowProcessSorter + ) { + super(ProvisionWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.threadPool = threadPool; + this.client = client; + this.workflowProcessSorter = workflowProcessSorter; + } + + @Override + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + + // Retrieve use case template from global context + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + // Stash thread context to interact with system index + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + // Parse template from document source + Template template = Template.parseFromDocumentSource(response.getSourceAsString()); + + // TODO : Update state index entry to PROVISIONING, given workflowId + + // Respond to rest action then execute provisioning workflow async + listener.onResponse(new WorkflowResponse(workflowId)); + executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW)); + }, exception -> { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + })); + } catch (Exception e) { + logger.error("Failed to retrieve template from global context.", e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + /** + * Retrieves a thread from the provision thread pool to execute a workflow + * @param workflowId The id of the workflow + * @param workflow The workflow to execute + */ + private void executeWorkflowAsync(String workflowId, Workflow workflow) { + // TODO : Update Action listener type to State index Request + ActionListener provisionWorkflowListener = ActionListener.wrap(response -> { + logger.info("Provisioning completed successuflly for workflow {}", workflowId); + + // TODO : Create State index request to update STATE entry status to READY + }, exception -> { + logger.error("Provisioning failed for workflow {} : {}", workflowId, exception); + + // TODO : Create State index request to update STATE entry status to FAILED + }); + try { + threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflow, provisionWorkflowListener); }); + } catch (Exception exception) { + provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + /** + * Topologically sorts a given workflow into a sequence of ProcessNodes and executes the workflow + * @param workflow The workflow to execute + * @param workflowListener The listener that updates the status of a workflow execution + */ + private void executeWorkflow(Workflow workflow, ActionListener workflowListener) { + + List processSequence = workflowProcessSorter.sortProcessNodes(workflow); + List> workflowFutureList = new ArrayList<>(); + + for (ProcessNode processNode : processSequence) { + List predecessors = processNode.predecessors(); + + logger.info( + "Queueing process [{}].{}", + processNode.id(), + predecessors.isEmpty() + ? " Can start immediately!" + : String.format( + Locale.getDefault(), + " Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) + ) + ); + + workflowFutureList.add(processNode.execute()); + } + try { + // Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally + workflowFutureList.forEach(CompletableFuture::join); + + // TODO : Create State Index request with provisioning state, start time, end time, etc, pending implementation. String for now + workflowListener.onResponse("READY"); + } catch (CancellationException | CompletionException ex) { + workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java new file mode 100644 index 000000000..0b105552f --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.Nullable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.flowframework.model.Template; + +import java.io.IOException; + +/** + * Transport Request to create and provision a workflow + */ +public class WorkflowRequest extends ActionRequest { + + /** + * The documentId of the workflow entry within the Global Context index + */ + @Nullable + private String workflowId; + /** + * The use case template to index + */ + @Nullable + private Template template; + + /** + * Instantiates a new WorkflowRequest + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + */ + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { + this.workflowId = workflowId; + this.template = template; + } + + /** + * Instantiates a new Workflow request + * @param in The input stream to read from + * @throws IOException If the stream cannot be read properly + */ + public WorkflowRequest(StreamInput in) throws IOException { + super(in); + this.workflowId = in.readOptionalString(); + String templateJson = in.readOptionalString(); + this.template = templateJson == null ? null : Template.parse(templateJson); + } + + /** + * Gets the workflow Id of the request + * @return the workflow Id + */ + @Nullable + public String getWorkflowId() { + return this.workflowId; + } + + /** + * Gets the use case template of the request + * @return the use case template + */ + @Nullable + public Template getTemplate() { + return this.template; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(workflowId); + out.writeOptionalString(template == null ? null : template.toJson()); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java new file mode 100644 index 000000000..20a7700a3 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; + +/** + * Transport Response from creating or provisioning a workflow + */ +public class WorkflowResponse extends ActionResponse implements ToXContentObject { + + /** + * The documentId of the workflow entry within the Global Context index + */ + private String workflowId; + + /** + * Instantiates a new WorkflowResponse from params + * @param workflowId the documentId of the indexed use case template + */ + public WorkflowResponse(String workflowId) { + this.workflowId = workflowId; + } + + /** + * Instatiates a new WorkflowResponse from an input stream + * @param in the input stream to read from + * @throws IOException if the workflowId cannot be read from the input stream + */ + public WorkflowResponse(StreamInput in) throws IOException { + super(in); + this.workflowId = in.readString(); + } + + /** + * Gets the workflowId of this repsonse + * @return the workflowId + */ + public String getWorkflowId() { + return this.workflowId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(workflowId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field(WORKFLOW_ID, this.workflowId).endObject(); + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 848f621a2..2b2f7338d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -117,6 +117,16 @@ public String getName() { return NAME; } + // TODO : Move to index management class, pending implementation + /** + * Checks if the given index exists + * @param indexName the name of the index + * @return boolean indicating the existence of an index + */ + public boolean doesIndexExist(String indexName) { + return clusterService.state().metadata().hasIndex(indexName); + } + /** * Create Index if it's absent * @param index The index that needs to be created diff --git a/src/main/resources/mappings/global-context.json b/src/main/resources/mappings/global-context.json index 86e952942..cd3ce4e6b 100644 --- a/src/main/resources/mappings/global-context.json +++ b/src/main/resources/mappings/global-context.json @@ -29,10 +29,10 @@ "type": "nested", "properties": { "template": { - "type": "integer" + "type": "text" }, "compatibility": { - "type": "integer" + "type": "text" } } }, diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index ea8a3b520..583636b0a 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -12,6 +12,7 @@ import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -31,6 +32,7 @@ public class FlowFrameworkPluginTests extends OpenSearchTestCase { private ClusterAdminClient clusterAdminClient; private ThreadPool threadPool; + private Settings settings; @Override public void setUp() throws Exception { @@ -41,6 +43,7 @@ public void setUp() throws Exception { when(client.admin()).thenReturn(adminClient); when(adminClient.cluster()).thenReturn(clusterAdminClient); threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName()); + settings = Settings.EMPTY; } @Override @@ -51,7 +54,10 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { - assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); + assertEquals(3, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); + assertEquals(2, ffp.getRestHandlers(null, null, null, null, null, null, null).size()); + assertEquals(2, ffp.getActions().size()); + assertEquals(1, ffp.getExecutorBuilders(settings).size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java index 0380e4808..800f1d49e 100644 --- a/src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java @@ -73,7 +73,7 @@ public void setUp() throws Exception { public void testPutTemplateToGlobalContext() throws IOException { Template template = mock(Template.class); - when(template.toXContent(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { + when(template.toDocumentSource(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { XContentBuilder builder = invocation.getArgument(0); return builder; }); @@ -84,7 +84,7 @@ public void testPutTemplateToGlobalContext() throws IOException { ActionListener callback = invocation.getArgument(1); callback.onResponse(true); return null; - }).when(createIndexStep).initIndexIfAbsent(any(), any()); + }).when(createIndexStep).initIndexIfAbsent(any(FlowFrameworkIndex.class), any()); globalContextHandler.putTemplateToGlobalContext(template, listener); @@ -109,4 +109,38 @@ public void testStoreResponseToGlobalContext() { assertEquals(GLOBAL_CONTEXT_INDEX, requestCaptor.getValue().index()); assertEquals(documentId, requestCaptor.getValue().id()); } + + public void testUpdateTemplateInGlobalContext() throws IOException { + Template template = mock(Template.class); + ActionListener listener = mock(ActionListener.class); + when(template.toDocumentSource(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { + XContentBuilder builder = invocation.getArgument(0); + return builder; + }); + when(createIndexStep.doesIndexExist(any())).thenReturn(true); + + globalContextHandler.updateTemplateInGlobalContext("1", template, null); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + verify(client, times(1)).index(requestCaptor.capture(), any()); + + assertEquals("1", requestCaptor.getValue().id()); + } + + public void testFailedUpdateTemplateInGlobalContext() throws IOException { + Template template = mock(Template.class); + ActionListener listener = mock(ActionListener.class); + when(createIndexStep.doesIndexExist(any())).thenReturn(false); + + globalContextHandler.updateTemplateInGlobalContext("1", template, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + + assertEquals( + "Failed to update template for workflow_id : 1, global_context index does not exist.", + exceptionCaptor.getValue().getMessage() + ); + + } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java new file mode 100644 index 000000000..fd07a91eb --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.Version; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; + +public class RestCreateWorkflowActionTests extends OpenSearchTestCase { + + private String invalidTemplate; + private RestCreateWorkflowAction createWorkflowRestAction; + private String createWorkflowPath; + private String updateWorkflowPath; + private NodeClient nodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + List operations = List.of("operation"); + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + + Template template = new Template( + "test", + "description", + "use case", + operations, + templateVersion, + compatibilityVersions, + Map.ofEntries(Map.entry("userKey", "userValue"), Map.entry("userMapKey", Map.of("nestedKey", "nestedValue"))), + Map.of("workflow", workflow), + Map.of("outputKey", "outputValue"), + Map.of("resourceKey", "resourceValue") + ); + + // Invalid template configuration, wrong field name + this.invalidTemplate = template.toJson().replace("use_case", "invalid"); + this.createWorkflowRestAction = new RestCreateWorkflowAction(); + this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI); + this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); + this.nodeClient = mock(NodeClient.class); + } + + public void testRestCreateWorkflowActionName() { + String name = createWorkflowRestAction.getName(); + assertEquals("create_workflow_action", name); + } + + public void testRestCreateWorkflowActionRoutes() { + List routes = createWorkflowRestAction.routes(); + assertEquals(2, routes.size()); + assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); + assertEquals(RestRequest.Method.PUT, routes.get(1).getMethod()); + assertEquals(this.createWorkflowPath, routes.get(0).getPath()); + assertEquals(this.updateWorkflowPath, routes.get(1).getPath()); + + } + + public void testInvalidCreateWorkflowRequest() throws IOException { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withContent(new BytesArray(invalidTemplate), MediaTypeRegistry.JSON) + .build(); + + IOException ex = expectThrows(IOException.class, () -> { createWorkflowRestAction.prepareRequest(request, nodeClient); }); + assertEquals("Unable to parse field [invalid] in a template object.", ex.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java new file mode 100644 index 000000000..a44817cec --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; + +public class RestProvisionWorkflowActionTests extends OpenSearchTestCase { + + private RestProvisionWorkflowAction provisionWorkflowRestAction; + private String provisionWorkflowPath; + private NodeClient nodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.provisionWorkflowRestAction = new RestProvisionWorkflowAction(); + this.provisionWorkflowPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_provision"); + this.nodeClient = mock(NodeClient.class); + } + + public void testRestProvisionWorkflowActionName() { + String name = provisionWorkflowRestAction.getName(); + assertEquals("provision_workflow_action", name); + } + + public void testRestProvisiionWorkflowActionRoutes() { + List routes = provisionWorkflowRestAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); + assertEquals(this.provisionWorkflowPath, routes.get(0).getPath()); + } + + public void testNullWorkflowIdAndTemplate() throws IOException { + + // Request with no content or params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .build(); + + FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { + provisionWorkflowRestAction.prepareRequest(request, nodeClient); + }); + assertEquals("workflow_id cannot be null", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + } + + public void testInvalidRequestWithContent() throws IOException { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { + provisionWorkflowRestAction.prepareRequest(request, nodeClient); + }); + assertEquals("Invalid request format", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java new file mode 100644 index 000000000..1b937570b --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.Version; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.flowframework.indices.GlobalContextHandler; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportService; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { + + private CreateWorkflowTransportAction createWorkflowTransportAction; + private GlobalContextHandler globalContextHandler; + private Template template; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.globalContextHandler = mock(GlobalContextHandler.class); + this.createWorkflowTransportAction = new CreateWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + globalContextHandler + ); + + List operations = List.of("operation"); + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + + this.template = new Template( + "test", + "description", + "use case", + operations, + templateVersion, + compatibilityVersions, + Map.ofEntries(Map.entry("userKey", "userValue"), Map.entry("userMapKey", Map.of("nestedKey", "nestedValue"))), + Map.of("workflow", workflow), + Map.of("outputKey", "outputValue"), + Map.of("resourceKey", "resourceValue") + ); + } + + public void testCreateNewWorkflow() { + + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(globalContextHandler).putTemplateToGlobalContext(any(Template.class), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + + assertEquals("1", responseCaptor.getValue().getWorkflowId()); + + } + + public void testFailedToCreateNewWorkflow() { + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to create global_context index")); + return null; + }).when(globalContextHandler).putTemplateToGlobalContext(any(Template.class), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to create global_context index", exceptionCaptor.getValue().getMessage()); + } + + public void testUpdateWorkflow() { + + ActionListener listener = mock(ActionListener.class); + WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(globalContextHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + + assertEquals("1", responseCaptor.getValue().getWorkflowId()); + } + + public void testFailedToUpdateWorkflow() { + ActionListener listener = mock(ActionListener.class); + WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onFailure(new Exception("Failed to update use case template")); + return null; + }).when(globalContextHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + + createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to update use case template", exceptionCaptor.getValue().getMessage()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java new file mode 100644 index 000000000..7e1e13e03 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.index.get.GetResult; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ProvisionWorkflowTransportActionTests extends OpenSearchTestCase { + + private ThreadPool threadPool; + private Client client; + private WorkflowProcessSorter workflowProcessSorter; + private ProvisionWorkflowTransportAction provisionWorkflowTransportAction; + private Template template; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.threadPool = mock(ThreadPool.class); + this.client = mock(Client.class); + this.workflowProcessSorter = mock(WorkflowProcessSorter.class); + + this.provisionWorkflowTransportAction = new ProvisionWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + threadPool, + client, + workflowProcessSorter + ); + + List operations = List.of("operation"); + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + + this.template = new Template( + "test", + "description", + "use case", + operations, + templateVersion, + compatibilityVersions, + Map.ofEntries(Map.entry("userKey", "userValue"), Map.entry("userMapKey", Map.of("nestedKey", "nestedValue"))), + Map.of("provision", workflow), + Map.of("outputKey", "outputValue"), + Map.of("resourceKey", "resourceValue") + ); + + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testProvisionWorkflow() { + + String workflowId = "1"; + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + this.template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS); + BytesReference templateBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, workflowId, 1, 1, 1, true, templateBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); + } + + public void testFailedToRetrieveTemplateFromGlobalContext() { + ActionListener listener = mock(ActionListener.class); + WorkflowRequest request = new WorkflowRequest("1", null); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to retrieve template from global context.")); + return null; + }).when(client).get(any(GetRequest.class), any()); + + provisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to retrieve template from global context.", exceptionCaptor.getValue().getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java new file mode 100644 index 000000000..cc5c19a09 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class WorkflowRequestResponseTests extends OpenSearchTestCase { + + private Template template; + + @Override + public void setUp() throws Exception { + super.setUp(); + List operations = List.of("operation"); + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + + this.template = new Template( + "test", + "description", + "use case", + operations, + templateVersion, + compatibilityVersions, + Map.ofEntries(Map.entry("userKey", "userValue"), Map.entry("userMapKey", Map.of("nestedKey", "nestedValue"))), + Map.of("workflow", workflow), + Map.of("outputKey", "outputValue"), + Map.of("resourceKey", "resourceValue") + ); + } + + public void testNullIdWorkflowRequest() throws IOException { + WorkflowRequest nullIdRequest = new WorkflowRequest(null, template); + assertNull(nullIdRequest.getWorkflowId()); + assertEquals(template, nullIdRequest.getTemplate()); + assertNull(nullIdRequest.validate()); + + BytesStreamOutput out = new BytesStreamOutput(); + nullIdRequest.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + WorkflowRequest streamInputRequest = new WorkflowRequest(in); + + assertEquals(nullIdRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); + assertEquals(nullIdRequest.getTemplate().toJson(), streamInputRequest.getTemplate().toJson()); + } + + public void testNullTemplateWorkflowRequest() throws IOException { + WorkflowRequest nullTemplateRequest = new WorkflowRequest("123", null); + assertNotNull(nullTemplateRequest.getWorkflowId()); + assertNull(nullTemplateRequest.getTemplate()); + assertNull(nullTemplateRequest.validate()); + + BytesStreamOutput out = new BytesStreamOutput(); + nullTemplateRequest.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + WorkflowRequest streamInputRequest = new WorkflowRequest(in); + + assertEquals(nullTemplateRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); + assertEquals(nullTemplateRequest.getTemplate(), streamInputRequest.getTemplate()); + } + + public void testWorkflowRequest() throws IOException { + WorkflowRequest workflowRequest = new WorkflowRequest("123", template); + assertNotNull(workflowRequest.getWorkflowId()); + assertEquals(template, workflowRequest.getTemplate()); + assertNull(workflowRequest.validate()); + + BytesStreamOutput out = new BytesStreamOutput(); + workflowRequest.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + WorkflowRequest streamInputRequest = new WorkflowRequest(in); + + assertEquals(workflowRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); + assertEquals(workflowRequest.getTemplate().toJson(), streamInputRequest.getTemplate().toJson()); + + } + + public void testWorkflowResponse() throws IOException { + WorkflowResponse response = new WorkflowResponse("123"); + assertEquals("123", response.getWorkflowId()); + + BytesStreamOutput out = new BytesStreamOutput(); + response.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + WorkflowResponse streamInputResponse = new WorkflowResponse(in); + assertEquals(response.getWorkflowId(), streamInputResponse.getWorkflowId()); + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + assertEquals("{\"workflow_id\":\"123\"}", builder.toString()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 72371095c..a952f6fe7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -128,7 +128,7 @@ public void testInitIndexIfAbsent_IndexNotPresent() { ActionListener listener = mock(ActionListener.class); createIndexStep.initIndexIfAbsent(FlowFrameworkIndex.GLOBAL_CONTEXT, listener); - verify(indicesAdminClient, times(1)).create(any(), any()); + verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), any()); } public void testInitIndexIfAbsent_IndexExist() { @@ -187,4 +187,18 @@ public void testInitIndexIfAbsent_IndexExist_returnFalse() { createIndexStep.initIndexIfAbsent(index, listener); assertTrue(indexMappingUpdated.get(index.getIndexName()).get()); } + + public void testDoesIndexExist() { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + + createIndexStep.doesIndexExist(GLOBAL_CONTEXT_INDEX); + + ArgumentCaptor indexExistsCaptor = ArgumentCaptor.forClass(String.class); + verify(mockMetaData, times(1)).hasIndex(indexExistsCaptor.capture()); + + assertEquals(GLOBAL_CONTEXT_INDEX, indexExistsCaptor.getValue()); + } }