Skip to content

Commit

Permalink
Adds method_parameters in neural search query to support ef_search (#787
Browse files Browse the repository at this point in the history
) (#814)

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas authored Jul 4, 2024
1 parent 07406f3 commit dea1fe3
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x)
### Features
### Enhancements
* Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
### Bug Fixes
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.opensearch.index.query.MatchQueryBuilder;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
Expand Down Expand Up @@ -69,6 +70,7 @@ private void validateNormalizationProcessor(final String fileName, final String
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100));
} finally {
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
}
Expand Down Expand Up @@ -96,10 +98,14 @@ private void createSearchPipeline(final String pipelineName) {
);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline) throws Exception {
private void validateTestIndex(final String modelId, final String index, final String searchPipeline) {
validateTestIndex(modelId, index, searchPipeline, null);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map<String, ?> methodParameters) {
int docCount = getDocCount(index);
assertEquals(6, docCount);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
assertNotNull(searchResponseAsMap);
int hits = getHitCount(searchResponseAsMap);
Expand All @@ -110,12 +116,15 @@ private void validateTestIndex(final String modelId, final String index, final S
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId) {
private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ private void validateIndexQuery(final String modelId) {
null,
0.01f,
null,
null,
null
);
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -74,6 +75,7 @@ private void validateIndexQuery(final String modelId) {
100000f,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ private void validateTestIndex(final String modelId) throws Exception {
null,
null,
null,
null,
null
);
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100));
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
}
Expand All @@ -83,10 +84,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception {
validateTestIndexOnUpgrade(numberOfDocs, modelId, null);
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map<String, ?> methodParameters)
throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(
getIndexNameForTest(),
hybridQueryBuilder,
Expand All @@ -103,12 +109,15 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId) {
private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<String, ?> methodParameters) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
0.01f,
null,
null,
null
);
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -100,6 +101,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
100000f,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
null,
null,
null,
null,
null
);
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.common;

import com.google.common.collect.ImmutableMap;
import org.opensearch.Version;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import java.util.Map;

import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD;

/**
* A util class which holds the logic to determine the min version supported by the request parameters
*/
public final class MinClusterVersionUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;

// Note this minimal version will act as a override
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
.put(MODEL_ID_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID)
.put(MAX_DISTANCE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
.put(MIN_SCORE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
.build();

public static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}

public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

public static boolean isClusterOnOrAfterMinReqVersion(String key) {
Version version;
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
version = MINIMAL_VERSION_NEURAL.get(key);
} else {
version = IndexUtil.minimalRequiredVersionMap.get(key);
}
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(version);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
package org.opensearch.neuralsearch.query;

import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;
Expand All @@ -19,7 +25,6 @@
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
Expand All @@ -34,8 +39,9 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import com.google.common.annotations.VisibleForTesting;

Expand Down Expand Up @@ -69,18 +75,11 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
@VisibleForTesting
static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");

@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
public static final ParseField MODEL_ID_FIELD = new ParseField("model_id");

@VisibleForTesting
static final ParseField K_FIELD = new ParseField("k");

@VisibleForTesting
static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");

@VisibleForTesting
static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");

private static final int DEFAULT_K = 10;

private static MLCommonsClientAccessor ML_CLIENT;
Expand All @@ -101,8 +100,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private Map<String, ?> methodParameters;

/**
* Constructor from stream input
Expand Down Expand Up @@ -130,6 +128,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.maxDistance = in.readOptionalFloat();
this.minScore = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
}

@Override
Expand All @@ -152,6 +153,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(this.maxDistance);
out.writeOptionalFloat(this.minScore);
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
}

@Override
Expand All @@ -174,6 +178,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(minScore)) {
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
if (Objects.nonNull(methodParameters)) {
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand Down Expand Up @@ -267,6 +274,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
} else if (token == XContentParser.Token.START_OBJECT) {
if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser));
}
} else {
throw new ParsingException(
Expand All @@ -292,15 +301,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (vectorSupplier().get() == null) {
return this;
}
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (maxDistance != null) {
knnQueryBuilder.maxDistance(maxDistance);
} else if (minScore != null) {
knnQueryBuilder.minScore(minScore);
} else {
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
return KNNQueryBuilder.builder()
.fieldName(fieldName())
.vector(vectorSupplier.get())
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.k(k)
.build();
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand All @@ -326,7 +334,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
maxDistance(),
minScore(),
vectorSetOnce::get,
filter()
filter(),
methodParameters()
);
}

Expand Down Expand Up @@ -359,14 +368,6 @@ public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}

private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
int queryCount = 0;
if (neuralQueryBuilder.k() != null) {
Expand Down
Loading

0 comments on commit dea1fe3

Please sign in to comment.