-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add integration test covering bulk API
Signed-off-by: Daniel Widdis <[email protected]>
- Loading branch information
Showing
2 changed files
with
154 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
148 changes: 148 additions & 0 deletions
148
plugin/src/test/java/org/opensearch/ml/rest/RestMLModelUndeployTenantAwareIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import static org.opensearch.ml.common.CommonValue.TENANT_ID; | ||
import static org.opensearch.ml.common.MLModel.MODEL_STATE_FIELD; | ||
import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; | ||
import static org.opensearch.ml.rest.RestMLRAGSearchProcessorIT.COHERE_CONNECTOR_BLUEPRINT; | ||
|
||
import java.util.Map; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
import org.opensearch.client.Response; | ||
import org.opensearch.client.ResponseException; | ||
import org.opensearch.rest.RestRequest; | ||
|
||
public class RestMLModelUndeployTenantAwareIT extends MLCommonsTenantAwareRestTestCase { | ||
|
||
// Tests the client.bulk API used for undeploying models | ||
public void testModelDeployUndeploy() throws Exception { | ||
boolean multiTenancyEnabled = isMultiTenancyEnabled(); | ||
|
||
/* | ||
* Setup | ||
*/ | ||
// Create a connector to use | ||
RestRequest createConnectorRequest = getRestRequestWithHeadersAndContent(tenantId, COHERE_CONNECTOR_BLUEPRINT); | ||
Response response = makeRequest(createConnectorRequest, POST, CONNECTORS_PATH + "_create"); | ||
assertOK(response); | ||
Map<String, Object> map = responseToMap(response); | ||
assertTrue(map.containsKey(CONNECTOR_ID)); | ||
String connectorId = map.get(CONNECTOR_ID).toString(); | ||
|
||
/* | ||
* Create | ||
*/ | ||
// Register and deploy a remote model with a tenant id | ||
RestRequest registerModelRequest = getRestRequestWithHeadersAndContent( | ||
tenantId, | ||
registerRemoteModelContent("test model", connectorId, null) | ||
); | ||
response = makeRequest(registerModelRequest, POST, MODELS_PATH + "_register?deploy=true"); | ||
assertOK(response); | ||
map = responseToMap(response); | ||
assertTrue(map.containsKey(MODEL_ID_FIELD)); | ||
String modelId = map.get(MODEL_ID_FIELD).toString(); | ||
|
||
/* | ||
* Get | ||
*/ | ||
// Now get that model and confirm it's deployed | ||
assertBusy(() -> { | ||
Response getResponse = makeRequest(tenantRequest, GET, MODELS_PATH + modelId); | ||
assertOK(getResponse); | ||
Map<String, Object> responseMap = responseToMap(getResponse); | ||
assertEquals("DEPLOYED", responseMap.get(MODEL_STATE_FIELD).toString()); | ||
if (multiTenancyEnabled) { | ||
assertEquals(tenantId, responseMap.get(TENANT_ID)); | ||
} else { | ||
assertNull(responseMap.get(TENANT_ID)); | ||
} | ||
}, 20, TimeUnit.SECONDS); | ||
|
||
/* | ||
* Test delete/deploy interaction | ||
*/ | ||
// Attempt to delete, should fail because it's deployed | ||
ResponseException ex = assertThrows(ResponseException.class, () -> makeRequest(tenantRequest, DELETE, MODELS_PATH + modelId)); | ||
response = ex.getResponse(); | ||
assertBadRequest(response); | ||
map = responseToMap(response); | ||
assertEquals(DEPLOYED_REASON, getErrorReasonFromResponseMap(map)); | ||
|
||
// Verify still exists | ||
response = makeRequest(tenantRequest, GET, MODELS_PATH + modelId); | ||
assertOK(response); | ||
|
||
/* | ||
* Undeploy | ||
*/ | ||
// Undeploy the model which uses the bulk API | ||
if (multiTenancyEnabled) { | ||
// Try with the wrong tenant | ||
ex = assertThrows(ResponseException.class, () -> makeRequest(otherTenantRequest, POST, MODELS_PATH + modelId + "/_undeploy")); | ||
response = ex.getResponse(); | ||
map = responseToMap(response); | ||
if (DDB) { | ||
assertNotFound(response); | ||
assertTrue(getErrorReasonFromResponseMap(map).startsWith("Failed to find model")); | ||
} else { | ||
assertForbidden(response); | ||
assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); | ||
} | ||
|
||
// Try with a null tenant | ||
ex = assertThrows(ResponseException.class, () -> makeRequest(nullTenantRequest, POST, MODELS_PATH + modelId + "/_undeploy")); | ||
response = ex.getResponse(); | ||
assertForbidden(response); | ||
map = responseToMap(response); | ||
assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); | ||
} | ||
|
||
// Now do with correct tenant | ||
response = makeRequest(tenantRequest, POST, MODELS_PATH + modelId + "/_undeploy"); | ||
assertOK(response); | ||
// This is an MLUndeployControllerNodeResponse | ||
map = responseToMap(response); | ||
// This map's keys are the nodes, and the values are a map with "stats" key | ||
// One of these is a map object with modelId as key and "undeployed" as value | ||
String expectedValue = modelId + "=undeployed"; | ||
assertTrue(map.toString().contains(expectedValue)); | ||
|
||
// Verify the undeploy update | ||
response = makeRequest(tenantRequest, GET, MODELS_PATH + modelId); | ||
assertOK(response); | ||
map = responseToMap(response); | ||
assertEquals("UNDEPLOYED", map.get(MODEL_STATE_FIELD).toString()); | ||
if (multiTenancyEnabled) { | ||
assertEquals(tenantId, map.get(TENANT_ID)); | ||
} else { | ||
assertNull(map.get(TENANT_ID)); | ||
} | ||
|
||
/* | ||
* Delete | ||
*/ | ||
// Delete, should now succeed because it's deployed | ||
response = makeRequest(tenantRequest, DELETE, MODELS_PATH + modelId); | ||
assertOK(response); | ||
map = responseToMap(response); | ||
assertEquals(modelId, map.get(DOC_ID).toString()); | ||
|
||
// Verify the deletion | ||
ex = assertThrows(ResponseException.class, () -> makeRequest(tenantRequest, GET, MODELS_PATH + modelId)); | ||
response = ex.getResponse(); | ||
assertNotFound(response); | ||
map = responseToMap(response); | ||
assertEquals("Failed to find model with the provided model id: " + modelId, getErrorReasonFromResponseMap(map)); | ||
|
||
/* | ||
* Cleanup other resources created | ||
*/ | ||
deleteAndWaitForSearch(tenantId, CONNECTORS_PATH, connectorId, 0); | ||
} | ||
} |