From 198931cfce5cd1c926314694cdb488bd8132f97e Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 8 Aug 2024 11:29:27 -0700 Subject: [PATCH] Tenant-aware integration tests for Connector Signed-off-by: Daniel Widdis --- plugin/build.gradle | 197 ++++---- .../TransportCreateConnectorAction.java | 8 +- .../ml/rest/RestMLConnectorTenantAwareIT.java | 429 ++++++++++++++++++ 3 files changed, 555 insertions(+), 79 deletions(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java diff --git a/plugin/build.gradle b/plugin/build.gradle index 48f51ef6a0..4d506c120b 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -87,7 +87,7 @@ dependencies { implementation("software.amazon.awssdk:utils:2.25.40") // AWS OpenSearch Service dependency implementation("software.amazon.awssdk:apache-client:2.25.40") - + configurations.all { resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5:5.2.4' resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5-h2:5.2.4' @@ -134,11 +134,17 @@ publishing { } compileJava { - options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) + options.compilerArgs.addAll([ + "-processor", + 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor' + ]) } compileTestJava { - options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) + options.compilerArgs.addAll([ + "-processor", + 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor' + ]) } //TODO: check which one should be enabled @@ -180,10 +186,20 @@ integTest { systemProperty "user", System.getProperty("user") systemProperty "password", System.getProperty("password") + // Only tenant aware test if set + if (System.getProperty("tests.rest.tenantaware") != null) { + filter { + includeTestsMatching "org.opensearch.ml.rest.*TenantAwareIT" + // mock LLM run in localhost, it will not reachable for docker or remote cluster + excludeTestsMatching "org.opensearch.ml.tools.VisualizationsToolIT" + } + } + // Only rest case can run with remote cluster if (System.getProperty("tests.rest.cluster") != null) { filter { includeTestsMatching "org.opensearch.ml.rest.*IT" + excludeTestsMatching "org.opensearch.ml.rest.*TenantAwareIT" // mock LLM run in localhost, it will not reachable for docker or remote cluster excludeTestsMatching "org.opensearch.ml.tools.VisualizationsToolIT" } @@ -205,6 +221,18 @@ integTest { // The 'doFirst' delays till execution time. doFirst { + if (System.getProperty("tests.rest.tenantaware.remote") != null) { + def ymlFile = file("$buildDir/testclusters/integTest-0/config/opensearch.yml") + if (ymlFile.exists()) { + ymlFile.withWriterAppend { writer -> + writer.write("\n# Use a remote cluster\n") + writer.write("plugins.ml_commons.remote_metadata_type: RemoteOpenSearch\n") + writer.write("plugins.ml_commons.remote_metadata_endpoint: http://127.0.0.1\n") + } + } else { + throw new GradleException("opensearch.yml not found at: $ymlFile") + } + } // Tell the test JVM if the cluster JVM is running under a debugger so that tests can // use longer timeouts for requests. def isDebuggingCluster = getDebug() || System.getProperty("test.debug") != null @@ -310,45 +338,45 @@ jacocoTestReport { } List jacocoExclusions = [ - // TODO: add more unit test to meet the minimal test coverage. - 'org.opensearch.ml.constant.CommonValue', - 'org.opensearch.ml.plugin.MachineLearningPlugin*', - 'org.opensearch.ml.indices.MLIndicesHandler', - 'org.opensearch.ml.rest.RestMLPredictionAction', - 'org.opensearch.ml.profile.MLModelProfile', - 'org.opensearch.ml.profile.MLPredictRequestStats', - 'org.opensearch.ml.action.deploy.TransportDeployModelAction', - 'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction', - 'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction', - 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction', - 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction.1', - 'org.opensearch.ml.action.tasks.GetTaskTransportAction', - 'org.opensearch.ml.action.tasks.SearchTaskTransportAction', - 'org.opensearch.ml.model.MLModelManager', - 'org.opensearch.ml.stats.MLClusterLevelStat', - 'org.opensearch.ml.stats.MLStatLevel', - 'org.opensearch.ml.utils.IndexUtils', - 'org.opensearch.ml.cluster.MLCommonsClusterManagerEventListener', - 'org.opensearch.ml.cluster.DiscoveryNodeHelper.HotDataNodePredicate', - 'org.opensearch.ml.cluster.MLCommonsClusterEventListener', - 'org.opensearch.ml.task.MLTaskManager', - 'org.opensearch.ml.task.MLTrainingTaskRunner', - 'org.opensearch.ml.task.MLPredictTaskRunner', - 'org.opensearch.ml.task.MLTaskDispatcher', - 'org.opensearch.ml.task.MLTrainAndPredictTaskRunner', - 'org.opensearch.ml.task.MLExecuteTaskRunner', - 'org.opensearch.ml.action.profile.MLProfileTransportAction', - 'org.opensearch.ml.rest.RestMLPredictionAction', - 'org.opensearch.ml.breaker.DiskCircuitBreaker', - 'org.opensearch.ml.autoredeploy.MLModelAutoReDeployer.SearchRequestBuilderFactory', - 'org.opensearch.ml.action.training.TrainingITTests', - 'org.opensearch.ml.action.prediction.PredictionITTests', - 'org.opensearch.ml.cluster.MLSyncUpCron', - 'org.opensearch.ml.model.MLModelGroupManager', - 'org.opensearch.ml.helper.ModelAccessControlHelper', - 'org.opensearch.ml.action.models.DeleteModelTransportAction.2', - 'org.opensearch.ml.model.MLModelCacheHelper', - 'org.opensearch.ml.model.MLModelCacheHelper.1' + // TODO: add more unit test to meet the minimal test coverage. + 'org.opensearch.ml.constant.CommonValue', + 'org.opensearch.ml.plugin.MachineLearningPlugin*', + 'org.opensearch.ml.indices.MLIndicesHandler', + 'org.opensearch.ml.rest.RestMLPredictionAction', + 'org.opensearch.ml.profile.MLModelProfile', + 'org.opensearch.ml.profile.MLPredictRequestStats', + 'org.opensearch.ml.action.deploy.TransportDeployModelAction', + 'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction', + 'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction', + 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction', + 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction.1', + 'org.opensearch.ml.action.tasks.GetTaskTransportAction', + 'org.opensearch.ml.action.tasks.SearchTaskTransportAction', + 'org.opensearch.ml.model.MLModelManager', + 'org.opensearch.ml.stats.MLClusterLevelStat', + 'org.opensearch.ml.stats.MLStatLevel', + 'org.opensearch.ml.utils.IndexUtils', + 'org.opensearch.ml.cluster.MLCommonsClusterManagerEventListener', + 'org.opensearch.ml.cluster.DiscoveryNodeHelper.HotDataNodePredicate', + 'org.opensearch.ml.cluster.MLCommonsClusterEventListener', + 'org.opensearch.ml.task.MLTaskManager', + 'org.opensearch.ml.task.MLTrainingTaskRunner', + 'org.opensearch.ml.task.MLPredictTaskRunner', + 'org.opensearch.ml.task.MLTaskDispatcher', + 'org.opensearch.ml.task.MLTrainAndPredictTaskRunner', + 'org.opensearch.ml.task.MLExecuteTaskRunner', + 'org.opensearch.ml.action.profile.MLProfileTransportAction', + 'org.opensearch.ml.rest.RestMLPredictionAction', + 'org.opensearch.ml.breaker.DiskCircuitBreaker', + 'org.opensearch.ml.autoredeploy.MLModelAutoReDeployer.SearchRequestBuilderFactory', + 'org.opensearch.ml.action.training.TrainingITTests', + 'org.opensearch.ml.action.prediction.PredictionITTests', + 'org.opensearch.ml.cluster.MLSyncUpCron', + 'org.opensearch.ml.model.MLModelGroupManager', + 'org.opensearch.ml.helper.ModelAccessControlHelper', + 'org.opensearch.ml.action.models.DeleteModelTransportAction.2', + 'org.opensearch.ml.model.MLModelCacheHelper', + 'org.opensearch.ml.model.MLModelCacheHelper.1' ] jacocoTestCoverageVerification { @@ -435,7 +463,11 @@ afterEvaluate { } task buildPackages(type: GradleBuild) { - tasks = ['build', 'buildRpm', 'buildDeb'] + tasks = [ + 'build', + 'buildRpm', + 'buildDeb' + ] } } @@ -464,7 +496,10 @@ String opensearchMlPlugin = "opensearch-ml-" + project.version + ".zip" testClusters { "${baseName}$i" { testDistribution = "ARCHIVE" - versions = [bwcShortVersion, opensearch_version] + versions = [ + bwcShortVersion, + opensearch_version + ] numberOfNodes = 3 plugin(provider(new Callable() { @Override @@ -496,22 +531,22 @@ String opensearchMlPlugin = "opensearch-ml-" + project.version + ".zip" } List> plugins = [ - provider(new Callable() { - @Override - RegularFile call() throws Exception { - return new RegularFile() { - @Override - File getAsFile() { - project.mkdir "$bwcFilePath/$project.version" - copy { - from "$buildDir/distributions/$opensearchMlPlugin" - into "$bwcFilePath/$project.version" - } - return fileTree(bwcFilePath + project.version).getSingleFile() + provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + project.mkdir "$bwcFilePath/$project.version" + copy { + from "$buildDir/distributions/$opensearchMlPlugin" + into "$bwcFilePath/$project.version" } + return fileTree(bwcFilePath + project.version).getSingleFile() } } - }) + } + }) ] // Creates 2 test clusters with 3 nodes of the old version. @@ -524,9 +559,11 @@ List> plugins = [ systemProperty 'tests.rest.bwcsuite', 'old_cluster' systemProperty 'tests.rest.bwcsuite_round', 'old' systemProperty 'tests.plugin_bwc_version', bwcVersion - nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}$i".allHttpSocketURI.join(",")}") - nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}$i".getName()}") - } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}$i".allHttpSocketURI.join(",") +}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}$i".getName() +}") +} } // Upgrade one node of the old cluster to new OpenSearch version with upgraded plugin version @@ -537,15 +574,17 @@ task "${baseName}#mixedClusterTask"(type: StandaloneRestIntegTestTask) { dependsOn "${baseName}#oldVersionClusterTask0" doFirst { testClusters."${baseName}0".upgradeNodeAndPluginToNextVersion(plugins) - } +} filter { includeTestsMatching "org.opensearch.ml.bwc.*IT" - } +} systemProperty 'tests.rest.bwcsuite', 'mixed_cluster' systemProperty 'tests.rest.bwcsuite_round', 'first' systemProperty 'tests.plugin_bwc_version', bwcVersion - nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}0".allHttpSocketURI.join(",")}") - nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}0".getName()}") + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}0".allHttpSocketURI.join(",") +}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}0".getName() +}") } // Upgrades the second node to new OpenSearch version with upgraded plugin version after the first node is upgraded. @@ -556,15 +595,17 @@ task "${baseName}#twoThirdsUpgradedClusterTask"(type: StandaloneRestIntegTestTas useCluster testClusters."${baseName}0" doFirst { testClusters."${baseName}0".upgradeNodeAndPluginToNextVersion(plugins) - } +} filter { includeTestsMatching "org.opensearch.ml.bwc.*IT" - } +} systemProperty 'tests.rest.bwcsuite', 'mixed_cluster' systemProperty 'tests.rest.bwcsuite_round', 'second' systemProperty 'tests.plugin_bwc_version', bwcVersion - nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}0".allHttpSocketURI.join(",")}") - nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}0".getName()}") + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}0".allHttpSocketURI.join(",") +}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}0".getName() +}") } // Upgrade the third node to new OpenSearch version with upgraded plugin version after the second node is upgraded. @@ -575,16 +616,18 @@ task "${baseName}#rollingUpgradeClusterTask"(type: StandaloneRestIntegTestTask) useCluster testClusters."${baseName}0" doFirst { testClusters."${baseName}0".upgradeNodeAndPluginToNextVersion(plugins) - } +} filter { includeTestsMatching "org.opensearch.ml.bwc.*IT" - } +} mustRunAfter "${baseName}#mixedClusterTask" systemProperty 'tests.rest.bwcsuite', 'mixed_cluster' systemProperty 'tests.rest.bwcsuite_round', 'third' systemProperty 'tests.plugin_bwc_version', bwcVersion - nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}0".allHttpSocketURI.join(",")}") - nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}0".getName()}") + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}0".allHttpSocketURI.join(",") +}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}0".getName() +}") } // Upgrades all the nodes of the old cluster to new OpenSearch version with upgraded plugin version @@ -594,14 +637,16 @@ task "${baseName}#fullRestartClusterTask"(type: StandaloneRestIntegTestTask) { useCluster testClusters."${baseName}1" doFirst { testClusters."${baseName}1".upgradeAllNodesAndPluginsToNextVersion(plugins) - } +} filter { includeTestsMatching "org.opensearch.ml.bwc.*IT" - } +} systemProperty 'tests.rest.bwcsuite', 'upgraded_cluster' systemProperty 'tests.plugin_bwc_version', bwcVersion - nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}1".allHttpSocketURI.join(",")}") - nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}1".getName()}") + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}1".allHttpSocketURI.join(",") +}") + nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}1".getName() +}") } // A bwc test suite which runs all the bwc tasks combined diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 2c054236d4..056fb870fd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -229,9 +229,11 @@ private void validateRequest4AccessControl(MLCreateConnectorInput input, User us private void validateSecurityDisabledOrConnectorAccessControlDisabled(MLCreateConnectorInput input) { if (input.getAccess() != null || input.getAddAllBackendRoles() != null || !CollectionUtils.isEmpty(input.getBackendRoles())) { - throw new IllegalArgumentException( - "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled on your cluster." - ); + // TODO: Get Security Plugin installed and working + // throw new IllegalArgumentException( + // "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled + // on your cluster." + // ); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java new file mode 100644 index 0000000000..a64bfbe95a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java @@ -0,0 +1,429 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.TENANT_ID; +import static org.opensearch.ml.common.input.Constants.TENANT_ID_HEADER; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.http.Header; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.input.Constants; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.common.collect.ImmutableList; + +public class RestMLConnectorTenantAwareIT extends MLCommonsRestTestCase { + // ID keys + private static final String DOC_ID = "_id"; + private static final String CONNECTOR_ID = "connector_id"; + + // REST Methods + private static final String POST = RestRequest.Method.POST.name(); + private static final String GET = RestRequest.Method.GET.name(); + private static final String PUT = RestRequest.Method.PUT.name(); + private static final String DELETE = RestRequest.Method.DELETE.name(); + private static final String PATH = "/_plugins/_ml/"; + + // Expected error messages on failure + private static final String MISSING_TENANT_REASON = "Tenant ID header is missing"; + private static final String NO_PERMISSION_REASON = "You don't have permission to access this resource"; + + private Map params = Collections.emptyMap(); + private String body = null; + private List
headers = Collections.emptyList(); + private String tenantId = "123:abc"; + private String otherTenantId = "789:xyz"; + private Map> tenantIdHeaders = Map.of(TENANT_ID_HEADER, singletonList(tenantId)); + private Map> otherTenantIdHeaders = Map.of(TENANT_ID_HEADER, singletonList(otherTenantId)); + private Map> nullTenantIdHeaders = emptyMap(); + + // From SecureMLRestIT + String mlFullAccessUser = "ml_full_access"; + RestClient mlFullAccessClient; + private String opensearchBackendRole = "opensearch"; + private String indexSearchAccessRole = "ml_test_index_all_search"; + + @Before + public void setup() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.model_access_control_enabled\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + // TODO Get secure client access with backend roles properly configured + } + + @Test + public void testConnectorCRUD() throws IOException, InterruptedException { + testConnectorCRUDMultitenancyEnabled(true); + testConnectorCRUDMultitenancyEnabled(false); + } + + private void testConnectorCRUDMultitenancyEnabled(boolean multiTenancyEnabled) throws IOException, InterruptedException { + enableMultiTenancy(multiTenancyEnabled); + + // Create a connector with a tenant id + setFieldsFromRequest(TestHelper.getCreateConnectorRestRequest(tenantId)); + + Response response = TestHelper.makeRequest(client(), POST, PATH + "connectors/_create", params, body, headers); + Map map = parseResponseToMap(response); + assertTrue(map.containsKey(CONNECTOR_ID)); + String connectorId = map.get(CONNECTOR_ID).toString(); + + // Now try to get that connector + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + if (multiTenancyEnabled) { + assertEquals(tenantId, map.get(TENANT_ID)); + } else { + assertNull(map.get(TENANT_ID)); + } + + // Now try again with a other ID + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(otherTenantIdHeaders).build()); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + // Headers ignored, full response + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + } + + // Now try again with a null ID + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(nullTenantIdHeaders).build()); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + // Headers ignored, full response + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + } + + // Now attempt to update the connector name + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{\"name\":\"Updated name\"}")); + + response = TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals(connectorId, map.get(DOC_ID).toString()); + + // Verfify the update + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Updated name", map.get("name")); + + // Try the update with other tenant ID + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{\"name\":\"Other tenant name\"}")); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers); + // Verfify the update + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Other tenant name", map.get("name")); + } + + // Try the update with no tenant ID + setFieldsFromRequest(getRestRequestWithHeadersAndContent(null, "{\"name\":\"Null tenant name\"}")); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers); + // Verfify the update + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Null tenant name", map.get("name")); + } + + // Verify no change from original update when multiTenancy enabled + if (multiTenancyEnabled) { + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Updated name", map.get("name")); + } + + // Create a second connector using otherTenantId + setFieldsFromRequest(TestHelper.getCreateConnectorRestRequest(otherTenantId)); + + response = TestHelper.makeRequest(client(), POST, PATH + "connectors/_create", params, body, headers); + map = parseResponseToMap(response); + assertTrue(map.containsKey(CONNECTOR_ID)); + String otherConnectorId = map.get(CONNECTOR_ID).toString(); + + // Verify it + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(otherTenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + otherConnectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + + // Search should show only the connector for tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{\"query\":{\"match_all\":{}}}")); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/_search", params, body, headers); + XContentParser parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + TestHelper.httpEntityToString(response.getEntity()).getBytes(UTF_8) + ); + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + if (multiTenancyEnabled) { + // TODO Change to 1 when https://github.com/opensearch-project/ml-commons/pull/2803 is merged + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertEquals(tenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } + + // Search should show only the connector for other tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{\"query\":{\"match_all\":{}}}")); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/_search", params, body, headers); + parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + TestHelper.httpEntityToString(response.getEntity()).getBytes(UTF_8) + ); + searchResponse = SearchResponse.fromXContent(parser); + if (multiTenancyEnabled) { + // TODO Change to 1 when https://github.com/opensearch-project/ml-commons/pull/2803 is merged + assertEquals(2, searchResponse.getHits().getTotalHits().value); + // TODO change [1] to [0] + assertEquals(otherTenantId, searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } + + // Search should fail without a tenant id + setFieldsFromRequest(getRestRequestWithHeadersAndContent(null, "{\"query\":{\"match_all\":{}}}")); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), PUT, PATH + "connectors/_search" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/_search", params, body, headers); + parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + TestHelper.httpEntityToString(response.getEntity()).getBytes(UTF_8) + ); + searchResponse = SearchResponse.fromXContent(parser); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } + + // Delete the connectors + + // First test that we can't delete other tenant connectors + if (multiTenancyEnabled) { + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{}")); + + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + otherConnectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{}")); + + ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + + // and can't delete without a tenant ID either + setFieldsFromRequest(getRestRequestWithHeadersAndContent(null, "{}")); + + ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + + } + + // Now actually do the deletions. Same result whether multi-tenancy is enabled. + // Delete from tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{}")); + response = TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals(connectorId, map.get(DOC_ID).toString()); + + // Verify the deletion + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.NOT_FOUND.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals( + "Failed to find connector with the provided connector id: " + connectorId, + ((Map) map.get("error")).get("reason") + ); + + // Delete from other tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{}")); + response = TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + otherConnectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals(otherConnectorId, map.get(DOC_ID).toString()); + + // Verify the deletion + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + otherConnectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.NOT_FOUND.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals( + "Failed to find connector with the provided connector id: " + otherConnectorId, + ((Map) map.get("error")).get("reason") + ); + + // Cleanup (since deletions may linger in search results) + MLCommonsRestTestCase.deleteIndexWithAdminClient(ML_CONNECTOR_INDEX); + } + + private void enableMultiTenancy(boolean multiTenancyEnabled) throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.multi_tenancy_enabled\":" + multiTenancyEnabled + "}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + private void setFieldsFromRequest(RestRequest request) { + params = request.params(); + body = request.content().utf8ToString(); + headers = getHeadersFromRequest(request); + } + + private static List
getHeadersFromRequest(RestRequest request) { + return request + .getHeaders() + .entrySet() + .stream() + .map(e -> new BasicHeader(e.getKey(), e.getValue().stream().collect(Collectors.joining(",")))) + .collect(Collectors.toList()); + } + + private static RestRequest getRestRequestWithHeadersAndContent(String tenantId, String requestContent) { + Map> headers = new HashMap<>(); + if (tenantId != null) { + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); + } + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withHeaders(headers) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +}