Skip to content

Commit

Permalink
send response in xcontent, if any exception, use plain text (opensear…
Browse files Browse the repository at this point in the history
…ch-project#2858)

Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es authored Aug 30, 2024
1 parent 7ecff1a commit 1da79ce
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.Input;
Expand Down Expand Up @@ -132,7 +133,21 @@ private void sendResponse(RestChannel channel, MLExecuteTaskResponse response) t

private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
ErrorMessage errorMessage = ErrorMessageFactory.createErrorMessage(e, status.getStatus());
channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), errorMessage.toString()));
try {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("status", errorMessage.getStatus());
builder.startObject("error");
builder.field("type", errorMessage.getType());
builder.field("reason", errorMessage.getReason());
builder.field("details", errorMessage.getDetails());
builder.endObject();
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), builder));
} catch (Exception exception) {
log.error("Failed to build xContent for an error response, so reply with a plain string.", exception);
channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), errorMessage.toString()));
}
}

private boolean isClientError(Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -281,4 +282,59 @@ public void testPrepareRequestSystemException() throws Exception {
"{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}";
assertEquals(expectedError, response.content().utf8ToString());
}

public void testAgentExecutionResponseXContent() throws Exception {
RestRequest request = getExecuteAgentRestRequest();
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener
.onFailure(
new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception"))
);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
when(channel.newBuilder()).thenReturn(XContentFactory.jsonBuilder());
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.AGENT, input.getFunctionName());
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
assertEquals(RestStatus.BAD_REQUEST, response.status());
assertEquals("application/json; charset=UTF-8", response.contentType());
String expectedError =
"{\"status\":400,\"error\":{\"type\":\"IllegalArgumentException\",\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\"}}";
assertEquals(expectedError, response.content().utf8ToString());
}

public void testAgentExecutionResponsePlainText() throws Exception {
RestRequest request = getExecuteAgentRestRequest();
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener
.onFailure(
new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception"))
);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.AGENT, input.getFunctionName());
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
assertEquals(RestStatus.BAD_REQUEST, response.status());
assertEquals("text/plain; charset=UTF-8", response.contentType());
String expectedError =
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
assertEquals(expectedError, response.content().utf8ToString());
}
}

0 comments on commit 1da79ce

Please sign in to comment.