Skip to content

Commit

Permalink
Support of new k-NN query parameter expand_nested.
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Zhang <[email protected]>
  • Loading branch information
bzhangam committed Dec 17, 2024
1 parent 3d72cc3 commit fa149d4
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import java.util.Map;

import org.opensearch.index.query.MatchQueryBuilder;

import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand Down Expand Up @@ -71,9 +74,9 @@ private void validateNormalizationProcessor(final String fileName, final String
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
} finally {
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
Expand Down Expand Up @@ -115,12 +118,23 @@ private void validateTestIndex(final String index, final String searchPipeline,
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters, RescoreContext rescoreContext) {
private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Boolean expandNestedDocs,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (expandNestedDocs != null) {
neuralQueryBuilder.expandNested(expandNestedDocs);
}
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName()) && expandNestedDocs != null) {
neuralQueryBuilder.expandNested(expandNestedDocs);
}
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ private void validateIndexQuery(final String modelId) {
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -78,6 +79,7 @@ private void validateIndexQuery(final String modelId) {
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ private void validateTestIndex(final String modelId) throws Exception {
null,
null,
null,
null,
null
);
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountMixed;
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
}
break;
Expand All @@ -83,10 +83,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
Expand Down Expand Up @@ -124,6 +124,7 @@ private void validateTestIndexOnUpgrade(

private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Boolean expandNestedDocs,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContextForNeuralQuery
) {
Expand All @@ -132,6 +133,9 @@ private HybridQueryBuilder getQueryBuilder(
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (expandNestedDocs != null) {
neuralQueryBuilder.expandNested(expandNestedDocs);
}
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -104,6 +105,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
null,
null,
null,
null,
null
);
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.query;

import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
Expand Down Expand Up @@ -98,6 +99,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private Integer k = null;
private Float maxDistance = null;
private Float minScore = null;
private Boolean expandNested;
@VisibleForTesting
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
Expand Down Expand Up @@ -132,6 +134,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.maxDistance = in.readOptionalFloat();
this.minScore = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
this.expandNested = in.readOptionalBoolean();
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
Expand All @@ -158,6 +163,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(this.maxDistance);
out.writeOptionalFloat(this.minScore);
}
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
out.writeOptionalBoolean(this.expandNested);
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
Expand All @@ -184,6 +192,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(minScore)) {
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
if (Objects.nonNull(expandNested)) {
xContentBuilder.field(EXPAND_NESTED_FIELD.getPreferredName(), expandNested);
}
if (Objects.nonNull(methodParameters)) {
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
}
Expand Down Expand Up @@ -274,6 +285,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
neuralQueryBuilder.maxDistance(parser.floatValue());
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.minScore(parser.floatValue());
} else if (EXPAND_NESTED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.expandNested(parser.booleanValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -318,6 +331,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.k(k)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
Expand Down Expand Up @@ -346,6 +360,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
k(),
maxDistance(),
minScore(),
expandNested(),
vectorSetOnce::get,
filter(),
methodParameters(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -150,6 +151,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -191,6 +193,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,20 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -249,7 +262,20 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down Expand Up @@ -299,7 +325,20 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -324,7 +363,20 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down
Loading

0 comments on commit fa149d4

Please sign in to comment.