Skip to content

Commit

Permalink
move seder to ml-commons
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Nov 25, 2023
1 parent 347f08c commit 7d5e8e7
Showing 1 changed file with 30 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

package org.opensearch.ml.rest;

import com.google.common.collect.ImmutableList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
Expand All @@ -24,12 +24,14 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest;
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;


import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
Expand All @@ -53,17 +55,6 @@ public List<Route> routes() {
return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT));
}

// @Override
// public List<ReplacedRoute> replacedRoutes() {
// return Arrays.asList(
// new ReplacedRoute(
// RestRequest.Method.POST, QUERY_API_ENDPOINT,
// RestRequest.Method.POST, LEGACY_QUERY_API_ENDPOINT),
// new ReplacedRoute(
// RestRequest.Method.POST, EXPLAIN_API_ENDPOINT,
// RestRequest.Method.POST, LEGACY_EXPLAIN_API_ENDPOINT));
// }

@Override
public String getName() {
return "ml_ppl_query_action";
Expand All @@ -87,7 +78,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nod
nodeClient.execute(
PPLQueryAction.INSTANCE,
transportPPLQueryRequest,
wrapActionListener(
getPPLTransportActionListener(
new ActionListener<>() {
@Override
public void onResponse(TransportPPLQueryResponse response) {
Expand Down Expand Up @@ -123,11 +114,28 @@ private void sendResponse(RestChannel channel, RestStatus status, String content
private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
channel.sendResponse(new BytesRestResponse(status, e.getMessage()));
}
private ActionListener<TransportPPLQueryResponse> wrapActionListener(
final ActionListener<TransportPPLQueryResponse> listener) {
return ActionListener.wrap(r -> {
TransportPPLQueryResponse pplQueryResponse = TransportPPLQueryResponse.fromActionResponse(r);
listener.onResponse(pplQueryResponse);

private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListener(ActionListener<TransportPPLQueryResponse> listener) {
return ActionListener.wrap(r -> {
listener.onResponse(fromActionResponse(r));
}, listener::onFailure);
}

private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof TransportPPLQueryResponse) {
return (TransportPPLQueryResponse) actionResponse;
}

try (
ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new TransportPPLQueryResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e);
}

}
}

0 comments on commit 7d5e8e7

Please sign in to comment.