diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java new file mode 100644 index 0000000000000..4ff0089cdbb71 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.search.ClearScrollRequest; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.test.InternalTestCluster; +import org.elasticsearch.transport.TransportService; +import org.hamcrest.Matchers; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; + +public class CrossClusterSearchLeakIT extends AbstractMultiClustersTestCase { + + @Override + protected Collection remoteClusterAlias() { + return List.of("cluster_a"); + } + + @Override + protected boolean reuseClusters() { + return false; + } + + private int indexDocs(Client client, String field, String index) { + int numDocs = between(1, 200); + for (int i = 0; i < numDocs; i++) { + client.prepareIndex(index).setSource(field, "v" + i).get(); + } + client.admin().indices().prepareRefresh(index).get(); + return numDocs; + } + + /** + * This test validates that we do not leak any memory when running CCS in various modes, actual validation is done by test framework + * (leak detection) + * + */ + public void testSearch() throws Exception { + assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo") + .setMapping("f", "type=keyword") + .setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3)))); + indexDocs(client(LOCAL_CLUSTER), "ignored", "demo"); + final InternalTestCluster remoteCluster = cluster("cluster_a"); + int minRemotes = between(2, 5); + remoteCluster.ensureAtLeastNumDataNodes(minRemotes); + List remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false) + .filter(DiscoveryNode::canContainData) + .map(DiscoveryNode::getName) + .collect(Collectors.toList()); + assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(minRemotes)); + List seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes); + disconnectFromRemoteClusters(); + configureRemoteCluster("cluster_a", seedNodes); + final Settings.Builder allocationFilter = Settings.builder(); + if (rarely()) { + allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes)); + } else { + // Provoke using proxy connections + allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes)); + } + assertAcked(client("cluster_a").admin().indices().prepareCreate("prod") + .setMapping("f", "type=keyword") + .setSettings(Settings.builder().put(allocationFilter.build()) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3)))); + assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod") + .setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut()); + int docs = indexDocs(client("cluster_a"), "f", "prod"); + + List> futures = new ArrayList<>(); + for (int i = 0; i < 10; ++i) { + String[] indices = randomBoolean() ? new String[] { "demo", "cluster_a:prod" } : new String[] { "cluster_a:prod" }; + final SearchRequest searchRequest = new SearchRequest(indices); + searchRequest.allowPartialSearchResults(false); + boolean scroll = randomBoolean(); + searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()) + .aggregation(terms("f").field("f").size(docs + between(scroll ? 1 : 0, 10))).size(between(0, 1000))); + if (scroll) { + searchRequest.scroll("30s"); + } + searchRequest.setCcsMinimizeRoundtrips(rarely()); + futures.add(client(LOCAL_CLUSTER).search(searchRequest)); + } + + for (ActionFuture future : futures) { + SearchResponse searchResponse = future.get(); + if (searchResponse.getScrollId() != null) { + ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); + clearScrollRequest.scrollIds(List.of(searchResponse.getScrollId())); + client(LOCAL_CLUSTER).clearScroll(clearScrollRequest).get(); + } + + Terms terms = searchResponse.getAggregations().get("f"); + assertThat(terms.getBuckets().size(), equalTo(docs)); + for (Terms.Bucket bucket : terms.getBuckets()) { + assertThat(bucket.getDocCount(), equalTo(1L)); + } + } + } + + @Override + protected void configureRemoteCluster(String clusterAlias, Collection seedNodes) throws Exception { + if (rarely()) { + super.configureRemoteCluster(clusterAlias, seedNodes); + } else { + final Settings.Builder settings = Settings.builder(); + final String seedNode = randomFrom(seedNodes); + final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, seedNode); + final String seedAddress = transportService.boundAddress().publishAddress().toString(); + + settings.put("cluster.remote." + clusterAlias + ".mode", "proxy"); + settings.put("cluster.remote." + clusterAlias + ".proxy_address", seedAddress); + client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get(); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index abc400734d368..41860c52174d4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -138,7 +138,7 @@ public void sendExecuteQuery(Transport.Connection connection, final ShardSearchR // we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request // this used to be the QUERY_AND_FETCH which doesn't exist anymore. final boolean fetchDocuments = request.numberOfShards() == 1; - Writeable.Reader reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new; + Writeable.Reader reader = fetchDocuments ? QueryFetchSearchResult::new : in -> new QuerySearchResult(in, true); final ActionListener handler = responseWrapper.apply(connection, listener); transportService.sendChildRequest(connection, QUERY_ACTION_NAME, request, task, diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java b/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java index d0e725651b16d..f103a4c11d5d6 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java @@ -9,6 +9,7 @@ package org.elasticsearch.common.io.stream; import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.core.Releasable; @@ -50,6 +51,12 @@ public static DelayableWriteable delayed(Writeable.Read return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readReleasableBytesReference()); } + public static DelayableWriteable referencing(Writeable.Reader reader, StreamInput in) throws IOException { + try (ReleasableBytesReference serialized = in.readReleasableBytesReference()) { + return new Referencing<>(deserialize(reader, in.getVersion(), in.namedWriteableRegistry(), serialized)); + } + } + private DelayableWriteable() {} /** @@ -67,7 +74,7 @@ private DelayableWriteable() {} * {@code true} if the {@linkplain Writeable} is being stored in * serialized form, {@code false} otherwise. */ - abstract boolean isSerialized(); + public abstract boolean isSerialized(); /** * Returns the serialized size of the inner {@link Writeable}. @@ -104,7 +111,7 @@ public Serialized asSerialized(Reader reader, NamedWriteableRegistry regis } @Override - boolean isSerialized() { + public boolean isSerialized() { return false; } @@ -169,11 +176,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public T expand() { try { - try (StreamInput in = registry == null ? - serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) { - in.setVersion(serializedAtVersion); - return reader.read(in); - } + return deserialize(reader, serializedAtVersion, registry, serialized); } catch (IOException e) { throw new RuntimeException("unexpected error expanding serialized delayed writeable", e); } @@ -185,7 +188,7 @@ public Serialized asSerialized(Reader reader, NamedWriteableRegistry regis } @Override - boolean isSerialized() { + public boolean isSerialized() { return true; } @@ -214,6 +217,15 @@ public static long getSerializedSize(Writeable ref) { } } + private static T deserialize(Reader reader, Version serializedAtVersion, NamedWriteableRegistry registry, + BytesReference serialized) throws IOException { + try (StreamInput in = + registry == null ? serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) { + in.setVersion(serializedAtVersion); + return reader.read(in); + } + } + private static class CountingStreamOutput extends StreamOutput { long size = 0; diff --git a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java index f35cb57314dba..2f594abfe7d9a 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java @@ -33,7 +33,6 @@ import static org.elasticsearch.common.lucene.Lucene.writeTopDocs; public final class QuerySearchResult extends SearchPhaseResult { - private int from; private int size; private TopDocsAndMaxScore topDocsAndMaxScore; @@ -65,6 +64,15 @@ public QuerySearchResult() { } public QuerySearchResult(StreamInput in) throws IOException { + this(in, false); + } + + /** + * Read the object, but using a delayed aggregations field when delayedAggregations=true. Using this, the caller must ensure that + * either `consumeAggs` or `releaseAggs` is called if `hasAggs() == true`. + * @param delayedAggregations whether to use delayed aggregations or not + */ + public QuerySearchResult(StreamInput in, boolean delayedAggregations) throws IOException { super(in); if (in.getVersion().onOrAfter(Version.V_7_7_0)) { isNull = in.readBoolean(); @@ -73,7 +81,7 @@ public QuerySearchResult(StreamInput in) throws IOException { } if (isNull == false) { ShardSearchContextId id = new ShardSearchContextId(in); - readFromWithId(id, in); + readFromWithId(id, in, delayedAggregations); } } @@ -316,6 +324,10 @@ public boolean hasSearchContext() { } public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOException { + readFromWithId(id, in, false); + } + + private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean delayedAggregations) throws IOException { this.contextId = id; from = in.readVInt(); size = in.readVInt(); @@ -333,7 +345,11 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOExc boolean success = false; try { if (hasAggs) { - aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in); + if (delayedAggregations) { + aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in); + } else { + aggregations = DelayableWriteable.referencing(InternalAggregations::readFrom, in); + } } if (in.readBoolean()) { suggest = new Suggest(in); @@ -359,6 +375,8 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOExc @Override public void writeTo(StreamOutput out) throws IOException { + // we do not know that it is being sent over transport, but this at least protects all writes from happening, including sending. + assert aggregations == null || aggregations.isSerialized() == false : "cannot send serialized version since it will leak"; if (out.getVersion().onOrAfter(Version.V_7_7_0)) { out.writeBoolean(isNull); } diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/DelayableWriteableTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/DelayableWriteableTests.java index b6d57e2997ce0..d18e9976e0968 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/DelayableWriteableTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/DelayableWriteableTests.java @@ -134,14 +134,12 @@ public void testRoundTripFromDelayedWithNamedWriteable() throws IOException { public void testRoundTripFromDelayedFromOldVersion() throws IOException { Example e = new Example(randomAlphaOfLength(5)); DelayableWriteable original = roundTrip(DelayableWriteable.referencing(e), Example::new, randomOldVersion()); - assertTrue(original.isSerialized()); roundTripTestCase(original, Example::new); } public void testRoundTripFromDelayedFromOldVersionWithNamedWriteable() throws IOException { NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5))); DelayableWriteable original = roundTrip(DelayableWriteable.referencing(n), NamedHolder::new, randomOldVersion()); - assertTrue(original.isSerialized()); roundTripTestCase(original, NamedHolder::new); } @@ -160,14 +158,20 @@ public void testAsSerializedIsNoopOnSerialized() throws IOException { private void roundTripTestCase(DelayableWriteable original, Writeable.Reader reader) throws IOException { DelayableWriteable roundTripped = roundTrip(original, reader, Version.CURRENT); - assertTrue(roundTripped.isSerialized()); assertThat(roundTripped.expand(), equalTo(original.expand())); } private DelayableWriteable roundTrip(DelayableWriteable original, Writeable.Reader reader, Version version) throws IOException { - return copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out), + DelayableWriteable delayed = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out), in -> DelayableWriteable.delayed(reader, in), version); + assertTrue(delayed.isSerialized()); + + DelayableWriteable referencing = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out), + in -> DelayableWriteable.referencing(reader, in), version); + assertFalse(referencing.isSerialized()); + + return randomFrom(delayed, referencing); } @Override diff --git a/server/src/test/java/org/elasticsearch/search/query/QuerySearchResultTests.java b/server/src/test/java/org/elasticsearch/search/query/QuerySearchResultTests.java index bf645d137cda9..2e77f6bff064c 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QuerySearchResultTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QuerySearchResultTests.java @@ -33,6 +33,8 @@ import org.elasticsearch.test.ESTestCase; import static java.util.Collections.emptyList; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; public class QuerySearchResultTests extends ESTestCase { @@ -68,8 +70,10 @@ private static QuerySearchResult createTestInstance() throws Exception { public void testSerialization() throws Exception { QuerySearchResult querySearchResult = createTestInstance(); + boolean delayed = randomBoolean(); QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, - QuerySearchResult::new, Version.CURRENT); + delayed ? in -> new QuerySearchResult(in, true) : QuerySearchResult::new, + Version.CURRENT); assertEquals(querySearchResult.getContextId().getId(), deserialized.getContextId().getId()); assertNull(deserialized.getSearchShardTarget()); assertEquals(querySearchResult.topDocs().maxScore, deserialized.topDocs().maxScore, 0f); @@ -78,9 +82,11 @@ public void testSerialization() throws Exception { assertEquals(querySearchResult.size(), deserialized.size()); assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs()); if (deserialized.hasAggs()) { + assertThat(deserialized.aggregations().isSerialized(), is(delayed)); Aggregations aggs = querySearchResult.consumeAggs(); Aggregations deserializedAggs = deserialized.consumeAggs(); assertEquals(aggs.asList(), deserializedAggs.asList()); + assertThat(deserialized.aggregations(), is(nullValue())); } assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly()); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java index 138d330884d83..95d4766a91687 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java @@ -134,6 +134,8 @@ protected void disconnectFromRemoteClusters() throws Exception { for (String clusterAlias : clusterAliases) { if (clusterAlias.equals(LOCAL_CLUSTER) == false) { settings.putNull("cluster.remote." + clusterAlias + ".seeds"); + settings.putNull("cluster.remote." + clusterAlias + ".mode"); + settings.putNull("cluster.remote." + clusterAlias + ".proxy_address"); } } client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();