Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Part 3] Concurrent segment search bug in Sorting #808

Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (vectorSupplier().get() == null) {
return this;
}
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (maxDistance != null) {
knnQueryBuilder.maxDistance(maxDistance);
} else if (minScore != null) {
knnQueryBuilder.minScore(minScore);
} else {
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
return KNNQueryBuilder.builder()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is to fix the neural-search package, as the main branch is broken and is fixed in this PR. Therefore, please ignore this change, as it will be taken care of while cherry-picking onto main.

.fieldName(fieldName())
.vector(vectorSupplier.get())
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.k(k)
.build();
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.collector;

import java.util.List;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.TopDocs;

/**
* Common contract for Hybrid search collectors. It extends Lucene's {@link Collector} with accessors for the
* collected top docs of the individual sub-queries, the total hit count and the maximum score.
*/
public interface HybridSearchCollector extends Collector {
/**
* Returns the collected results of the hybrid query, grouped by sub-query.
*
* @return List of topDocs which contains topDocs of individual subqueries.
*/
List<? extends TopDocs> topDocs();

/**
* Returns the number of hits counted by this collector.
*
* @return count of total hits per shard
*/
int getTotalHits();

/**
* Returns the highest score seen by this collector.
*
* @return maxScore found on a shard
*/
float getMaxScore();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.ScoreDoc;
Expand All @@ -38,7 +37,7 @@
The individual query results are sorted as per the sort criteria sent in the search request.
*/
@Log4j2
public abstract class HybridTopFieldDocSortCollector implements Collector {
public abstract class HybridTopFieldDocSortCollector implements HybridSearchCollector {
private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final Sort sort;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import lombok.Getter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
Expand All @@ -30,7 +29,7 @@
* Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
*/
@Log4j2
public class HybridTopScoreDocCollector implements Collector {
public class HybridTopScoreDocCollector implements HybridSearchCollector {
private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private int docBase;
private final HitsThresholdChecker hitsThresholdChecker;
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import java.util.Comparator;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;

/**
 * Comparator that orders two {@link FieldDoc}s according to the sort criteria. Sort fields are evaluated in
 * declaration order; when every sort value is equal the supplied tie breaker decides.
 */
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
class HybridQueryFieldDocComparator implements Comparator<FieldDoc> {
    final SortField[] sortFields;
    final FieldComparator<?>[] comparators;
    final int[] reverseMul;
    final Comparator<ScoreDoc> tieBreaker;

    /**
     * Builds one value comparator and one direction multiplier per sort field.
     *
     * @param sortFields sort criteria, in priority order
     * @param tieBreaker comparator consulted when all sort values of two docs are equal
     */
    public HybridQueryFieldDocComparator(SortField[] sortFields, Comparator<ScoreDoc> tieBreaker) {
        this.sortFields = sortFields;
        this.tieBreaker = tieBreaker;
        final int fieldCount = sortFields.length;
        this.comparators = new FieldComparator[fieldCount];
        this.reverseMul = new int[fieldCount];
        int position = 0;
        for (final SortField field : sortFields) {
            // comparators are only used through compareValues, so they are requested with one slot and no pruning
            comparators[position] = field.getComparator(1, Pruning.NONE);
            // -1 flips the comparison result for reversed sort fields
            reverseMul[position] = field.getReverse() ? -1 : 1;
            position++;
        }
    }

    /**
     * Compares the sort values of both docs field by field and returns the first non-zero result,
     * adjusted for the sort direction. Falls back to the tie breaker when all sort values are equal.
     *
     * @param firstFD first field doc
     * @param secondFD second field doc
     * @return negative, zero or positive value following the {@link Comparator} contract
     */
    @Override
    public int compare(final FieldDoc firstFD, final FieldDoc secondFD) {
        int position = 0;
        while (position < comparators.length) {
            final FieldComparator valueComparator = comparators[position];
            final int ordering = reverseMul[position] * valueComparator.compareValues(
                firstFD.fields[position],
                secondFD.fields[position]
            );
            if (ordering != 0) {
                return ordering;
            }
            position++;
        }
        return tieBreakCompare(firstFD, secondFD, tieBreaker);
    }

    /**
     * Delegates the comparison to the given tie breaker, which must not be null.
     */
    private int tieBreakCompare(ScoreDoc firstDoc, ScoreDoc secondDoc, Comparator<ScoreDoc> tieBreaker) {
        assert tieBreaker != null;
        return tieBreaker.compare(firstDoc, secondDoc);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement;

/**
 * Merges two ScoreDoc arrays, both in the Hybrid Query result format, into one.
 *
 * @param <T> concrete score doc type held by the arrays
 */
@NoArgsConstructor(access = AccessLevel.PACKAGE)
class HybridQueryScoreDocsMerger<T extends ScoreDoc> {

    // smallest valid hybrid result: start element, one delimiter, stop element
    private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3;

    /**
     * Merge two score docs objects, result ScoreDocs[] object will have all hits per sub-query from both original objects.
     * Input and output ScoreDocs are in format that is specific to Hybrid Query. This method should not be used for ScoreDocs from
     * other query types.
     * Logic is based on assumption that hits of every sub-query are already ordered (by score, or by the sort criteria
     * when sorting is enabled).
     * Method returns new object and doesn't mutate original ScoreDocs arrays.
     * @param sourceScoreDocs original score docs from query result
     * @param newScoreDocs new score docs that we need to merge into existing scores
     * @param comparator comparator to compare the score docs
     * @param isSortEnabled flag that shows if sort is enabled or disabled
     * @return merged array of ScoreDocs objects
     * @throws NullPointerException if either array is null
     * @throws IllegalArgumentException if either array has fewer than 3 elements
     */
    public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator<T> comparator, final boolean isSortEnabled) {
        if (Objects.requireNonNull(sourceScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC
            || Objects.requireNonNull(newScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) {
            throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements");
        }
        // we overshoot and preallocate more than we need - length of both top docs combined.
        // we will take only portion of the array at the end
        List<T> mergedScoreDocs = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length);
        int sourcePointer = 0;
        // mark beginning of hybrid query results by start element
        mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
        sourcePointer++;
        // new pointer is set to 1 as we don't care about its start-stop element
        int newPointer = 1;

        while (sourcePointer < sourceScoreDocs.length - 1 && newPointer < newScoreDocs.length - 1) {
            // every iteration is for results of one sub-query; both pointers sit on a delimiter here,
            // only the one from source is copied over
            mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
            sourcePointer++;
            newPointer++;
            // simplest case when both arrays have results for sub-query
            while (sourcePointer < sourceScoreDocs.length
                && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])
                && newPointer < newScoreDocs.length
                && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
                if (compareCondition(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer], comparator, isSortEnabled)) {
                    mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
                    sourcePointer++;
                } else {
                    mergedScoreDocs.add(newScoreDocs[newPointer]);
                    newPointer++;
                }
            }
            // at least one object got exhausted at this point, now merge all elements from object that's left
            while (sourcePointer < sourceScoreDocs.length && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])) {
                mergedScoreDocs.add(sourceScoreDocs[sourcePointer]);
                sourcePointer++;
            }
            while (newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) {
                mergedScoreDocs.add(newScoreDocs[newPointer]);
                newPointer++;
            }
        }
        // mark end of hybrid query results by end element
        mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]);
        // The runtime component type of the result follows the mode: FieldDoc[] when sorting, ScoreDoc[] otherwise.
        // Generic arrays cannot be created directly, hence the unchecked cast of the empty template array.
        if (isSortEnabled) {
            @SuppressWarnings("unchecked")
            final T[] fieldDocTemplate = (T[]) new FieldDoc[0];
            return mergedScoreDocs.toArray(fieldDocTemplate);
        }
        @SuppressWarnings("unchecked")
        final T[] scoreDocTemplate = (T[]) new ScoreDoc[0];
        return mergedScoreDocs.toArray(scoreDocTemplate);
    }

    /**
     * Decides whether the element from the source array goes into the merged result before the element from the new array.
     *
     * @return true if the source element must be taken first
     */
    private boolean compareCondition(
        final T oldScoreDoc,
        final T secondScoreDoc,
        final Comparator<T> comparator,
        final boolean isSortEnabled
    ) {
        final int comparison = comparator.compare(oldScoreDoc, secondScoreDoc);
        // If sorting is enabled then compare condition is different than for normal HybridQuery:
        // with sort the source element wins only when it compares strictly lower,
        // without sort it wins when it compares greater than or equal
        return isSortEnabled ? comparison < 0 : comparison >= 0;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;

import java.util.Comparator;
import java.util.Objects;
import org.opensearch.search.sort.SortAndFormats;

/**
* Utility class for merging TopDocs and MaxScore across multiple search queries
*/
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
class TopDocsMerger {

private HybridQueryScoreDocsMerger scoreDocsMerger;
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
private SortAndFormats sortAndFormats;
@VisibleForTesting
protected static Comparator<ScoreDoc> SCORE_DOC_BY_SCORE_COMPARATOR;
@VisibleForTesting
protected static HybridQueryFieldDocComparator FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR;
private final Comparator<ScoreDoc> MERGING_TIE_BREAKER = (o1, o2) -> {
int docIdComparison = Integer.compare(o1.doc, o2.doc);
return docIdComparison;
};

/**
* Uses hybrid query score docs merger to merge internal score docs
*/
TopDocsMerger(final SortAndFormats sortAndFormats) {
this.sortAndFormats = sortAndFormats;
if (this.sortAndFormats != null) {
scoreDocsMerger = new HybridQueryScoreDocsMerger<FieldDoc>();
FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR = new HybridQueryFieldDocComparator(sortAndFormats.sort.getSort(), MERGING_TIE_BREAKER);
} else {
scoreDocsMerger = new HybridQueryScoreDocsMerger<>();
SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score);
}
}

/**
* Merge TopDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object.
* @param source TopDocsAndMaxScore for the original query
* @param newTopDocs TopDocsAndMaxScore for the new query
* @return merged TopDocsAndMaxScore object
*/
public TopDocsAndMaxScore merge(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) {
if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) {
return source;
}
TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs);
TopDocsAndMaxScore result = new TopDocsAndMaxScore(
getTopDocs(getMergedScoreDocs(source.topDocs.scoreDocs, newTopDocs.topDocs.scoreDocs), mergedTotalHits),
Math.max(source.maxScore, newTopDocs.maxScore)
);
return result;
}

private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) {
// merged value is a lower bound - if both are equal_to than merged will also be equal_to,
// otherwise assign greater_than_or_equal
TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|| newTopDocs.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation);
}

private TopDocs getTopDocs(ScoreDoc[] mergedScoreDocs, TotalHits mergedTotalHits) {
if (sortAndFormats != null) {
return new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort());
}
return new TopDocs(mergedTotalHits, mergedScoreDocs);
}

private ScoreDoc[] getMergedScoreDocs(ScoreDoc[] source, ScoreDoc[] newScoreDocs) {
// Case 1 when sorting is enabled then below will be the TopDocs format
// we need to merge hits per individual sub-query
// format of results in both new and source TopDocs is following
// doc_id | magic_number_1 | [1]
// doc_id | magic_number_2 | [1]
// ...
// doc_id | magic_number_2 | [1]
// ...
// doc_id | magic_number_2 | [1]
// ...
// doc_id | magic_number_1 | [1]

// Case 2 when sorting is disabled then below will be the TopDocs format
// we need to merge hits per individual sub-query
// format of results in both new and source TopDocs is following
// doc_id | magic_number_1
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_1
return scoreDocsMerger.merge(source, newScoreDocs, comparator(), sortAndFormats != null);
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
}

private Comparator<? extends ScoreDoc> comparator() {
return sortAndFormats != null ? FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR : SCORE_DOC_BY_SCORE_COMPARATOR;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,30 @@ public static FieldDoc createFieldDocDelimiterElementForHybridSearchResults(fina
return new FieldDoc(docId, MAGIC_NUMBER_DELIMITER, fields);
}

/**
* Checking if passed scoreDocs object is a special element (start/stop or delimiter) in the list of hybrid query result scores
* @param scoreDoc score doc object to check on
* @return true if it is a special element
*/
public static boolean isHybridQuerySpecialElement(final ScoreDoc scoreDoc) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not review this class, as it is already merged in main. The change shows up here because we have cherry-picked the commit.

if (Objects.isNull(scoreDoc)) {
return false;
}
return isHybridQueryStartStopElement(scoreDoc) || isHybridQueryDelimiterElement(scoreDoc);
}

/**
 * Checks whether the passed scoreDoc is a regular document score element, i.e. it is present and is not one of
 * the special (start/stop or delimiter) elements of the hybrid query result list.
 * @param scoreDoc score doc object to check on
 * @return true if element has score, false for null and for special elements
 */
public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) {
    return !Objects.isNull(scoreDoc) && !isHybridQuerySpecialElement(scoreDoc);
}

/**
* This method is for creating dummy sort object for the field docs having magic number scores which acts as delimiters.
* The sort object should be in the same type of the field on which sorting criteria is applied.
Expand Down
Loading
Loading