From 8e4e88adcec83ead5b888353cbe69dabaa8cd4da Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 8 Aug 2024 11:29:27 -0700 Subject: [PATCH] Tenant-aware integration tests for Connector Signed-off-by: Daniel Widdis --- plugin/build.gradle | 23 + .../TransportCreateConnectorAction.java | 8 +- .../ml/rest/RestMLConnectorTenantAwareIT.java | 429 ++++++++++++++++++ 3 files changed, 457 insertions(+), 3 deletions(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java diff --git a/plugin/build.gradle b/plugin/build.gradle index 48f51ef6a0..43683a7359 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -180,10 +180,20 @@ integTest { systemProperty "user", System.getProperty("user") systemProperty "password", System.getProperty("password") + // Only tenant aware test if set + if (System.getProperty("tests.rest.tenantaware") != null) { + filter { + includeTestsMatching "org.opensearch.ml.rest.*TenantAwareIT" + // mock LLM run in localhost, it will not reachable for docker or remote cluster + excludeTestsMatching "org.opensearch.ml.tools.VisualizationsToolIT" + } + } + // Only rest case can run with remote cluster if (System.getProperty("tests.rest.cluster") != null) { filter { includeTestsMatching "org.opensearch.ml.rest.*IT" + excludeTestsMatching "org.opensearch.ml.rest.*TenantAwareIT" // mock LLM run in localhost, it will not reachable for docker or remote cluster excludeTestsMatching "org.opensearch.ml.tools.VisualizationsToolIT" } @@ -205,6 +215,19 @@ integTest { // The 'doFirst' delays till execution time. doFirst { + // TODO this properly uses the remote client factory but needs a remote cluster set up + if (System.getProperty("tests.rest.tenantaware.remote") != null) { + def ymlFile = file("$buildDir/testclusters/integTest-0/config/opensearch.yml") + if (ymlFile.exists()) { + ymlFile.withWriterAppend { writer -> + writer.write("\n# Use a remote cluster\n") + writer.write("plugins.ml_commons.remote_metadata_type: RemoteOpenSearch\n") + writer.write("plugins.ml_commons.remote_metadata_endpoint: http://127.0.0.1\n") + } + } else { + throw new GradleException("opensearch.yml not found at: $ymlFile") + } + } // Tell the test JVM if the cluster JVM is running under a debugger so that tests can // use longer timeouts for requests. def isDebuggingCluster = getDebug() || System.getProperty("test.debug") != null diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 2c054236d4..056fb870fd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -229,9 +229,11 @@ private void validateRequest4AccessControl(MLCreateConnectorInput input, User us private void validateSecurityDisabledOrConnectorAccessControlDisabled(MLCreateConnectorInput input) { if (input.getAccess() != null || input.getAddAllBackendRoles() != null || !CollectionUtils.isEmpty(input.getBackendRoles())) { - throw new IllegalArgumentException( - "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled on your cluster." - ); + // TODO: Get Security Plugin installed and working + // throw new IllegalArgumentException( + // "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled + // on your cluster." + // ); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java new file mode 100644 index 0000000000..a64bfbe95a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLConnectorTenantAwareIT.java @@ -0,0 +1,429 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.TENANT_ID; +import static org.opensearch.ml.common.input.Constants.TENANT_ID_HEADER; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.http.Header; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.input.Constants; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.common.collect.ImmutableList; + +public class RestMLConnectorTenantAwareIT extends MLCommonsRestTestCase { + // ID keys + private static final String DOC_ID = "_id"; + private static final String CONNECTOR_ID = "connector_id"; + + // REST Methods + private static final String POST = RestRequest.Method.POST.name(); + private static final String GET = RestRequest.Method.GET.name(); + private static final String PUT = RestRequest.Method.PUT.name(); + private static final String DELETE = RestRequest.Method.DELETE.name(); + private static final String PATH = "/_plugins/_ml/"; + + // Expected error messages on failure + private static final String MISSING_TENANT_REASON = "Tenant ID header is missing"; + private static final String NO_PERMISSION_REASON = "You don't have permission to access this resource"; + + private Map params = Collections.emptyMap(); + private String body = null; + private List
headers = Collections.emptyList(); + private String tenantId = "123:abc"; + private String otherTenantId = "789:xyz"; + private Map> tenantIdHeaders = Map.of(TENANT_ID_HEADER, singletonList(tenantId)); + private Map> otherTenantIdHeaders = Map.of(TENANT_ID_HEADER, singletonList(otherTenantId)); + private Map> nullTenantIdHeaders = emptyMap(); + + // From SecureMLRestIT + String mlFullAccessUser = "ml_full_access"; + RestClient mlFullAccessClient; + private String opensearchBackendRole = "opensearch"; + private String indexSearchAccessRole = "ml_test_index_all_search"; + + @Before + public void setup() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.model_access_control_enabled\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + // TODO Get secure client access with backend roles properly configured + } + + @Test + public void testConnectorCRUD() throws IOException, InterruptedException { + testConnectorCRUDMultitenancyEnabled(true); + testConnectorCRUDMultitenancyEnabled(false); + } + + private void testConnectorCRUDMultitenancyEnabled(boolean multiTenancyEnabled) throws IOException, InterruptedException { + enableMultiTenancy(multiTenancyEnabled); + + // Create a connector with a tenant id + setFieldsFromRequest(TestHelper.getCreateConnectorRestRequest(tenantId)); + + Response response = TestHelper.makeRequest(client(), POST, PATH + "connectors/_create", params, body, headers); + Map map = parseResponseToMap(response); + assertTrue(map.containsKey(CONNECTOR_ID)); + String connectorId = map.get(CONNECTOR_ID).toString(); + + // Now try to get that connector + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + if (multiTenancyEnabled) { + assertEquals(tenantId, map.get(TENANT_ID)); + } else { + assertNull(map.get(TENANT_ID)); + } + + // Now try again with a other ID + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(otherTenantIdHeaders).build()); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + // Headers ignored, full response + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + } + + // Now try again with a null ID + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(nullTenantIdHeaders).build()); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + // Headers ignored, full response + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + } + + // Now attempt to update the connector name + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{\"name\":\"Updated name\"}")); + + response = TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals(connectorId, map.get(DOC_ID).toString()); + + // Verfify the update + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Updated name", map.get("name")); + + // Try the update with other tenant ID + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{\"name\":\"Other tenant name\"}")); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers); + // Verfify the update + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Other tenant name", map.get("name")); + } + + // Try the update with no tenant ID + setFieldsFromRequest(getRestRequestWithHeadersAndContent(null, "{\"name\":\"Null tenant name\"}")); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), PUT, PATH + "connectors/" + connectorId, params, body, headers); + // Verfify the update + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Null tenant name", map.get("name")); + } + + // Verify no change from original update when multiTenancy enabled + if (multiTenancyEnabled) { + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("Updated name", map.get("name")); + } + + // Create a second connector using otherTenantId + setFieldsFromRequest(TestHelper.getCreateConnectorRestRequest(otherTenantId)); + + response = TestHelper.makeRequest(client(), POST, PATH + "connectors/_create", params, body, headers); + map = parseResponseToMap(response); + assertTrue(map.containsKey(CONNECTOR_ID)); + String otherConnectorId = map.get(CONNECTOR_ID).toString(); + + // Verify it + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(otherTenantIdHeaders).build()); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/" + otherConnectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals("OpenAI Connector", map.get("name")); + + // Search should show only the connector for tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{\"query\":{\"match_all\":{}}}")); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/_search", params, body, headers); + XContentParser parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + TestHelper.httpEntityToString(response.getEntity()).getBytes(UTF_8) + ); + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + if (multiTenancyEnabled) { + // TODO Change to 1 when https://github.com/opensearch-project/ml-commons/pull/2803 is merged + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertEquals(tenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } + + // Search should show only the connector for other tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{\"query\":{\"match_all\":{}}}")); + + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/_search", params, body, headers); + parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + TestHelper.httpEntityToString(response.getEntity()).getBytes(UTF_8) + ); + searchResponse = SearchResponse.fromXContent(parser); + if (multiTenancyEnabled) { + // TODO Change to 1 when https://github.com/opensearch-project/ml-commons/pull/2803 is merged + assertEquals(2, searchResponse.getHits().getTotalHits().value); + // TODO change [1] to [0] + assertEquals(otherTenantId, searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } + + // Search should fail without a tenant id + setFieldsFromRequest(getRestRequestWithHeadersAndContent(null, "{\"query\":{\"match_all\":{}}}")); + + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), PUT, PATH + "connectors/_search" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + } else { + response = TestHelper.makeRequest(client(), GET, PATH + "connectors/_search", params, body, headers); + parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + TestHelper.httpEntityToString(response.getEntity()).getBytes(UTF_8) + ); + searchResponse = SearchResponse.fromXContent(parser); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID)); + } + + // Delete the connectors + + // First test that we can't delete other tenant connectors + if (multiTenancyEnabled) { + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{}")); + + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + otherConnectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{}")); + + ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(NO_PERMISSION_REASON, ((Map) map.get("error")).get("reason")); + + // and can't delete without a tenant ID either + setFieldsFromRequest(getRestRequestWithHeadersAndContent(null, "{}")); + + ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals(MISSING_TENANT_REASON, ((Map) map.get("error")).get("reason")); + + } + + // Now actually do the deletions. Same result whether multi-tenancy is enabled. + // Delete from tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(tenantId, "{}")); + response = TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + connectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals(connectorId, map.get(DOC_ID).toString()); + + // Verify the deletion + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + ResponseException ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + connectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.NOT_FOUND.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals( + "Failed to find connector with the provided connector id: " + connectorId, + ((Map) map.get("error")).get("reason") + ); + + // Delete from other tenant + setFieldsFromRequest(getRestRequestWithHeadersAndContent(otherTenantId, "{}")); + response = TestHelper.makeRequest(client(), DELETE, PATH + "connectors/" + otherConnectorId, params, body, headers); + map = parseResponseToMap(response); + assertEquals(otherConnectorId, map.get(DOC_ID).toString()); + + // Verify the deletion + setFieldsFromRequest(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(tenantIdHeaders).build()); + ex = assertThrows( + ResponseException.class, + () -> TestHelper.makeRequest(client(), GET, PATH + "connectors/" + otherConnectorId, params, body, headers) + ); + response = ex.getResponse(); + assertEquals(RestStatus.NOT_FOUND.getStatus(), response.getStatusLine().getStatusCode()); + map = parseResponseToMap(response); + assertEquals( + "Failed to find connector with the provided connector id: " + otherConnectorId, + ((Map) map.get("error")).get("reason") + ); + + // Cleanup (since deletions may linger in search results) + MLCommonsRestTestCase.deleteIndexWithAdminClient(ML_CONNECTOR_INDEX); + } + + private void enableMultiTenancy(boolean multiTenancyEnabled) throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.multi_tenancy_enabled\":" + multiTenancyEnabled + "}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + private void setFieldsFromRequest(RestRequest request) { + params = request.params(); + body = request.content().utf8ToString(); + headers = getHeadersFromRequest(request); + } + + private static List
getHeadersFromRequest(RestRequest request) { + return request + .getHeaders() + .entrySet() + .stream() + .map(e -> new BasicHeader(e.getKey(), e.getValue().stream().collect(Collectors.joining(",")))) + .collect(Collectors.toList()); + } + + private static RestRequest getRestRequestWithHeadersAndContent(String tenantId, String requestContent) { + Map> headers = new HashMap<>(); + if (tenantId != null) { + headers.put(Constants.TENANT_ID_HEADER, Collections.singletonList(tenantId)); + } + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withHeaders(headers) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +}