Collect warnings in compute service (#103031) #103079

Merged · 1 commit · Dec 6, 2023
9 changes: 9 additions & 0 deletions docs/changelog/103031.yaml
@@ -0,0 +1,9 @@
pr: 103031
summary: Collect warnings in compute service
area: ES|QL
type: bug
issues:
- 100163
- 103028
- 102871
- 102982
DriverRunner.java
@@ -9,16 +9,11 @@

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.tasks.TaskCancelledException;

-import java.util.HashMap;
-import java.util.LinkedHashSet;
import java.util.List;
-import java.util.Map;
-import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

/**
@@ -41,11 +36,10 @@ public DriverRunner(ThreadContext threadContext) {
*/
public void runToCompletion(List<Driver> drivers, ActionListener<Void> listener) {
AtomicReference<Exception> failure = new AtomicReference<>();
-AtomicArray<Map<String, List<String>>> responseHeaders = new AtomicArray<>(drivers.size());
+var responseHeadersCollector = new ResponseHeadersCollector(threadContext);
CountDown counter = new CountDown(drivers.size());
for (int i = 0; i < drivers.size(); i++) {
Driver driver = drivers.get(i);
-int driverIndex = i;
ActionListener<Void> driverListener = new ActionListener<>() {
@Override
public void onResponse(Void unused) {
@@ -80,9 +74,9 @@ public void onFailure(Exception e) {
}

private void done() {
-responseHeaders.setOnce(driverIndex, threadContext.getResponseHeaders());
+responseHeadersCollector.collect();
if (counter.countDown()) {
-mergeResponseHeaders(responseHeaders);
+responseHeadersCollector.finish();
Exception error = failure.get();
if (error != null) {
listener.onFailure(error);
@@ -96,23 +90,4 @@ private void done() {
start(driver, driverListener);
}
}

-private void mergeResponseHeaders(AtomicArray<Map<String, List<String>>> responseHeaders) {
-final Map<String, Set<String>> merged = new HashMap<>();
-for (int i = 0; i < responseHeaders.length(); i++) {
-final Map<String, List<String>> resp = responseHeaders.get(i);
-if (resp == null || resp.isEmpty()) {
-continue;
-}
-for (Map.Entry<String, List<String>> e : resp.entrySet()) {
-// Use LinkedHashSet to retain the order of the values
-merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
-}
-}
-for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
-for (String v : e.getValue()) {
-threadContext.addResponseHeader(e.getKey(), v);
-}
-}
-}
}
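Note that the merge logic removed above is not gone: it moves into the new ResponseHeadersCollector below. As a minimal, self-contained sketch (plain Java, not ES code, with made-up header values) of the merge semantics being preserved: values for the same header key are de-duplicated while their first-seen order is retained.

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class MergeSemanticsDemo {
    public static void main(String[] args) {
        // Headers captured from two hypothetical child requests
        List<Map<String, List<String>>> collected = List.of(
            Map.of("Warning", List.of("w1", "w2")),
            Map.of("Warning", List.of("w2", "w3"))
        );
        Map<String, Set<String>> merged = new HashMap<>();
        for (Map<String, List<String>> resp : collected) {
            for (Map.Entry<String, List<String>> e : resp.entrySet()) {
                // LinkedHashSet drops duplicates but keeps first-seen order
                merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>()).addAll(e.getValue());
            }
        }
        System.out.println(merged); // prints {Warning=[w1, w2, w3]}
    }
}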
ResponseHeadersCollector.java (new file)
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

/**
* A helper class that can be used to collect and merge response headers from multiple child requests.
*/
public final class ResponseHeadersCollector {
private final ThreadContext threadContext;
private final Queue<Map<String, List<String>>> collected = ConcurrentCollections.newQueue();

public ResponseHeadersCollector(ThreadContext threadContext) {
this.threadContext = threadContext;
}

/**
* Called when a child request is completed to collect the response headers of the responding thread
*/
public void collect() {
Map<String, List<String>> responseHeaders = threadContext.getResponseHeaders();
if (responseHeaders.isEmpty() == false) {
collected.add(responseHeaders);
}
}

/**
* Called when all child requests are completed. This merges all collected response headers
* from the child requests and restores them to the current thread's context.
*/
public void finish() {
final Map<String, Set<String>> merged = new HashMap<>();
Map<String, List<String>> resp;
while ((resp = collected.poll()) != null) {
for (Map.Entry<String, List<String>> e : resp.entrySet()) {
// Use LinkedHashSet to retain the order of the values
merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
}
}
for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
for (String v : e.getValue()) {
threadContext.addResponseHeader(e.getKey(), v);
}
}
}
}
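For illustration, a minimal usage sketch of the collector, assuming a standalone ThreadContext built from Settings.EMPTY (in production the context comes from the node's thread pool); it mirrors the stash-then-collect pattern the unit test below uses:

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;

class CollectorUsageSketch {
    public static void main(String[] args) {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        ResponseHeadersCollector collector = new ResponseHeadersCollector(threadContext);
        for (String warning : new String[] { "w1", "w2" }) {
            // Each child request runs in its own stashed context...
            try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
                threadContext.addResponseHeader("Warning", warning);
                collector.collect(); // ...whose headers are captured on completion
            }
        }
        collector.finish(); // merge everything back into the caller's context
        System.out.println(threadContext.getResponseHeaders().get("Warning")); // [w1, w2]
    }
}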
ResponseHeadersCollectorTests.java (new file)
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.equalTo;

public class ResponseHeadersCollectorTests extends ESTestCase {

public void testCollect() {
int numThreads = randomIntBetween(1, 10);
TestThreadPool threadPool = new TestThreadPool(
getTestClass().getSimpleName(),
new FixedExecutorBuilder(Settings.EMPTY, "test", numThreads, 1024, "test", EsExecutors.TaskTrackingConfig.DEFAULT)
);
Set<String> expectedWarnings = new HashSet<>();
try {
ThreadContext threadContext = threadPool.getThreadContext();
var collector = new ResponseHeadersCollector(threadContext);
PlainActionFuture<Void> future = new PlainActionFuture<>();
Runnable mergeAndVerify = () -> {
collector.finish();
List<String> actualWarnings = threadContext.getResponseHeaders().getOrDefault("Warnings", List.of());
assertThat(Sets.newHashSet(actualWarnings), equalTo(expectedWarnings));
};
try (RefCountingListener refs = new RefCountingListener(ActionListener.runAfter(future, mergeAndVerify))) {
CyclicBarrier barrier = new CyclicBarrier(numThreads);
for (int i = 0; i < numThreads; i++) {
String warning = "warning-" + i;
expectedWarnings.add(warning);
ActionListener<Void> listener = ActionListener.runBefore(refs.acquire(), collector::collect);
threadPool.schedule(new ActionRunnable<>(listener) {
@Override
protected void doRun() throws Exception {
barrier.await(30, TimeUnit.SECONDS);
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.addResponseHeader("Warnings", warning);
listener.onResponse(null);
}
}
}, TimeValue.timeValueNanos(between(0, 1000_000)), threadPool.executor("test"));
}
}
future.actionGet(TimeValue.timeValueSeconds(30));
} finally {
terminate(threadPool);
}
}
}
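Design note: the random scheduling delays and the CyclicBarrier make the collect() calls race across the fixed pool's threads, so the test exercises the collector's concurrent queue under real contention rather than a single-threaded happy path.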
WarningsIT.java (new file)
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.transport.TransportService;

import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug")
public class WarningsIT extends AbstractEsqlIntegTestCase {

public void testCollectWarnings() {
final String node1, node2;
if (randomBoolean()) {
internalCluster().ensureAtLeastNumDataNodes(2);
node1 = randomDataNode().getName();
node2 = randomValueOtherThan(node1, () -> randomDataNode().getName());
} else {
node1 = randomDataNode().getName();
node2 = randomDataNode().getName();
}

int numDocs1 = randomIntBetween(1, 15);
assertAcked(
client().admin()
.indices()
.prepareCreate("index-1")
.setSettings(Settings.builder().put("index.routing.allocation.require._name", node1))
.setMapping("host", "type=keyword")
);
for (int i = 0; i < numDocs1; i++) {
client().prepareIndex("index-1").setSource("host", "192." + i).get();
}
int numDocs2 = randomIntBetween(1, 15);
assertAcked(
client().admin()
.indices()
.prepareCreate("index-2")
.setSettings(Settings.builder().put("index.routing.allocation.require._name", node2))
.setMapping("host", "type=keyword")
);
for (int i = 0; i < numDocs2; i++) {
client().prepareIndex("index-2").setSource("host", "10." + i).get();
}

DiscoveryNode coordinator = randomFrom(clusterService().state().nodes().stream().toList());
client().admin().indices().prepareRefresh("index-1", "index-2").get();

EsqlQueryRequest request = new EsqlQueryRequest();
request.query("FROM index-* | EVAL ip = to_ip(host) | STATS s = COUNT(*) by ip | KEEP ip | LIMIT 100");
request.pragmas(randomPragmas());
PlainActionFuture<EsqlQueryResponse> future = new PlainActionFuture<>();
client(coordinator.getName()).execute(EsqlQueryAction.INSTANCE, request, ActionListener.runBefore(future, () -> {
var threadpool = internalCluster().getInstance(TransportService.class, coordinator.getName()).getThreadPool();
Map<String, List<String>> responseHeaders = threadpool.getThreadContext().getResponseHeaders();
List<String> warnings = responseHeaders.getOrDefault("Warning", List.of())
.stream()
.filter(w -> w.contains("is not an IP string literal"))
.toList();
int expectedWarnings = Math.min(20, numDocs1 + numDocs2);
// we cap the number of warnings per node
assertThat(warnings.size(), greaterThanOrEqualTo(expectedWarnings));
}));
future.actionGet(30, TimeUnit.SECONDS).close();
}

private DiscoveryNode randomDataNode() {
return randomFrom(clusterService().state().nodes().getDataNodes().values());
}
}
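The host values indexed above ("192.0", "10.1", and so on) are not valid IP string literals, so TO_IP emits a warning for each failing document on the data node that holds it; the assertion checks that those per-node warnings survive the trip back to the coordinator's thread context instead of being lost at the transport boundary.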
ComputeService.java
@@ -29,6 +29,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverTaskRunner;
+import org.elasticsearch.compute.operator.ResponseHeadersCollector;
import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
@@ -148,6 +149,8 @@ public void execute(
LOGGER.debug("Sending data node plan\n{}\n with filter [{}]", dataNodePlan, requestFilter);

String[] originalIndices = PlannerUtils.planOriginalIndices(physicalPlan);
+var responseHeadersCollector = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext());
+listener = ActionListener.runBefore(listener, responseHeadersCollector::finish);
computeTargetNodes(
rootTask,
requestFilter,
@@ -168,7 +171,16 @@
exchangeSource.addCompletionListener(requestRefs.acquire());
// run compute on the coordinator
var computeContext = new ComputeContext(sessionId, List.of(), configuration, exchangeSource, null);
-runCompute(rootTask, computeContext, coordinatorPlan, cancelOnFailure(rootTask, cancelled, requestRefs.acquire()));
+runCompute(
+rootTask,
+computeContext,
+coordinatorPlan,
+cancelOnFailure(
+rootTask,
+cancelled,
+ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
+)
+);
// run compute on remote nodes
// TODO: This is wrong, we need to be able to cancel
runComputeOnRemoteNodes(
@@ -178,7 +190,11 @@
dataNodePlan,
exchangeSource,
targetNodes,
-() -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(unused -> null)
+() -> cancelOnFailure(
+rootTask,
+cancelled,
+ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
+)
);
}
})
@@ -192,7 +208,7 @@ private void runComputeOnRemoteNodes(
PhysicalPlan dataNodePlan,
ExchangeSourceHandler exchangeSource,
List<TargetNode> targetNodes,
-Supplier<ActionListener<DataNodeResponse>> listener
+Supplier<ActionListener<Void>> listener
) {
// Do not complete the exchange sources until we have linked all remote sinks
final SubscribableListener<Void> blockingSinkFuture = new SubscribableListener<>();
@@ -221,7 +237,7 @@ private void runComputeOnRemoteNodes(
new DataNodeRequest(sessionId, configuration, targetNode.shardIds, targetNode.aliasFilters, dataNodePlan),
rootTask,
TransportRequestOptions.EMPTY,
-new ActionListenerResponseHandler<>(delegate, DataNodeResponse::new, esqlExecutor)
+new ActionListenerResponseHandler<>(delegate.map(ignored -> null), DataNodeResponse::new, esqlExecutor)
);
})
);
@@ -432,7 +448,10 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T
runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(unused -> {
// don't return until all pages are fetched
exchangeSink.addCompletionListener(
-ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null))
+ContextPreservingActionListener.wrapPreservingContext(
+ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null)),
+transportService.getThreadPool().getThreadContext()
+)
);
}, e -> {
exchangeService.finishSinkHandler(sessionId, e);
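Taken together, the ComputeService changes follow one pattern: each per-child listener is wrapped with runBefore(collector::collect) and the final listener with runBefore(collector::finish). A condensed sketch of that wiring (runChild is a hypothetical stand-in for runCompute and the data-node requests; this is not the actual ComputeService code):

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;

import java.util.function.BiConsumer;

final class WarningsWiringSketch {
    static void execute(
        ThreadContext threadContext,
        int children,
        BiConsumer<Integer, ActionListener<Void>> runChild, // hypothetical child runner
        ActionListener<Void> listener
    ) {
        ResponseHeadersCollector collector = new ResponseHeadersCollector(threadContext);
        // finish() restores the merged warnings just before the caller's listener fires
        ActionListener<Void> done = ActionListener.runBefore(listener, collector::finish);
        try (RefCountingListener refs = new RefCountingListener(done)) {
            for (int i = 0; i < children; i++) {
                // collect() captures the responding thread's headers before each ref is released
                runChild.accept(i, ActionListener.runBefore(refs.acquire(), collector::collect));
            }
        }
    }
}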