diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 14df7e17e..2b2286b1f 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -63,6 +63,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; /** @@ -106,7 +107,7 @@ public Collection createComponents( mlClient, flowFrameworkIndicesHandler ); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } @@ -144,6 +145,7 @@ public List> getSettings() { List> settings = ImmutableList.of( FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, + MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY ); diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index 1824197e8..536fa2c73 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -18,6 +18,8 @@ private FlowFrameworkSettings() {} /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; + /** The upper limit of max workflow steps that can be in a single workflow */ + public static final int MAX_WORKFLOW_STEPS_LIMIT = 500; /** This setting sets max workflows that can be created */ public static final Setting MAX_WORKFLOWS = Setting.intSetting( @@ -29,6 +31,16 @@ private FlowFrameworkSettings() {} Setting.Property.Dynamic ); + /** This setting sets max workflows that can be created */ + public static final Setting MAX_WORKFLOW_STEPS = Setting.intSetting( + "plugins.flow_framework.max_workflow_steps", + 50, + 1, + MAX_WORKFLOW_STEPS_LIMIT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + /** This setting sets the timeout for the request */ public static final Setting WORKFLOW_REQUEST_TIMEOUT = Setting.positiveTimeSetting( "plugins.flow_framework.request_timeout", diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 3e8b77f9d..da362383b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,6 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -32,6 +34,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD; @@ -45,16 +48,26 @@ public class WorkflowProcessSorter { private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; + private Integer maxWorkflowSteps; /** * Instantiate this class. * * @param workflowStepFactory The factory which matches template step types to instances. * @param threadPool The OpenSearch Thread pool to pass to process nodes. + * @param clusterService The OpenSearch cluster service. + * @param settings OpenSerch settings */ - public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { + public WorkflowProcessSorter( + WorkflowStepFactory workflowStepFactory, + ThreadPool threadPool, + ClusterService clusterService, + Settings settings + ) { this.workflowStepFactory = workflowStepFactory; this.threadPool = threadPool; + this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); } /** @@ -64,6 +77,20 @@ public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ public List sortProcessNodes(Workflow workflow, String workflowId) { + if (workflow.nodes().size() > this.maxWorkflowSteps) { + throw new FlowFrameworkException( + "Workflow " + + workflowId + + " has " + + workflow.nodes().size() + + " nodes, which exceeds the maximum of " + + this.maxWorkflowSteps + + ". Change the setting [" + + MAX_WORKFLOW_STEPS.getKey() + + "] to increase this.", + RestStatus.BAD_REQUEST + ); + } List sortedNodes = topologicalSort(workflow.nodes(), workflow.edges()); List nodes = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index e3827e0b3..2585ffb09 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -29,6 +29,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -62,7 +63,7 @@ public void setUp() throws Exception { final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); @@ -84,7 +85,7 @@ public void testPlugin() throws IOException { assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(4, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); - assertEquals(4, ffp.getSettings().size()); + assertEquals(5, ffp.getSettings().size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 70c066c0e..9b664b729 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -13,6 +13,9 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; @@ -34,18 +37,24 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.mockito.ArgumentCaptor; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -61,6 +70,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private Template template; private Client client = mock(Client.class); private ThreadPool threadPool; + private ClusterSettings clusterSettings; + private ClusterService clusterService; private ParseUtils parseUtils; private ThreadContext threadContext; private Settings settings; @@ -73,8 +84,15 @@ public void setUp() throws Exception { .put("plugins.flow_framework.max_workflows.", 2) .put("plugins.flow_framework.request_timeout", TimeValue.timeValueSeconds(10)) .build(); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + ).collect(Collectors.toSet()); + clusterSettings = new ClusterSettings(settings, settingsSet); + clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool); + this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool, clusterService, settings); this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index e9e792add..8958720f4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -16,6 +16,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.TemplateTestJsonUtil; @@ -32,6 +33,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -41,6 +43,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; @@ -79,11 +82,12 @@ public static void setup() { MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + Settings settings = Settings.builder().put("plugins.flow_framework.max_workflow_steps", 5).build(); final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) + Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.admin()).thenReturn(adminClient); @@ -96,7 +100,7 @@ public static void setup() { mlClient, flowFrameworkIndicesHandler ); - workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); + workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, settings); } @AfterClass @@ -245,6 +249,21 @@ public void testExceptions() throws IOException { ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("A")), Collections.emptyList()))); assertEquals("Duplicate node id A.", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); + + ex = assertThrows( + FlowFrameworkException.class, + () -> parse(workflow(List.of(node("A"), node("B"), node("C"), node("D"), node("E"), node("F")), Collections.emptyList())) + ); + String message = String.format( + Locale.ROOT, + "Workflow %s has %d nodes, which exceeds the maximum of %d. Change the setting [%s] to increase this.", + "123", + 6, + 5, + FlowFrameworkSettings.MAX_WORKFLOW_STEPS.getKey() + ); + assertEquals(message, ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); } public void testSuccessfulGraphValidation() throws Exception {