Tie-break completion suggestions with same score and surface form (#39564)

When multiple completion suggestion options have the same score and
surface form, the order in which they are returned is currently not
deterministic.

With this commit we introduce tie-breaking for such situations, based
on shard id, index name, index uuid and doc id, as we already do for
ordinary search hits. With this change we also make shardIndex
mandatory when sorting and comparing completion suggestion options;
previously it was only needed later, when fetching hits.

We also make sure shardIndex is properly set when merging completion
suggestions coming from multiple clusters in `SearchResponseMerger`.
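
To make the new ordering concrete, here is a minimal standalone sketch of the comparison chain described above: score (descending), then surface form, then shard index, then doc id. The `Option` record and its field names are hypothetical stand-ins, not the Elasticsearch types.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical stand-in for a completion suggestion option; not the Elasticsearch class.
record Option(String surfaceForm, float score, int shardIndex, int docId) {}

class TieBreakSketch {
    // Score descending first, then surface form, then shard index, then doc id.
    static final Comparator<Option> ORDER =
        Comparator.comparingDouble(Option::score).reversed()
            .thenComparing(Option::surfaceForm)
            .thenComparingInt(Option::shardIndex)
            .thenComparingInt(Option::docId);

    public static void main(String[] args) {
        List<Option> options = new ArrayList<>(List.of(
            new Option("foo", 1.0f, 1, 7),
            new Option("foo", 1.0f, 0, 3)));   // same score and surface form
        options.sort(ORDER);
        // With the tie-break, the option from shard index 0 always sorts first.
        System.out.println(options);
    }
}
```
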
javanna committed Mar 5, 2019
1 parent ca83408 commit 9d02114
Showing 4 changed files with 175 additions and 21 deletions.
@@ -39,6 +39,7 @@
import org.elasticsearch.search.profile.ProfileShardResult;
import org.elasticsearch.search.profile.SearchProfileShardResults;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;

import java.util.ArrayList;
import java.util.Arrays;
@@ -152,6 +153,16 @@ SearchResponse getMergedResponse(Clusters clusters) {
List<Suggest.Suggestion> suggestionList = groupedSuggestions.computeIfAbsent(entries.getName(), s -> new ArrayList<>());
suggestionList.add(entries);
}
List<CompletionSuggestion> completionSuggestions = suggest.filter(CompletionSuggestion.class);
for (CompletionSuggestion completionSuggestion : completionSuggestions) {
for (CompletionSuggestion.Entry options : completionSuggestion) {
for (CompletionSuggestion.Entry.Option option : options) {
SearchShardTarget shard = option.getHit().getShard();
ShardIdAndClusterAlias shardId = new ShardIdAndClusterAlias(shard.getShardId(), shard.getClusterAlias());
shards.putIfAbsent(shardId, null);
}
}
}
}

SearchHits searchHits = searchResponse.getHits();
@@ -174,14 +185,15 @@ SearchResponse getMergedResponse(Clusters clusters) {
}

//after going through all the hits and collecting all their distinct shards, we can assign shardIndex and set it to the ScoreDocs
setShardIndex(shards, topDocsList);
setTopDocsShardIndex(shards, topDocsList);
setSuggestShardIndex(shards, groupedSuggestions);
TopDocs topDocs = mergeTopDocs(topDocsList, size, from);
SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats);
Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
InternalAggregations reducedAggs = InternalAggregations.reduce(aggs, reduceContextFunction.apply(true));
ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY);
SearchProfileShardResults profileShardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
//make failures ordering consistent with ordinary search and CCS
//make failures ordering consistent between ordinary search and CCS by looking at the shard they come from
Arrays.sort(shardFailures, FAILURES_COMPARATOR);
InternalSearchResponse response = new InternalSearchResponse(mergedSearchHits, reducedAggs, suggest, profileShardResults,
topDocsStats.timedOut, topDocsStats.terminatedEarly, numReducePhases);
@@ -275,14 +287,8 @@ private static TopDocs searchHitsToTopDocs(SearchHits searchHits, TotalHits tota
return topDocs;
}

private static void setShardIndex(Map<ShardIdAndClusterAlias, Integer> shards, List<TopDocs> topDocsList) {
{
//assign a different shardIndex to each shard, based on their shardId natural ordering and their cluster alias
int shardIndex = 0;
for (Map.Entry<ShardIdAndClusterAlias, Integer> shard : shards.entrySet()) {
shard.setValue(shardIndex++);
}
}
private static void setTopDocsShardIndex(Map<ShardIdAndClusterAlias, Integer> shards, List<TopDocs> topDocsList) {
assignShardIndex(shards);
//go through all the scoreDocs from each cluster and set their corresponding shardIndex
for (TopDocs topDocs : topDocsList) {
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
@@ -295,6 +301,34 @@ private static void setShardIndex(Map<ShardIdAndClusterAlias, Integer> shards, List<TopDocs> topDocsList) {
}
}

private static void setSuggestShardIndex(Map<ShardIdAndClusterAlias, Integer> shards,
Map<String, List<Suggest.Suggestion>> groupedSuggestions) {
assignShardIndex(shards);
for (List<Suggest.Suggestion> suggestions : groupedSuggestions.values()) {
for (Suggest.Suggestion suggestion : suggestions) {
if (suggestion instanceof CompletionSuggestion) {
CompletionSuggestion completionSuggestion = (CompletionSuggestion) suggestion;
for (CompletionSuggestion.Entry options : completionSuggestion) {
for (CompletionSuggestion.Entry.Option option : options) {
SearchShardTarget shard = option.getHit().getShard();
ShardIdAndClusterAlias shardId = new ShardIdAndClusterAlias(shard.getShardId(), shard.getClusterAlias());
assert shards.containsKey(shardId);
option.setShardIndex(shards.get(shardId));
}
}
}
}
}
}

private static void assignShardIndex(Map<ShardIdAndClusterAlias, Integer> shards) {
//assign a different shardIndex to each shard, based on their shardId natural ordering and their cluster alias
int shardIndex = 0;
for (Map.Entry<ShardIdAndClusterAlias, Integer> shard : shards.entrySet()) {
shard.setValue(shardIndex++);
}
}

private static SearchHits topDocsToSearchHits(TopDocs topDocs, TopDocsStats topDocsStats) {
SearchHit[] searchHits = new SearchHit[topDocs.scoreDocs.length];
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
@@ -340,6 +374,7 @@ private static final class ShardIdAndClusterAlias implements Comparable<ShardIdAndClusterAlias> {

ShardIdAndClusterAlias(ShardId shardId, String clusterAlias) {
this.shardId = shardId;
assert clusterAlias != null : "clusterAlias is null";
this.clusterAlias = clusterAlias;
}

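
A note on the shardIndex assignment above: `assignShardIndex` relies on the iteration order of the `shards` map, so the result is deterministic only if that map is sorted by the natural ordering of its keys (shard id, then cluster alias, as the code comment says). The sketch below illustrates the idea with a hypothetical `ShardKey` standing in for `ShardIdAndClusterAlias`; it is not the merger's actual code.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical stand-in for ShardIdAndClusterAlias: ordered by shard id, then cluster alias.
record ShardKey(int shardId, String clusterAlias) implements Comparable<ShardKey> {
    public int compareTo(ShardKey other) {
        int cmp = Integer.compare(shardId, other.shardId);
        return cmp != 0 ? cmp : clusterAlias.compareTo(other.clusterAlias);
    }
}

class AssignShardIndexSketch {
    public static void main(String[] args) {
        // Assumption: distinct shards are collected into a sorted map before indices are assigned.
        Map<ShardKey, Integer> shards = new TreeMap<>();
        shards.putIfAbsent(new ShardKey(1, "remote1"), null);
        shards.putIfAbsent(new ShardKey(0, ""), null);        // "" for the local cluster
        shards.putIfAbsent(new ShardKey(0, "remote1"), null);

        // Same loop as assignShardIndex above: positions follow the sorted key order.
        int shardIndex = 0;
        for (Map.Entry<ShardKey, Integer> shard : shards.entrySet()) {
            shard.setValue(shardIndex++);
        }
        // Deterministic result: (0, "") -> 0, (0, "remote1") -> 1, (1, "remote1") -> 2
        System.out.println(shards);
    }
}
```
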
@@ -145,7 +145,16 @@ private static final class OptionPriorityQueue extends PriorityQueue<ShardOptions> {

@Override
protected boolean lessThan(ShardOptions a, ShardOptions b) {
return COMPARATOR.compare(a.current, b.current) < 0;
int compare = COMPARATOR.compare(a.current, b.current);
if (compare != 0) {
return compare < 0;
}
ScoreDoc aDoc = a.current.getDoc();
ScoreDoc bDoc = b.current.getDoc();
if (aDoc.shardIndex == bDoc.shardIndex) {
return aDoc.doc < bDoc.doc;
}
return aDoc.shardIndex < bDoc.shardIndex;
}
}

@@ -157,6 +166,7 @@ private ShardOptions(Iterator<Entry.Option> optionsIterator) {
assert optionsIterator.hasNext();
this.optionsIterator = optionsIterator;
this.current = optionsIterator.next();
assert this.current.getDoc().shardIndex != -1 : "shardIndex is not set";
}

boolean advanceToNextOption() {
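
As a usage-style note on the `lessThan` change above: once the existing score/surface-form comparator reports a tie, the option whose hit carries the lower shardIndex wins, and within the same shard the lower Lucene doc id wins. Below is a hedged standalone rendering of that predicate using Lucene's `ScoreDoc`; the surrounding queue and the primary comparator are simplified stand-ins, not the Elasticsearch code.

```java
import org.apache.lucene.search.ScoreDoc;

class LessThanSketch {
    // Simplified stand-in for the primary comparison (score descending); the real code
    // delegates to Suggest.COMPARATOR, which also looks at the option text.
    static int compareByScore(ScoreDoc a, ScoreDoc b) {
        return Float.compare(b.score, a.score);
    }

    // Mirrors the tie-break added to OptionPriorityQueue.lessThan above.
    static boolean lessThan(ScoreDoc a, ScoreDoc b) {
        int compare = compareByScore(a, b);
        if (compare != 0) {
            return compare < 0;
        }
        if (a.shardIndex == b.shardIndex) {
            return a.doc < b.doc;
        }
        return a.shardIndex < b.shardIndex;
    }

    public static void main(String[] args) {
        ScoreDoc fromShard0 = new ScoreDoc(3, 1.0f, 0);
        ScoreDoc fromShard1 = new ScoreDoc(7, 1.0f, 1);
        // Same score: the doc from shard index 0 is "less than" (ranked ahead of) the other.
        System.out.println(lessThan(fromShard0, fromShard1)); // true
    }
}
```
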
@@ -47,6 +47,7 @@
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.transport.RemoteClusterService;
import org.junit.Before;

import java.util.ArrayList;
@@ -64,6 +65,8 @@
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class SearchResponseMergerTests extends ESTestCase {
@@ -241,17 +244,25 @@ public void testMergeProfileResults() throws InterruptedException {
assertEquals(expectedProfile, mergedResponse.getProfileResults());
}

public void testMergeSuggestions() throws InterruptedException {
public void testMergeCompletionSuggestions() throws InterruptedException {
String suggestionName = randomAlphaOfLengthBetween(4, 8);
boolean skipDuplicates = randomBoolean();
int size = randomIntBetween(1, 100);
SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0), flag -> null);
for (int i = 0; i < numResponses; i++) {
List<Suggest.Suggestion<? extends Suggest.Suggestion.Entry<? extends Suggest.Suggestion.Entry.Option>>> suggestions =
new ArrayList<>();
CompletionSuggestion completionSuggestion = new CompletionSuggestion(suggestionName, size, skipDuplicates);
CompletionSuggestion completionSuggestion = new CompletionSuggestion(suggestionName, size, false);
CompletionSuggestion.Entry options = new CompletionSuggestion.Entry(new Text("suggest"), 0, 10);
options.addOption(new CompletionSuggestion.Entry.Option(randomInt(), new Text("suggestion"), i, Collections.emptyMap()));
int docId = randomIntBetween(0, Integer.MAX_VALUE);
CompletionSuggestion.Entry.Option option = new CompletionSuggestion.Entry.Option(docId,
new Text(randomAlphaOfLengthBetween(5, 10)), i, Collections.emptyMap());
SearchHit hit = new SearchHit(docId);
ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10),
randomIntBetween(0, Integer.MAX_VALUE));
String clusterAlias = randomBoolean() ? "" : randomAlphaOfLengthBetween(5, 10);
hit.shard(new SearchShardTarget("node", shardId, clusterAlias, OriginalIndices.NONE));
option.setHit(hit);
options.addOption(option);
completionSuggestion.addTerm(options);
suggestions.add(completionSuggestion);
Suggest suggest = new Suggest(suggestions);
@@ -275,14 +286,69 @@ SearchResponse getMergedResponse(Clusters clusters) {
mergedResponse.getSuggest().getSuggestion(suggestionName);
assertEquals(1, suggestion.getEntries().size());
Suggest.Suggestion.Entry<? extends Suggest.Suggestion.Entry.Option> options = suggestion.getEntries().get(0);
assertEquals(skipDuplicates ? 1 : Math.min(numResponses, size), options.getOptions().size());
assertEquals(Math.min(numResponses, size), options.getOptions().size());
int i = numResponses;
for (Suggest.Suggestion.Entry.Option option : options) {
assertEquals("suggestion", option.getText().string());
assertEquals(--i, option.getScore(), 0f);
}
}

public void testMergeCompletionSuggestionsTieBreak() throws InterruptedException {
String suggestionName = randomAlphaOfLengthBetween(4, 8);
int size = randomIntBetween(1, 100);
SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0), flag -> null);
for (int i = 0; i < numResponses; i++) {
List<Suggest.Suggestion<? extends Suggest.Suggestion.Entry<? extends Suggest.Suggestion.Entry.Option>>> suggestions =
new ArrayList<>();
CompletionSuggestion completionSuggestion = new CompletionSuggestion(suggestionName, size, false);
CompletionSuggestion.Entry options = new CompletionSuggestion.Entry(new Text("suggest"), 0, 10);
int docId = randomIntBetween(0, Integer.MAX_VALUE);
CompletionSuggestion.Entry.Option option = new CompletionSuggestion.Entry.Option(docId, new Text("suggestion"), 1F,
Collections.emptyMap());
SearchHit searchHit = new SearchHit(docId);
searchHit.shard(new SearchShardTarget("node", new ShardId("index", "uuid", randomIntBetween(0, Integer.MAX_VALUE)),
randomBoolean() ? RemoteClusterService.LOCAL_CLUSTER_GROUP_KEY : randomAlphaOfLengthBetween(5, 10), OriginalIndices.NONE));
option.setHit(searchHit);
options.addOption(option);
completionSuggestion.addTerm(options);
suggestions.add(completionSuggestion);
Suggest suggest = new Suggest(suggestions);
SearchHits searchHits = new SearchHits(new SearchHit[0], null, Float.NaN);
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, suggest, null, false, null, 1);
SearchResponse searchResponse = new SearchResponse(internalSearchResponse, null, 1, 1, 0, randomLong(),
ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY);
addResponse(searchResponseMerger, searchResponse);
}
awaitResponsesAdded();
assertEquals(numResponses, searchResponseMerger.numResponses());
SearchResponse.Clusters clusters = SearchResponseTests.randomClusters();
SearchResponse mergedResponse = searchResponseMerger.getMergedResponse(clusters);
assertSame(clusters, mergedResponse.getClusters());
assertEquals(numResponses, mergedResponse.getTotalShards());
assertEquals(numResponses, mergedResponse.getSuccessfulShards());
assertEquals(0, mergedResponse.getSkippedShards());
assertEquals(0, mergedResponse.getFailedShards());
assertEquals(0, mergedResponse.getShardFailures().length);
CompletionSuggestion suggestion = mergedResponse.getSuggest().getSuggestion(suggestionName);
assertEquals(1, suggestion.getEntries().size());
CompletionSuggestion.Entry options = suggestion.getEntries().get(0);
assertEquals(Math.min(numResponses, size), options.getOptions().size());
int lastShardId = 0;
String lastClusterAlias = null;
for (CompletionSuggestion.Entry.Option option : options) {
assertEquals("suggestion", option.getText().string());
SearchShardTarget shard = option.getHit().getShard();
int currentShardId = shard.getShardId().id();
assertThat(currentShardId, greaterThanOrEqualTo(lastShardId));
if (currentShardId == lastShardId) {
assertThat(shard.getClusterAlias(), greaterThan(lastClusterAlias));
} else {
lastShardId = currentShardId;
}
lastClusterAlias = shard.getClusterAlias();
}
}

public void testMergeAggs() throws InterruptedException {
SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0),
flag -> new InternalAggregation.ReduceContext(null, null, flag));
Expand Up @@ -30,11 +30,13 @@

import static org.elasticsearch.search.suggest.Suggest.COMPARATOR;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class CompletionSuggestionTests extends ESTestCase {

public void testToReduce() {
public void testReduce() {
List<Suggest.Suggestion<CompletionSuggestion.Entry>> shardSuggestions = new ArrayList<>();
int nShards = randomIntBetween(1, 10);
String name = randomAlphaOfLength(10);
@@ -50,8 +52,10 @@ public void testToReduce() {
Suggest.Suggestion<CompletionSuggestion.Entry> suggestion = randomFrom(shardSuggestions);
CompletionSuggestion.Entry entry = suggestion.getEntries().get(0);
if (entry.getOptions().size() < size) {
entry.addOption(new CompletionSuggestion.Entry.Option(i, new Text(""),
maxScore - i, Collections.emptyMap()));
CompletionSuggestion.Entry.Option option = new CompletionSuggestion.Entry.Option(i, new Text(""),
maxScore - i, Collections.emptyMap());
option.setShardIndex(randomIntBetween(0, Integer.MAX_VALUE));
entry.addOption(option);
}
}
CompletionSuggestion reducedSuggestion = (CompletionSuggestion) shardSuggestions.get(0).reduce(shardSuggestions);
@@ -64,7 +68,7 @@
}
}

public void testToReduceWithDuplicates() {
public void testReduceWithDuplicates() {
List<Suggest.Suggestion<CompletionSuggestion.Entry>> shardSuggestions = new ArrayList<>();
int nShards = randomIntBetween(2, 10);
String name = randomAlphaOfLength(10);
@@ -85,6 +89,7 @@ public void testToReduceWithDuplicates() {
String surfaceForm = randomFrom(surfaceForms);
CompletionSuggestion.Entry.Option newOption =
new CompletionSuggestion.Entry.Option(j, new Text(surfaceForm), maxScore - j, Collections.emptyMap());
newOption.setShardIndex(0);
entry.addOption(newOption);
options.add(newOption);
}
@@ -100,4 +105,42 @@
assertThat(reducedSuggestion.getOptions().size(), lessThanOrEqualTo(size));
assertEquals(expected, reducedSuggestion.getOptions());
}

public void testReduceTiebreak() {
List<Suggest.Suggestion<CompletionSuggestion.Entry>> shardSuggestions = new ArrayList<>();
Text surfaceForm = new Text(randomAlphaOfLengthBetween(5, 10));
float score = randomFloat();
int numResponses = randomIntBetween(2, 10);
String name = randomAlphaOfLength(10);
int size = randomIntBetween(10, 100);
for (int i = 0; i < numResponses; i++) {
CompletionSuggestion suggestion = new CompletionSuggestion(name, size, false);
CompletionSuggestion.Entry entry = new CompletionSuggestion.Entry(new Text(""), 0, 0);
suggestion.addTerm(entry);
int shardIndex = 0;
for (int j = 0; j < size; j++) {
CompletionSuggestion.Entry.Option newOption =
new CompletionSuggestion.Entry.Option((j + 1) * (i + 1), surfaceForm, score, Collections.emptyMap());
newOption.setShardIndex(shardIndex++);
entry.addOption(newOption);
}
shardSuggestions.add(suggestion);
}
CompletionSuggestion reducedSuggestion = (CompletionSuggestion) shardSuggestions.get(0).reduce(shardSuggestions);
assertNotNull(reducedSuggestion);
List<CompletionSuggestion.Entry.Option> options = reducedSuggestion.getOptions();
assertThat(options.size(), lessThanOrEqualTo(size));
int shardIndex = 0;
int docId = -1;
for (CompletionSuggestion.Entry.Option option : options) {
assertThat(option.getDoc().shardIndex, greaterThanOrEqualTo(shardIndex));
if (option.getDoc().shardIndex == shardIndex) {
assertThat(option.getDoc().doc, greaterThan(docId));
} else {
assertThat(option.getDoc().shardIndex, equalTo(shardIndex + 1));
shardIndex = option.getDoc().shardIndex;
}
docId = option.getDoc().doc;
}
}
}
