From ac3053aa747ad55f9fed8ad9817ea2dd26ba4569 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 11 Apr 2024 20:03:23 +0800 Subject: [PATCH] Fix test cases since the error message change Signed-off-by: zane-neo --- .../algorithms/remote/ConnectorUtils.java | 7 -- .../remote/MLSdkAsyncHttpResponseHandler.java | 16 ++-- .../MLSdkAsyncHttpResponseHandlerTest.java | 81 +++++++++++++++---- 3 files changed, 73 insertions(+), 31 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 8b6d310283..af788c8b34 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -225,13 +225,6 @@ public static ModelTensors processOutput( return ModelTensors.builder().mlModelTensors(modelTensors).build(); } - public static ModelTensors processErrorResponse(String errorResponse) { - return ModelTensors - .builder() - .mlModelTensors(List.of(ModelTensor.builder().dataAsMap(Map.of("remote_response", errorResponse)).build())) - .build(); - } - private static String fillProcessFunctionParameter(Map parameters, String processFunction) { if (processFunction != null && processFunction.contains("${parameters.")) { Map tmpParameters = new HashMap<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index 388ba7bbcf..d647633120 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -7,7 +7,7 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processErrorResponse; +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; import java.nio.ByteBuffer; @@ -26,6 +26,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.script.ScriptService; import org.reactivestreams.Publisher; @@ -118,25 +119,24 @@ private void processResponse( Map parameters, Map tensorOutputs ) { - ModelTensors tensors; if (Strings.isBlank(body)) { log.error("Remote model response body is empty!"); - tensors = processErrorResponse("Remote model response is empty!"); + actionListener.onFailure(new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST)); } else { if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { log.error("Remote server returned error code: {}", statusCode); - tensors = processErrorResponse(body); + actionListener.onFailure(new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode))); } else { try { - tensors = processOutput(body, connector, scriptService, parameters); + ModelTensors tensors = processOutput(body, connector, scriptService, parameters); + tensors.setStatusCode(statusCode); + tensorOutputs.put(countDownLatch.getSequence(), tensors); } catch (Exception e) { log.error("Failed to process response body: {}", body, e); - tensors = processErrorResponse(body); + actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e)); } } } - tensors.setStatusCode(statusCode); - tensorOutputs.put(countDownLatch.getSequence(), tensors); } private void reOrderTensorResponses(Map tensorOutputs) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index 5635a8b4fb..b7c598d6bd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -49,6 +49,8 @@ public class MLSdkAsyncHttpResponseHandlerTest { private Map tensorOutputs = new ConcurrentHashMap<>(); private Connector connector; + private Connector noProcessFunctionConnector; + @Mock private SdkHttpFullResponse sdkHttpResponse; @Mock @@ -77,6 +79,21 @@ public void setup() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); + + ConnectorAction noProcessFunctionPredictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + noProcessFunctionConnector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(noProcessFunctionPredictAction)) + .build(); mlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler( countDownLatch, actionListener, @@ -95,11 +112,23 @@ public void test_OnHeaders() { } @Test - public void test_OnStream() { + public void test_OnStream_with_postProcessFunction_bedRock() { + String response = "{\n" + + " \"embedding\": [\n" + + " 0.46484375,\n" + + " -0.017822266,\n" + + " 0.17382812,\n" + + " 0.10595703,\n" + + " 0.875,\n" + + " 0.19140625,\n" + + " -0.36914062,\n" + + " -0.0011978149\n" + + " ]\n" + + "}"; Publisher stream = s -> { try { s.onSubscribe(mock(Subscription.class)); - s.onNext(ByteBuffer.wrap("hello world".getBytes())); + s.onNext(ByteBuffer.wrap(response.getBytes())); s.onComplete(); } catch (Throwable e) { s.onError(e); @@ -110,7 +139,34 @@ public void test_OnStream() { ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); verify(actionListener).onResponse(captor.capture()); assert captor.getValue().size() == 1; - assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("hello world"); + assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8; + } + + @Test + public void test_OnStream_without_postProcessFunction() { + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap("{\"key\": \"hello world\"}".getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + MLSdkAsyncHttpResponseHandler noProcessFunctionMlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler( + countDownLatch, + actionListener, + parameters, + tensorOutputs, + noProcessFunctionConnector, + scriptService + ); + noProcessFunctionMlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); + noProcessFunctionMlSdkAsyncHttpResponseHandler.onStream(stream); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(actionListener).onResponse(captor.capture()); + assert captor.getValue().size() == 1; + assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("key").equals("hello world"); } @Test @@ -150,7 +206,7 @@ public void test_MLResponseSubscriber_onError() { ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof OpenSearchStatusException; - assert captor.getValue().getMessage().equals("{\"remote_response\":\"Remote model response is empty!\"}"); + assert captor.getValue().getMessage().equals("No response from model"); } @Test @@ -245,7 +301,7 @@ public void test_onComplete_partial_success_exceptionSecond() { mlSdkAsyncHttpResponseHandler2.onStream(stream2); ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status is: FAILED\"}"); + assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED"); assert captor.getValue().status().getStatus() == 500; } @@ -311,7 +367,7 @@ public void test_onComplete_partial_success_exceptionFirst() { mlSdkAsyncHttpResponseHandler1.onStream(stream1); ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status is: FAILED\"}"); + assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED"); assert captor.getValue().status().getStatus() == 500; } @@ -328,16 +384,9 @@ public void test_onComplete_empty_response_body() { } }; mlSdkAsyncHttpResponseHandler.onStream(stream); - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(actionListener, times(1)).onResponse(captor.capture()); - assert captor - .getValue() - .get(0) - .getMlModelTensors() - .get(0) - .getDataAsMap() - .get("remote_response") - .equals("Remote model response is empty!"); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("No response from model"); } @Test