diff --git a/CHANGELOG.md b/CHANGELOG.md index b302d51e6..a66add559 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.13...2.x) ### Features ### Enhancements +- Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658)) ### Bug Fixes - Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635)) - Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639)) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 8df5613d4..ac40e5f8e 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -202,4 +202,18 @@ private CommonValue() {} public static final String RESOURCE_ID = "resource_id"; /** The field name for the opensearch-ml plugin */ public static final String OPENSEARCH_ML = "opensearch-ml"; + + /* + * Constants assoicated with substitution / default templates + */ + /** The field name for connector credential key substitution */ + public static final String CREATE_CONNECTOR_CREDENTIAL_KEY = "create_connector.credential.key"; + /** The field name for connector credential access key substitution */ + public static final String CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY = "create_connector.credential.access_key"; + /** The field name for connector credential secret key substitution */ + public static final String CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY = "create_connector.credential.secret_key"; + /** The field name for connector credential session token substitution */ + public static final String CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN = "create_connector.credential.session_token"; + /** The field name for ingest pipeline model ID substitution */ + public static final String CREATE_INGEST_PIPELINE_MODEL_ID = "create_ingest_pipeline.model_id"; } diff --git a/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java index 265409562..bc88f2b4d 100644 --- a/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java +++ b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java @@ -13,6 +13,16 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY; +import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY; +import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY; +import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN; +import static org.opensearch.flowframework.common.CommonValue.CREATE_INGEST_PIPELINE_MODEL_ID; + /** * Enum encapsulating the different default use cases and templates we have stored */ @@ -22,94 +32,119 @@ public enum DefaultUseCases { OPEN_AI_EMBEDDING_MODEL_DEPLOY( "open_ai_embedding_model_deploy", "defaults/openai-embedding-defaults.json", - "substitutionTemplates/deploy-remote-model-template.json" + "substitutionTemplates/deploy-remote-model-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) ), /** defaults file and substitution ready template for Cohere embedding model */ COHERE_EMBEDDING_MODEL_DEPLOY( "cohere_embedding_model_deploy", "defaults/cohere-embedding-defaults.json", - "substitutionTemplates/deploy-remote-model-extra-params-template.json" + "substitutionTemplates/deploy-remote-model-extra-params-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) ), /** defaults file and substitution ready template for Bedrock Titan embedding model */ BEDROCK_TITAN_EMBEDDING_MODEL_DEPLOY( "bedrock_titan_embedding_model_deploy", "defaults/bedrock-titan-embedding-defaults.json", - "substitutionTemplates/deploy-remote-bedrock-model-template.json" + "substitutionTemplates/deploy-remote-bedrock-model-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN) ), /** defaults file and substitution ready template for Bedrock Titan multimodal embedding model */ BEDROCK_TITAN_MULTIMODAL_MODEL_DEPLOY( "bedrock_titan_multimodal_model_deploy", "defaults/bedrock-titan-multimodal-defaults.json", - "substitutionTemplates/deploy-remote-bedrock-model-template.json" + "substitutionTemplates/deploy-remote-bedrock-model-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN) ), /** defaults file and substitution ready template for Cohere chat model */ COHERE_CHAT_MODEL_DEPLOY( "cohere_chat_model_deploy", "defaults/cohere-chat-defaults.json", - "substitutionTemplates/deploy-remote-model-chat-template.json" + "substitutionTemplates/deploy-remote-model-chat-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) ), /** defaults file and substitution ready template for OpenAI chat model */ OPENAI_CHAT_MODEL_DEPLOY( "openai_chat_model_deploy", "defaults/openai-chat-defaults.json", - "substitutionTemplates/deploy-remote-model-chat-template.json" + "substitutionTemplates/deploy-remote-model-chat-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) ), /** defaults file and substitution ready template for local neural sparse model and ingest pipeline*/ LOCAL_NEURAL_SPARSE_SEARCH_BI_ENCODER( "local_neural_sparse_search_bi_encoder", "defaults/local-sparse-search-biencoder-defaults.json", - "substitutionTemplates/neural-sparse-local-biencoder-template.json" + "substitutionTemplates/neural-sparse-local-biencoder-template.json", + Collections.emptyList() ), /** defaults file and substitution ready template for semantic search, no model creation*/ - SEMANTIC_SEARCH("semantic_search", "defaults/semantic-search-defaults.json", "substitutionTemplates/semantic-search-template.json"), + SEMANTIC_SEARCH( + "semantic_search", + "defaults/semantic-search-defaults.json", + "substitutionTemplates/semantic-search-template.json", + List.of(CREATE_INGEST_PIPELINE_MODEL_ID) + ), /** defaults file and substitution ready template for multimodal search, no model creation*/ MULTI_MODAL_SEARCH( "multimodal_search", "defaults/multi-modal-search-defaults.json", - "substitutionTemplates/multi-modal-search-template.json" + "substitutionTemplates/multi-modal-search-template.json", + List.of(CREATE_INGEST_PIPELINE_MODEL_ID) ), /** defaults file and substitution ready template for multimodal search, no model creation*/ MULTI_MODAL_SEARCH_WITH_BEDROCK_TITAN( "multimodal_search_with_bedrock_titan", "defaults/multimodal-search-bedrock-titan-defaults.json", - "substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json" + "substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN) ), /** defaults file and substitution ready template for semantic search with query enricher processor attached, no model creation*/ SEMANTIC_SEARCH_WITH_QUERY_ENRICHER( "semantic_search_with_query_enricher", "defaults/semantic-search-query-enricher-defaults.json", - "substitutionTemplates/semantic-search-with-query-enricher-template.json" + "substitutionTemplates/semantic-search-with-query-enricher-template.json", + List.of(CREATE_INGEST_PIPELINE_MODEL_ID) ), /** defaults file and substitution ready template for semantic search with cohere embedding model*/ SEMANTIC_SEARCH_WITH_COHERE_EMBEDDING( "semantic_search_with_cohere_embedding", "defaults/cohere-embedding-semantic-search-defaults.json", - "substitutionTemplates/semantic-search-with-model-template.json" + "substitutionTemplates/semantic-search-with-model-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) ), /** defaults file and substitution ready template for semantic search with query enricher processor attached and cohere embedding model*/ SEMANTIC_SEARCH_WITH_COHERE_EMBEDDING_AND_QUERY_ENRICHER( "semantic_search_with_cohere_embedding_query_enricher", "defaults/cohere-embedding-semantic-search-with-query-enricher-defaults.json", - "substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json" + "substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) ), /** defaults file and substitution ready template for hybrid search, no model creation*/ - HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json"), + HYBRID_SEARCH( + "hybrid_search", + "defaults/hybrid-search-defaults.json", + "substitutionTemplates/hybrid-search-template.json", + List.of(CREATE_INGEST_PIPELINE_MODEL_ID) + ), /** defaults file and substitution ready template for conversational search with cohere chat model*/ CONVERSATIONAL_SEARCH_WITH_COHERE_DEPLOY( "conversational_search_with_llm_deploy", "defaults/conversational-search-defaults.json", - "substitutionTemplates/conversational-search-with-cohere-model-template.json" + "substitutionTemplates/conversational-search-with-cohere-model-template.json", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) ); private final String useCaseName; private final String defaultsFile; private final String substitutionReadyFile; + private final List requiredParams; private static final Logger logger = LogManager.getLogger(DefaultUseCases.class); - DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile) { + DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile, List requiredParams) { this.useCaseName = useCaseName; this.defaultsFile = defaultsFile; this.substitutionReadyFile = substitutionReadyFile; + this.requiredParams = requiredParams; } /** @@ -136,6 +171,14 @@ public String getSubstitutionReadyFile() { return substitutionReadyFile; } + /** + * Returns the required params for the given enum Constant + * @return the required params of the given useCase + */ + public List getRequiredParams() { + return requiredParams; + } + /** * Gets the defaultsFile based on the given use case. * @param useCaseName name of the given use case @@ -171,4 +214,21 @@ public static String getSubstitutionReadyFileByUseCaseName(String useCaseName) t logger.error("Unable to find substitution ready file for use case: {}", useCaseName); throw new FlowFrameworkException("Unable to find substitution ready file for use case: " + useCaseName, RestStatus.BAD_REQUEST); } + + /** + * Gets the required parameters based on the given use case + * @param useCaseName name of the given use case + * @return the list of required params + */ + public static List getRequiredParamsByUseCaseName(String useCaseName) { + if (useCaseName != null && !useCaseName.isEmpty()) { + for (DefaultUseCases useCase : values()) { + if (useCase.getUseCaseName().equals(useCaseName)) { + return new ArrayList(useCase.getRequiredParams()); + } + } + } + logger.error("Default use case [" + useCaseName + "] does not exist"); + throw new FlowFrameworkException("Default use case [" + useCaseName + "] does not exist", RestStatus.BAD_REQUEST); + } } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 59c8a3b59..5db17d2b7 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -126,12 +127,31 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); String defaultsFilePath = DefaultUseCases.getDefaultsFileByUseCaseName(useCase); useCaseDefaultsMap = ParseUtils.parseJsonFileToStringToStringMap("/" + defaultsFilePath); - - if (request.hasContent()) { + List requiredParams = DefaultUseCases.getRequiredParamsByUseCaseName(useCase); + + if (!request.hasContent()) { + if (!requiredParams.isEmpty()) { + throw new FlowFrameworkException( + "Missing the following required parameters for use case [" + useCase + "] : " + requiredParams.toString(), + RestStatus.BAD_REQUEST + ); + } + } else { try { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); Map userDefaults = ParseUtils.parseStringToObjectMap(parser); + + // Validate user defaults key set + Set userDefaultKeys = userDefaults.keySet(); + if (!userDefaultKeys.containsAll(requiredParams)) { + requiredParams.removeAll(userDefaultKeys); + throw new FlowFrameworkException( + "Missing the following required parameters for use case [" + useCase + "] : " + requiredParams.toString(), + RestStatus.BAD_REQUEST + ); + } + // updates the default params with anything user has given that matches for (Map.Entry userDefaultsEntry : userDefaults.entrySet()) { String key = userDefaultsEntry.getKey(); @@ -141,13 +161,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } } catch (Exception ex) { - RestStatus status = ex instanceof IOException ? RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex); - String errorMessage = - "failure parsing request body when a use case is given, make sure to provide a map with values that are either Strings, Arrays, or Map of Strings to Strings"; - logger.error(errorMessage, ex); - throw new FlowFrameworkException(errorMessage, status); + if (ex instanceof FlowFrameworkException) { + throw ex; + } else { + RestStatus status = ex instanceof IOException ? RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex); + String errorMessage = + "failure parsing request body when a use case is given, make sure to provide a map with values that are either Strings, Arrays, or Map of Strings to Strings"; + logger.error(errorMessage, ex); + throw new FlowFrameworkException(errorMessage, status); + } } - } useCaseTemplateFileInStringFormat = (String) ParseUtils.conditionallySubstitute( diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 8cbabc95f..2eaff69b4 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -345,16 +345,26 @@ protected Response createWorkflow(RestClient client, Template template) throws E * Helper method to invoke the Create Workflow Rest Action without validation * @param client the rest client * @param useCase the usecase to create + * @param the required params * @throws Exception if the request fails * @return a rest response */ - protected Response createWorkflowWithUseCase(RestClient client, String useCase) throws Exception { + protected Response createWorkflowWithUseCase(RestClient client, String useCase, List params) throws Exception { + + StringBuilder sb = new StringBuilder(); + for (String param : params) { + sb.append('"').append(param).append("\" : \"\","); + } + if (!params.isEmpty()) { + sb.deleteCharAt(sb.length() - 1); + } + return TestHelpers.makeRequest( client, "POST", WORKFLOW_URI + "?validation=off&use_case=" + useCase, Collections.emptyMap(), - "{}", + "{" + sb.toString() + "}", null ); } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index ada8d5513..894ee3010 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -43,6 +43,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY; +import static org.opensearch.flowframework.common.CommonValue.CREATE_INGEST_PIPELINE_MODEL_ID; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; @@ -406,7 +408,7 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception { public void testDefaultCohereUseCase() throws Exception { // Hit Create Workflow API with original template - Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy"); + Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy", List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); @@ -442,8 +444,18 @@ public void testDefaultCohereUseCase() throws Exception { } public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Exception { - // Hit Create Workflow API with original template - Response response = createWorkflowWithUseCase(client(), "semantic_search"); + // Hit Create Workflow API with original template without required params + ResponseException exception = expectThrows( + ResponseException.class, + () -> createWorkflowWithUseCase(client(), "semantic_search", Collections.emptyList()) + ); + assertTrue( + exception.getMessage() + .contains("Missing the following required parameters for use case [semantic_search] : [create_ingest_pipeline.model_id]") + ); + + // Pass in required params + Response response = createWorkflowWithUseCase(client(), "semantic_search", List.of(CREATE_INGEST_PIPELINE_MODEL_ID)); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); @@ -483,7 +495,11 @@ public void testAllDefaultUseCasesCreation() throws Exception { .collect(Collectors.toSet()); for (String useCaseName : allUseCaseNames) { - Response response = createWorkflowWithUseCase(client(), useCaseName); + Response response = createWorkflowWithUseCase( + client(), + useCaseName, + DefaultUseCases.getRequiredParamsByUseCaseName(useCaseName) + ); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 3381bbbec..b55d6b1f2 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -34,6 +34,7 @@ import java.util.Locale; import java.util.Map; +import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; @@ -140,7 +141,7 @@ public void testCreateWorkflowRequestWithUseCaseButNoProvision() throws Exceptio RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) .withParams(Map.of(USE_CASE, DefaultUseCases.COHERE_EMBEDDING_MODEL_DEPLOY.getUseCaseName())) - .withContent(new BytesArray(""), MediaTypeRegistry.JSON) + .withContent(new BytesArray("{\"" + CREATE_CONNECTOR_CREDENTIAL_KEY + "\":\"\"}"), MediaTypeRegistry.JSON) .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); doAnswer(invocation -> { @@ -157,7 +158,7 @@ public void testCreateWorkflowRequestWithUseCaseAndContent() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) .withParams(Map.of(USE_CASE, DefaultUseCases.COHERE_EMBEDDING_MODEL_DEPLOY.getUseCaseName())) - .withContent(new BytesArray("{\"key\":\"step\"}"), MediaTypeRegistry.JSON) + .withContent(new BytesArray("{\"" + CREATE_CONNECTOR_CREDENTIAL_KEY + "\":\"step\"}"), MediaTypeRegistry.JSON) .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); doAnswer(invocation -> {