diff --git a/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java index d186ad2bdd..543190e0e8 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java @@ -5,16 +5,16 @@ package org.opensearch.ml.rest; -import com.google.common.collect.ImmutableList; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchSecurityException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.rest.RestStatus; -import org.opensearch.ml.common.conversation.ActionConstants; -import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; @@ -24,12 +24,14 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; - +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; -import java.util.function.Function; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; @@ -53,17 +55,6 @@ public List routes() { return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT)); } -// @Override -// public List replacedRoutes() { -// return Arrays.asList( -// new ReplacedRoute( -// RestRequest.Method.POST, QUERY_API_ENDPOINT, -// RestRequest.Method.POST, LEGACY_QUERY_API_ENDPOINT), -// new ReplacedRoute( -// RestRequest.Method.POST, EXPLAIN_API_ENDPOINT, -// RestRequest.Method.POST, LEGACY_EXPLAIN_API_ENDPOINT)); -// } - @Override public String getName() { return "ml_ppl_query_action"; @@ -87,7 +78,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nod nodeClient.execute( PPLQueryAction.INSTANCE, transportPPLQueryRequest, - wrapActionListener( + getPPLTransportActionListener( new ActionListener<>() { @Override public void onResponse(TransportPPLQueryResponse response) { @@ -123,11 +114,28 @@ private void sendResponse(RestChannel channel, RestStatus status, String content private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { channel.sendResponse(new BytesRestResponse(status, e.getMessage())); } - private ActionListener wrapActionListener( - final ActionListener listener) { - return ActionListener.wrap(r -> { - TransportPPLQueryResponse pplQueryResponse = TransportPPLQueryResponse.fromActionResponse(r); - listener.onResponse(pplQueryResponse); + + private ActionListener getPPLTransportActionListener(ActionListener listener) { + return ActionListener.wrap(r -> { + listener.onResponse(fromActionResponse(r)); }, listener::onFailure); } + + private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof TransportPPLQueryResponse) { + return (TransportPPLQueryResponse) actionResponse; + } + + try ( + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new TransportPPLQueryResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e); + } + + } }