diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index 6f7579c9c7..06d1ba1c73 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -9,6 +9,7 @@ import static org.opensearch.rest.RestStatus.BAD_REQUEST; import static org.opensearch.rest.RestStatus.OK; import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.sql.opensearch.executor.Scheduler.schedule; import com.alibaba.druid.sql.parser.ParserException; import com.google.common.collect.ImmutableList; @@ -147,19 +148,27 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli Format format = SqlRequestParam.getFormat(request.params()); - // Route request to new query engine if it's supported already - SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(), - sqlRequest.getSql(), request.path(), request.params()); - RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client); - if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) { - LOG.info("[{}] Request is handled by new SQL query engine", QueryContext.getRequestId()); - return result; - } - LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine", - QueryContext.getRequestId(), newSqlRequest); - - final QueryAction queryAction = explainRequest(client, sqlRequest, format); - return channel -> executeSqlRequest(request, queryAction, client, channel); + return channel -> schedule(client, () -> { + try { + // Route request to new query engine if it's supported already + SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(), + sqlRequest.getSql(), request.path(), request.params()); + RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client); + if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) { + LOG.info("[{}] Request is handled by new SQL query engine", + QueryContext.getRequestId()); + result.accept(channel); + } else { + LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine", + QueryContext.getRequestId(), newSqlRequest); + QueryAction queryAction = explainRequest(client, sqlRequest, format); + executeSqlRequest(request, queryAction, client, channel); + } + } catch (Exception e) { + logAndPublishMetrics(e); + reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + } + }); } catch (Exception e) { logAndPublishMetrics(e); return channel -> reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/Scheduler.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/Scheduler.java new file mode 100644 index 0000000000..5567d1f9b2 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/Scheduler.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.executor; + +import java.util.Map; +import lombok.experimental.UtilityClass; +import org.apache.logging.log4j.ThreadContext; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; + +/** The scheduler which schedule the task run in sql-worker thread pool. */ +@UtilityClass +public class Scheduler { + + public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; + + public static void schedule(NodeClient client, Runnable task) { + ThreadPool threadPool = client.threadPool(); + threadPool.schedule(withCurrentContext(task), new TimeValue(0), SQL_WORKER_THREAD_POOL_NAME); + } + + private static Runnable withCurrentContext(final Runnable task) { + final Map currentContext = ThreadContext.getImmutableContext(); + return () -> { + ThreadContext.putAll(currentContext); + task.run(); + }; + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/SchedulerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/SchedulerTest.java new file mode 100644 index 0000000000..f14bda7a95 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/SchedulerTest.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.executor; + +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.node.NodeClient; +import org.opensearch.threadpool.ThreadPool; + +@ExtendWith(MockitoExtension.class) +class SchedulerTest { + @Test + public void schedule() { + NodeClient nodeClient = mock(NodeClient.class); + ThreadPool threadPool = mock(ThreadPool.class); + when(nodeClient.threadPool()).thenReturn(threadPool); + + doAnswer( + invocation -> { + Runnable task = invocation.getArgument(0); + task.run(); + return null; + }) + .when(threadPool) + .schedule(any(), any(), any()); + AtomicBoolean isRun = new AtomicBoolean(false); + Scheduler.schedule(nodeClient, () -> isRun.set(true)); + assertTrue(isRun.get()); + } +} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java index 0a65fb7b0d..e9202d96e8 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java @@ -9,6 +9,7 @@ import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.rest.RestStatus.OK; import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.sql.opensearch.executor.Scheduler.schedule; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -112,42 +113,43 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nod PPLQueryRequestFactory.getPPLRequest(request) ); - return channel -> - nodeClient.execute( - PPLQueryAction.INSTANCE, - transportPPLQueryRequest, - new ActionListener<>() { - @Override - public void onResponse(TransportPPLQueryResponse response) { - sendResponse(channel, OK, response.getResult()); - } - - @Override - public void onFailure(Exception e) { - if (transportPPLQueryRequest.isExplainRequest()) { - LOG.error("Error happened during explain", e); - sendResponse( - channel, - INTERNAL_SERVER_ERROR, - "Failed to explain the query due to error: " + e.getMessage()); - } else if (e instanceof IllegalAccessException) { + return channel -> schedule(nodeClient, () -> + nodeClient.execute( + PPLQueryAction.INSTANCE, + transportPPLQueryRequest, + new ActionListener<>() { + @Override + public void onResponse(TransportPPLQueryResponse response) { + sendResponse(channel, OK, response.getResult()); + } + + @Override + public void onFailure(Exception e) { + if (transportPPLQueryRequest.isExplainRequest()) { + LOG.error("Error happened during explain", e); + sendResponse( + channel, + INTERNAL_SERVER_ERROR, + "Failed to explain the query due to error: " + e.getMessage()); + } else if (e instanceof IllegalAccessException) { + reportError(channel, e, BAD_REQUEST); + } else { + LOG.error("Error happened during query handling", e); + if (isClientError(e)) { + Metrics.getInstance() + .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS) + .increment(); reportError(channel, e, BAD_REQUEST); } else { - LOG.error("Error happened during query handling", e); - if (isClientError(e)) { - Metrics.getInstance() - .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS) - .increment(); - reportError(channel, e, BAD_REQUEST); - } else { - Metrics.getInstance() - .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS) - .increment(); - reportError(channel, e, SERVICE_UNAVAILABLE); - } + Metrics.getInstance() + .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS) + .increment(); + reportError(channel, e, SERVICE_UNAVAILABLE); } } - }); + } + } + )); } private void sendResponse(RestChannel channel, RestStatus status, String content) {