diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index b22aa9669fbbc..b587c11aba690 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -443,7 +443,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection reduceContextFunction; private final List searchResponses = new CopyOnWriteArrayList<>(); - SearchResponseMerger(int from, int size, SearchTimeProvider searchTimeProvider, Clusters clusters, + SearchResponseMerger(int from, int size, int trackTotalHitsUpTo, SearchTimeProvider searchTimeProvider, Clusters clusters, Function reduceContextFunction) { this.from = from; this.size = size; + this.trackTotalHitsUpTo = trackTotalHitsUpTo; this.searchTimeProvider = Objects.requireNonNull(searchTimeProvider); this.clusters = Objects.requireNonNull(clusters); this.reduceContextFunction = Objects.requireNonNull(reduceContextFunction); @@ -102,7 +107,6 @@ SearchResponse getMergedResponse() { Boolean terminatedEarly = null; //the current reduce phase counts as one int numReducePhases = 1; - float maxScore = Float.NEGATIVE_INFINITY; List failures = new ArrayList<>(); Map profileResults = new HashMap<>(); List aggs = new ArrayList<>(); @@ -111,6 +115,8 @@ SearchResponse getMergedResponse() { Map> groupedSuggestions = new HashMap<>(); Boolean trackTotalHits = null; + TopDocsStats topDocsStats = new TopDocsStats(trackTotalHitsUpTo); + for (SearchResponse searchResponse : searchResponses) { totalShards += searchResponse.getTotalShards(); skippedShards += searchResponse.getSkippedShards(); @@ -139,12 +145,10 @@ SearchResponse getMergedResponse() { } SearchHits searchHits = searchResponse.getHits(); - if (Float.isNaN(searchHits.getMaxScore()) == false) { - maxScore = Math.max(maxScore, searchHits.getMaxScore()); - } + final TotalHits totalHits; if (searchHits.getTotalHits() == null) { - //in case we did't track total hits, we get null from each cluster, but we need to set 0 eq to the TopDocs + //in case we didn't track total hits, we get null from each cluster, but we need to set 0 eq to the TopDocs totalHits = new TotalHits(0, TotalHits.Relation.EQUAL_TO); assert trackTotalHits == null || trackTotalHits == false; trackTotalHits = false; @@ -153,7 +157,9 @@ SearchResponse getMergedResponse() { assert trackTotalHits == null || trackTotalHits; trackTotalHits = true; } - topDocsList.add(searchHitsToTopDocs(searchHits, totalHits, shards)); + TopDocs topDocs = searchHitsToTopDocs(searchHits, totalHits, shards); + topDocsStats.add(new TopDocsAndMaxScore(topDocs, searchHits.getMaxScore())); + topDocsList.add(topDocs); } //now that we've gone through all the hits and we collected all the shards they come from, we can assign shardIndex to each shard @@ -165,13 +171,15 @@ SearchResponse getMergedResponse() { for (TopDocs topDocs : topDocsList) { for (ScoreDoc scoreDoc : topDocs.scoreDocs) { FieldDocAndSearchHit fieldDocAndSearchHit = (FieldDocAndSearchHit) scoreDoc; + //When hits come from the indices with same names on multiple clusters and same shard identifier, we rely on such indices + //to have a different uuid across multiple clusters. That's how they will get a different shardIndex. ShardId shardId = fieldDocAndSearchHit.searchHit.getShard().getShardId(); fieldDocAndSearchHit.shardIndex = shards.get(shardId); } } - TopDocs topDocs = SearchPhaseController.mergeTopDocs(topDocsList, size, from); - SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, Float.isInfinite(maxScore) ? Float.NaN : maxScore, trackTotalHits); + TopDocs topDocs = mergeTopDocs(topDocsList, size, from); + SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats); Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions)); InternalAggregations reducedAggs = InternalAggregations.reduce(aggs, reduceContextFunction.apply(true)); ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY); @@ -250,7 +258,7 @@ private static TopDocs searchHitsToTopDocs(SearchHits searchHits, TotalHits tota return topDocs; } - private static SearchHits topDocsToSearchHits(TopDocs topDocs, float maxScore, boolean trackTotalHits) { + private static SearchHits topDocsToSearchHits(TopDocs topDocs, TopDocsStats topDocsStats) { SearchHit[] searchHits = new SearchHit[topDocs.scoreDocs.length]; for (int i = 0; i < topDocs.scoreDocs.length; i++) { FieldDocAndSearchHit scoreDoc = (FieldDocAndSearchHit)topDocs.scoreDocs[i]; @@ -268,9 +276,8 @@ private static SearchHits topDocsToSearchHits(TopDocs topDocs, float maxScore, b collapseValues = collapseTopFieldDocs.collapseValues; } } - //in case we didn't track total hits, we got null from each cluster, and we need to set null to the final response - final TotalHits totalHits = trackTotalHits ? topDocs.totalHits : null; - return new SearchHits(searchHits, totalHits, maxScore, sortFields, collapseField, collapseValues); + return new SearchHits(searchHits, topDocsStats.getTotalHits(), topDocsStats.getMaxScore(), + sortFields, collapseField, collapseValues); } private static void setShardIndex(Collection> shardResults) { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index a5ab81d83fbcd..23e5626e11aa9 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -139,7 +139,7 @@ public void testSortIsIdempotent() throws Exception { assertEquals(sortedDocs[i].shardIndex, sortedDocs2[i].shardIndex); assertEquals(sortedDocs[i].score, sortedDocs2[i].score, 0.0f); } - assertEquals(topDocsStats.maxScore, topDocsStats2.maxScore, 0.0f); + assertEquals(topDocsStats.getMaxScore(), topDocsStats2.getMaxScore(), 0.0f); assertEquals(topDocsStats.getTotalHits().value, topDocsStats2.getTotalHits().value); assertEquals(topDocsStats.getTotalHits().relation, topDocsStats2.getTotalHits().relation); assertEquals(topDocsStats.fetchHits, topDocsStats2.fetchHits); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchResponseMergerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchResponseMergerTests.java index 19af374d0be30..f2e49a109810d 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchResponseMergerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchResponseMergerTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.InternalSearchResponse; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.profile.ProfileShardResult; import org.elasticsearch.search.profile.SearchProfileShardResults; import org.elasticsearch.search.profile.SearchProfileShardResultsTests; @@ -84,7 +85,7 @@ public void testMergeTookInMillis() throws InterruptedException { SearchTimeProvider timeProvider = new SearchTimeProvider(randomLong(), 0, () -> currentRelativeTime); SearchResponse.Clusters clusters = SearchResponseTests.randomClusters(); SearchResponseMerger merger = new SearchResponseMerger(randomIntBetween(0, 1000), randomIntBetween(0, 10000), - timeProvider, clusters, flag -> null); + SearchContext.TRACK_TOTAL_HITS_ACCURATE, timeProvider, clusters, flag -> null); for (int i = 0; i < numResponses; i++) { SearchResponse searchResponse = new SearchResponse(InternalSearchResponse.empty(), null, 1, 1, 0, randomLong(), ShardSearchFailure.EMPTY_ARRAY, SearchResponseTests.randomClusters()); @@ -97,7 +98,8 @@ public void testMergeTookInMillis() throws InterruptedException { public void testMergeShardFailures() throws InterruptedException { SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0); - SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null); + SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE, + searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null); PriorityQueue> priorityQueue = new PriorityQueue<>(Comparator.comparing(Tuple::v1)); int numIndices = numResponses * randomIntBetween(1, 3); Iterator> indicesPerCluster = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); @@ -136,7 +138,8 @@ public void testMergeShardFailures() throws InterruptedException { public void testMergeShardFailuresNullShardId() throws InterruptedException { SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0); - SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null); + SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE, + searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null); List expectedFailures = new ArrayList<>(); for (int i = 0; i < numResponses; i++) { int numFailures = randomIntBetween(1, 50); @@ -157,7 +160,8 @@ public void testMergeShardFailuresNullShardId() throws InterruptedException { public void testMergeProfileResults() throws InterruptedException { SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0); - SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null); + SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE, + searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null); Map expectedProfile = new HashMap<>(); for (int i = 0; i < numResponses; i++) { SearchProfileShardResults profile = SearchProfileShardResultsTests.createTestItem(); @@ -206,10 +210,14 @@ public void testMergeSearchHits() throws InterruptedException { sortFields = null; scoreSort = true; } - TotalHits.Relation totalHitsRelation = frequently() ? randomFrom(TotalHits.Relation.values()) : null; + Tuple randomTrackTotalHits = randomTrackTotalHits(); + int trackTotalHitsUpTo = randomTrackTotalHits.v1(); + TotalHits.Relation totalHitsRelation = randomTrackTotalHits.v2(); PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields)); - SearchResponseMerger searchResponseMerger = new SearchResponseMerger(from, size, timeProvider, clusters, flag -> null); + SearchResponseMerger searchResponseMerger = new SearchResponseMerger(from, size, trackTotalHitsUpTo, + timeProvider, clusters, flag -> null); + TotalHits expectedTotalHits = null; int expectedTotal = 0; int expectedSuccessful = 0; @@ -232,11 +240,10 @@ public void testMergeSearchHits() throws InterruptedException { expectedSkipped += skipped; TotalHits totalHits = null; - if (totalHitsRelation != null) { - //TODO totalHits may overflow if each cluster reports a very high number? + if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED) { totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); long previousValue = expectedTotalHits == null ? 0 : expectedTotalHits.value; - expectedTotalHits = new TotalHits(previousValue + totalHits.value, totalHitsRelation); + expectedTotalHits = new TotalHits(Math.min(previousValue + totalHits.value, trackTotalHitsUpTo), totalHitsRelation); } final int numDocs = totalHits == null || totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; @@ -321,6 +328,19 @@ public void testMergeSearchHits() throws InterruptedException { } } + private static Tuple randomTrackTotalHits() { + switch(randomIntBetween(0, 2)) { + case 0: + return Tuple.tuple(SearchContext.TRACK_TOTAL_HITS_DISABLED, null); + case 1: + return Tuple.tuple(randomIntBetween(10, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + case 2: + return Tuple.tuple(SearchContext.TRACK_TOTAL_HITS_ACCURATE, TotalHits.Relation.EQUAL_TO); + default: + throw new UnsupportedOperationException(); + } + } + private static SearchHit[] randomSearchHitArray(int numDocs, int numResponses, String clusterAlias, Index[] indices, float maxScore, int scoreFactor, SortField[] sortFields, PriorityQueue priorityQueue) { SearchHit[] hits = new SearchHit[numDocs];