Handle trackTotalHitsUpTo and disabling local hits tracking
Adapt TopDocsStats so it can be reused.
javanna committed Jan 18, 2019
1 parent 205f0aa commit a325129
Showing 4 changed files with 58 additions and 31 deletions.
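The thread running through all four files: TopDocsStats is constructed with the request's trackTotalHitsUpTo, accumulates totals and max score behind private fields, and normalizes its sentinels in getters, so both the local reduce (SearchPhaseController) and the cross-cluster merge (SearchResponseMerger) can consume the same object. A minimal standalone sketch of that contract — class name, constant values, and the add signature are illustrative simplifications, not the Elasticsearch API:

```java
// Minimal sketch of the contract this commit gives TopDocsStats.
// The TRACK_* values are assumptions for the demo, not the real constants.
final class TopDocsStatsSketch {
    static final int TRACK_TOTAL_HITS_DISABLED = -1;
    static final int TRACK_TOTAL_HITS_ACCURATE = Integer.MAX_VALUE;

    private final int trackTotalHitsUpTo;
    private long totalHits;
    private float maxScore = Float.NEGATIVE_INFINITY; // sentinel: nothing scored yet

    TopDocsStatsSketch(int trackTotalHitsUpTo) {
        this.trackTotalHitsUpTo = trackTotalHitsUpTo;
    }

    void add(long hits, float topScore) {
        totalHits += hits;
        if (Float.isNaN(topScore) == false) { // NaN marks a response with no scored hits
            maxScore = Math.max(maxScore, topScore);
        }
    }

    // The sentinel never escapes: callers see NaN, which ReducedQueryPhase
    // previously had to compute for itself.
    float getMaxScore() {
        return Float.isInfinite(maxScore) ? Float.NaN : maxScore;
    }

    // When tracking is disabled there is no total to report, hence null.
    Long getTotalHits() {
        return trackTotalHitsUpTo == TRACK_TOTAL_HITS_DISABLED ? null : totalHits;
    }
}
```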
SearchPhaseController.java
@@ -443,7 +443,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult>
Boolean terminatedEarly = null;
if (queryResults.isEmpty()) { // early terminate we have nothing to reduce
final TotalHits totalHits = topDocsStats.getTotalHits();
- return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
+ return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
timedOut, terminatedEarly, null, null, null, SortedTopDocs.EMPTY, null, numReducePhases, 0, 0, true);
}
final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult();
@@ -508,7 +508,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult>
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size);
final TotalHits totalHits = topDocsStats.getTotalHits();
- return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
+ return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
timedOut, terminatedEarly, suggest, aggregations, shardResults, sortedTopDocs,
firstResult.sortValueFormats(), numReducePhases, size, from, false);
}
@@ -577,11 +577,7 @@ public static final class ReducedQueryPhase {
}
this.totalHits = totalHits;
this.fetchHits = fetchHits;
- if (Float.isInfinite(maxScore)) {
- this.maxScore = Float.NaN;
- } else {
- this.maxScore = maxScore;
- }
+ this.maxScore = maxScore;
this.timedOut = timedOut;
this.terminatedEarly = terminatedEarly;
this.suggest = suggest;
@@ -744,7 +740,7 @@ static final class TopDocsStats {
private long totalHits;
private TotalHits.Relation totalHitsRelation;
long fetchHits;
- float maxScore = Float.NEGATIVE_INFINITY;
+ private float maxScore = Float.NEGATIVE_INFINITY;

TopDocsStats() {
this(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
@@ -756,6 +752,10 @@ static final class TopDocsStats {
this.totalHitsRelation = Relation.EQUAL_TO;
}

+ float getMaxScore() {
+ return Float.isInfinite(maxScore) ? Float.NaN : maxScore;
+ }

TotalHits getTotalHits() {
if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
return null;
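The body of getTotalHits() is cut off above. Judging from the fields it reads and from the capping the tests below start asserting (sums bounded by the threshold, relation flipped to a lower bound), a plausible reconstruction — not verbatim from this commit — is:

```java
// Hedged reconstruction of the truncated getTotalHits(); the threshold branch
// is inferred from the test expectations later in this commit.
TotalHits getTotalHits() {
    if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
        return null; // local hits tracking disabled: no total to report
    }
    if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
        return new TotalHits(totalHits, totalHitsRelation); // exact count
    }
    // Tracking only up to a threshold: beyond it the count is a lower bound.
    return totalHits < trackTotalHitsUpTo
        ? new TotalHits(totalHits, totalHitsRelation)
        : new TotalHits(totalHits, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
}
```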
SearchResponseMerger.java
@@ -28,6 +28,7 @@
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
+ import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
@@ -52,6 +53,8 @@
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Function;

+ import static org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
+ import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
import static org.elasticsearch.action.search.SearchResponse.Clusters;

/**
@@ -66,15 +69,17 @@
final class SearchResponseMerger {
private final int from;
private final int size;
+ int trackTotalHitsUpTo;
private final SearchTimeProvider searchTimeProvider;
private final Clusters clusters;
private final Function<Boolean, ReduceContext> reduceContextFunction;
private final List<SearchResponse> searchResponses = new CopyOnWriteArrayList<>();

- SearchResponseMerger(int from, int size, SearchTimeProvider searchTimeProvider, Clusters clusters,
+ SearchResponseMerger(int from, int size, int trackTotalHitsUpTo, SearchTimeProvider searchTimeProvider, Clusters clusters,
Function<Boolean, ReduceContext> reduceContextFunction) {
this.from = from;
this.size = size;
+ this.trackTotalHitsUpTo = trackTotalHitsUpTo;
this.searchTimeProvider = Objects.requireNonNull(searchTimeProvider);
this.clusters = Objects.requireNonNull(clusters);
this.reduceContextFunction = Objects.requireNonNull(reduceContextFunction);
@@ -102,7 +107,6 @@ SearchResponse getMergedResponse() {
Boolean terminatedEarly = null;
//the current reduce phase counts as one
int numReducePhases = 1;
- float maxScore = Float.NEGATIVE_INFINITY;
List<ShardSearchFailure> failures = new ArrayList<>();
Map<String, ProfileShardResult> profileResults = new HashMap<>();
List<InternalAggregations> aggs = new ArrayList<>();
@@ -111,6 +115,8 @@ SearchResponse getMergedResponse() {
Map<String, List<Suggest.Suggestion>> groupedSuggestions = new HashMap<>();
Boolean trackTotalHits = null;

+ TopDocsStats topDocsStats = new TopDocsStats(trackTotalHitsUpTo);

for (SearchResponse searchResponse : searchResponses) {
totalShards += searchResponse.getTotalShards();
skippedShards += searchResponse.getSkippedShards();
@@ -139,12 +145,10 @@ SearchResponse getMergedResponse() {
}

SearchHits searchHits = searchResponse.getHits();
- if (Float.isNaN(searchHits.getMaxScore()) == false) {
- maxScore = Math.max(maxScore, searchHits.getMaxScore());
- }

final TotalHits totalHits;
if (searchHits.getTotalHits() == null) {
- //in case we did't track total hits, we get null from each cluster, but we need to set 0 eq to the TopDocs
+ //in case we didn't track total hits, we get null from each cluster, but we need to set 0 eq to the TopDocs
totalHits = new TotalHits(0, TotalHits.Relation.EQUAL_TO);
assert trackTotalHits == null || trackTotalHits == false;
trackTotalHits = false;
@@ -153,7 +157,9 @@
assert trackTotalHits == null || trackTotalHits;
trackTotalHits = true;
}
- topDocsList.add(searchHitsToTopDocs(searchHits, totalHits, shards));
+ TopDocs topDocs = searchHitsToTopDocs(searchHits, totalHits, shards);
+ topDocsStats.add(new TopDocsAndMaxScore(topDocs, searchHits.getMaxScore()));
+ topDocsList.add(topDocs);
}

//now that we've gone through all the hits and we collected all the shards they come from, we can assign shardIndex to each shard
@@ -165,13 +171,15 @@
for (TopDocs topDocs : topDocsList) {
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
FieldDocAndSearchHit fieldDocAndSearchHit = (FieldDocAndSearchHit) scoreDoc;
+ //When hits come from the indices with same names on multiple clusters and same shard identifier, we rely on such indices
+ //to have a different uuid across multiple clusters. That's how they will get a different shardIndex.
ShardId shardId = fieldDocAndSearchHit.searchHit.getShard().getShardId();
fieldDocAndSearchHit.shardIndex = shards.get(shardId);
}
}

- TopDocs topDocs = SearchPhaseController.mergeTopDocs(topDocsList, size, from);
- SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, Float.isInfinite(maxScore) ? Float.NaN : maxScore, trackTotalHits);
+ TopDocs topDocs = mergeTopDocs(topDocsList, size, from);
+ SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats);
Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
InternalAggregations reducedAggs = InternalAggregations.reduce(aggs, reduceContextFunction.apply(true));
ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY);
@@ -250,7 +258,7 @@ private static TopDocs searchHitsToTopDocs(SearchHits searchHits, TotalHits totalHits,
return topDocs;
}

- private static SearchHits topDocsToSearchHits(TopDocs topDocs, float maxScore, boolean trackTotalHits) {
+ private static SearchHits topDocsToSearchHits(TopDocs topDocs, TopDocsStats topDocsStats) {
SearchHit[] searchHits = new SearchHit[topDocs.scoreDocs.length];
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
FieldDocAndSearchHit scoreDoc = (FieldDocAndSearchHit)topDocs.scoreDocs[i];
@@ -268,9 +276,8 @@ private static SearchHits topDocsToSearchHits(TopDocs topDocs, float maxScore, boolean trackTotalHits) {
collapseValues = collapseTopFieldDocs.collapseValues;
}
}
- //in case we didn't track total hits, we got null from each cluster, and we need to set null to the final response
- final TotalHits totalHits = trackTotalHits ? topDocs.totalHits : null;
- return new SearchHits(searchHits, totalHits, maxScore, sortFields, collapseField, collapseValues);
+ return new SearchHits(searchHits, topDocsStats.getTotalHits(), topDocsStats.getMaxScore(),
+ sortFields, collapseField, collapseValues);
}

private static void setShardIndex(Collection<List<FieldDoc>> shardResults) {
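With these changes SearchResponseMerger no longer tracks maxScore by hand: each cluster's SearchHits are folded into a shared TopDocsStats, and the merged SearchHits are built from the stats' normalized getters. A condensed view of the new flow in getMergedResponse(), trimmed to the hit-merging parts (shard bookkeeping and failure handling omitted; variables such as totalHits and shards come from the elided loop body):

```java
// Condensed flow of getMergedResponse() after this commit (trimmed sketch).
TopDocsStats topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
List<TopDocs> topDocsList = new ArrayList<>();
for (SearchResponse searchResponse : searchResponses) {
    SearchHits searchHits = searchResponse.getHits();
    // per-cluster totals and max score now accumulate inside TopDocsStats:
    TopDocs topDocs = searchHitsToTopDocs(searchHits, totalHits, shards);
    topDocsStats.add(new TopDocsAndMaxScore(topDocs, searchHits.getMaxScore()));
    topDocsList.add(topDocs);
}
TopDocs merged = mergeTopDocs(topDocsList, size, from);
// getTotalHits() is null when tracking is disabled and getMaxScore() is NaN
// when nothing was scored; both flow straight into the final response:
SearchHits mergedHits = topDocsToSearchHits(merged, topDocsStats);
```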
SearchPhaseControllerTests.java
@@ -139,7 +139,7 @@ public void testSortIsIdempotent() throws Exception {
assertEquals(sortedDocs[i].shardIndex, sortedDocs2[i].shardIndex);
assertEquals(sortedDocs[i].score, sortedDocs2[i].score, 0.0f);
}
- assertEquals(topDocsStats.maxScore, topDocsStats2.maxScore, 0.0f);
+ assertEquals(topDocsStats.getMaxScore(), topDocsStats2.getMaxScore(), 0.0f);
assertEquals(topDocsStats.getTotalHits().value, topDocsStats2.getTotalHits().value);
assertEquals(topDocsStats.getTotalHits().relation, topDocsStats2.getTotalHits().relation);
assertEquals(topDocsStats.fetchHits, topDocsStats2.fetchHits);
SearchResponseMergerTests.java
@@ -32,6 +32,7 @@
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchResponse;
+ import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.ProfileShardResult;
import org.elasticsearch.search.profile.SearchProfileShardResults;
import org.elasticsearch.search.profile.SearchProfileShardResultsTests;
@@ -84,7 +85,7 @@ public void testMergeTookInMillis() throws InterruptedException {
SearchTimeProvider timeProvider = new SearchTimeProvider(randomLong(), 0, () -> currentRelativeTime);
SearchResponse.Clusters clusters = SearchResponseTests.randomClusters();
SearchResponseMerger merger = new SearchResponseMerger(randomIntBetween(0, 1000), randomIntBetween(0, 10000),
- timeProvider, clusters, flag -> null);
+ SearchContext.TRACK_TOTAL_HITS_ACCURATE, timeProvider, clusters, flag -> null);
for (int i = 0; i < numResponses; i++) {
SearchResponse searchResponse = new SearchResponse(InternalSearchResponse.empty(), null, 1, 1, 0, randomLong(),
ShardSearchFailure.EMPTY_ARRAY, SearchResponseTests.randomClusters());
@@ -97,7 +98,8 @@

public void testMergeShardFailures() throws InterruptedException {
SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
- SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
+ SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
+ searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
PriorityQueue<Tuple<ShardId, ShardSearchFailure>> priorityQueue = new PriorityQueue<>(Comparator.comparing(Tuple::v1));
int numIndices = numResponses * randomIntBetween(1, 3);
Iterator<Map.Entry<String, Index[]>> indicesPerCluster = randomRealisticIndices(numIndices, numResponses).entrySet().iterator();
@@ -136,7 +138,8 @@

public void testMergeShardFailuresNullShardId() throws InterruptedException {
SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
- SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
+ SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
+ searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
List<ShardSearchFailure> expectedFailures = new ArrayList<>();
for (int i = 0; i < numResponses; i++) {
int numFailures = randomIntBetween(1, 50);
@@ -157,7 +160,8 @@

public void testMergeProfileResults() throws InterruptedException {
SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
- SearchResponseMerger merger = new SearchResponseMerger(0, 0, searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
+ SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
+ searchTimeProvider, SearchResponse.Clusters.EMPTY, flag -> null);
Map<String, ProfileShardResult> expectedProfile = new HashMap<>();
for (int i = 0; i < numResponses; i++) {
SearchProfileShardResults profile = SearchProfileShardResultsTests.createTestItem();
@@ -206,10 +210,14 @@ public void testMergeSearchHits() throws InterruptedException {
sortFields = null;
scoreSort = true;
}
- TotalHits.Relation totalHitsRelation = frequently() ? randomFrom(TotalHits.Relation.values()) : null;
+ Tuple<Integer, TotalHits.Relation> randomTrackTotalHits = randomTrackTotalHits();
+ int trackTotalHitsUpTo = randomTrackTotalHits.v1();
+ TotalHits.Relation totalHitsRelation = randomTrackTotalHits.v2();

PriorityQueue<SearchHit> priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields));
- SearchResponseMerger searchResponseMerger = new SearchResponseMerger(from, size, timeProvider, clusters, flag -> null);
+ SearchResponseMerger searchResponseMerger = new SearchResponseMerger(from, size, trackTotalHitsUpTo,
+ timeProvider, clusters, flag -> null);

TotalHits expectedTotalHits = null;
int expectedTotal = 0;
int expectedSuccessful = 0;
@@ -232,11 +240,10 @@
expectedSkipped += skipped;

TotalHits totalHits = null;
- if (totalHitsRelation != null) {
- //TODO totalHits may overflow if each cluster reports a very high number?
+ if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED) {
totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation);
long previousValue = expectedTotalHits == null ? 0 : expectedTotalHits.value;
- expectedTotalHits = new TotalHits(previousValue + totalHits.value, totalHitsRelation);
+ expectedTotalHits = new TotalHits(Math.min(previousValue + totalHits.value, trackTotalHitsUpTo), totalHitsRelation);
}

final int numDocs = totalHits == null || totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value;
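The Math.min cap in the expected value mirrors the threshold semantics: for example, with trackTotalHitsUpTo = 1000 and two clusters reporting 600 hits each, the merged total is expected to be min(600 + 600, 1000) = 1000, reported with the GREATER_THAN_OR_EQUAL_TO relation rather than as an exact count.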
@@ -321,6 +328,19 @@ public void testMergeSearchHits() throws InterruptedException {
}
}

+ private static Tuple<Integer, TotalHits.Relation> randomTrackTotalHits() {
+ switch(randomIntBetween(0, 2)) {
+ case 0:
+ return Tuple.tuple(SearchContext.TRACK_TOTAL_HITS_DISABLED, null);
+ case 1:
+ return Tuple.tuple(randomIntBetween(10, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
+ case 2:
+ return Tuple.tuple(SearchContext.TRACK_TOTAL_HITS_ACCURATE, TotalHits.Relation.EQUAL_TO);
+ default:
+ throw new UnsupportedOperationException();
+ }
+ }

private static SearchHit[] randomSearchHitArray(int numDocs, int numResponses, String clusterAlias, Index[] indices, float maxScore,
int scoreFactor, SortField[] sortFields, PriorityQueue<SearchHit> priorityQueue) {
SearchHit[] hits = new SearchHit[numDocs];
