From 1da79ce53754d11ea93ee242e9114ad71f3a406b Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Fri, 30 Aug 2024 10:56:05 -0700 Subject: [PATCH] send response in xcontent, if any exception, use plain text (#2858) Signed-off-by: Jing Zhang --- .../ml/rest/RestMLExecuteAction.java | 17 +++++- .../ml/rest/RestMLExecuteActionTests.java | 56 +++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java index 3284da09b7..90caee44c5 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java @@ -21,6 +21,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; @@ -132,7 +133,21 @@ private void sendResponse(RestChannel channel, MLExecuteTaskResponse response) t private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { ErrorMessage errorMessage = ErrorMessageFactory.createErrorMessage(e, status.getStatus()); - channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), errorMessage.toString())); + try { + XContentBuilder builder = channel.newBuilder(); + builder.startObject(); + builder.field("status", errorMessage.getStatus()); + builder.startObject("error"); + builder.field("type", errorMessage.getType()); + builder.field("reason", errorMessage.getReason()); + builder.field("details", errorMessage.getDetails()); + builder.endObject(); + builder.endObject(); + channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), builder)); + } catch (Exception exception) { + log.error("Failed to build xContent for an error response, so reply with a plain string.", exception); + channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), errorMessage.toString())); + } } private boolean isClientError(Exception e) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java index ac570a6a4d..acbaacab0b 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java @@ -28,6 +28,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; @@ -281,4 +282,59 @@ public void testPrepareRequestSystemException() throws Exception { "{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}"; assertEquals(expectedError, response.content().utf8ToString()); } + + public void testAgentExecutionResponseXContent() throws Exception { + RestRequest request = getExecuteAgentRestRequest(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener + .onFailure( + new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception")) + ); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + when(channel.newBuilder()).thenReturn(XContentFactory.jsonBuilder()); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.AGENT, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.BAD_REQUEST, response.status()); + assertEquals("application/json; charset=UTF-8", response.contentType()); + String expectedError = + "{\"status\":400,\"error\":{\"type\":\"IllegalArgumentException\",\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\"}}"; + assertEquals(expectedError, response.content().utf8ToString()); + } + + public void testAgentExecutionResponsePlainText() throws Exception { + RestRequest request = getExecuteAgentRestRequest(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener + .onFailure( + new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception")) + ); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.AGENT, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.BAD_REQUEST, response.status()); + assertEquals("text/plain; charset=UTF-8", response.contentType()); + String expectedError = + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; + assertEquals(expectedError, response.content().utf8ToString()); + } }