Fix aggregation memory leak for CCS #78404

Merged
CrossClusterSearchLeakIT.java (new file)
@@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.ccs;

import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractMultiClustersTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.transport.TransportService;
import org.hamcrest.Matchers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;

public class CrossClusterSearchLeakIT extends AbstractMultiClustersTestCase {

@Override
protected Collection<String> remoteClusterAlias() {
return List.of("cluster_a");
}

@Override
protected boolean reuseClusters() {
return false;
}

private int indexDocs(Client client, String field, String index) {
int numDocs = between(1, 200);
for (int i = 0; i < numDocs; i++) {
client.prepareIndex(index).setSource(field, "v" + i).get();
}
client.admin().indices().prepareRefresh(index).get();
return numDocs;
}

/**
* This test validates that we do not leak any memory when running CCS in various modes; the actual validation is done by
* the test framework (leak detection). The dimensions covered are:
* <ul>
* <li>proxy vs non-proxy</li>
* <li>single-phase query-fetch or multi-phase</li>
* <li>minimize roundtrip vs not</li>
* <li>scroll vs no scroll</li>
* </ul>
*/
public void testSearch() throws Exception {
assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo")
.setMapping("f", "type=keyword")
.setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
indexDocs(client(LOCAL_CLUSTER), "ignored", "demo");
final InternalTestCluster remoteCluster = cluster("cluster_a");
int minRemotes = between(2, 5);
remoteCluster.ensureAtLeastNumDataNodes(minRemotes);
List<String> remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false)
.filter(DiscoveryNode::canContainData)
.map(DiscoveryNode::getName)
.collect(Collectors.toList());
assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(minRemotes));
List<String> seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes);
disconnectFromRemoteClusters();
configureRemoteCluster("cluster_a", seedNodes);
final Settings.Builder allocationFilter = Settings.builder();
if (rarely()) {
allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes));
} else {
// Provoke using proxy connections
allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes));
}
assertAcked(client("cluster_a").admin().indices().prepareCreate("prod")
.setMapping("f", "type=keyword")
.setSettings(Settings.builder().put(allocationFilter.build())
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod")
.setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut());
int docs = indexDocs(client("cluster_a"), "f", "prod");

List<ActionFuture<SearchResponse>> futures = new ArrayList<>();
for (int i = 0; i < 10; ++i) {
String[] indices = randomBoolean() ? new String[] { "demo", "cluster_a:prod" } : new String[] { "cluster_a:prod" };
final SearchRequest searchRequest = new SearchRequest(indices);
searchRequest.allowPartialSearchResults(false);
boolean scroll = randomBoolean();
searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())
.aggregation(terms("f").field("f").size(docs + between(scroll ? 1 : 0, 10))).size(between(0, 1000)));
if (scroll) {
searchRequest.scroll("30s");
}
searchRequest.setCcsMinimizeRoundtrips(rarely());
futures.add(client(LOCAL_CLUSTER).search(searchRequest));
}

for (ActionFuture<SearchResponse> future : futures) {
SearchResponse searchResponse = future.get();
if (searchResponse.getScrollId() != null) {
ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
clearScrollRequest.scrollIds(List.of(searchResponse.getScrollId()));
client(LOCAL_CLUSTER).clearScroll(clearScrollRequest).get();
}

Terms terms = searchResponse.getAggregations().get("f");
assertThat(terms.getBuckets().size(), equalTo(docs));
for (Terms.Bucket bucket : terms.getBuckets()) {
assertThat(bucket.getDocCount(), equalTo(1L));
}
}
}

@Override
protected void configureRemoteCluster(String clusterAlias, Collection<String> seedNodes) throws Exception {
if (rarely()) {
super.configureRemoteCluster(clusterAlias, seedNodes);
} else {
final Settings.Builder settings = Settings.builder();
final String seedNode = randomFrom(seedNodes);
final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, seedNode);
final String seedAddress = transportService.boundAddress().publishAddress().toString();

settings.put("cluster.remote." + clusterAlias + ".mode", "proxy");
settings.put("cluster.remote." + clusterAlias + ".proxy_address", seedAddress);
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();
}
}
}
SearchTransportService.java
@@ -138,7 +138,7 @@ public void sendExecuteQuery(Transport.Connection connection, final ShardSearchRequest request, ...
// we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
// this used to be the QUERY_AND_FETCH which doesn't exist anymore.
final boolean fetchDocuments = request.numberOfShards() == 1;
Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new;
Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : in -> new QuerySearchResult(in, true);

final ActionListener<? super SearchPhaseResult> handler = responseWrapper.apply(connection, listener);
transportService.sendChildRequest(connection, QUERY_ACTION_NAME, request, task,
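The one-line change above swaps the reader used for query-phase results on the coordinating node. Since Writeable.Reader<T> is a functional interface, the lambda simply binds the new two-argument QuerySearchResult constructor instead of the eager one. A minimal sketch of the two bindings (illustrative, not part of the diff):

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.query.QuerySearchResult;

// Writeable.Reader<T> declares a single method: T read(StreamInput in) throws IOException.
// The method reference binds the (StreamInput) constructor, which deserializes aggregations
// eagerly; the lambda binds the constructor that keeps them in delayed, still-serialized
// form, so the coordinating node must later consume or release them explicitly.
Writeable.Reader<SearchPhaseResult> eager = QuerySearchResult::new;
Writeable.Reader<SearchPhaseResult> delayed = in -> new QuerySearchResult(in, true);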
DelayableWriteable.java
@@ -9,6 +9,7 @@
package org.elasticsearch.common.io.stream;

import org.elasticsearch.Version;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.core.Releasable;

@@ -50,6 +51,12 @@ public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readReleasableBytesReference());
}

public static <T extends Writeable> DelayableWriteable<T> referencing(Writeable.Reader<T> reader, StreamInput in) throws IOException {
try (ReleasableBytesReference serialized = in.readReleasableBytesReference()) {
return new Referencing<>(deserialize(reader, in.getVersion(), in.namedWriteableRegistry(), serialized));
}
}

private DelayableWriteable() {}

/**
@@ -67,7 +74,7 @@ private DelayableWriteable() {}
* {@code true} if the {@linkplain Writeable} is being stored in
* serialized form, {@code false} otherwise.
*/
abstract boolean isSerialized();
public abstract boolean isSerialized();

/**
* Returns the serialized size of the inner {@link Writeable}.
@@ -104,7 +111,7 @@ public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
}

@Override
boolean isSerialized() {
public boolean isSerialized() {
return false;
}

@@ -169,11 +176,7 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public T expand() {
try {
try (StreamInput in = registry == null ?
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(serializedAtVersion);
return reader.read(in);
}
return deserialize(reader, serializedAtVersion, registry, serialized);
} catch (IOException e) {
throw new RuntimeException("unexpected error expanding serialized delayed writeable", e);
}
@@ -185,7 +188,7 @@ public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
}

@Override
boolean isSerialized() {
public boolean isSerialized() {
return true;
}

@@ -214,6 +217,15 @@ public static long getSerializedSize(Writeable ref) {
}
}

private static <T> T deserialize(Reader<T> reader, Version serializedAtVersion, NamedWriteableRegistry registry,
BytesReference serialized) throws IOException {
try (StreamInput in =
registry == null ? serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(serializedAtVersion);
return reader.read(in);
}
}

private static class CountingStreamOutput extends StreamOutput {
long size = 0;

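A minimal sketch of how the two deserializing factories differ, assuming a small Writeable class Example with a (StreamInput) constructor (hypothetical, mirroring the Example helper in the unit tests further down):

import java.io.IOException;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.StreamInput;

// Hypothetical round trip: delayed() keeps the raw bytes and defers reader.read()
// until expand(); referencing(reader, in) deserializes inside a try-with-resources
// and releases the bytes before returning.
static void roundTripSketch() throws IOException {
    Example original = new Example("v");
    try (BytesStreamOutput out = new BytesStreamOutput()) {
        DelayableWriteable.referencing(original).writeTo(out); // length-prefixed blob

        try (StreamInput in = out.bytes().streamInput()) {
            // The caller now owns a retained buffer and must eventually release it.
            DelayableWriteable<Example> delayed = DelayableWriteable.delayed(Example::new, in);
            assert delayed.isSerialized();
            delayed.close(); // assumed Releasable here; frees the retained bytes
        }

        try (StreamInput in = out.bytes().streamInput()) {
            // Deserializes eagerly; nothing stays retained past this call.
            DelayableWriteable<Example> eager = DelayableWriteable.referencing(Example::new, in);
            assert eager.isSerialized() == false;
        }
    }
}

This is why QuerySearchResult (below) defaults to the referencing form and only opts into the delayed form when the caller is known to consume or release the aggregations.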
QuerySearchResult.java
@@ -33,7 +33,6 @@
import static org.elasticsearch.common.lucene.Lucene.writeTopDocs;

public final class QuerySearchResult extends SearchPhaseResult {

private int from;
private int size;
private TopDocsAndMaxScore topDocsAndMaxScore;
@@ -65,6 +64,15 @@ public QuerySearchResult() {
}

public QuerySearchResult(StreamInput in) throws IOException {
this(in, false);
}

/**
* Read the object, using a delayed aggregations field when delayedAggregations=true. When this is used, the caller must
* ensure that either `consumeAggs` or `releaseAggs` is called if `hasAggs() == true`; otherwise the buffer backing the
* aggregations leaks.
* @param delayedAggregations whether to keep the aggregations in delayed (still-serialized) form
*/
public QuerySearchResult(StreamInput in, boolean delayedAggregations) throws IOException {
super(in);
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
isNull = in.readBoolean();
@@ -73,7 +81,7 @@ public QuerySearchResult(StreamInput in) throws IOException {
}
if (isNull == false) {
ShardSearchContextId id = new ShardSearchContextId(in);
readFromWithId(id, in);
readFromWithId(id, in, delayedAggregations);
}
}

@@ -316,6 +324,10 @@ public boolean hasSearchContext() {
}

public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOException {
readFromWithId(id, in, false);
}

private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean delayedAggregations) throws IOException {
this.contextId = id;
from = in.readVInt();
size = in.readVInt();
@@ -333,7 +345,11 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOException {
boolean success = false;
try {
if (hasAggs) {
aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
if (delayedAggregations) {
aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
} else {
aggregations = DelayableWriteable.referencing(InternalAggregations::readFrom, in);
}
}
if (in.readBoolean()) {
suggest = new Suggest(in);
@@ -359,6 +375,8 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOException {

@Override
public void writeTo(StreamOutput out) throws IOException {
// we cannot tell here whether the result is being sent over the transport, but this assertion at least guards every write path, including sends.
assert aggregations == null || aggregations.isSerialized() == false : "cannot send serialized version since it will leak";
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeBoolean(isNull);
}
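A minimal sketch of the consume-or-release contract established by the new constructor (illustrative handler, not part of the diff; reduce is a hypothetical downstream step):

import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.query.QuerySearchResult;

class DelayedAggsHandler {
    // A result read with delayedAggregations=true retains the network buffer backing
    // its aggregations until they are either consumed or released.
    void onQueryResult(QuerySearchResult result) {
        if (result.hasAggs()) {
            boolean consumed = false;
            try {
                Aggregations aggs = result.consumeAggs(); // expands the delayed bytes and clears the field
                consumed = true;
                reduce(aggs);
            } finally {
                if (consumed == false) {
                    result.releaseAggs(); // release instead of leaking the retained buffer
                }
            }
        }
    }

    private void reduce(Aggregations aggs) {
        // hypothetical downstream step (e.g., a partial reduce on the coordinating node)
    }
}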
DelayableWriteableTests.java
@@ -134,14 +134,12 @@ public void testRoundTripFromDelayedWithNamedWriteable() throws IOException {
public void testRoundTripFromDelayedFromOldVersion() throws IOException {
Example e = new Example(randomAlphaOfLength(5));
DelayableWriteable<Example> original = roundTrip(DelayableWriteable.referencing(e), Example::new, randomOldVersion());
assertTrue(original.isSerialized());
roundTripTestCase(original, Example::new);
}

public void testRoundTripFromDelayedFromOldVersionWithNamedWriteable() throws IOException {
NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5)));
DelayableWriteable<NamedHolder> original = roundTrip(DelayableWriteable.referencing(n), NamedHolder::new, randomOldVersion());
assertTrue(original.isSerialized());
roundTripTestCase(original, NamedHolder::new);
}

@@ -160,14 +158,20 @@ public void testAsSerializedIsNoopOnSerialized() throws IOException {

private <T extends Writeable> void roundTripTestCase(DelayableWriteable<T> original, Writeable.Reader<T> reader) throws IOException {
DelayableWriteable<T> roundTripped = roundTrip(original, reader, Version.CURRENT);
assertTrue(roundTripped.isSerialized());
assertThat(roundTripped.expand(), equalTo(original.expand()));
}

private <T extends Writeable> DelayableWriteable<T> roundTrip(DelayableWriteable<T> original,
Writeable.Reader<T> reader, Version version) throws IOException {
return copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
DelayableWriteable<T> delayed = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
in -> DelayableWriteable.delayed(reader, in), version);
assertTrue(delayed.isSerialized());

DelayableWriteable<T> referencing = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
in -> DelayableWriteable.referencing(reader, in), version);
assertFalse(referencing.isSerialized());

return randomFrom(delayed, referencing);
}

@Override
QuerySearchResultTests.java
@@ -33,6 +33,8 @@
import org.elasticsearch.test.ESTestCase;

import static java.util.Collections.emptyList;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class QuerySearchResultTests extends ESTestCase {

@@ -68,8 +70,10 @@ private static QuerySearchResult createTestInstance() throws Exception {

public void testSerialization() throws Exception {
QuerySearchResult querySearchResult = createTestInstance();
boolean delayed = randomBoolean();
QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry,
QuerySearchResult::new, Version.CURRENT);
delayed ? in -> new QuerySearchResult(in, true) : QuerySearchResult::new,
Version.CURRENT);
assertEquals(querySearchResult.getContextId().getId(), deserialized.getContextId().getId());
assertNull(deserialized.getSearchShardTarget());
assertEquals(querySearchResult.topDocs().maxScore, deserialized.topDocs().maxScore, 0f);
@@ -78,9 +82,11 @@ public void testSerialization() throws Exception {
assertEquals(querySearchResult.size(), deserialized.size());
assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs());
if (deserialized.hasAggs()) {
assertThat(deserialized.aggregations().isSerialized(), is(delayed));
Aggregations aggs = querySearchResult.consumeAggs();
Aggregations deserializedAggs = deserialized.consumeAggs();
assertEquals(aggs.asList(), deserializedAggs.asList());
assertThat(deserialized.aggregations(), is(nullValue()));
}
assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly());
}
AbstractMultiClustersTestCase.java
@@ -134,6 +134,8 @@ protected void disconnectFromRemoteClusters() throws Exception {
for (String clusterAlias : clusterAliases) {
if (clusterAlias.equals(LOCAL_CLUSTER) == false) {
settings.putNull("cluster.remote." + clusterAlias + ".seeds");
settings.putNull("cluster.remote." + clusterAlias + ".mode");
settings.putNull("cluster.remote." + clusterAlias + ".proxy_address");
}
}
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();