Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
(cherry picked from commit a98dbbf)

Co-authored-by: Sicheng Song <[email protected]>
  • Loading branch information
2 people authored and dhrubo-os committed May 17, 2024
1 parent 3e124c5 commit 78c8c1e
Showing 1 changed file with 252 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,70 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
assertFalse(((String) responseMap.get("text")).isEmpty());
}

public void testPredictRemoteModelWithInterface(String testCase, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
throws IOException,
InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
}
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, testCase);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
try {
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
verifyResponse.accept(responseMap);
} catch (Exception e) {
verifyException.accept(e);
}
}

public void testPredictRemoteModelWithCorrectInterface() throws IOException, InterruptedException {
testPredictRemoteModelWithInterface("correctInterface", (responseMap) -> {
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("choices");
if (responseList == null) {
assertTrue(checkThrottlingOpenAI(responseMap));
return;
}
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());
}, null);
}

public void testPredictRemoteModelWithWrongInputInterface() throws IOException, InterruptedException {
testPredictRemoteModelWithInterface("wrongInputInterface", null, (exception) -> {
assertTrue(exception instanceof org.opensearch.client.ResponseException);
String stackTrace = ExceptionUtils.getStackTrace(exception);
assertTrue(stackTrace.contains("Error validating input schema"));
});
}

public void testPredictRemoteModelWithWrongOutputInterface() throws IOException, InterruptedException {
testPredictRemoteModelWithInterface("wrongOutputInterface", null, (exception) -> {
assertTrue(exception instanceof org.opensearch.client.ResponseException);
String stackTrace = ExceptionUtils.getStackTrace(exception);
assertTrue(stackTrace.contains("Error validating output schema"));
});
}

public void testUndeployRemoteModel() throws IOException, InterruptedException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
Expand Down Expand Up @@ -777,6 +841,183 @@ public static Response registerRemoteModel(String name, String connectorId) thro
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
+ " \"description\": \"This is an example description\"\n"
+ "}";
Response response = TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/model_groups/_register",
null,
TestHelper.toHttpEntity(registerModelGroupEntity),
null
);
Map responseMap = parseResponseToMap(response);
assertEquals((String) responseMap.get("status"), "CREATED");
String modelGroupId = (String) responseMap.get("model_group_id");

final String openaiConnectorEntityWithCorrectInterface = "{\n"
+ " \"name\": \""
+ name
+ "\",\n"
+ " \"model_group_id\": \""
+ modelGroupId
+ "\",\n"
+ " \"function_name\": \"remote\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\",\n"
+ " \"interface\": {\n"
+ " \"input\": {\n"
+ " \"properties\": {\n"
+ " \"parameters\": {\n"
+ " \"properties\": {\n"
+ " \"prompt\": {\n"
+ " \"type\": \"string\",\n"
+ " \"description\": \"This is a test description field\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"output\": {\n"
+ " \"properties\": {\n"
+ " \"inference_results\": {\n"
+ " \"type\": \"array\",\n"
+ " \"items\": {\n"
+ " \"type\": \"object\",\n"
+ " \"properties\": {\n"
+ " \"output\": {\n"
+ " \"type\": \"array\",\n"
+ " \"items\": {\n"
+ " \"properties\": {\n"
+ " \"name\": {\n"
+ " \"type\": \"string\",\n"
+ " \"description\": \"This is a test description field\"\n"
+ " },\n"
+ " \"dataAsMap\": {\n"
+ " \"type\": \"object\",\n"
+ " \"description\": \"This is a test description field\"\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"description\": \"This is a test description field\"\n"
+ " },\n"
+ " \"status_code\": {\n"
+ " \"type\": \"integer\",\n"
+ " \"description\": \"This is a test description field\"\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"description\": \"This is a test description field\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";

final String openaiConnectorEntityWithWrongInputInterface = "{\n"
+ " \"name\": \""
+ name
+ "\",\n"
+ " \"model_group_id\": \""
+ modelGroupId
+ "\",\n"
+ " \"function_name\": \"remote\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\",\n"
+ " \"interface\": {\n"
+ " \"input\": {\n"
+ " \"properties\": {\n"
+ " \"parameters\": {\n"
+ " \"properties\": {\n"
+ " \"prompt\": {\n"
+ " \"type\": \"integer\",\n"
+ " \"description\": \"This is a test description field\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";

final String openaiConnectorEntityWithWrongOutputInterface = "{\n"
+ " \"name\": \""
+ name
+ "\",\n"
+ " \"model_group_id\": \""
+ modelGroupId
+ "\",\n"
+ " \"function_name\": \"remote\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\",\n"
+ " \"interface\": {\n"
+ " \"output\": {\n"
+ " \"properties\": {\n"
+ " \"inference_results\": {\n"
+ " \"type\": \"array\",\n"
+ " \"items\": {\n"
+ " \"type\": \"object\",\n"
+ " \"properties\": {\n"
+ " \"output\": {\n"
+ " \"type\": \"integer\",\n"
+ " \"description\": \"This is a test description field\"\n"
+ " },\n"
+ " \"status_code\": {\n"
+ " \"type\": \"integer\",\n"
+ " \"description\": \"This is a test description field\"\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"description\": \"This is a test description field\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";

switch (testCase) {
case "correctInterface":
return TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/models/_register",
null,
TestHelper.toHttpEntity(openaiConnectorEntityWithCorrectInterface),
null
);
case "wrongInputInterface":
return TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/models/_register",
null,
TestHelper.toHttpEntity(openaiConnectorEntityWithWrongInputInterface),
null
);
case "wrongOutputInterface":
return TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/models/_register",
null,
TestHelper.toHttpEntity(openaiConnectorEntityWithWrongOutputInterface),
null
);
default:
throw new IllegalArgumentException("Invalid test case");
}
}

public static Response deployRemoteModel(String modelId) throws IOException {
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null);
}
Expand Down Expand Up @@ -831,4 +1072,15 @@ public String registerRemoteModel() throws IOException {
logger.info("task ID created: {}", taskId);
return taskId;
}

public String registerRemoteModelWithInterface(String testCase) throws IOException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, testCase);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
logger.info("task ID created: {}", taskId);
return taskId;
}
}

0 comments on commit 78c8c1e

Please sign in to comment.