From 1a659c8048dd23cdab3ab881622bf8cd355e1631 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Fri, 6 Dec 2024 15:44:11 -0800 Subject: [PATCH] Fixes Two Flaky IT classes RestMLGuardrailsIT & ToolIntegrationWithLLMTest (#3253) * fix uneeded call to get model_id for task api within RestMLGuardrailsIT Following #3244 this IT called the task api to check the model id again however this is redundant. Instead one can directly pull the model_id upon creating the model group. Manual testing was done to see that the behavior is intact, this should help reduce the calls within a IT to make it less flaky Signed-off-by: Brian Flores * fix ToolIntegrationWithLLMTest model undeploy race condition Previously the test class attempted to delete a model without fully knowing if the model was undeployed in time. This change adds a waiting for 5 retries each 1 second to check the status of the model and only when undeployed will it proceed to delete the model. When the number of retries are exceeded it throws a error indicating a deeper problem. Manual testing was done to check that the model is undeployed by searching for the specific model via the checkForModelUndeployedStatus method. Signed-off-by: Brian Flores --------- Signed-off-by: Brian Flores --- .../ml/rest/MLCommonsRestTestCase.java | 2 +- .../ml/rest/RestMLGuardrailsIT.java | 63 +++++++++---------- .../ml/tools/ToolIntegrationWithLLMTest.java | 23 ++++--- 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 6f710ea1de..6871747fc7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -977,7 +977,7 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt } return taskDone.get(); }, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS); - assertTrue(taskDone.get()); + assertTrue(String.format(Locale.ROOT, "Task Id %s could not get to %s state", taskId, targetState.name()), taskDone.get()); } public String registerConnector(String createConnectorInput) throws IOException, InterruptedException { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index fbabf1dbb7..bd1c9536bf 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -124,17 +124,16 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); @@ -144,6 +143,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep responseMap = (Map) responseList.get(0); responseMap = (Map) responseMap.get("dataAsMap"); responseList = (List) responseMap.get("choices"); + if (responseList == null) { assertTrue(checkThrottlingOpenAI(responseMap)); return; @@ -160,18 +160,18 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept exceptionRule.expect(ResponseException.class); exceptionRule.expectMessage("guardrails triggered for user input"); Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; predictRemoteModel(modelId, predictInput); @@ -187,17 +187,16 @@ public void testPredictRemoteModelFailedNonType() throws IOException, Interrupte Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; predictRemoteModel(modelId, predictInput); } @@ -211,17 +210,16 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException Response response = createConnector(completionModelConnectorEntityWithGuardrail); Map responseMap = parseResponseToMap(response); String guardrailConnectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "accept". String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; response = predictRemoteModel(guardrailModelId, predictInput); @@ -233,21 +231,21 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException responseMap = (Map) responseMap.get("dataAsMap"); String validationResult = (String) responseMap.get("response"); Assert.assertTrue(validateRegex(validationResult, acceptRegex)); + // Create predict model. response = createConnector(completionModelConnectorEntity); responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Predict. predictInput = "{\n" + " \"parameters\": {\n" @@ -282,17 +280,17 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, Response response = createConnector(completionModelConnectorEntityWithGuardrail); Map responseMap = parseResponseToMap(response); String guardrailConnectorId = (String) responseMap.get("connector_id"); + + // Create the model ID response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "reject". String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"I will be executed or tortured.\"\n" + " }\n" + "}"; response = predictRemoteModel(guardrailModelId, predictInput); @@ -304,17 +302,16 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, responseMap = (Map) responseMap.get("dataAsMap"); String validationResult = (String) responseMap.get("response"); Assert.assertTrue(validateRegex(validationResult, rejectRegex)); + // Create predict model. response = createConnector(completionModelConnectorEntity); responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index b46d2d6a3b..a9b3f84bfc 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -9,7 +9,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.function.Predicate; @@ -32,7 +31,7 @@ @Log4j2 public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT { - private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 30; + private static final int MAX_RETRIES = 5; private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; protected HttpServer server; @@ -72,16 +71,17 @@ public void stopMockLLM() { @After public void deleteModel() throws IOException { undeployModel(modelId); + checkForModelUndeployedStatus(modelId); deleteModel(client(), modelId, null); } @SneakyThrows - private void waitModelUndeployed(String modelId) { + private void checkForModelUndeployedStatus(String modelId) { Predicate condition = response -> { try { Map responseInMap = parseResponseToMap(response); MLModelState state = MLModelState.from(responseInMap.get(MLModel.MODEL_STATE_FIELD).toString()); - return Set.of(MLModelState.UNDEPLOYED, MLModelState.DEPLOY_FAILED).contains(state); + return MLModelState.UNDEPLOYED.equals(state); } catch (Exception e) { return false; } @@ -91,16 +91,25 @@ private void waitModelUndeployed(String modelId) { @SneakyThrows protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate condition) { - for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) { + for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) { Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); if (condition.test(response)) { return response; } - logger.info("The {}-th response: {}", i, response.toString()); + logger.info("The {}-th attempt on {}:{} . response: {}", attempt, method, endpoint, response.toString()); Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); } - fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds."); + fail( + String + .format( + Locale.ROOT, + "The response failed to meet condition after %d attempts. Attempted to perform %s : %s", + MAX_RETRIES, + method, + endpoint + ) + ); return null; }