diff --git a/docs/changelog/103031.yaml b/docs/changelog/103031.yaml
new file mode 100644
index 0000000000000..f63094139f5ca
--- /dev/null
+++ b/docs/changelog/103031.yaml
@@ -0,0 +1,9 @@
+pr: 103031
+summary: Collect warnings in compute service
+area: ES|QL
+type: bug
+issues:
+ - 100163
+ - 103028
+ - 102871
+ - 102982
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java
index 4f16a615572b7..5de017fbd279e 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java
@@ -9,16 +9,11 @@
 
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.tasks.TaskCancelledException;
 
-import java.util.HashMap;
-import java.util.LinkedHashSet;
 import java.util.List;
-import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
 
 /**
@@ -41,11 +36,10 @@ public DriverRunner(ThreadContext threadContext) {
      */
    public void runToCompletion(List<Driver> drivers, ActionListener<Void> listener) {
        AtomicReference<Exception> failure = new AtomicReference<>();
-        AtomicArray<Map<String, List<String>>> responseHeaders = new AtomicArray<>(drivers.size());
+        var responseHeadersCollector = new ResponseHeadersCollector(threadContext);
         CountDown counter = new CountDown(drivers.size());
         for (int i = 0; i < drivers.size(); i++) {
             Driver driver = drivers.get(i);
-            int driverIndex = i;
            ActionListener<Void> driverListener = new ActionListener<>() {
                 @Override
                 public void onResponse(Void unused) {
@@ -80,9 +74,9 @@ public void onFailure(Exception e) {
                 }
 
                 private void done() {
-                    responseHeaders.setOnce(driverIndex, threadContext.getResponseHeaders());
+                    responseHeadersCollector.collect();
                     if (counter.countDown()) {
-                        mergeResponseHeaders(responseHeaders);
+                        responseHeadersCollector.finish();
                         Exception error = failure.get();
                         if (error != null) {
                             listener.onFailure(error);
@@ -96,23 +90,4 @@ private void done() {
             start(driver, driverListener);
         }
     }
-
-    private void mergeResponseHeaders(AtomicArray<Map<String, List<String>>> responseHeaders) {
-        final Map<String, Set<String>> merged = new HashMap<>();
-        for (int i = 0; i < responseHeaders.length(); i++) {
-            final Map<String, List<String>> resp = responseHeaders.get(i);
-            if (resp == null || resp.isEmpty()) {
-                continue;
-            }
-            for (Map.Entry<String, List<String>> e : resp.entrySet()) {
-                // Use LinkedHashSet to retain the order of the values
-                merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
-            }
-        }
-        for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
-            for (String v : e.getValue()) {
-                threadContext.addResponseHeader(e.getKey(), v);
-            }
-        }
-    }
 }
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ResponseHeadersCollector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ResponseHeadersCollector.java
new file mode 100644
index 0000000000000..8f40664be74d4
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ResponseHeadersCollector.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.operator;
+
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+
+/**
+ * A helper class that can be used to collect and merge response headers from multiple child requests.
+ */
+public final class ResponseHeadersCollector {
+    private final ThreadContext threadContext;
+    private final Queue<Map<String, List<String>>> collected = ConcurrentCollections.newQueue();
+
+    public ResponseHeadersCollector(ThreadContext threadContext) {
+        this.threadContext = threadContext;
+    }
+
+    /**
+     * Called when a child request completes, to collect the response headers of the responding thread
+     */
+    public void collect() {
+        Map<String, List<String>> responseHeaders = threadContext.getResponseHeaders();
+        if (responseHeaders.isEmpty() == false) {
+            collected.add(responseHeaders);
+        }
+    }
+
+    /**
+     * Called when all child requests have completed. This merges all collected response headers
+     * from the child requests and restores them to the current thread.
+     */
+    public void finish() {
+        final Map<String, Set<String>> merged = new HashMap<>();
+        Map<String, List<String>> resp;
+        while ((resp = collected.poll()) != null) {
+            for (Map.Entry<String, List<String>> e : resp.entrySet()) {
+                // Use LinkedHashSet to retain the order of the values
+                merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
+            }
+        }
+        for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
+            for (String v : e.getValue()) {
+                threadContext.addResponseHeader(e.getKey(), v);
+            }
+        }
+    }
+}
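The new collector is the heart of this change, so a brief orientation may help: every thread that finishes a child request calls collect() to snapshot its own response headers into a concurrent queue, and whichever thread owns the overall response calls finish() exactly once to merge them back. What follows is a minimal standalone sketch of that pattern, not code from this PR; the header name and warning text are invented, and only ResponseHeadersCollector and ThreadContext come from the diff above.

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;

public class CollectorUsageSketch {
    public static void main(String[] args) throws InterruptedException {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        ResponseHeadersCollector collector = new ResponseHeadersCollector(threadContext);
        Runnable child = () -> {
            // Each child works in its own stashed context, as a real child request would.
            try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
                threadContext.addResponseHeader("Warning", "something went wrong"); // invented value
                collector.collect(); // thread-safe: appends this thread's headers to a concurrent queue
            }
        };
        Thread t1 = new Thread(child);
        Thread t2 = new Thread(child);
        t1.start();
        t2.start();
        t1.join();
        t2.join();
        collector.finish(); // merges into the calling thread's context, deduplicated
        System.out.println(threadContext.getResponseHeaders()); // {Warning=[something went wrong]}
    }
}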
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ResponseHeadersCollectorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ResponseHeadersCollectorTests.java
new file mode 100644
index 0000000000000..b09372f3a962c
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ResponseHeadersCollectorTests.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.operator;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.RefCountingListener;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.FixedExecutorBuilder;
+import org.elasticsearch.threadpool.TestThreadPool;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.TimeUnit;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class ResponseHeadersCollectorTests extends ESTestCase {
+
+    public void testCollect() {
+        int numThreads = randomIntBetween(1, 10);
+        TestThreadPool threadPool = new TestThreadPool(
+            getTestClass().getSimpleName(),
+            new FixedExecutorBuilder(Settings.EMPTY, "test", numThreads, 1024, "test", EsExecutors.TaskTrackingConfig.DEFAULT)
+        );
+        Set<String> expectedWarnings = new HashSet<>();
+        try {
+            ThreadContext threadContext = threadPool.getThreadContext();
+            var collector = new ResponseHeadersCollector(threadContext);
+            PlainActionFuture<Void> future = new PlainActionFuture<>();
+            Runnable mergeAndVerify = () -> {
+                collector.finish();
+                List<String> actualWarnings = threadContext.getResponseHeaders().getOrDefault("Warnings", List.of());
+                assertThat(Sets.newHashSet(actualWarnings), equalTo(expectedWarnings));
+            };
+            try (RefCountingListener refs = new RefCountingListener(ActionListener.runAfter(future, mergeAndVerify))) {
+                CyclicBarrier barrier = new CyclicBarrier(numThreads);
+                for (int i = 0; i < numThreads; i++) {
+                    String warning = "warning-" + i;
+                    expectedWarnings.add(warning);
+                    ActionListener<Void> listener = ActionListener.runBefore(refs.acquire(), collector::collect);
+                    threadPool.schedule(new ActionRunnable<>(listener) {
+                        @Override
+                        protected void doRun() throws Exception {
+                            barrier.await(30, TimeUnit.SECONDS);
+                            try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
+                                threadContext.addResponseHeader("Warnings", warning);
+                                listener.onResponse(null);
+                            }
+                        }
+                    }, TimeValue.timeValueNanos(between(0, 1000_000)), threadPool.executor("test"));
+                }
+            }
+            future.actionGet(TimeValue.timeValueSeconds(30));
+        } finally {
+            terminate(threadPool);
+        }
+    }
+}
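The test threads collect() and finish() through ActionListener combinators, and the ordering is easy to misread. Here is a small self-contained sketch (all names invented) of the guarantees the test relies on: runBefore fires before its delegate, while runAfter fires only once the delegate has completed.

import org.elasticsearch.action.ActionListener;

public class ListenerOrderSketch {
    public static void main(String[] args) {
        ActionListener<Void> base = ActionListener.wrap(
            r -> System.out.println("2: delegate completed"),
            e -> { throw new AssertionError(e); }
        );
        // In the test, collector::collect is a runBefore task, so headers are
        // captured before each ref is released; mergeAndVerify is a runAfter
        // task on the future, so it runs only after every ref has completed.
        ActionListener<Void> composed = ActionListener.runBefore(
            ActionListener.runAfter(base, () -> System.out.println("3: after delegate")),
            () -> System.out.println("1: before delegate")
        );
        composed.onResponse(null); // prints 1, 2, 3
    }
}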
diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/WarningsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/WarningsIT.java
new file mode 100644
index 0000000000000..12897979a47e0
--- /dev/null
+++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/WarningsIT.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.action;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.test.junit.annotations.TestLogging;
+import org.elasticsearch.transport.TransportService;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+
+@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug")
+public class WarningsIT extends AbstractEsqlIntegTestCase {
+
+    public void testCollectWarnings() {
+        final String node1, node2;
+        if (randomBoolean()) {
+            internalCluster().ensureAtLeastNumDataNodes(2);
+            node1 = randomDataNode().getName();
+            node2 = randomValueOtherThan(node1, () -> randomDataNode().getName());
+        } else {
+            node1 = randomDataNode().getName();
+            node2 = randomDataNode().getName();
+        }
+
+        int numDocs1 = randomIntBetween(1, 15);
+        assertAcked(
+            client().admin()
+                .indices()
+                .prepareCreate("index-1")
+                .setSettings(Settings.builder().put("index.routing.allocation.require._name", node1))
+                .setMapping("host", "type=keyword")
+        );
+        for (int i = 0; i < numDocs1; i++) {
+            client().prepareIndex("index-1").setSource("host", "192." + i).get();
+        }
+        int numDocs2 = randomIntBetween(1, 15);
+        assertAcked(
+            client().admin()
+                .indices()
+                .prepareCreate("index-2")
+                .setSettings(Settings.builder().put("index.routing.allocation.require._name", node2))
+                .setMapping("host", "type=keyword")
+        );
+        for (int i = 0; i < numDocs2; i++) {
+            client().prepareIndex("index-2").setSource("host", "10." + i).get();
+        }
+
+        DiscoveryNode coordinator = randomFrom(clusterService().state().nodes().stream().toList());
+        client().admin().indices().prepareRefresh("index-1", "index-2").get();
+
+        EsqlQueryRequest request = new EsqlQueryRequest();
+        request.query("FROM index-* | EVAL ip = to_ip(host) | STATS s = COUNT(*) by ip | KEEP ip | LIMIT 100");
+        request.pragmas(randomPragmas());
+        PlainActionFuture<EsqlQueryResponse> future = new PlainActionFuture<>();
+        client(coordinator.getName()).execute(EsqlQueryAction.INSTANCE, request, ActionListener.runBefore(future, () -> {
+            var threadpool = internalCluster().getInstance(TransportService.class, coordinator.getName()).getThreadPool();
+            Map<String, List<String>> responseHeaders = threadpool.getThreadContext().getResponseHeaders();
+            List<String> warnings = responseHeaders.getOrDefault("Warning", List.of())
+                .stream()
+                .filter(w -> w.contains("is not an IP string literal"))
+                .toList();
+            int expectedWarnings = Math.min(20, numDocs1 + numDocs2);
+            // we cap the number of warnings per node
+            assertThat(warnings.size(), greaterThanOrEqualTo(expectedWarnings));
+        }));
+        future.actionGet(30, TimeUnit.SECONDS).close();
+    }
+
+    private DiscoveryNode randomDataNode() {
+        return randomFrom(clusterService().state().nodes().getDataNodes().values());
+    }
+}
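The assertion above covers warnings collected from two different data nodes and merged on the coordinator. As a standalone illustration of the LinkedHashSet merge that ResponseHeadersCollector.finish() performs (the header values here are invented), duplicates across nodes collapse while first-seen order is preserved:

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class MergeSemanticsSketch {
    public static void main(String[] args) {
        // Two child responses with one overlapping Warning value.
        Map<String, List<String>> nodeA = Map.of("Warning", List.of("w1", "w2"));
        Map<String, List<String>> nodeB = Map.of("Warning", List.of("w2", "w3"));
        Map<String, Set<String>> merged = new HashMap<>();
        for (Map<String, List<String>> resp : List.of(nodeA, nodeB)) {
            for (Map.Entry<String, List<String>> e : resp.entrySet()) {
                // LinkedHashSet deduplicates while keeping insertion order.
                merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>()).addAll(e.getValue());
            }
        }
        System.out.println(merged); // {Warning=[w1, w2, w3]}
    }
}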
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
index e10469a4ff97d..3409d8f61d865 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
@@ -29,6 +29,7 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.Driver;
 import org.elasticsearch.compute.operator.DriverTaskRunner;
+import org.elasticsearch.compute.operator.ResponseHeadersCollector;
 import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
 import org.elasticsearch.compute.operator.exchange.ExchangeService;
 import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
@@ -148,6 +149,8 @@ public void execute(
         LOGGER.debug("Sending data node plan\n{}\n with filter [{}]", dataNodePlan, requestFilter);
         String[] originalIndices = PlannerUtils.planOriginalIndices(physicalPlan);
+        var responseHeadersCollector = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext());
+        listener = ActionListener.runBefore(listener, responseHeadersCollector::finish);
         computeTargetNodes(
             rootTask,
             requestFilter,
@@ -168,7 +171,16 @@ public void execute(
                 exchangeSource.addCompletionListener(requestRefs.acquire());
                 // run compute on the coordinator
                 var computeContext = new ComputeContext(sessionId, List.of(), configuration, exchangeSource, null);
-                runCompute(rootTask, computeContext, coordinatorPlan, cancelOnFailure(rootTask, cancelled, requestRefs.acquire()));
+                runCompute(
+                    rootTask,
+                    computeContext,
+                    coordinatorPlan,
+                    cancelOnFailure(
+                        rootTask,
+                        cancelled,
+                        ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
+                    )
+                );
                 // run compute on remote nodes
                 // TODO: This is wrong, we need to be able to cancel
                 runComputeOnRemoteNodes(
@@ -178,7 +190,11 @@ public void execute(
                     dataNodePlan,
                     exchangeSource,
                     targetNodes,
-                    () -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(unused -> null)
+                    () -> cancelOnFailure(
+                        rootTask,
+                        cancelled,
+                        ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
+                    )
                 );
             }
         })
@@ -192,7 +208,7 @@ private void runComputeOnRemoteNodes(
         PhysicalPlan dataNodePlan,
         ExchangeSourceHandler exchangeSource,
         List<TargetNode> targetNodes,
-        Supplier<ActionListener<DataNodeResponse>> listener
+        Supplier<ActionListener<Void>> listener
     ) {
         // Do not complete the exchange sources until we have linked all remote sinks
         final SubscribableListener<Void> blockingSinkFuture = new SubscribableListener<>();
@@ -221,7 +237,7 @@ private void runComputeOnRemoteNodes(
                     new DataNodeRequest(sessionId, configuration, targetNode.shardIds, targetNode.aliasFilters, dataNodePlan),
                     rootTask,
                     TransportRequestOptions.EMPTY,
-                    new ActionListenerResponseHandler<>(delegate, DataNodeResponse::new, esqlExecutor)
+                    new ActionListenerResponseHandler<>(delegate.map(ignored -> null), DataNodeResponse::new, esqlExecutor)
                 );
             })
         );
@@ -432,7 +448,10 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T
             runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(unused -> {
                 // don't return until all pages are fetched
                 exchangeSink.addCompletionListener(
-                    ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null))
+                    ContextPreservingActionListener.wrapPreservingContext(
+                        ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null)),
+                        transportService.getThreadPool().getThreadContext()
+                    )
                 );
             }, e -> {
                 exchangeService.finishSinkHandler(sessionId, e);
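The wrapPreservingContext change in the final hunk is what keeps warnings from landing in the wrong context when the exchange sink completes the listener on a different transport thread. A minimal sketch of the guarantee, under the assumption that the wrapper captures the context at wrap time and restores it around the callback (the header name and value are invented):

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;

public class PreserveContextSketch {
    public static void main(String[] args) {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        threadContext.putHeader("x-request-id", "abc123"); // invented request header
        ActionListener<Void> inner = ActionListener.wrap(
            r -> System.out.println("in callback: " + threadContext.getHeader("x-request-id")),
            e -> { throw new AssertionError(e); }
        );
        // The request context is captured here, at wrap time.
        ActionListener<Void> wrapped = ContextPreservingActionListener.wrapPreservingContext(inner, threadContext);
        // Simulate completion from a fresh context, as a transport thread would have.
        try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
            wrapped.onResponse(null); // prints "in callback: abc123", not null
        }
    }
}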