Skip to content

Commit

Permalink
hanlde the throttling error in the response header (#2442)
Browse files Browse the repository at this point in the history
* hanlde the throttling error in the response header

Signed-off-by: Xun Zhang <[email protected]>

* address comments and UT

Signed-off-by: Xun Zhang <[email protected]>

* more UTs

Signed-off-by: Xun Zhang <[email protected]>

* add more comments

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored May 17, 2024
1 parent d74c623 commit 7add721
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.List;
import java.util.Map;

import org.apache.commons.collections.MapUtils;
import org.apache.http.HttpStatus;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
Expand All @@ -38,6 +39,7 @@

@Log4j2
public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler {
public static final String AMZ_ERROR_HEADER = "x-amzn-ErrorType";
@Getter
private Integer statusCode;
@Getter
Expand Down Expand Up @@ -80,6 +82,10 @@ public void onHeaders(SdkHttpResponse response) {
SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response;
log.debug("received response headers: " + sdkResponse.headers());
this.statusCode = sdkResponse.statusCode();
if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) {
handleThrottlingInHeader(sdkResponse);
// add more handling here for other exceptions in headers
}
}

@Override
Expand All @@ -95,6 +101,31 @@ public void onError(Throwable error) {
actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
}

private void handleThrottlingInHeader(SdkHttpFullResponse sdkResponse) {
if (MapUtils.isEmpty(sdkResponse.headers())) {
return;
}
List<String> errorsInHeader = sdkResponse.headers().get(AMZ_ERROR_HEADER);
if (errorsInHeader == null || errorsInHeader.isEmpty()) {
return;
}
// Check the throttling exception from AMZN servers, e.g. sageMaker.
// See [https://github.com/opensearch-project/ml-commons/issues/2429] for more details.
boolean containsThrottlingException = errorsInHeader.stream().anyMatch(str -> str.startsWith("ThrottlingException"));
if (containsThrottlingException && executionContext.getExceptionHolder().get() == null) {
log.error("Remote server returned error code: {}", statusCode);
executionContext
.getExceptionHolder()
.compareAndSet(
null,
new OpenSearchStatusException(
REMOTE_SERVICE_ERROR + "The request was denied due to remote server throttling.",
RestStatus.fromCode(statusCode)
)
);
}
}

private void processResponse(
Integer statusCode,
String body,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
import static org.opensearch.ml.engine.algorithms.remote.MLSdkAsyncHttpResponseHandler.AMZ_ERROR_HEADER;

import java.nio.ByteBuffer;
import java.util.Arrays;
Expand Down Expand Up @@ -51,6 +53,8 @@ public class MLSdkAsyncHttpResponseHandlerTest {

private Connector noProcessFunctionConnector;

private Map<String, List<String>> headersMap;

@Mock
private SdkHttpFullResponse sdkHttpResponse;
@Mock
Expand Down Expand Up @@ -104,6 +108,7 @@ public void setup() {
null
);
responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber();
headersMap = Map.of(AMZ_ERROR_HEADER, Arrays.asList("ThrottlingException:request throttled!"));
}

@Test
Expand All @@ -112,6 +117,13 @@ public void test_OnHeaders() {
assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 200;
}

@Test
public void test_OnHeaders_withError() {
when(sdkHttpResponse.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 400;
}

@Test
public void test_OnStream_with_postProcessFunction_bedRock() {
String response = "{\n"
Expand Down Expand Up @@ -419,4 +431,168 @@ public void test_onComplete_error_http_status() {
System.out.println(captor.getValue().getMessage());
assert captor.getValue().getMessage().contains("runtime error");
}

@Test
public void test_onComplete_throttle_error_headers() {
String error = "{\"message\": null}";
SdkHttpResponse response = mock(SdkHttpFullResponse.class);
when(response.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
when(response.headers()).thenReturn(headersMap);
mlSdkAsyncHttpResponseHandler.onHeaders(response);
Publisher<ByteBuffer> stream = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap(error.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
mlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
System.out.println(captor.getValue().getMessage());
assert captor.getValue().getMessage().contains(REMOTE_SERVICE_ERROR);
}

@Test
public void test_onComplete_throttle_exceptionFirst() {
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
String response1 = "{\n"
+ " \"embedding\": [\n"
+ " 0.46484375,\n"
+ " -0.017822266,\n"
+ " 0.17382812,\n"
+ " 0.10595703,\n"
+ " 0.875,\n"
+ " 0.19140625,\n"
+ " -0.36914062,\n"
+ " -0.0011978149\n"
+ " ]\n"
+ "}";
String response2 = "{\"message\": null}";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(0, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(1, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
null
);

SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
when(sdkHttpResponse2.headers()).thenReturn(headersMap);
mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
Publisher<ByteBuffer> stream2 = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap(response2.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
mlSdkAsyncHttpResponseHandler2.onStream(stream2);

SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
when(sdkHttpResponse1.statusCode()).thenReturn(200);
mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
Publisher<ByteBuffer> stream1 = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap(response1.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
mlSdkAsyncHttpResponseHandler1.onStream(stream1);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling.");
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
}

@Test
public void test_onComplete_throttle_exceptionSecond() {
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
String response1 = "{\n"
+ " \"embedding\": [\n"
+ " 0.46484375,\n"
+ " -0.017822266,\n"
+ " 0.17382812,\n"
+ " 0.10595703,\n"
+ " 0.875,\n"
+ " 0.19140625,\n"
+ " -0.36914062,\n"
+ " -0.0011978149\n"
+ " ]\n"
+ "}";
String response2 = "{\"message\": null}";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(0, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(1, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
null
);
SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
when(sdkHttpResponse1.statusCode()).thenReturn(200);
mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
Publisher<ByteBuffer> stream1 = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap(response1.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
mlSdkAsyncHttpResponseHandler1.onStream(stream1);

SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
when(sdkHttpResponse2.headers()).thenReturn(headersMap);
mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
Publisher<ByteBuffer> stream2 = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap(response2.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
mlSdkAsyncHttpResponseHandler2.onStream(stream2);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling.");
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
}

}

0 comments on commit 7add721

Please sign in to comment.