Skip to content

Commit

Permalink
Adds rescore parameter to KNNQuery (opensearch-project#1969)
Browse files Browse the repository at this point in the history
Adds rescore parameter to knn query. With this commit, the rescore is a
no-op. The functionality and validation will be added in a later commit.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored Aug 15, 2024
1 parent b566753 commit 9db7058
Show file tree
Hide file tree
Showing 12 changed files with 612 additions and 16 deletions.
38 changes: 36 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.opensearch.knn.index.engine.model.QueryContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -54,6 +56,8 @@
import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH;
import static org.opensearch.knn.index.engine.validation.ParameterValidator.validateParameters;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_OVERSAMPLE_PARAMETER;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER;

/**
* Helper class to build the KNN query
Expand All @@ -73,6 +77,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);
public static final ParseField RESCORE_FIELD = new ParseField(RESCORE_PARAMETER);
public static final ParseField RESCORE_OVERSAMPLE_FIELD = new ParseField(RESCORE_OVERSAMPLE_PARAMETER);

public static final int K_MAX = 10000;
/**
Expand All @@ -96,6 +102,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private QueryBuilder filter;
@Getter
private boolean ignoreUnmapped;
@Getter
private RescoreContext rescoreContext;

/**
* Constructs a new query with the given field name and vector
Expand Down Expand Up @@ -136,6 +144,7 @@ public static class Builder {
private boolean ignoreUnmapped;
private String queryName;
private float boost = DEFAULT_BOOST;
private RescoreContext rescoreContext;

public Builder() {}

Expand Down Expand Up @@ -189,11 +198,25 @@ public Builder boost(float boost) {
return this;
}

public Builder rescoreContext(RescoreContext rescoreContext) {
this.rescoreContext = rescoreContext;
return this;
}

public KNNQueryBuilder build() {
validate();
int k = this.k == null ? 0 : this.k;
return new KNNQueryBuilder(fieldName, vector, k, maxDistance, minScore, methodParameters, filter, ignoreUnmapped).boost(boost)
.queryName(queryName);
return new KNNQueryBuilder(
fieldName,
vector,
k,
maxDistance,
minScore,
methodParameters,
filter,
ignoreUnmapped,
rescoreContext
).boost(boost).queryName(queryName);
}

private void validate() {
Expand Down Expand Up @@ -240,6 +263,15 @@ private void validate() {
);
}
}

if (rescoreContext != null) {
ValidationException validationException = RescoreParser.validate(rescoreContext);
if (validationException != null) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] errors in rescore parameter [%s]", NAME, validationException.getMessage())
);
}
}
}
}

Expand Down Expand Up @@ -284,6 +316,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.ignoreUnmapped = false;
this.maxDistance = null;
this.minScore = null;
this.rescoreContext = null;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -305,6 +338,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
maxDistance = builder.maxDistance;
minScore = builder.minScore;
methodParameters = builder.methodParameters;
rescoreContext = builder.rescoreContext;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;

Expand All @@ -29,6 +30,8 @@
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER;
import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.IGNORE_UNMAPPED_FIELD;
Expand Down Expand Up @@ -79,6 +82,17 @@ private static ObjectParser<KNNQueryBuilder.Builder, Void> createInternalObjectP
);
internalParser.declareObject(KNNQueryBuilder.Builder::filter, (p, v) -> parseInnerQueryBuilder(p), FILTER_FIELD);

internalParser.declareObjectOrDefault(
KNNQueryBuilder.Builder::rescoreContext,
(p, v) -> RescoreParser.fromXContent(p),
RescoreContext::getDefault,
RESCORE_FIELD
);

// Declare fields that cannot be set at the same time. Right now, rescore and radial is not supported
internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MAX_DISTANCE_FIELD.getPreferredName());
internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MIN_SCORE_FIELD.getPreferredName());

return internalParser;
}

Expand Down Expand Up @@ -110,6 +124,10 @@ public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function<Strin
builder.methodParameters(MethodParametersParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion));
}

if (minClusterVersionCheck.apply(RESCORE_PARAMETER)) {
builder.rescoreContext(RescoreParser.streamInput(in));
}

return builder;
}

Expand Down Expand Up @@ -139,6 +157,9 @@ public static void streamOutput(StreamOutput out, KNNQueryBuilder builder, Funct
if (minClusterVersionCheck.apply(METHOD_PARAMETER)) {
MethodParametersParser.streamOutput(out, builder.getMethodParameters(), IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}
if (minClusterVersionCheck.apply(RESCORE_PARAMETER)) {
RescoreParser.streamOutput(out, builder.getRescoreContext());
}
}

/**
Expand Down Expand Up @@ -204,6 +225,9 @@ public static void toXContent(XContentBuilder builder, ToXContent.Params params,
if (knnQueryBuilder.getMethodParameters() != null) {
MethodParametersParser.doXContent(builder, knnQueryBuilder.getMethodParameters());
}
if (knnQueryBuilder.getRescoreContext() != null) {
RescoreParser.doXContent(builder, knnQueryBuilder.getRescoreContext());
}

builder.field(BOOST_FIELD.getPreferredName(), knnQueryBuilder.boost());
if (knnQueryBuilder.queryName() != null) {
Expand Down
131 changes: 131 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/parser/RescoreParser.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.parser;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.ValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;

import java.io.IOException;
import java.util.Locale;

import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_OVERSAMPLE_FIELD;

/**
* Note: This parser is used by neural plugin as well, breaking changes will require changes in neural as well
*/
@Getter
@AllArgsConstructor
@Log4j2
public final class RescoreParser {

public static final String RESCORE_PARAMETER = "rescore";
public static final String RESCORE_OVERSAMPLE_PARAMETER = "oversample_factor";

private static final ObjectParser<RescoreContext.RescoreContextBuilder, Void> INTERNAL_PARSER = createInternalObjectParser();

private static ObjectParser<RescoreContext.RescoreContextBuilder, Void> createInternalObjectParser() {
ObjectParser<RescoreContext.RescoreContextBuilder, Void> internalParser = new ObjectParser<>(
RESCORE_PARAMETER,
RescoreContext::builder
);
internalParser.declareFloat(RescoreContext.RescoreContextBuilder::oversampleFactor, RESCORE_OVERSAMPLE_FIELD);
return internalParser;
}

/**
* Validate the rescore context
*
* @return ValidationException if validation fails, null otherwise
*/
public static ValidationException validate(RescoreContext rescoreContext) {
if (rescoreContext.getOversampleFactor() < RescoreContext.MIN_OVERSAMPLE_FACTOR) {
ValidationException validationException = new ValidationException();
validationException.addValidationError(
String.format(
Locale.ROOT,
"Oversample factor [%f] cannot be less than [%f]",
rescoreContext.getOversampleFactor(),
RescoreContext.MIN_OVERSAMPLE_FACTOR
)
);
return validationException;
}

if (rescoreContext.getOversampleFactor() > RescoreContext.MAX_OVERSAMPLE_FACTOR) {
ValidationException validationException = new ValidationException();
validationException.addValidationError(
String.format(
Locale.ROOT,
"Oversample factor [%f] cannot be more than [%f]",
rescoreContext.getOversampleFactor(),
RescoreContext.MAX_OVERSAMPLE_FACTOR
)
);
return validationException;
}
return null;
}

/**
*
* @param in stream input
* @return RescoreContext
* @throws IOException on stream failure
*/
public static RescoreContext streamInput(StreamInput in) throws IOException {
if (!IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), RESCORE_PARAMETER)) {
return null;
}
Float oversample = in.readOptionalFloat();
if (oversample == null) {
return null;
}
return RescoreContext.builder().oversampleFactor(oversample).build();
}

/**
*
* @param out stream output
* @param rescoreContext RescoreContext
* @throws IOException on stream failure
*/
public static void streamOutput(StreamOutput out, RescoreContext rescoreContext) throws IOException {
if (!IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), RESCORE_PARAMETER)) {
return;
}
out.writeOptionalFloat(rescoreContext == null ? null : rescoreContext.getOversampleFactor());
}

/**
*
* @param builder XContentBuilder
* @param rescoreContext RescoreContext
* @throws IOException on XContent failure
*/
public static void doXContent(final XContentBuilder builder, final RescoreContext rescoreContext) throws IOException {
builder.startObject(RESCORE_PARAMETER);
builder.field(RESCORE_OVERSAMPLE_PARAMETER, rescoreContext.getOversampleFactor());
builder.endObject();
}

/**
*
* @param parser input parser
* @return RescoreContext
*/
public static RescoreContext fromXContent(final XContentParser parser) {
return INTERNAL_PARSER.apply(parser, null).build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.rescore;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;

@Getter
@AllArgsConstructor
@Builder
@EqualsAndHashCode
public final class RescoreContext {

public static final float DEFAULT_OVERSAMPLE_FACTOR = 1.0f;
public static final float MAX_OVERSAMPLE_FACTOR = 100.0f;
public static final float MIN_OVERSAMPLE_FACTOR = 0.0f;

@Builder.Default
private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR;

/**
*
* @return default RescoreContext
*/
public static RescoreContext getDefault() {
return RescoreContext.builder().build();
}
}
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER;

public class IndexUtil {

Expand All @@ -48,6 +49,8 @@ public class IndexUtil {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0;
// TODO: Will update once 2.17 backport change is merged
private static final Version MINIMAL_RESCORE_FEATURE = Version.V_3_0_0;
// public so neural search can access it
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();

Expand Down Expand Up @@ -402,6 +405,7 @@ private static Map<String, Version> initializeMinimalRequiredVersionMap() {
put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH);
put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS);
put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE);
put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import lombok.AllArgsConstructor;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -86,13 +87,22 @@ public static Collection<Object[]> invalidParameters() {
"min score less than 0",
"[knn] requires minScore to be greater than 0",
KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(-1f)
),
$(
"Rescore context",
" cannot be less than",
KNNQueryBuilder.builder()
.rescoreContext(RescoreContext.builder().oversampleFactor(RescoreContext.MIN_OVERSAMPLE_FACTOR - 1).build())
.fieldName(FIELD_NAME)
.vector(QUERY_VECTOR)
.k(1)
)
)
);
}

public void testInvalidBuilder() {
Throwable exception = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilderBuilder.build());
assertEquals(expectedMessage, expectedMessage, exception.getMessage());
assertTrue(exception.getMessage(), exception.getMessage().contains(expectedMessage));
}
}
Loading

0 comments on commit 9db7058

Please sign in to comment.