Skip to content

Commit

Permalink
Added specific rescore exception, refactor code
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 3, 2024
1 parent 16a70b5 commit 6125e56
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.OpenSearchException;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
Expand All @@ -37,6 +36,7 @@
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;
import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -189,23 +189,22 @@ private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(List<TopFieldDocs> topDoc
}

private TopDocsAndMaxScore getTopDocsAndMaxScore(List<TopDocs> topDocs, HybridSearchCollector hybridSearchCollector) {
List<TopDocs> rescoredTopDocs = rescore(topDocs);
float maxScore = calculateMaxScore(rescoredTopDocs, hybridSearchCollector.getMaxScore());
TopDocs finalTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, rescoredTopDocs, hybridSearchCollector.getTotalHits()),
rescoredTopDocs
);
if (shouldRescore()) {
topDocs = rescore(topDocs);
}
float maxScore = calculateMaxScore(topDocs, hybridSearchCollector.getMaxScore());
TopDocs finalTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs);
return new TopDocsAndMaxScore(finalTopDocs, maxScore);
}

private List<TopDocs> rescore(List<TopDocs> topDocs) {
private boolean shouldRescore() {
List<RescoreContext> rescoreContexts = searchContext.rescore();
boolean shouldRescore = Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty();
if (!shouldRescore) {
return topDocs;
}
return Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty();
}

private List<TopDocs> rescore(List<TopDocs> topDocs) {
List<TopDocs> rescoredTopDocs = topDocs;
for (RescoreContext ctx : rescoreContexts) {
for (RescoreContext ctx : searchContext.rescore()) {
rescoredTopDocs = rescoredTopDocs(ctx, rescoredTopDocs);
}
return rescoredTopDocs;
Expand All @@ -220,8 +219,8 @@ private List<TopDocs> rescoredTopDocs(final RescoreContext ctx, final List<TopDo
try {
result.add(ctx.rescorer().rescore(topDoc, searchContext.searcher(), ctx));
} catch (IOException exception) {
log.error("rescore failed for hybrid query", exception);
throw new OpenSearchException("rescore failed", exception);
log.error("rescore failed for hybrid query in collector_manager.reduce call", exception);
throw new HybridSearchRescoreQueryException(exception);
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query.exception;

import org.opensearch.OpenSearchException;

/**
 * Signals that rescoring the results of a hybrid query failed.
 * <p>
 * The exception always carries the same fixed message; the details of the failure are
 * available through the wrapped {@linkplain #getCause() cause}.
 */
public class HybridSearchRescoreQueryException extends OpenSearchException {

    /** Fixed message reported for every rescore failure of a hybrid query. */
    private static final String RESCORE_FAILURE_MESSAGE = "rescore failed for hybrid query";

    /**
     * Creates the exception around the failure that interrupted rescoring.
     *
     * @param cause the underlying error that made the rescore fail
     */
    public HybridSearchRescoreQueryException(final Throwable cause) {
        super(RESCORE_FAILURE_MESSAGE, cause);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package org.opensearch.neuralsearch.search.query;

import com.carrotsearch.randomizedtesting.RandomizedTest;

import java.io.IOException;
import java.util.Arrays;
import lombok.SneakyThrows;
import org.apache.lucene.document.FieldType;
Expand Down Expand Up @@ -47,6 +49,7 @@
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
Expand All @@ -66,6 +69,7 @@

import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.rescore.Rescorer;
import org.opensearch.search.rescore.RescorerBuilder;
import org.opensearch.search.sort.SortAndFormats;

Expand Down Expand Up @@ -1004,4 +1008,74 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the
reader2.close();
directory2.close();
}

/**
 * Verifies that a failure inside a rescorer is surfaced by {@code reduce} as a
 * {@link HybridSearchRescoreQueryException}.
 * <p>
 * The search context is configured with a single mocked {@link RescoreContext} whose
 * {@link Rescorer} throws an {@link IOException} on every call. One document is indexed and
 * collected through a {@link HybridTopScoreDocCollector} so that the collector manager has
 * results to rescore when {@code reduce} is invoked.
 */
@SneakyThrows
public void testReduceAndRescore_whenRescorerThrowsException_thenFail() {
    SearchContext searchContext = mock(SearchContext.class);
    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
    when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

    HybridQuery hybridQueryWithTerm = new HybridQuery(
        List.of(
            QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext),
            QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext)
        )
    );
    when(searchContext.query()).thenReturn(hybridQueryWithTerm);
    ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
    IndexReader indexReader = mock(IndexReader.class);
    when(indexReader.numDocs()).thenReturn(3);
    when(indexSearcher.getIndexReader()).thenReturn(indexReader);
    when(searchContext.searcher()).thenReturn(indexSearcher);
    when(searchContext.size()).thenReturn(2);
    IndexReaderContext indexReaderContext = mock(IndexReaderContext.class);
    when(indexReader.getContext()).thenReturn(indexReaderContext);

    Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
    when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
    when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);

    // try-with-resources releases the directory, writer and reader even when an assertion or a
    // setup step fails; resources are closed in reverse order (reader, writer, directory).
    try (
        Directory directory = newDirectory();
        IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())))
    ) {
        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
        ft.setOmitNorms(random().nextBoolean());
        ft.freeze();

        int docId1 = RandomizedTest.randomInt();
        w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
        w.flush();
        w.commit();

        try (IndexReader reader = DirectoryReader.open(w)) {
            IndexSearcher searcher = newSearcher(reader);

            // the only configured rescorer fails with an IOException on every invocation
            RescoreContext rescoreContext = mock(RescoreContext.class);
            Rescorer rescorer = mock(Rescorer.class);
            when(rescoreContext.rescorer()).thenReturn(rescorer);
            when(rescorer.rescore(any(), any(), any())).thenThrow(new IOException("something happened with rescorer"));
            List<RescoreContext> rescoreContexts = List.of(rescoreContext);
            when(searchContext.rescore()).thenReturn(rescoreContexts);

            CollectorManager hybridCollectorManager1 = HybridCollectorManager.createHybridCollectorManager(searchContext);
            HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager1.newCollector();

            Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST);
            collector.setWeight(weight);

            // collect hits from the single segment so that reduce has top docs to rescore
            LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
            LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext);

            BulkScorer scorer = weight.bulkScorer(leafReaderContext);
            scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs());
            leafCollector.finish();

            expectThrows(HybridSearchRescoreQueryException.class, () -> hybridCollectorManager1.reduce(List.of()));
        }
    }
}
}

0 comments on commit 6125e56

Please sign in to comment.