Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Nov 25, 2023
1 parent c840012 commit 64189a4
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 117 deletions.
24 changes: 12 additions & 12 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -244,20 +244,20 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
FunctionName functionName = FunctionName.from((String) sourceAsMap.get(MLModel.ALGORITHM_FIELD));
MLModelState state = MLModelState.from((String) sourceAsMap.get(MLModel.MODEL_STATE_FIELD));
Long lastUpdateTime = sourceAsMap.containsKey(MLModel.LAST_UPDATED_TIME_FIELD)
? (Long) sourceAsMap.get(MLModel.LAST_UPDATED_TIME_FIELD)
: null;
? (Long) sourceAsMap.get(MLModel.LAST_UPDATED_TIME_FIELD)
: null;
int planningWorkerNodeCount = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD)
? (int) sourceAsMap.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD)
: 0;
? (int) sourceAsMap.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD)
: 0;
int currentWorkerNodeCountInIndex = sourceAsMap.containsKey(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)
? (int) sourceAsMap.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)
: 0;
? (int) sourceAsMap.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)
: 0;
boolean deployToAllNodes = sourceAsMap.containsKey(MLModel.DEPLOY_TO_ALL_NODES_FIELD)
? (boolean) sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)
: false;
? (boolean) sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)
: false;
List<String> planningWorkNodes = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODES_FIELD)
? (List<String>) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD)
: new ArrayList<>();
? (List<String>) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD)
: new ArrayList<>();
if (deployToAllNodes) {
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(functionName);
planningWorkerNodeCount = eligibleNodes.length;
Expand Down Expand Up @@ -312,8 +312,8 @@ private MLModelState getNewModelState(
if (currentWorkerNodeCount == 0
&& state != MLModelState.DEPLOY_FAILED
&& !(state == MLModelState.DEPLOYING
&& lastUpdateTime != null
&& lastUpdateTime + DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS > Instant.now().toEpochMilli())) {
&& lastUpdateTime != null
&& lastUpdateTime + DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS > Instant.now().toEpochMilli())) {
// If model not deployed to any node and no node is deploying the model, then set model state as DEPLOY_FAILED
return MLModelState.DEPLOY_FAILED;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
//import org.opensearch.ml.rest.MyRestPPLQueryAction;
import org.opensearch.ml.rest.MyRestPPLQueryAction;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLDeleteConnectorAction;
Expand Down
196 changes: 92 additions & 104 deletions plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@

package org.opensearch.ml.rest;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.OK;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchSecurityException;
Expand All @@ -24,118 +37,93 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest;
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
public class MyRestPPLQueryAction extends BaseRestHandler {
public static final String QUERY_API_ENDPOINT = "_ml/_ppl";
public static final String EXPLAIN_API_ENDPOINT = "_ml/_ppl/_explain";
public static final String LEGACY_QUERY_API_ENDPOINT = "_ml/_opendistro/_ppl";
public static final String LEGACY_EXPLAIN_API_ENDPOINT = "_ml/_opendistro/_ppl/_explain";

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.OK;
private static final Logger LOG = LogManager.getLogger();

public class MyRestPPLQueryAction extends BaseRestHandler {
public static final String QUERY_API_ENDPOINT = "_ml/_ppl";
public static final String EXPLAIN_API_ENDPOINT = "_ml/_ppl/_explain";
public static final String LEGACY_QUERY_API_ENDPOINT = "_ml/_opendistro/_ppl";
public static final String LEGACY_EXPLAIN_API_ENDPOINT = "_ml/_opendistro/_ppl/_explain";

private static final Logger LOG = LogManager.getLogger();

/** Constructor of RestPPLQueryAction. */
public MyRestPPLQueryAction() {
super();
}

@Override
public List<Route> routes() {
return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT));
}

@Override
public String getName() {
return "ml_ppl_query_action";
}

@Override
protected Set<String> responseParams() {
Set<String> responseParams = new HashSet<>(super.responseParams());
responseParams.addAll(Arrays.asList("format", "sanitize"));
return responseParams;
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) {
TransportPPLQueryRequest transportPPLQueryRequest =
new TransportPPLQueryRequest(PPLQueryRequestFactory.getPPLRequest(request));
LOG.info("request classloader: " + transportPPLQueryRequest.getClass().getClassLoader());
LOG.info("response classloader:" + TransportPPLQueryResponse.class.getClassLoader());

return channel ->
nodeClient.execute(
PPLQueryAction.INSTANCE,
transportPPLQueryRequest,
getPPLTransportActionListener(
new ActionListener<>() {
@Override
public void onResponse(TransportPPLQueryResponse response) {
sendResponse(channel, OK, response.getResult());
}

@Override
public void onFailure(Exception e) {
if (e instanceof IllegalAccessException) {
LOG.error("Error happened during query handling", e);
reportError(channel, e, BAD_REQUEST);
} else if (transportPPLQueryRequest.isExplainRequest()) {
LOG.error("Error happened during explain", e);
sendResponse(
channel,
INTERNAL_SERVER_ERROR,
"Failed to explain the query due to error: " + e.getMessage());
} else if (e instanceof OpenSearchSecurityException) {
OpenSearchSecurityException exception = (OpenSearchSecurityException) e;
reportError(channel, exception, exception.status());
} else {
LOG.error("Error happened during query handling", e);
reportError(channel, e, INTERNAL_SERVER_ERROR);
}
}
}));
}
/** Constructor of RestPPLQueryAction. */
public MyRestPPLQueryAction() {
super();
}

private void sendResponse(RestChannel channel, RestStatus status, String content) {
channel.sendResponse(new BytesRestResponse(status, "application/json; charset=UTF-8", content));
}
@Override
public List<Route> routes() {
return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT));
}

private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
channel.sendResponse(new BytesRestResponse(status, e.getMessage()));
}
@Override
public String getName() {
return "ml_ppl_query_action";
}

private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListener(ActionListener<TransportPPLQueryResponse> listener) {
return ActionListener.wrap(r -> {
listener.onResponse(fromActionResponse(r));
}, listener::onFailure);
}
@Override
protected Set<String> responseParams() {
Set<String> responseParams = new HashSet<>(super.responseParams());
responseParams.addAll(Arrays.asList("format", "sanitize"));
return responseParams;
}

private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof TransportPPLQueryResponse) {
return (TransportPPLQueryResponse) actionResponse;
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) {
TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(PPLQueryRequestFactory.getPPLRequest(request));
LOG.info("request classloader: " + transportPPLQueryRequest.getClass().getClassLoader());
LOG.info("response classloader:" + TransportPPLQueryResponse.class.getClassLoader());

return channel -> nodeClient
.execute(PPLQueryAction.INSTANCE, transportPPLQueryRequest, getPPLTransportActionListener(new ActionListener<>() {
@Override
public void onResponse(TransportPPLQueryResponse response) {
sendResponse(channel, OK, response.getResult());
}

@Override
public void onFailure(Exception e) {
if (e instanceof IllegalAccessException) {
LOG.error("Error happened during query handling", e);
reportError(channel, e, BAD_REQUEST);
} else if (transportPPLQueryRequest.isExplainRequest()) {
LOG.error("Error happened during explain", e);
sendResponse(channel, INTERNAL_SERVER_ERROR, "Failed to explain the query due to error: " + e.getMessage());
} else if (e instanceof OpenSearchSecurityException) {
OpenSearchSecurityException exception = (OpenSearchSecurityException) e;
reportError(channel, exception, exception.status());
} else {
LOG.error("Error happened during query handling", e);
reportError(channel, e, INTERNAL_SERVER_ERROR);
}
}
}));
}

private void sendResponse(RestChannel channel, RestStatus status, String content) {
channel.sendResponse(new BytesRestResponse(status, "application/json; charset=UTF-8", content));
}

try (
ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new TransportPPLQueryResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e);
private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
channel.sendResponse(new BytesRestResponse(status, e.getMessage()));
}

}
private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListener(ActionListener<TransportPPLQueryResponse> listener) {
return ActionListener.wrap(r -> { listener.onResponse(fromActionResponse(r)); }, listener::onFailure);
}

private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof TransportPPLQueryResponse) {
return (TransportPPLQueryResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new TransportPPLQueryResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e);
}

}
}

0 comments on commit 64189a4

Please sign in to comment.