Skip to content

Commit

Permalink
Schedule request in worker thread (#748)
Browse files Browse the repository at this point in the history
Signed-off-by: penghuo <[email protected]>
  • Loading branch information
penghuo authored Aug 17, 2022
1 parent 23a4a88 commit ce15448
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.rest.RestStatus.OK;
import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.sql.opensearch.executor.Scheduler.schedule;

import com.alibaba.druid.sql.parser.ParserException;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -147,19 +148,27 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

Format format = SqlRequestParam.getFormat(request.params());

// Route request to new query engine if it's supported already
SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(),
sqlRequest.getSql(), request.path(), request.params());
RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client);
if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) {
LOG.info("[{}] Request is handled by new SQL query engine", QueryContext.getRequestId());
return result;
}
LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine",
QueryContext.getRequestId(), newSqlRequest);

final QueryAction queryAction = explainRequest(client, sqlRequest, format);
return channel -> executeSqlRequest(request, queryAction, client, channel);
return channel -> schedule(client, () -> {
try {
// Route request to new query engine if it's supported already
SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(),
sqlRequest.getSql(), request.path(), request.params());
RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client);
if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) {
LOG.info("[{}] Request is handled by new SQL query engine",
QueryContext.getRequestId());
result.accept(channel);
} else {
LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine",
QueryContext.getRequestId(), newSqlRequest);
QueryAction queryAction = explainRequest(client, sqlRequest, format);
executeSqlRequest(request, queryAction, client, channel);
}
} catch (Exception e) {
logAndPublishMetrics(e);
reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
}
});
} catch (Exception e) {
logAndPublishMetrics(e);
return channel -> reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.executor;

import java.util.Map;
import lombok.experimental.UtilityClass;
import org.apache.logging.log4j.ThreadContext;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.threadpool.ThreadPool;

/** The scheduler which schedule the task run in sql-worker thread pool. */
@UtilityClass
public class Scheduler {

public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker";

public static void schedule(NodeClient client, Runnable task) {
ThreadPool threadPool = client.threadPool();
threadPool.schedule(withCurrentContext(task), new TimeValue(0), SQL_WORKER_THREAD_POOL_NAME);
}

private static Runnable withCurrentContext(final Runnable task) {
final Map<String, String> currentContext = ThreadContext.getImmutableContext();
return () -> {
ThreadContext.putAll(currentContext);
task.run();
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.executor;

import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.client.node.NodeClient;
import org.opensearch.threadpool.ThreadPool;

@ExtendWith(MockitoExtension.class)
class SchedulerTest {
@Test
public void schedule() {
NodeClient nodeClient = mock(NodeClient.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(nodeClient.threadPool()).thenReturn(threadPool);

doAnswer(
invocation -> {
Runnable task = invocation.getArgument(0);
task.run();
return null;
})
.when(threadPool)
.schedule(any(), any(), any());
AtomicBoolean isRun = new AtomicBoolean(false);
Scheduler.schedule(nodeClient, () -> isRun.set(true));
assertTrue(isRun.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.rest.RestStatus.OK;
import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.sql.opensearch.executor.Scheduler.schedule;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand Down Expand Up @@ -112,42 +113,43 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nod
PPLQueryRequestFactory.getPPLRequest(request)
);

return channel ->
nodeClient.execute(
PPLQueryAction.INSTANCE,
transportPPLQueryRequest,
new ActionListener<>() {
@Override
public void onResponse(TransportPPLQueryResponse response) {
sendResponse(channel, OK, response.getResult());
}

@Override
public void onFailure(Exception e) {
if (transportPPLQueryRequest.isExplainRequest()) {
LOG.error("Error happened during explain", e);
sendResponse(
channel,
INTERNAL_SERVER_ERROR,
"Failed to explain the query due to error: " + e.getMessage());
} else if (e instanceof IllegalAccessException) {
return channel -> schedule(nodeClient, () ->
nodeClient.execute(
PPLQueryAction.INSTANCE,
transportPPLQueryRequest,
new ActionListener<>() {
@Override
public void onResponse(TransportPPLQueryResponse response) {
sendResponse(channel, OK, response.getResult());
}

@Override
public void onFailure(Exception e) {
if (transportPPLQueryRequest.isExplainRequest()) {
LOG.error("Error happened during explain", e);
sendResponse(
channel,
INTERNAL_SERVER_ERROR,
"Failed to explain the query due to error: " + e.getMessage());
} else if (e instanceof IllegalAccessException) {
reportError(channel, e, BAD_REQUEST);
} else {
LOG.error("Error happened during query handling", e);
if (isClientError(e)) {
Metrics.getInstance()
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS)
.increment();
reportError(channel, e, BAD_REQUEST);
} else {
LOG.error("Error happened during query handling", e);
if (isClientError(e)) {
Metrics.getInstance()
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS)
.increment();
reportError(channel, e, BAD_REQUEST);
} else {
Metrics.getInstance()
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS)
.increment();
reportError(channel, e, SERVICE_UNAVAILABLE);
}
Metrics.getInstance()
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS)
.increment();
reportError(channel, e, SERVICE_UNAVAILABLE);
}
}
});
}
}
));
}

private void sendResponse(RestChannel channel, RestStatus status, String content) {
Expand Down

0 comments on commit ce15448

Please sign in to comment.