diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 9ca4f390b1..32698c81eb 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -237,6 +237,70 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { assertFalse(((String) responseMap.get("text")).isEmpty()); } + public void testPredictRemoteModelWithInterface(String testCase, Consumer verifyResponse, Consumer verifyException) + throws IOException, + InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, testCase); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; + try { + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + verifyResponse.accept(responseMap); + } catch (Exception e) { + verifyException.accept(e); + } + } + + public void testPredictRemoteModelWithCorrectInterface() throws IOException, InterruptedException { + testPredictRemoteModelWithInterface("correctInterface", (responseMap) -> { + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + if (responseList == null) { + assertTrue(checkThrottlingOpenAI(responseMap)); + return; + } + responseMap = (Map) responseList.get(0); + assertFalse(((String) responseMap.get("text")).isEmpty()); + }, null); + } + + public void testPredictRemoteModelWithWrongInputInterface() throws IOException, InterruptedException { + testPredictRemoteModelWithInterface("wrongInputInterface", null, (exception) -> { + assertTrue(exception instanceof org.opensearch.client.ResponseException); + String stackTrace = ExceptionUtils.getStackTrace(exception); + assertTrue(stackTrace.contains("Error validating input schema")); + }); + } + + public void testPredictRemoteModelWithWrongOutputInterface() throws IOException, InterruptedException { + testPredictRemoteModelWithInterface("wrongOutputInterface", null, (exception) -> { + assertTrue(exception instanceof org.opensearch.client.ResponseException); + String stackTrace = ExceptionUtils.getStackTrace(exception); + assertTrue(stackTrace.contains("Error validating output schema")); + }); + } + public void testUndeployRemoteModel() throws IOException, InterruptedException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); @@ -777,6 +841,183 @@ public static Response registerRemoteModel(String name, String connectorId) thro .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } + public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException { + String registerModelGroupEntity = "{\n" + + " \"name\": \"remote_model_group\",\n" + + " \"description\": \"This is an example description\"\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + TestHelper.toHttpEntity(registerModelGroupEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((String) responseMap.get("status"), "CREATED"); + String modelGroupId = (String) responseMap.get("model_group_id"); + + final String openaiConnectorEntityWithCorrectInterface = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"connector_id\": \"" + + connectorId + + "\",\n" + + " \"interface\": {\n" + + " \"input\": {\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"properties\": {\n" + + " \"prompt\": {\n" + + " \"type\": \"string\",\n" + + " \"description\": \"This is a test description field\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"output\": {\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\",\n" + + " \"description\": \"This is a test description field\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"description\": \"This is a test description field\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"description\": \"This is a test description field\"\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\",\n" + + " \"description\": \"This is a test description field\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"description\": \"This is a test description field\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + final String openaiConnectorEntityWithWrongInputInterface = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"connector_id\": \"" + + connectorId + + "\",\n" + + " \"interface\": {\n" + + " \"input\": {\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"properties\": {\n" + + " \"prompt\": {\n" + + " \"type\": \"integer\",\n" + + " \"description\": \"This is a test description field\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + final String openaiConnectorEntityWithWrongOutputInterface = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"connector_id\": \"" + + connectorId + + "\",\n" + + " \"interface\": {\n" + + " \"output\": {\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"integer\",\n" + + " \"description\": \"This is a test description field\"\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\",\n" + + " \"description\": \"This is a test description field\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"description\": \"This is a test description field\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + switch (testCase) { + case "correctInterface": + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/_register", + null, + TestHelper.toHttpEntity(openaiConnectorEntityWithCorrectInterface), + null + ); + case "wrongInputInterface": + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/_register", + null, + TestHelper.toHttpEntity(openaiConnectorEntityWithWrongInputInterface), + null + ); + case "wrongOutputInterface": + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/_register", + null, + TestHelper.toHttpEntity(openaiConnectorEntityWithWrongOutputInterface), + null + ); + default: + throw new IllegalArgumentException("Invalid test case"); + } + } + public static Response deployRemoteModel(String modelId) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null); } @@ -831,4 +1072,15 @@ public String registerRemoteModel() throws IOException { logger.info("task ID created: {}", taskId); return taskId; } + + public String registerRemoteModelWithInterface(String testCase) throws IOException { + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, testCase); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + logger.info("task ID created: {}", taskId); + return taskId; + } }