diff --git a/.github/workflows/publish-snapshots.yml b/.github/workflows/publish-snapshots.yml index 737787a4d..33887585a 100644 --- a/.github/workflows/publish-snapshots.yml +++ b/.github/workflows/publish-snapshots.yml @@ -20,7 +20,7 @@ jobs: contents: write steps: - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v3 with: distribution: temurin # Temurin is a distribution of adoptium java-version: 17 diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 5e68587e2..107386c3d 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -36,7 +36,7 @@ This package uses the [Gradle](https://docs.gradle.org/current/userguide/usergui #### Building from the command line 1. `./gradlew check` builds and tests. -2. `./gradlew :run` installs and runs ML-Commons and Flow Framework Plugins into a local cluster +2. `./gradlew :run` runs the plugin. 3. `./gradlew spotlessApply` formats code. And/or import formatting rules in [formatterConfig.xml](formatter/formatterConfig.xml) with IDE. 4. `./gradlew test` to run the complete test suite. diff --git a/build.gradle b/build.gradle index 59f3cc9b3..47381c288 100644 --- a/build.gradle +++ b/build.gradle @@ -162,7 +162,7 @@ dependencies { configurations.all { resolutionStrategy { force("com.google.guava:guava:32.1.3-jre") // CVE for 31.1 - force("org.eclipse.platform:org.eclipse.core.runtime:3.30.0") // CVE for < 3.29.0 + force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // CVE for < 3.29.0 force("com.fasterxml.jackson.core:jackson-core:2.16.0") // Dependency Jar Hell } } diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index d64cd4917..7f93135c4 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 48338a458..3fa8f862f 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,8 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionSha256Sum=9d926787066a081739e8200858338b4a69e837c3a821a33aca9db09dd4a41026 diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 14df7e17e..2b2286b1f 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -63,6 +63,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; /** @@ -106,7 +107,7 @@ public Collection createComponents( mlClient, flowFrameworkIndicesHandler ); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } @@ -144,6 +145,7 @@ public List> getSettings() { List> settings = ImmutableList.of( FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, + MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY ); diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index 1824197e8..536fa2c73 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -18,6 +18,8 @@ private FlowFrameworkSettings() {} /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; + /** The upper limit of max workflow steps that can be in a single workflow */ + public static final int MAX_WORKFLOW_STEPS_LIMIT = 500; /** This setting sets max workflows that can be created */ public static final Setting MAX_WORKFLOWS = Setting.intSetting( @@ -29,6 +31,16 @@ private FlowFrameworkSettings() {} Setting.Property.Dynamic ); + /** This setting sets max workflows that can be created */ + public static final Setting MAX_WORKFLOW_STEPS = Setting.intSetting( + "plugins.flow_framework.max_workflow_steps", + 50, + 1, + MAX_WORKFLOW_STEPS_LIMIT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + /** This setting sets the timeout for the request */ public static final Setting WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting( "plugins.flow_framework.request_timeout", diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java deleted file mode 100644 index bf950f280..000000000 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.workflow; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.ExceptionsHelper; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTask; - -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.concurrent.CompletableFuture; - -import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; -import static org.opensearch.flowframework.common.CommonValue.TASK_ID; - -/** - * Step to retrieve an ML Task - */ -public class GetMLTaskStep implements WorkflowStep { - - private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class); - private MachineLearningNodeClient mlClient; - static final String NAME = "get_ml_task"; - - /** - * Instantiate this class - * @param mlClient client to instantiate MLClient - */ - public GetMLTaskStep(MachineLearningNodeClient mlClient) { - this.mlClient = mlClient; - } - - @Override - public CompletableFuture execute(List data) { - - CompletableFuture getMLTaskFuture = new CompletableFuture<>(); - - ActionListener actionListener = ActionListener.wrap(response -> { - - // TODO : Add retry capability if response status is not COMPLETED : - // https://github.com/opensearch-project/flow-framework/issues/158 - - logger.info("ML Task retrieval successful"); - getMLTaskFuture.complete( - new WorkflowData( - Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())), - data.get(0).getWorkflowId() - ) - ); - }, exception -> { - logger.error("Failed to retrieve ML Task"); - getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); - }); - - String taskId = null; - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case TASK_ID: - taskId = (String) content.get(TASK_ID); - break; - default: - break; - } - } - } - - if (taskId == null) { - logger.error("Failed to retrieve ML Task"); - getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)); - } else { - mlClient.getTask(taskId, actionListener); - } - - return getMLTaskFuture; - } - - @Override - public String getName() { - return NAME; - } - -} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 3e8b77f9d..da362383b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,6 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -32,6 +34,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD; @@ -45,16 +48,26 @@ public class WorkflowProcessSorter { private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; + private Integer maxWorkflowSteps; /** * Instantiate this class. * * @param workflowStepFactory The factory which matches template step types to instances. * @param threadPool The OpenSearch Thread pool to pass to process nodes. + * @param clusterService The OpenSearch cluster service. + * @param settings OpenSerch settings */ - public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { + public WorkflowProcessSorter( + WorkflowStepFactory workflowStepFactory, + ThreadPool threadPool, + ClusterService clusterService, + Settings settings + ) { this.workflowStepFactory = workflowStepFactory; this.threadPool = threadPool; + this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); } /** @@ -64,6 +77,20 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ public List sortProcessNodes(Workflow workflow, String workflowId) { + if (workflow.nodes().size() > this.maxWorkflowSteps) { + throw new FlowFrameworkException( + "Workflow " + + workflowId + + " has " + + workflow.nodes().size() + + " nodes, which exceeds the maximum of " + + this.maxWorkflowSteps + + ". Change the setting [" + + MAX_WORKFLOW_STEPS.getKey() + + "] to increase this.", + RestStatus.BAD_REQUEST + ); + } List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); List nodes = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index e3827e0b3..2585ffb09 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -29,6 +29,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -62,7 +63,7 @@ public void setUp() throws Exception { final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); @@ -84,7 +85,7 @@ public void testPlugin() throws IOException { assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(4, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); - assertEquals(4, ffp.getSettings().size()); + assertEquals(5, ffp.getSettings().size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 70c066c0e..9b664b729 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -13,6 +13,9 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; @@ -34,18 +37,24 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.mockito.ArgumentCaptor; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -61,6 +70,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private Template template; private Client client = mock(Client.class); private ThreadPool threadPool; + private ClusterSettings clusterSettings; + private ClusterService clusterService; private ParseUtils parseUtils; private ThreadContext threadContext; private Settings settings; @@ -73,8 +84,15 @@ public void setUp() throws Exception { .put("plugins.flow_framework.max_workflows.", 2) .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) .build(); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + ).collect(Collectors.toSet()); + clusterSettings = new ClusterSettings(settings, settingsSet); + clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool); + this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings); this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 8103f4fbf..d1590acd8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -16,6 +16,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.TemplateTestJsonUtil; @@ -32,6 +33,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -41,6 +43,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; @@ -79,11 +82,12 @@ public static void setup() { MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + Settings settings = Settings.builder().put("plugins.flow_framework.max_workflow_steps", 5).build(); final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.admin()).thenReturn(adminClient); @@ -96,7 +100,7 @@ public static void setup() { mlClient, flowFrameworkIndicesHandler ); - workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); + workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, settings); } @AfterClass @@ -245,6 +249,21 @@ public void testExceptions() throws IOException { ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("A")), Collections.emptyList()))); assertEquals("Duplicate node id A.", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); + + ex = assertThrows( + FlowFrameworkException.class, + () -> parse(workflow(List.of(node("A"), node("B"), node("C"), node("D"), node("E"), node("F")), Collections.emptyList())) + ); + String message = String.format( + Locale.ROOT, + "Workflow %s has %d nodes, which exceeds the maximum of %d. Change the setting [%s] to increase this.", + "123", + 6, + 5, + FlowFrameworkSettings.MAX_WORKFLOW_STEPS.getKey() + ); + assertEquals(message, ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); } public void testSuccessfulGraphValidation() throws Exception {