Skip to content

Commit

Permalink
Change some name of functiion to provide a clearer and more precise d…
Browse files Browse the repository at this point in the history
…escription.

Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed Apr 22, 2024
1 parent d19132a commit 607ff2d
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import org.apache.lucene.search.Weight;

/**
* Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual
* scores for each sub-query.
* Implementation of Query interface for type NeuralSparseQuery when TwoPhaseNeuralSparse Enabled.
* Initialized, it currentQuery include all tokenQuery. After
*/
@AllArgsConstructor
@Getter
Expand Down Expand Up @@ -86,7 +86,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
return currentQuery.createWeight(searcher, scoreMode, boost);
}

public void extractLowScoreToken() {
public void setCurrentQueryToHighScoreTokenQuery() {
this.currentQuery = highScoreTokenQuery;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import lombok.extern.log4j.Log4j2;

import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addTwoPhaseNeuralSparseQuery;
import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addSecondPhaseRescoreContextFromValidNeuralSparseQuery;
import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasAliasFilter;
import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasNestedFieldOrNestedDocs;
import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery;
Expand All @@ -45,7 +45,7 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
addTwoPhaseNeuralSparseQuery(query, searchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, searchContext);
if (!isHybridQuery(query, searchContext)) {
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
import static java.lang.Integer.min;
import static org.opensearch.index.IndexSettings.MAX_RESCORE_WINDOW_SETTING;

/**
* Util class for do two phase preprocess for the NeuralSparseQuery.
* Include adding the second phase query to searchContext and set the currentQuery to highScoreTokenQuery.
*/
public class NeuralSparseTwoPhaseUtil {

private static float populateQueryWeightsMapAndGetWindowSizeExpansion(
Expand All @@ -48,7 +52,7 @@ private static float populateQueryWeightsMapAndGetWindowSizeExpansion(
}
} else if (query instanceof NeuralSparseQuery) {
query2Weight.put(((NeuralSparseQuery) query).getLowScoreTokenQuery(), weight);
((NeuralSparseQuery) query).extractLowScoreToken();
((NeuralSparseQuery) query).setCurrentQueryToHighScoreTokenQuery();
windoSizeExpansion = max(windoSizeExpansion, ((NeuralSparseQuery) query).getRescoreWindowSizeExpansion());
}
// ToDo Support for other compound query.
Expand All @@ -68,7 +72,12 @@ private static Query getNestedTwoPhaseQuery(Map<Query, Float> query2weight) {
return builder.build();
}

public static void addTwoPhaseNeuralSparseQuery(final Query query, SearchContext searchContext) {
/**
*
* @param query The whole query include neuralSparseQuery to executed.
* @param searchContext The searchContext with this query.
*/
public static void addSecondPhaseRescoreContextFromValidNeuralSparseQuery(final Query query, SearchContext searchContext) {
Map<Query, Float> query2weight = new HashMap<>();
float windowSizeExpansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, query2weight, 1.0f, 1.0f);
Query twoPhaseQuery;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@ public final class NeuralSearchSettings {
Setting.Property.NodeScope
);

/**
* Use this setting to manage if a neuralSparseQuery build a two-phase query of not.
*/
public static final Setting<Boolean> NEURAL_SPARSE_TWO_PHASE_DEFAULT_ENABLED = Setting.boolSetting(
"plugins.neural_search.neural_sparse.two_phase.default_enabled",
true,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

/**
* Control the number of TopDocs rescored by the second phase of NeuralSparseQuery.
*/
public static final Setting<Float> NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION = Setting.floatSetting(
"plugins.neural_search.neural_sparse.two_phase.default_window_size_expansion",
5f,
Expand All @@ -49,6 +55,9 @@ public final class NeuralSearchSettings {
Setting.Property.Dynamic
);

/**
* Control the token score threshold to splitting the NeuralSparseQuery.
*/
public static final Setting<Float> NEURAL_SPARSE_TWO_PHASE_DEFAULT_PRUNING_RATIO = Setting.floatSetting(
"plugins.neural_search.neural_sparse.two_phase.default_pruning_ratio",
0.4f,
Expand All @@ -58,7 +67,10 @@ public final class NeuralSearchSettings {
Setting.Property.Dynamic
);

// the default value is consistent with core settings MAX_RESCORE_WINDOW_SETTING
/**
* Control the max rescore windows size of the second phase of NeuralSparseQuery.
* The default value is consistent with core settings MAX_RESCORE_WINDOW_SETTING.
*/
public static final Setting<Integer> NEURAL_SPARSE_TWO_PHASE_MAX_WINDOW_SIZE = Setting.intSetting(
"plugins.neural_search.neural_sparse.two_phase.max_window_size",
10000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void testEqualsAndHashCode() {
public void testExtractLowScoreToken_thenCurrentChanged() {
NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(currentQuery, highScoreTokenQuery, lowScoreTokenQuery, 5.0f);
assertSame(neuralSparseQuery.getCurrentQuery(), currentQuery);
neuralSparseQuery.extractLowScoreToken();
neuralSparseQuery.setCurrentQueryToHighScoreTokenQuery();
assertSame(neuralSparseQuery.getCurrentQuery(), highScoreTokenQuery);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addTwoPhaseNeuralSparseQuery;
import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addSecondPhaseRescoreContextFromValidNeuralSparseQuery;

public class NeuralSparseTwoPhaseUtilTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -98,30 +98,30 @@ public void testInitialize() {
@SneakyThrows
public void testAddTwoPhaseNeuralSparseQuery_whenQuery2WeightEmpty_thenNoRescoreAdded() {
Query query = mock(Query.class);
addTwoPhaseNeuralSparseQuery(query, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext);
verify(mockSearchContext, never()).addRescore(any());
}

@SneakyThrows
public void testAddTwoPhaseNeuralSparseQuery_whenUnSupportedQuery_thenNoRescoreAdded() {
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(normalNeuralSparseQuery, mock(DoubleValuesSource.class));
addTwoPhaseNeuralSparseQuery(functionScoreQuery, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(functionScoreQuery, mockSearchContext);
DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery(Collections.emptyList(), 1.0f);
addTwoPhaseNeuralSparseQuery(disjunctionMaxQuery, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(disjunctionMaxQuery, mockSearchContext);
List<Query> subQueries = new ArrayList<>();
List<Query> filterQueries = new ArrayList<>();
subQueries.add(normalNeuralSparseQuery);
filterQueries.add(new MatchAllDocsQuery());
HybridQuery hybridQuery = new HybridQuery(subQueries, filterQueries);
addTwoPhaseNeuralSparseQuery(hybridQuery, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(hybridQuery, mockSearchContext);
assertEquals(normalNeuralSparseQuery.getCurrentQuery(), currentQuery);
verify(mockSearchContext, never()).addRescore(any());
}

@SneakyThrows
public void testAddTwoPhaseNeuralSparseQuery_whenSingleEntryInQuery2Weight_thenRescoreAdded() {
NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5.0f);
addTwoPhaseNeuralSparseQuery(neuralSparseQuery, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(neuralSparseQuery, mockSearchContext);
verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class));
}

Expand All @@ -135,7 +135,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenCompoundBooleanQuery_thenRescor
queryBuilder.add(boostQuery1, BooleanClause.Occur.SHOULD);
queryBuilder.add(boostQuery2, BooleanClause.Occur.SHOULD);
BooleanQuery booleanQuery = queryBuilder.build();
addTwoPhaseNeuralSparseQuery(booleanQuery, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(booleanQuery, mockSearchContext);
verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class));
}

Expand All @@ -155,7 +155,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo
queryBuilder.add(boostQuery3, BooleanClause.Occur.FILTER);
queryBuilder.add(boostQuery4, BooleanClause.Occur.MUST_NOT);
BooleanQuery booleanQuery = queryBuilder.build();
addTwoPhaseNeuralSparseQuery(booleanQuery, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(booleanQuery, mockSearchContext);
ArgumentCaptor<RescoreContext> rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class);
verify(mockSearchContext).addRescore(rtxCaptor.capture());
QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();
Expand All @@ -179,7 +179,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo
@SneakyThrows
public void testWindowSize_whenNormalConditions_thenWindowSizeIsAsSet() {
NeuralSparseQuery query = normalNeuralSparseQuery;
addTwoPhaseNeuralSparseQuery(query, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext);
ArgumentCaptor<QueryRescorer.QueryRescoreContext> rescoreContextArgumentCaptor = ArgumentCaptor.forClass(
QueryRescorer.QueryRescoreContext.class
);
Expand All @@ -192,10 +192,10 @@ public void testWindowSize_whenBoundaryConditions_thenThrowException() {

NeuralSparseQuery query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), 5000f);
NeuralSparseQuery finalQuery1 = query;
expectThrows(IllegalArgumentException.class, () -> { addTwoPhaseNeuralSparseQuery(finalQuery1, mockSearchContext); });
expectThrows(IllegalArgumentException.class, () -> { addSecondPhaseRescoreContextFromValidNeuralSparseQuery(finalQuery1, mockSearchContext); });
query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), Float.MAX_VALUE);
NeuralSparseQuery finalQuery = query;
expectThrows(IllegalArgumentException.class, () -> { addTwoPhaseNeuralSparseQuery(finalQuery, mockSearchContext); });
expectThrows(IllegalArgumentException.class, () -> { addSecondPhaseRescoreContextFromValidNeuralSparseQuery(finalQuery, mockSearchContext); });
}

@SneakyThrows
Expand All @@ -207,7 +207,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal
List<RescoreContext> rescoreContextList = Arrays.asList(mockContext1, mockContext2);
when(mockSearchContext.rescore()).thenReturn(rescoreContextList);
NeuralSparseQuery query = normalNeuralSparseQuery;
addTwoPhaseNeuralSparseQuery(query, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext);
ArgumentCaptor<RescoreContext> rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class);
verify(mockSearchContext).addRescore(rtxCaptor.capture());
QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();
Expand All @@ -218,7 +218,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal
public void testEmptyRescoreListWeight_whenRescoreListEmpty_thenDefaultWeightUsed() {
when(mockSearchContext.rescore()).thenReturn(Collections.emptyList());
NeuralSparseQuery query = normalNeuralSparseQuery;
addTwoPhaseNeuralSparseQuery(query, mockSearchContext);
addSecondPhaseRescoreContextFromValidNeuralSparseQuery(query, mockSearchContext);
ArgumentCaptor<RescoreContext> rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class);
verify(mockSearchContext).addRescore(rtxCaptor.capture());
QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();
Expand Down

0 comments on commit 607ff2d

Please sign in to comment.