Skip to content

Commit

Permalink
Fix aggregation memory leak for CCS (elastic#78404)
Browse files Browse the repository at this point in the history
When a CCS search is proxied, the memory for the aggregations on the
proxy node would not be freed.

Now only use the non-copying byte referencing version on the coordinating node,
which itself ensures that memory is freed by calling `consumeAggs`.
  • Loading branch information
henningandersen committed Oct 4, 2021
1 parent 69e5452 commit 1c311e7
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.ccs;

import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractMultiClustersTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.transport.TransportService;
import org.hamcrest.Matchers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;

public class CrossClusterSearchLeakIT extends AbstractMultiClustersTestCase {

@Override
protected Collection<String> remoteClusterAlias() {
return org.elasticsearch.core.List.of("cluster_a");
}

@Override
protected boolean reuseClusters() {
return false;
}

private int indexDocs(Client client, String field, String index) {
int numDocs = between(1, 200);
for (int i = 0; i < numDocs; i++) {
client.prepareIndex(index, "_doc").setSource(field, "v" + i).get();
}
client.admin().indices().prepareRefresh(index).get();
return numDocs;
}

/**
* This test validates that we do not leak any memory when running CCS in various modes, actual validation is done by test framework
* (leak detection)
* <ul>
* <li>proxy vs non-proxy</li>
* <li>single-phase query-fetch or multi-phase</li>
* <li>minimize roundtrip vs not</li>
* <li>scroll vs no scroll</li>
* </ul>
*/
public void testSearch() throws Exception {
assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo")
.addMapping("_doc", "f", "type=keyword")
.setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
indexDocs(client(LOCAL_CLUSTER), "ignored", "demo");
final InternalTestCluster remoteCluster = cluster("cluster_a");
int minRemotes = between(2, 5);
remoteCluster.ensureAtLeastNumDataNodes(minRemotes);
List<String> remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false)
.filter(DiscoveryNode::canContainData)
.map(DiscoveryNode::getName)
.collect(Collectors.toList());
assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(minRemotes));
List<String> seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes);
disconnectFromRemoteClusters();
configureRemoteCluster("cluster_a", seedNodes);
final Settings.Builder allocationFilter = Settings.builder();
if (rarely()) {
allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes));
} else {
// Provoke using proxy connections
allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes));
}
assertAcked(client("cluster_a").admin().indices().prepareCreate("prod")
.addMapping("_doc", "f", "type=keyword")
.setSettings(Settings.builder().put(allocationFilter.build())
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod")
.setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut());
int docs = indexDocs(client("cluster_a"), "f", "prod");

List<ActionFuture<SearchResponse>> futures = new ArrayList<>();
for (int i = 0; i < 10; ++i) {
String[] indices = randomBoolean() ? new String[] { "demo", "cluster_a:prod" } : new String[] { "cluster_a:prod" };
final SearchRequest searchRequest = new SearchRequest(indices);
searchRequest.allowPartialSearchResults(false);
boolean scroll = randomBoolean();
searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())
.aggregation(terms("f").field("f").size(docs + between(scroll ? 1 : 0, 10))).size(between(0, 1000)));
if (scroll) {
searchRequest.scroll("30s");
}
searchRequest.setCcsMinimizeRoundtrips(rarely());
futures.add(client(LOCAL_CLUSTER).search(searchRequest));
}

for (ActionFuture<SearchResponse> future : futures) {
SearchResponse searchResponse = future.get();
if (searchResponse.getScrollId() != null) {
ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
clearScrollRequest.scrollIds(org.elasticsearch.core.List.of(searchResponse.getScrollId()));
client(LOCAL_CLUSTER).clearScroll(clearScrollRequest).get();
}

Terms terms = searchResponse.getAggregations().get("f");
assertThat(terms.getBuckets().size(), equalTo(docs));
for (Terms.Bucket bucket : terms.getBuckets()) {
assertThat(bucket.getDocCount(), equalTo(1L));
}
}
}

@Override
protected void configureRemoteCluster(String clusterAlias, Collection<String> seedNodes) throws Exception {
if (rarely()) {
super.configureRemoteCluster(clusterAlias, seedNodes);
} else {
final Settings.Builder settings = Settings.builder();
final String seedNode = randomFrom(seedNodes);
final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, seedNode);
final String seedAddress = transportService.boundAddress().publishAddress().toString();

settings.put("cluster.remote." + clusterAlias + ".mode", "proxy");
settings.put("cluster.remote." + clusterAlias + ".proxy_address", seedAddress);
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void sendExecuteQuery(Transport.Connection connection, final ShardSearchR
// we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
// this used to be the QUERY_AND_FETCH which doesn't exist anymore.
final boolean fetchDocuments = request.numberOfShards() == 1;
Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new;
Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : in -> new QuerySearchResult(in, true);

final ActionListener<? super SearchPhaseResult> handler = responseWrapper.apply(connection, listener);
transportService.sendChildRequest(connection, QUERY_ACTION_NAME, request, task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.common.io.stream;

import org.elasticsearch.Version;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.core.Releasable;

Expand Down Expand Up @@ -50,6 +51,12 @@ public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Read
return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readReleasableBytesReference());
}

public static <T extends Writeable> DelayableWriteable<T> referencing(Writeable.Reader<T> reader, StreamInput in) throws IOException {
try (ReleasableBytesReference serialized = in.readReleasableBytesReference()) {
return new Referencing<>(deserialize(reader, in.getVersion(), in.namedWriteableRegistry(), serialized));
}
}

private DelayableWriteable() {}

/**
Expand All @@ -67,7 +74,7 @@ private DelayableWriteable() {}
* {@code true} if the {@linkplain Writeable} is being stored in
* serialized form, {@code false} otherwise.
*/
abstract boolean isSerialized();
public abstract boolean isSerialized();

/**
* Returns the serialized size of the inner {@link Writeable}.
Expand Down Expand Up @@ -104,7 +111,7 @@ public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry regis
}

@Override
boolean isSerialized() {
public boolean isSerialized() {
return false;
}

Expand Down Expand Up @@ -169,11 +176,7 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public T expand() {
try {
try (StreamInput in = registry == null ?
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(serializedAtVersion);
return reader.read(in);
}
return deserialize(reader, serializedAtVersion, registry, serialized);
} catch (IOException e) {
throw new RuntimeException("unexpected error expanding serialized delayed writeable", e);
}
Expand All @@ -185,7 +188,7 @@ public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry regis
}

@Override
boolean isSerialized() {
public boolean isSerialized() {
return true;
}

Expand Down Expand Up @@ -214,6 +217,15 @@ public static long getSerializedSize(Writeable ref) {
}
}

private static <T> T deserialize(Reader<T> reader, Version serializedAtVersion, NamedWriteableRegistry registry,
BytesReference serialized) throws IOException {
try (StreamInput in =
registry == null ? serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(serializedAtVersion);
return reader.read(in);
}
}

private static class CountingStreamOutput extends StreamOutput {
long size = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.elasticsearch.search.suggest.Suggest;

public final class QuerySearchResult extends SearchPhaseResult {

private int from;
private int size;
private TopDocsAndMaxScore topDocsAndMaxScore;
Expand Down Expand Up @@ -67,6 +66,15 @@ public QuerySearchResult() {
}

public QuerySearchResult(StreamInput in) throws IOException {
this(in, false);
}

/**
* Read the object, but using a delayed aggregations field when delayedAggregations=true. Using this, the caller must ensure that
* either `consumeAggs` or `releaseAggs` is called if `hasAggs() == true`.
* @param delayedAggregations whether to use delayed aggregations or not
*/
public QuerySearchResult(StreamInput in, boolean delayedAggregations) throws IOException {
super(in);
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
isNull = in.readBoolean();
Expand All @@ -75,7 +83,7 @@ public QuerySearchResult(StreamInput in) throws IOException {
}
if (isNull == false) {
ShardSearchContextId id = new ShardSearchContextId(in);
readFromWithId(id, in);
readFromWithId(id, in, delayedAggregations);
}
}

Expand Down Expand Up @@ -318,6 +326,10 @@ public boolean hasSearchContext() {
}

public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOException {
readFromWithId(id, in, false);
}

private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean delayedAggregations) throws IOException {
this.contextId = id;
from = in.readVInt();
size = in.readVInt();
Expand All @@ -344,7 +356,11 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOExc
}
} else {
if (hasAggs) {
aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
if (delayedAggregations) {
aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
} else {
aggregations = DelayableWriteable.referencing(InternalAggregations::readFrom, in);
}
}
}
if (in.readBoolean()) {
Expand All @@ -371,6 +387,8 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOExc

@Override
public void writeTo(StreamOutput out) throws IOException {
// we do not know that it is being sent over transport, but this at least protects all writes from happening, including sending.
assert aggregations == null || aggregations.isSerialized() == false : "cannot send serialized version since it will leak";
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeBoolean(isNull);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,12 @@ public void testRoundTripFromDelayedWithNamedWriteable() throws IOException {
public void testRoundTripFromDelayedFromOldVersion() throws IOException {
Example e = new Example(randomAlphaOfLength(5));
DelayableWriteable<Example> original = roundTrip(DelayableWriteable.referencing(e), Example::new, randomOldVersion());
assertTrue(original.isSerialized());
roundTripTestCase(original, Example::new);
}

public void testRoundTripFromDelayedFromOldVersionWithNamedWriteable() throws IOException {
NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5)));
DelayableWriteable<NamedHolder> original = roundTrip(DelayableWriteable.referencing(n), NamedHolder::new, randomOldVersion());
assertTrue(original.isSerialized());
roundTripTestCase(original, NamedHolder::new);
}

Expand All @@ -160,14 +158,20 @@ public void testAsSerializedIsNoopOnSerialized() throws IOException {

private <T extends Writeable> void roundTripTestCase(DelayableWriteable<T> original, Writeable.Reader<T> reader) throws IOException {
DelayableWriteable<T> roundTripped = roundTrip(original, reader, Version.CURRENT);
assertTrue(roundTripped.isSerialized());
assertThat(roundTripped.expand(), equalTo(original.expand()));
}

private <T extends Writeable> DelayableWriteable<T> roundTrip(DelayableWriteable<T> original,
Writeable.Reader<T> reader, Version version) throws IOException {
return copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
DelayableWriteable<T> delayed = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
in -> DelayableWriteable.delayed(reader, in), version);
assertTrue(delayed.isSerialized());

DelayableWriteable<T> referencing = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
in -> DelayableWriteable.referencing(reader, in), version);
assertFalse(referencing.isSerialized());

return randomFrom(delayed, referencing);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import java.util.Base64;

import static java.util.Collections.emptyList;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class QuerySearchResultTests extends ESTestCase {

Expand Down Expand Up @@ -74,7 +76,9 @@ private static QuerySearchResult createTestInstance() throws Exception {

public void testSerialization() throws Exception {
QuerySearchResult querySearchResult = createTestInstance();
QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, QuerySearchResult::new);
boolean delayed = randomBoolean();
QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry,
delayed ? in -> new QuerySearchResult(in, true) : QuerySearchResult::new);
assertEquals(querySearchResult.getContextId().getId(), deserialized.getContextId().getId());
assertNull(deserialized.getSearchShardTarget());
assertEquals(querySearchResult.topDocs().maxScore, deserialized.topDocs().maxScore, 0f);
Expand All @@ -83,9 +87,11 @@ public void testSerialization() throws Exception {
assertEquals(querySearchResult.size(), deserialized.size());
assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs());
if (deserialized.hasAggs()) {
assertThat(deserialized.aggregations().isSerialized(), is(delayed));
Aggregations aggs = querySearchResult.consumeAggs();
Aggregations deserializedAggs = deserialized.consumeAggs();
assertEquals(aggs.asList(), deserializedAggs.asList());
assertThat(deserialized.aggregations(), is(nullValue()));
}
assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ protected void disconnectFromRemoteClusters() throws Exception {
for (String clusterAlias : clusterAliases) {
if (clusterAlias.equals(LOCAL_CLUSTER) == false) {
settings.putNull("cluster.remote." + clusterAlias + ".seeds");
settings.putNull("cluster.remote." + clusterAlias + ".mode");
settings.putNull("cluster.remote." + clusterAlias + ".proxy_address");
}
}
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();
Expand Down

0 comments on commit 1c311e7

Please sign in to comment.