diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index 28ed1bc200..045d24166f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -16,6 +16,7 @@ import java.util.List; import java.util.Map; +import org.apache.commons.collections.MapUtils; import org.apache.http.HttpStatus; import org.apache.logging.log4j.util.Strings; import org.opensearch.OpenSearchStatusException; @@ -38,6 +39,7 @@ @Log4j2 public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler { + public static final String AMZ_ERROR_HEADER = "x-amzn-ErrorType"; @Getter private Integer statusCode; @Getter @@ -80,6 +82,10 @@ public void onHeaders(SdkHttpResponse response) { SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response; log.debug("received response headers: " + sdkResponse.headers()); this.statusCode = sdkResponse.statusCode(); + if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { + handleThrottlingInHeader(sdkResponse); + // add more handling here for other exceptions in headers + } } @Override @@ -95,6 +101,31 @@ public void onError(Throwable error) { actionListener.onFailure(new OpenSearchStatusException(errorMessage, status)); } + private void handleThrottlingInHeader(SdkHttpFullResponse sdkResponse) { + if (MapUtils.isEmpty(sdkResponse.headers())) { + return; + } + List errorsInHeader = sdkResponse.headers().get(AMZ_ERROR_HEADER); + if (errorsInHeader == null || errorsInHeader.isEmpty()) { + return; + } + // Check the throttling exception from AMZN servers, e.g. sageMaker. + // See [https://github.com/opensearch-project/ml-commons/issues/2429] for more details. + boolean containsThrottlingException = errorsInHeader.stream().anyMatch(str -> str.startsWith("ThrottlingException")); + if (containsThrottlingException && executionContext.getExceptionHolder().get() == null) { + log.error("Remote server returned error code: {}", statusCode); + executionContext + .getExceptionHolder() + .compareAndSet( + null, + new OpenSearchStatusException( + REMOTE_SERVICE_ERROR + "The request was denied due to remote server throttling.", + RestStatus.fromCode(statusCode) + ) + ); + } + } + private void processResponse( Integer statusCode, String body, diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index 68b5cdeb5f..4ac156c0ee 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -11,6 +11,8 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; +import static org.opensearch.ml.engine.algorithms.remote.MLSdkAsyncHttpResponseHandler.AMZ_ERROR_HEADER; import java.nio.ByteBuffer; import java.util.Arrays; @@ -51,6 +53,8 @@ public class MLSdkAsyncHttpResponseHandlerTest { private Connector noProcessFunctionConnector; + private Map> headersMap; + @Mock private SdkHttpFullResponse sdkHttpResponse; @Mock @@ -104,6 +108,7 @@ public void setup() { null ); responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber(); + headersMap = Map.of(AMZ_ERROR_HEADER, Arrays.asList("ThrottlingException:request throttled!")); } @Test @@ -112,6 +117,13 @@ public void test_OnHeaders() { assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 200; } + @Test + public void test_OnHeaders_withError() { + when(sdkHttpResponse.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST); + mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); + assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 400; + } + @Test public void test_OnStream_with_postProcessFunction_bedRock() { String response = "{\n" @@ -419,4 +431,168 @@ public void test_onComplete_error_http_status() { System.out.println(captor.getValue().getMessage()); assert captor.getValue().getMessage().contains("runtime error"); } + + @Test + public void test_onComplete_throttle_error_headers() { + String error = "{\"message\": null}"; + SdkHttpResponse response = mock(SdkHttpFullResponse.class); + when(response.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST); + when(response.headers()).thenReturn(headersMap); + mlSdkAsyncHttpResponseHandler.onHeaders(response); + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(error.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler.onStream(stream); + ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof OpenSearchStatusException; + System.out.println(captor.getValue().getMessage()); + assert captor.getValue().getMessage().contains(REMOTE_SERVICE_ERROR); + } + + @Test + public void test_onComplete_throttle_exceptionFirst() { + AtomicReference exceptionHolder = new AtomicReference<>(); + String response1 = "{\n" + + " \"embedding\": [\n" + + " 0.46484375,\n" + + " -0.017822266,\n" + + " 0.17382812,\n" + + " 0.10595703,\n" + + " 0.875,\n" + + " 0.19140625,\n" + + " -0.36914062,\n" + + " -0.0011978149\n" + + " ]\n" + + "}"; + String response2 = "{\"message\": null}"; + CountDownLatch count = new CountDownLatch(2); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(0, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(1, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + + SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST); + when(sdkHttpResponse2.headers()).thenReturn(headersMap); + mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2); + Publisher stream2 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response2.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler2.onStream(stream2); + + SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse1.statusCode()).thenReturn(200); + mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1); + Publisher stream1 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response1.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler1.onStream(stream1); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling."); + assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST; + } + + @Test + public void test_onComplete_throttle_exceptionSecond() { + AtomicReference exceptionHolder = new AtomicReference<>(); + String response1 = "{\n" + + " \"embedding\": [\n" + + " 0.46484375,\n" + + " -0.017822266,\n" + + " 0.17382812,\n" + + " 0.10595703,\n" + + " 0.875,\n" + + " 0.19140625,\n" + + " -0.36914062,\n" + + " -0.0011978149\n" + + " ]\n" + + "}"; + String response2 = "{\"message\": null}"; + CountDownLatch count = new CountDownLatch(2); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(0, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler( + new ExecutionContext(1, count, exceptionHolder), + actionListener, + parameters, + tensorOutputs, + connector, + scriptService, + null + ); + SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse1.statusCode()).thenReturn(200); + mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1); + Publisher stream1 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response1.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler1.onStream(stream1); + + SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class); + when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST); + when(sdkHttpResponse2.headers()).thenReturn(headersMap); + mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2); + Publisher stream2 = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(response2.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler2.onStream(stream2); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling."); + assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST; + } + }