diff --git a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java index cb3f1e80a393c..a6d7eb837941d 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java @@ -43,6 +43,7 @@ import org.elasticsearch.search.Scroll; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchSortValuesAndFormats; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.sort.FieldSortBuilder; @@ -212,7 +213,7 @@ public ShardSearchRequest( this.shardRequestIndex = shardRequestIndex; this.numberOfShards = numberOfShards; this.searchType = searchType; - this.source = source; + this.source(source); this.requestCache = requestCache; this.aliasFilter = aliasFilter; this.indexBoost = indexBoost; @@ -285,7 +286,7 @@ public ShardSearchRequest(ShardSearchRequest clone) { this.searchType = clone.searchType; this.numberOfShards = clone.numberOfShards; this.scroll = clone.scroll; - this.source = clone.source; + this.source(clone.source); this.aliasFilter = clone.aliasFilter; this.indexBoost = clone.indexBoost; this.nowInMillis = clone.nowInMillis; @@ -392,6 +393,13 @@ public void setAliasFilter(AliasFilter aliasFilter) { } public void source(SearchSourceBuilder source) { + if (source != null && source.pointInTimeBuilder() != null) { + // Discard the actual point in time as data nodes don't use it to reduce the memory usage and the serialization cost + // of shard-level search requests. However, we need to assign as a dummy PIT instead of null as we verify PIT for + // slice requests on data nodes. + source = source.shallowCopy(); + source.pointInTimeBuilder(new PointInTimeBuilder("")); + } this.source = source; } diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index ca7ca2f47d4eb..9a95ab04bab4c 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -22,6 +22,10 @@ import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.search.ClearScrollRequest; +import org.elasticsearch.action.search.ClosePointInTimeAction; +import org.elasticsearch.action.search.ClosePointInTimeRequest; +import org.elasticsearch.action.search.OpenPointInTimeAction; +import org.elasticsearch.action.search.OpenPointInTimeRequest; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; @@ -78,6 +82,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.support.AggregationContext; import org.elasticsearch.search.aggregations.support.ValueType; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; @@ -105,6 +110,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; @@ -113,6 +119,7 @@ import java.util.function.Consumer; import java.util.function.Function; +import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; @@ -126,6 +133,7 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.Matchers.not; import static org.mockito.Mockito.mock; public class SearchServiceTests extends ESSingleNodeTestCase { @@ -1824,6 +1832,38 @@ public void testWaitOnRefreshTimeout() { assertThat(ex.getMessage(), containsString("Wait for seq_no [0] refreshed timed out [")); } + public void testMinimalSearchSourceInShardRequests() { + createIndex("test"); + int numDocs = between(0, 10); + for (int i = 0; i < numDocs; i++) { + client().prepareIndex("test").setSource("id", Integer.toString(i)).get(); + } + client().admin().indices().prepareRefresh("test").get(); + + String pitId = client().execute( + OpenPointInTimeAction.INSTANCE, + new OpenPointInTimeRequest("test").keepAlive(TimeValue.timeValueMinutes(10)) + ).actionGet().getPointInTimeId(); + final MockSearchService searchService = (MockSearchService) getInstanceFromNode(SearchService.class); + final List shardRequests = new CopyOnWriteArrayList<>(); + searchService.setOnCreateSearchContext(ctx -> shardRequests.add(ctx.request())); + try { + SearchRequest searchRequest = new SearchRequest().source( + new SearchSourceBuilder().size(between(numDocs, numDocs * 2)).pointInTimeBuilder(new PointInTimeBuilder(pitId)) + ); + final SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertHitCount(searchResponse, numDocs); + } finally { + client().execute(ClosePointInTimeAction.INSTANCE, new ClosePointInTimeRequest(pitId)).actionGet(); + } + assertThat(shardRequests, not(emptyList())); + for (ShardSearchRequest shardRequest : shardRequests) { + assertNotNull(shardRequest.source()); + assertNotNull(shardRequest.source().pointInTimeBuilder()); + assertThat(shardRequest.source().pointInTimeBuilder().getEncodedId(), equalTo("")); + } + } + private ReaderContext createReaderContext(IndexService indexService, IndexShard indexShard) { return new ReaderContext( new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()),