Skip to content

Commit

Permalink
Fix test cases since the error message change
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Apr 11, 2024
1 parent 6628933 commit ac3053a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,6 @@ public static ModelTensors processOutput(
return ModelTensors.builder().mlModelTensors(modelTensors).build();
}

public static ModelTensors processErrorResponse(String errorResponse) {
return ModelTensors
.builder()
.mlModelTensors(List.of(ModelTensor.builder().dataAsMap(Map.of("remote_response", errorResponse)).build()))
.build();
}

private static String fillProcessFunctionParameter(Map<String, String> parameters, String processFunction) {
if (processFunction != null && processFunction.contains("${parameters.")) {
Map<String, String> tmpParameters = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processErrorResponse;
import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;

import java.nio.ByteBuffer;
Expand All @@ -26,6 +26,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.script.ScriptService;
import org.reactivestreams.Publisher;
Expand Down Expand Up @@ -118,25 +119,24 @@ private void processResponse(
Map<String, String> parameters,
Map<Integer, ModelTensors> tensorOutputs
) {
ModelTensors tensors;
if (Strings.isBlank(body)) {
log.error("Remote model response body is empty!");
tensors = processErrorResponse("Remote model response is empty!");
actionListener.onFailure(new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST));
} else {
if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) {
log.error("Remote server returned error code: {}", statusCode);
tensors = processErrorResponse(body);
actionListener.onFailure(new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode)));
} else {
try {
tensors = processOutput(body, connector, scriptService, parameters);
ModelTensors tensors = processOutput(body, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.put(countDownLatch.getSequence(), tensors);
} catch (Exception e) {
log.error("Failed to process response body: {}", body, e);
tensors = processErrorResponse(body);
actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e));
}
}
}
tensors.setStatusCode(statusCode);
tensorOutputs.put(countDownLatch.getSequence(), tensors);
}

private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class MLSdkAsyncHttpResponseHandlerTest {
private Map<Integer, ModelTensors> tensorOutputs = new ConcurrentHashMap<>();
private Connector connector;

private Connector noProcessFunctionConnector;

@Mock
private SdkHttpFullResponse sdkHttpResponse;
@Mock
Expand Down Expand Up @@ -77,6 +79,21 @@ public void setup() {
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();

ConnectorAction noProcessFunctionPredictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
noProcessFunctionConnector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(noProcessFunctionPredictAction))
.build();
mlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler(
countDownLatch,
actionListener,
Expand All @@ -95,11 +112,23 @@ public void test_OnHeaders() {
}

@Test
public void test_OnStream() {
public void test_OnStream_with_postProcessFunction_bedRock() {
String response = "{\n"
+ " \"embedding\": [\n"
+ " 0.46484375,\n"
+ " -0.017822266,\n"
+ " 0.17382812,\n"
+ " 0.10595703,\n"
+ " 0.875,\n"
+ " 0.19140625,\n"
+ " -0.36914062,\n"
+ " -0.0011978149\n"
+ " ]\n"
+ "}";
Publisher<ByteBuffer> stream = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap("hello world".getBytes()));
s.onNext(ByteBuffer.wrap(response.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
Expand All @@ -110,7 +139,34 @@ public void test_OnStream() {
ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
verify(actionListener).onResponse(captor.capture());
assert captor.getValue().size() == 1;
assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("hello world");
assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8;
}

@Test
public void test_OnStream_without_postProcessFunction() {
Publisher<ByteBuffer> stream = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap("{\"key\": \"hello world\"}".getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
MLSdkAsyncHttpResponseHandler noProcessFunctionMlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler(
countDownLatch,
actionListener,
parameters,
tensorOutputs,
noProcessFunctionConnector,
scriptService
);
noProcessFunctionMlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
noProcessFunctionMlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
verify(actionListener).onResponse(captor.capture());
assert captor.getValue().size() == 1;
assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("key").equals("hello world");
}

@Test
Expand Down Expand Up @@ -150,7 +206,7 @@ public void test_MLResponseSubscriber_onError() {
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
assert captor.getValue().getMessage().equals("{\"remote_response\":\"Remote model response is empty!\"}");
assert captor.getValue().getMessage().equals("No response from model");
}

@Test
Expand Down Expand Up @@ -245,7 +301,7 @@ public void test_onComplete_partial_success_exceptionSecond() {
mlSdkAsyncHttpResponseHandler2.onStream(stream2);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status is: FAILED\"}");
assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED");
assert captor.getValue().status().getStatus() == 500;
}

Expand Down Expand Up @@ -311,7 +367,7 @@ public void test_onComplete_partial_success_exceptionFirst() {
mlSdkAsyncHttpResponseHandler1.onStream(stream1);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status is: FAILED\"}");
assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED");
assert captor.getValue().status().getStatus() == 500;
}

Expand All @@ -328,16 +384,9 @@ public void test_onComplete_empty_response_body() {
}
};
mlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
verify(actionListener, times(1)).onResponse(captor.capture());
assert captor
.getValue()
.get(0)
.getMlModelTensors()
.get(0)
.getDataAsMap()
.get("remote_response")
.equals("Remote model response is empty!");
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("No response from model");
}

@Test
Expand Down

0 comments on commit ac3053a

Please sign in to comment.