From 6eef35687fd27bd3f3919bb040640c808456dd69 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 7 Nov 2024 13:43:05 -0800 Subject: [PATCH] Add integration test covering bulk API Signed-off-by: Daniel Widdis --- .../MLCommonsTenantAwareRestTestCase.java | 6 + .../RestMLModelUndeployTenantAwareIT.java | 148 ++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLModelUndeployTenantAwareIT.java diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsTenantAwareRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsTenantAwareRestTestCase.java index 15010b96e5..9f2353e846 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsTenantAwareRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsTenantAwareRestTestCase.java @@ -62,6 +62,8 @@ public abstract class MLCommonsTenantAwareRestTestCase extends MLCommonsRestTest // REST Response error reasons protected static final String MISSING_TENANT_REASON = "Tenant ID header is missing"; protected static final String NO_PERMISSION_REASON = "You don't have permission to access this resource"; + protected static final String DEPLOYED_REASON = + "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"; // Common constants and fields used in subclasses protected static final String CONNECTOR_ID = "connector_id"; @@ -167,6 +169,10 @@ protected static SearchResponse searchResponseFromResponse(Response response) th return SearchResponse.fromXContent(parser); } + protected static void assertBadRequest(Response response) { + assertEquals(RestStatus.BAD_REQUEST.getStatus(), response.getStatusLine().getStatusCode()); + } + protected static void assertNotFound(Response response) { assertEquals(RestStatus.NOT_FOUND.getStatus(), response.getStatusLine().getStatusCode()); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLModelUndeployTenantAwareIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLModelUndeployTenantAwareIT.java new file mode 100644 index 0000000000..bb54ee1882 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLModelUndeployTenantAwareIT.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.CommonValue.TENANT_ID; +import static org.opensearch.ml.common.MLModel.MODEL_STATE_FIELD; +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.rest.RestMLRAGSearchProcessorIT.COHERE_CONNECTOR_BLUEPRINT; + +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.rest.RestRequest; + +public class RestMLModelUndeployTenantAwareIT extends MLCommonsTenantAwareRestTestCase { + + // Tests the client.bulk API used for undeploying models + public void testModelDeployUndeploy() throws Exception { + boolean multiTenancyEnabled = isMultiTenancyEnabled(); + + /* + * Setup + */ + // Create a connector to use + RestRequest createConnectorRequest = getRestRequestWithHeadersAndContent(tenantId, COHERE_CONNECTOR_BLUEPRINT); + Response response = makeRequest(createConnectorRequest, POST, CONNECTORS_PATH + "_create"); + assertOK(response); + Map map = responseToMap(response); + assertTrue(map.containsKey(CONNECTOR_ID)); + String connectorId = map.get(CONNECTOR_ID).toString(); + + /* + * Create + */ + // Register and deploy a remote model with a tenant id + RestRequest registerModelRequest = getRestRequestWithHeadersAndContent( + tenantId, + registerRemoteModelContent("test model", connectorId, null) + ); + response = makeRequest(registerModelRequest, POST, MODELS_PATH + "_register?deploy=true"); + assertOK(response); + map = responseToMap(response); + assertTrue(map.containsKey(MODEL_ID_FIELD)); + String modelId = map.get(MODEL_ID_FIELD).toString(); + + /* + * Get + */ + // Now get that model and confirm it's deployed + assertBusy(() -> { + Response getResponse = makeRequest(tenantRequest, GET, MODELS_PATH + modelId); + assertOK(getResponse); + Map responseMap = responseToMap(getResponse); + assertEquals("DEPLOYED", responseMap.get(MODEL_STATE_FIELD).toString()); + if (multiTenancyEnabled) { + assertEquals(tenantId, responseMap.get(TENANT_ID)); + } else { + assertNull(responseMap.get(TENANT_ID)); + } + }, 20, TimeUnit.SECONDS); + + /* + * Test delete/deploy interaction + */ + // Attempt to delete, should fail because it's deployed + ResponseException ex = assertThrows(ResponseException.class, () -> makeRequest(tenantRequest, DELETE, MODELS_PATH + modelId)); + response = ex.getResponse(); + assertBadRequest(response); + map = responseToMap(response); + assertEquals(DEPLOYED_REASON, getErrorReasonFromResponseMap(map)); + + // Verify still exists + response = makeRequest(tenantRequest, GET, MODELS_PATH + modelId); + assertOK(response); + + /* + * Undeploy + */ + // Undeploy the model which uses the bulk API + if (multiTenancyEnabled) { + // Try with the wrong tenant + ex = assertThrows(ResponseException.class, () -> makeRequest(otherTenantRequest, POST, MODELS_PATH + modelId + "/_undeploy")); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertTrue(getErrorReasonFromResponseMap(map).startsWith("Failed to find model")); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + + // Try with a null tenant + ex = assertThrows(ResponseException.class, () -> makeRequest(nullTenantRequest, POST, MODELS_PATH + modelId + "/_undeploy")); + response = ex.getResponse(); + assertForbidden(response); + map = responseToMap(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } + + // Now do with correct tenant + response = makeRequest(tenantRequest, POST, MODELS_PATH + modelId + "/_undeploy"); + assertOK(response); + // This is an MLUndeployControllerNodeResponse + map = responseToMap(response); + // This map's keys are the nodes, and the values are a map with "stats" key + // One of these is a map object with modelId as key and "undeployed" as value + String expectedValue = modelId + "=undeployed"; + assertTrue(map.toString().contains(expectedValue)); + + // Verify the undeploy update + response = makeRequest(tenantRequest, GET, MODELS_PATH + modelId); + assertOK(response); + map = responseToMap(response); + assertEquals("UNDEPLOYED", map.get(MODEL_STATE_FIELD).toString()); + if (multiTenancyEnabled) { + assertEquals(tenantId, map.get(TENANT_ID)); + } else { + assertNull(map.get(TENANT_ID)); + } + + /* + * Delete + */ + // Delete, should now succeed because it's deployed + response = makeRequest(tenantRequest, DELETE, MODELS_PATH + modelId); + assertOK(response); + map = responseToMap(response); + assertEquals(modelId, map.get(DOC_ID).toString()); + + // Verify the deletion + ex = assertThrows(ResponseException.class, () -> makeRequest(tenantRequest, GET, MODELS_PATH + modelId)); + response = ex.getResponse(); + assertNotFound(response); + map = responseToMap(response); + assertEquals("Failed to find model with the provided model id: " + modelId, getErrorReasonFromResponseMap(map)); + + /* + * Cleanup other resources created + */ + deleteAndWaitForSearch(tenantId, CONNECTORS_PATH, connectorId, 0); + } +}