Skip to content

Commit

Permalink
Simplify some logic, and correct some format.
Browse files Browse the repository at this point in the history
Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed May 20, 2024
1 parent 9a1d52c commit 91f6a25
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -81,15 +82,18 @@ protected NeuralSparseTwoPhaseProcessor(
* @return request the search request with the two-phase rescore query of the neural sparse query added.
*/
@Override
public SearchRequest processRequest(SearchRequest request) throws Exception {
public SearchRequest processRequest(final SearchRequest request) throws Exception {
if (!enabled || ratio == 0f) {
return request;
}
QueryBuilder queryBuilder = request.source().query();
// Collect the nested NeuralSparseQueryBuilder in the whole query.
Multimap<NeuralSparseQueryBuilder, Float> queryBuilderMap = ArrayListMultimap.create();
collectNeuralSparseQueryBuilder(queryBuilder, queryBuilderMap, 1.0f);
queryBuilderMap = collectNeuralSparseQueryBuilder(queryBuilder, 1.0f);
if (queryBuilderMap.isEmpty()) return request;
// Make a nestedQueryBuilder which includes all the two-phase QueryBuilder.
QueryBuilder nestedTwoPhaseQueryBuilder = getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(queryBuilderMap);
// Add it to the rescorer.
nestedTwoPhaseQueryBuilder.boost(getOriginQueryWeightAfterRescore(request.source()));
RescorerBuilder<QueryRescorerBuilder> twoPhaseRescorer = new QueryRescorerBuilder(nestedTwoPhaseQueryBuilder);
int requestSize = request.source().size();
Expand Down Expand Up @@ -156,15 +160,15 @@ public NeuralSparseTwoPhaseProcessor create(
) throws IllegalArgumentException {

boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, true);
Map<String, Object> map = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY);
Map<String, Object> twoPhaseConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY);

float ratio = 0.4f;
float window_expansion = 5.0f;
int max_window_size = 10000;
if (map != null) {
ratio = ((Number) map.getOrDefault(RATIO_KEY, ratio)).floatValue();
window_expansion = ((Number) map.getOrDefault(EXPANSION_KEY, window_expansion)).floatValue();
max_window_size = ((Number) map.getOrDefault(MAX_WINDOW_SIZE_KEY, max_window_size)).intValue();
if (Objects.nonNull(twoPhaseConfigMap)) {
ratio = ((Number) twoPhaseConfigMap.getOrDefault(RATIO_KEY, ratio)).floatValue();
window_expansion = ((Number) twoPhaseConfigMap.getOrDefault(EXPANSION_KEY, window_expansion)).floatValue();
max_window_size = ((Number) twoPhaseConfigMap.getOrDefault(MAX_WINDOW_SIZE_KEY, max_window_size)).intValue();
}

return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, window_expansion, max_window_size);
Expand All @@ -191,40 +195,25 @@ private float getOriginQueryWeightAfterRescore(final SearchSourceBuilder searchS
.reduce(1.0f, (a, b) -> a * b);
}

private void collectNeuralSparseQueryBuilder(
final QueryBuilder queryBuilder,
final Multimap<NeuralSparseQueryBuilder, Float> neuralSparseQueryBuilderFloatMap,
float baseBoost
) {
if (queryBuilder instanceof BoolQueryBuilder) {
collectNeuralSparseQueryBuilder((BoolQueryBuilder) queryBuilder, neuralSparseQueryBuilderFloatMap, baseBoost);
} else if (queryBuilder instanceof NeuralSparseQueryBuilder) {
collectNeuralSparseQueryBuilder((NeuralSparseQueryBuilder) queryBuilder, neuralSparseQueryBuilderFloatMap, baseBoost);
}
}
private Multimap<NeuralSparseQueryBuilder, Float> collectNeuralSparseQueryBuilder(final QueryBuilder queryBuilder, float baseBoost) {
Multimap<NeuralSparseQueryBuilder, Float> result = ArrayListMultimap.create();

private void collectNeuralSparseQueryBuilder(
final BoolQueryBuilder queryBuilder,
final Multimap<NeuralSparseQueryBuilder, Float> neuralSparseQueryBuilderFloatMap,
float baseBoost
) {
baseBoost *= queryBuilder.boost();
for (QueryBuilder subQuery : queryBuilder.should()) {
if (subQuery instanceof BoolQueryBuilder) {
collectNeuralSparseQueryBuilder(subQuery, neuralSparseQueryBuilderFloatMap, baseBoost);
} else if (subQuery instanceof NeuralSparseQueryBuilder) {
collectNeuralSparseQueryBuilder((NeuralSparseQueryBuilder) subQuery, neuralSparseQueryBuilderFloatMap, baseBoost);
if (queryBuilder instanceof BoolQueryBuilder) {
BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) queryBuilder;
float updatedBoost = baseBoost * boolQueryBuilder.boost();
for (QueryBuilder subQuery : boolQueryBuilder.should()) {
Multimap<NeuralSparseQueryBuilder, Float> subResult = collectNeuralSparseQueryBuilder(subQuery, updatedBoost);
result.putAll(subResult);
}
} else if (queryBuilder instanceof NeuralSparseQueryBuilder) {
NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) queryBuilder;
float updatedBoost = baseBoost * neuralSparseQueryBuilder.boost();
NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio);
result.put(modifiedQueryBuilder, updatedBoost);
}
}

private void collectNeuralSparseQueryBuilder(
final NeuralSparseQueryBuilder queryBuilder,
final Multimap<NeuralSparseQueryBuilder, Float> neuralSparseQueryBuilderFloatMap,
float baseBoost
) {
baseBoost *= queryBuilder.boost();
neuralSparseQueryBuilderFloatMap.put(queryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio), baseBoost);
// Only BoostQuery, BooleanQuery and NeuralSparseQuery are supported for now. Any other compound query type is not supported yet,
// so nothing is collected for it and the method simply returns.
return result;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
// A parameter that indicates whether two-phase is enabled. When twoPhasePruneRatio equals -1f, it means this builder is the two-phase
// rescoreQueryBuilder's subQueryBuilder.
private float twoPhasePruneRatio = 0F;
private NeuralSparseQueryBuilder twoPhaseNeuralSparseQueryBuilder = null;
private Supplier<Map<String, Float>> twoPhaseQueryTokensSupplier;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

public static void initialize(MLCommonsClientAccessor mlClient) {
Expand Down Expand Up @@ -138,7 +138,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
this.queryTokensSupplier(splitTokens.get(true)::get);
copy.queryTokensSupplier(splitTokens.get(false)::get);
} else copy.queryTokensSupplier(new SetOnce<Map<String, Float>>()::get);
this.twoPhaseNeuralSparseQueryBuilder = copy;
this.twoPhaseQueryTokensSupplier = copy.twoPhaseQueryTokensSupplier();
return copy;
}

Expand Down Expand Up @@ -300,7 +300,7 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (queryTokensSupplier != null) {
if (Objects.nonNull(queryTokensSupplier)) {
return this;
}
validateForRewrite(queryText, modelId);
Expand All @@ -314,7 +314,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
.maxTokenScore(maxTokenScore)
.queryTokensSupplier(queryTokensSetOnce::get)
.twoPhasePruneRatio(twoPhasePruneRatio)
.twoPhaseNeuralSparseQueryBuilder(twoPhaseNeuralSparseQueryBuilder);
.twoPhaseQueryTokensSupplier(twoPhaseQueryTokensSupplier);
}

private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map<String, Float>> setOnce) {
Expand All @@ -323,13 +323,13 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
List.of(queryText),
ActionListener.wrap(mapResultList -> {
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
if (twoPhaseNeuralSparseQueryBuilder != null) {
if (twoPhaseQueryTokensSupplier != null) {
Map<Boolean, SetOnce<Map<String, Float>>> splitSetOnce = getSplitSetOnceByScoreThreshold(
queryTokens,
twoPhasePruneRatio
);
setOnce.set(splitSetOnce.get(true).get());
twoPhaseNeuralSparseQueryBuilder.queryTokensSupplier(splitSetOnce.get(false)::get);
twoPhaseQueryTokensSupplier = splitSetOnce.get(false)::get;
} else {
setOnce.set(queryTokens);
}
Expand All @@ -343,8 +343,10 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
final MappedFieldType ft = context.fieldMapper(fieldName);
validateFieldType(ft);
Map<String, Float> queryTokens = queryTokensSupplier.get();
if (null == queryTokens) {
if (twoPhasePruneRatio == -1f) return new MatchNoDocsQuery();
if (Objects.isNull(queryTokens)) {
if (twoPhasePruneRatio == -1f) {
return new MatchNoDocsQuery();
}
throw new IllegalArgumentException("Query tokens cannot be null.");
}
BooleanQuery.Builder builder = new BooleanQuery.Builder();
Expand Down Expand Up @@ -379,8 +381,8 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) return false;
if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
if (Objects.nonNull(twoPhaseNeuralSparseQueryBuilder) && Objects.isNull(obj.twoPhaseNeuralSparseQueryBuilder)) return false;
if (Objects.isNull(twoPhaseNeuralSparseQueryBuilder) && Objects.nonNull(obj.twoPhaseNeuralSparseQueryBuilder)) return false;
if (Objects.nonNull(twoPhaseQueryTokensSupplier) && Objects.isNull(obj.twoPhaseQueryTokensSupplier)) return false;
if (Objects.isNull(twoPhaseQueryTokensSupplier) && Objects.nonNull(obj.twoPhaseQueryTokensSupplier)) return false;

EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
Expand All @@ -390,8 +392,8 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (Objects.nonNull(queryTokensSupplier)) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
if (Objects.nonNull(twoPhaseNeuralSparseQueryBuilder)) {
equalsBuilder.append(twoPhaseNeuralSparseQueryBuilder, obj.twoPhaseNeuralSparseQueryBuilder);
if (Objects.nonNull(twoPhaseQueryTokensSupplier)) {
equalsBuilder.append(twoPhaseQueryTokensSupplier, obj.twoPhaseQueryTokensSupplier);
}
return equalsBuilder.isEquals();
}
Expand All @@ -406,8 +408,8 @@ protected int doHashCode() {
if (queryTokensSupplier != null) {
builder.append(queryTokensSupplier.get());
}
if (twoPhaseNeuralSparseQueryBuilder != null) {
builder.append(twoPhaseNeuralSparseQueryBuilder);
if (twoPhaseQueryTokensSupplier != null) {
builder.append(twoPhaseQueryTokensSupplier);
}
return builder.toHashCode();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInFunctionScoreQuery_then
}
}


@SneakyThrows
protected void initializeIndexIfNotExist(String indexName) {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.SetOnce;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.test.OpenSearchTestCase;

import java.util.Collections;
Expand Down Expand Up @@ -68,6 +73,36 @@ public void testProcessRequest_whenTwoPhaseEnabled_thenSuccess() throws Exceptio
assertNotNull(searchRequest.source().rescores());
}

/**
 * Checks that a neural sparse query placed in the {@code should} clause of a bool query is handled by the
 * two-phase processor: after processing, the nested query reports a prune ratio of 0.5 and the search source
 * carries rescorers.
 *
 * @throws Exception if {@code processRequest} fails
 */
public void testProcessRequest_whenTwoPhaseEnabledAndNestedBoolean_thenSuccess() throws Exception {
    NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
    NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder();
    BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
    boolQueryBuilder.should(neuralQueryBuilder);
    SearchRequest searchRequest = new SearchRequest();
    searchRequest.source(new SearchSourceBuilder().query(boolQueryBuilder));
    // NOTE(review): arguments presumably are (factory, ratio, enabled, window expansion, max window size) - confirm against createTestProcessor.
    NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000);
    processor.processRequest(searchRequest);
    BoolQueryBuilder queryBuilder = (BoolQueryBuilder) searchRequest.source().query();
    NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) queryBuilder.should().get(0);
    // Expected value goes first so that a failure message reports expected/actual the right way round.
    assertEquals(0.5f, neuralSparseQueryBuilder.twoPhasePruneRatio(), 1e-3);
    assertNotNull(searchRequest.source().rescores());
}

/**
 * Checks that the two-phase processor still works when the request already carries a rescorer: a match-all
 * rescorer with rescore query weight 0 is added up front, and after processing the neural sparse query
 * reports a prune ratio of 0.5.
 *
 * @throws Exception if {@code processRequest} fails
 */
public void testProcessRequestWithRescorer_whenTwoPhaseEnabled_thenSuccess() throws Exception {
    NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
    NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder();
    SearchRequest searchRequest = new SearchRequest();
    searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder));
    QueryRescorerBuilder queryRescorerBuilder = new QueryRescorerBuilder(new MatchAllQueryBuilder());
    queryRescorerBuilder.setRescoreQueryWeight(0f);
    searchRequest.source().addRescorer(queryRescorerBuilder);
    // NOTE(review): arguments presumably are (factory, ratio, enabled, window expansion, max window size) - confirm against createTestProcessor.
    NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000);
    processor.processRequest(searchRequest);
    NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
    // Expected value goes first so that a failure message reports expected/actual the right way round.
    assertEquals(0.5f, queryBuilder.twoPhasePruneRatio(), 1e-3);
    // NOTE(review): a rescorer was already added above, so this check alone presumably does not prove the
    // processor appended its own rescorer - consider asserting on the rescorer count instead.
    assertNotNull(searchRequest.source().rescores());
}

public void testProcessRequest_whenTwoPhaseDisabled_thenSuccess() throws Exception {
NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder();
Expand All @@ -80,6 +115,39 @@ public void testProcessRequest_whenTwoPhaseDisabled_thenSuccess() throws Excepti
assertNull(searchRequest.source().rescores());
}

/**
 * Expects {@code processRequest} to reject the request with an {@link IllegalArgumentException} when the
 * processor is created with the values (0.5f, true, 400.0f, 100).
 * NOTE(review): presumably 400.0f is the window expansion and 100 the max window size, so the expanded
 * rescore window exceeds the limit - confirm against createTestProcessor.
 */
@SneakyThrows
public void testProcessRequest_whenTwoPhaseEnabledAndOutOfWindowSize_thenThrowException() {
    // Pre-existing rescorer on the request: match-all with rescore query weight 0.
    QueryRescorerBuilder existingRescorer = new QueryRescorerBuilder(new MatchAllQueryBuilder());
    existingRescorer.setRescoreQueryWeight(0f);

    // Search source holding a bare neural sparse query plus the rescorer above.
    SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(new NeuralSparseQueryBuilder());
    sourceBuilder.addRescorer(existingRescorer);
    SearchRequest request = new SearchRequest();
    request.source(sourceBuilder);

    NeuralSparseTwoPhaseProcessor twoPhaseProcessor = createTestProcessor(
        new NeuralSparseTwoPhaseProcessor.Factory(),
        0.5f,
        true,
        400.0f,
        100
    );
    expectThrows(IllegalArgumentException.class, () -> twoPhaseProcessor.processRequest(request));
}

/**
 * Feeds ten tokens ("0".."9", each weighted by its own numeric value) through
 * {@code getSplitSetOnceByScoreThreshold} with a ratio of 0.4f and checks the sizes of the two partitions:
 * six entries under {@code true} and four under {@code false}.
 */
@SneakyThrows
public void testGetSplitSetOnceByScoreThreshold() {
    Map<String, Float> tokenWeights = new HashMap<>();
    for (int weight = 0; weight < 10; weight++) {
        tokenWeights.put(Integer.toString(weight), (float) weight);
    }

    Map<Boolean, SetOnce<Map<String, Float>>> partitions = NeuralSparseTwoPhaseProcessor.getSplitSetOnceByScoreThreshold(
        tokenWeights,
        0.4f
    );
    assertNotNull(partitions);

    SetOnce<Map<String, Float>> highScoreTokens = partitions.get(true);
    SetOnce<Map<String, Float>> lowScoreTokens = partitions.get(false);

    assertNotNull(highScoreTokens);
    assertEquals(6, highScoreTokens.get().size());

    assertNotNull(lowScoreTokens);
    assertEquals(4, lowScoreTokens.get().size());

    assertNotNull(partitions.get(false));
}

public void testType() throws Exception {
NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory);
Expand Down

0 comments on commit 91f6a25

Please sign in to comment.