Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds rescore parameter to KNNQuery #1969

Merged
merged 2 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.opensearch.knn.index.engine.model.QueryContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -54,6 +56,8 @@
import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH;
import static org.opensearch.knn.index.engine.validation.ParameterValidator.validateParameters;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_OVERSAMPLE_PARAMETER;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER;

/**
* Helper class to build the KNN query
Expand All @@ -73,6 +77,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);
public static final ParseField RESCORE_FIELD = new ParseField(RESCORE_PARAMETER);
public static final ParseField RESCORE_OVERSAMPLE_FIELD = new ParseField(RESCORE_OVERSAMPLE_PARAMETER);

public static final int K_MAX = 10000;
/**
Expand All @@ -96,6 +102,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private QueryBuilder filter;
@Getter
private boolean ignoreUnmapped;
@Getter
private RescoreContext rescoreContext;

/**
* Constructs a new query with the given field name and vector
Expand Down Expand Up @@ -136,6 +144,7 @@ public static class Builder {
private boolean ignoreUnmapped;
private String queryName;
private float boost = DEFAULT_BOOST;
private RescoreContext rescoreContext;

public Builder() {}

Expand Down Expand Up @@ -189,11 +198,25 @@ public Builder boost(float boost) {
return this;
}

/**
* Sets the rescore context that {@link #build()} hands to the constructed query.
*
* @param rescoreContext rescore settings to attach; may be null, in which case validation of it is skipped
* @return this builder, for call chaining
*/
public Builder rescoreContext(RescoreContext rescoreContext) {
this.rescoreContext = rescoreContext;
return this;
}

/**
 * Validates the builder state and creates the query from it.
 *
 * @return a new {@link KNNQueryBuilder} carrying every value set on this builder, including the rescore context
 * @throws IllegalArgumentException if {@code validate()} rejects the current builder state
 */
public KNNQueryBuilder build() {
    validate();
    // k is optional on the builder; an unset k is handed to the query as 0.
    int k = this.k == null ? 0 : this.k;
    // Exactly one return: the constructor overload that also receives rescoreContext. An earlier return that
    // used the overload without rescoreContext would silently drop the setting and make this statement unreachable.
    return new KNNQueryBuilder(
        fieldName,
        vector,
        k,
        maxDistance,
        minScore,
        methodParameters,
        filter,
        ignoreUnmapped,
        rescoreContext
    ).boost(boost).queryName(queryName);
}

private void validate() {
Expand Down Expand Up @@ -240,6 +263,15 @@ private void validate() {
);
}
}

if (rescoreContext != null) {
ValidationException validationException = RescoreParser.validate(rescoreContext);
if (validationException != null) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] errors in rescore parameter [%s]", NAME, validationException.getMessage())
);
}
}
}
}

Expand Down Expand Up @@ -284,6 +316,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.ignoreUnmapped = false;
this.maxDistance = null;
this.minScore = null;
this.rescoreContext = null;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -305,6 +338,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
maxDistance = builder.maxDistance;
minScore = builder.minScore;
methodParameters = builder.methodParameters;
rescoreContext = builder.rescoreContext;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;

Expand All @@ -29,6 +30,8 @@
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER;
import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.IGNORE_UNMAPPED_FIELD;
Expand Down Expand Up @@ -79,6 +82,17 @@ private static ObjectParser<KNNQueryBuilder.Builder, Void> createInternalObjectP
);
internalParser.declareObject(KNNQueryBuilder.Builder::filter, (p, v) -> parseInnerQueryBuilder(p), FILTER_FIELD);

internalParser.declareObjectOrDefault(
KNNQueryBuilder.Builder::rescoreContext,
(p, v) -> RescoreParser.fromXContent(p),
RescoreContext::getDefault,
RESCORE_FIELD
);

// Declare fields that cannot be set together: rescore is not yet supported with radial search (max distance or min score).
internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MAX_DISTANCE_FIELD.getPreferredName());
internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MIN_SCORE_FIELD.getPreferredName());

return internalParser;
}

Expand Down Expand Up @@ -110,6 +124,10 @@ public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function<Strin
builder.methodParameters(MethodParametersParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion));
}

if (minClusterVersionCheck.apply(RESCORE_PARAMETER)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will affect neural, just a heads up

builder.rescoreContext(RescoreParser.streamInput(in));
}

return builder;
}

Expand Down Expand Up @@ -139,6 +157,9 @@ public static void streamOutput(StreamOutput out, KNNQueryBuilder builder, Funct
if (minClusterVersionCheck.apply(METHOD_PARAMETER)) {
MethodParametersParser.streamOutput(out, builder.getMethodParameters(), IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}
if (minClusterVersionCheck.apply(RESCORE_PARAMETER)) {
RescoreParser.streamOutput(out, builder.getRescoreContext());
}
}

/**
Expand Down Expand Up @@ -204,6 +225,9 @@ public static void toXContent(XContentBuilder builder, ToXContent.Params params,
if (knnQueryBuilder.getMethodParameters() != null) {
MethodParametersParser.doXContent(builder, knnQueryBuilder.getMethodParameters());
}
if (knnQueryBuilder.getRescoreContext() != null) {
RescoreParser.doXContent(builder, knnQueryBuilder.getRescoreContext());
}

builder.field(BOOST_FIELD.getPreferredName(), knnQueryBuilder.boost());
if (knnQueryBuilder.queryName() != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.parser;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.ValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;

import java.io.IOException;
import java.util.Locale;

import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_OVERSAMPLE_FIELD;

/**
 * Static helpers that parse, validate and (de)serialize the {@code rescore} section of a k-NN query.
 * <p>
 * Note: This parser is used by neural plugin as well, breaking changes will require changes in neural as well
 */
@Getter
@AllArgsConstructor
@Log4j2
public final class RescoreParser {

    public static final String RESCORE_PARAMETER = "rescore";
    public static final String RESCORE_OVERSAMPLE_PARAMETER = "oversample_factor";

    private static final ObjectParser<RescoreContext.RescoreContextBuilder, Void> INTERNAL_PARSER = createInternalObjectParser();

    /**
     * Builds the object parser that fills a {@link RescoreContext} builder from the fields of the rescore object.
     *
     * @return parser that understands the oversample factor field
     */
    private static ObjectParser<RescoreContext.RescoreContextBuilder, Void> createInternalObjectParser() {
        ObjectParser<RescoreContext.RescoreContextBuilder, Void> internalParser = new ObjectParser<>(
            RESCORE_PARAMETER,
            RescoreContext::builder
        );
        internalParser.declareFloat(RescoreContext.RescoreContextBuilder::oversampleFactor, RESCORE_OVERSAMPLE_FIELD);
        return internalParser;
    }

    /**
     * Validate the rescore context
     *
     * @param rescoreContext context to check; must not be null
     * @return ValidationException if validation fails, null otherwise
     */
    public static ValidationException validate(RescoreContext rescoreContext) {
        final float oversampleFactor = rescoreContext.getOversampleFactor();

        // NaN compares false against every bound, so without this check it would pass both range checks below.
        if (Float.isNaN(oversampleFactor)) {
            return validationError(String.format(Locale.ROOT, "Oversample factor [%f] must be a number", oversampleFactor));
        }

        if (oversampleFactor < RescoreContext.MIN_OVERSAMPLE_FACTOR) {
            return validationError(
                String.format(
                    Locale.ROOT,
                    "Oversample factor [%f] cannot be less than [%f]",
                    oversampleFactor,
                    RescoreContext.MIN_OVERSAMPLE_FACTOR
                )
            );
        }

        if (oversampleFactor > RescoreContext.MAX_OVERSAMPLE_FACTOR) {
            return validationError(
                String.format(
                    Locale.ROOT,
                    "Oversample factor [%f] cannot be more than [%f]",
                    oversampleFactor,
                    RescoreContext.MAX_OVERSAMPLE_FACTOR
                )
            );
        }
        return null;
    }

    /**
     * Wraps a single message in a {@link ValidationException}.
     *
     * @param message validation error text
     * @return exception holding exactly that validation error
     */
    private static ValidationException validationError(String message) {
        ValidationException validationException = new ValidationException();
        validationException.addValidationError(message);
        return validationException;
    }

    /**
     * Reads a rescore context from the stream.
     *
     * @param in stream input
     * @return RescoreContext, or null when the stream version is older than the minimum required for rescore or
     *         when no oversample factor was written
     * @throws IOException on stream failure
     */
    public static RescoreContext streamInput(StreamInput in) throws IOException {
        if (!IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), RESCORE_PARAMETER)) {
            return null;
        }
        Float oversample = in.readOptionalFloat();
        if (oversample == null) {
            return null;
        }
        return RescoreContext.builder().oversampleFactor(oversample).build();
    }

    /**
     * Writes a rescore context to the stream. Nothing is written when the stream version is older than the minimum
     * required for rescore; otherwise the oversample factor is written as an optional float (absent for a null context).
     *
     * @param out stream output
     * @param rescoreContext RescoreContext, may be null
     * @throws IOException on stream failure
     */
    public static void streamOutput(StreamOutput out, RescoreContext rescoreContext) throws IOException {
        if (!IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), RESCORE_PARAMETER)) {
            return;
        }
        out.writeOptionalFloat(rescoreContext == null ? null : rescoreContext.getOversampleFactor());
    }

    /**
     * Writes the rescore context as a {@code rescore} object containing the oversample factor.
     *
     * @param builder XContentBuilder
     * @param rescoreContext RescoreContext, must not be null
     * @throws IOException on XContent failure
     */
    public static void doXContent(final XContentBuilder builder, final RescoreContext rescoreContext) throws IOException {
        builder.startObject(RESCORE_PARAMETER);
        builder.field(RESCORE_OVERSAMPLE_PARAMETER, rescoreContext.getOversampleFactor());
        builder.endObject();
    }

    /**
     * Parses a rescore object into a {@link RescoreContext}.
     *
     * @param parser input parser
     * @return RescoreContext
     */
    public static RescoreContext fromXContent(final XContentParser parser) {
        return INTERNAL_PARSER.apply(parser, null).build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.rescore;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;

/**
 * Rescore settings attached to a k-NN query. At present the only setting is the oversample factor.
 */
@Builder
@Getter
@EqualsAndHashCode
@AllArgsConstructor
public final class RescoreContext {

    /** Smallest oversample factor accepted by rescore validation. */
    public static final float MIN_OVERSAMPLE_FACTOR = 0.0f;
    /** Largest oversample factor accepted by rescore validation. */
    public static final float MAX_OVERSAMPLE_FACTOR = 100.0f;
    /** Oversample factor used when the builder is not given one. */
    public static final float DEFAULT_OVERSAMPLE_FACTOR = 1.0f;

    // Falls back to DEFAULT_OVERSAMPLE_FACTOR when the builder's oversampleFactor(...) is never called.
    @Builder.Default
    private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR;

    /**
     * Creates a context in which every setting keeps its default value.
     *
     * @return default RescoreContext
     */
    public static RescoreContext getDefault() {
        return builder().build();
    }
}
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER;

public class IndexUtil {

Expand All @@ -48,6 +49,8 @@ public class IndexUtil {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0;
// TODO: Will update once 2.17 backport change is merged
private static final Version MINIMAL_RESCORE_FEATURE = Version.V_3_0_0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Look around if there is a need to have this in neural plugin

// public so neural search can access it
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();

Expand Down Expand Up @@ -402,6 +405,7 @@ private static Map<String, Version> initializeMinimalRequiredVersionMap() {
put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH);
put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS);
put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE);
put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import lombok.AllArgsConstructor;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -86,13 +87,22 @@ public static Collection<Object[]> invalidParameters() {
"min score less than 0",
"[knn] requires minScore to be greater than 0",
KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(-1f)
),
$(
"Rescore context",
" cannot be less than",
KNNQueryBuilder.builder()
.rescoreContext(RescoreContext.builder().oversampleFactor(RescoreContext.MIN_OVERSAMPLE_FACTOR - 1).build())
.fieldName(FIELD_NAME)
.vector(QUERY_VECTOR)
.k(1)
)
)
);
}

/**
 * Verifies that building from an invalid parameter combination fails with an IllegalArgumentException whose
 * message contains the expected text.
 */
public void testInvalidBuilder() {
    Throwable exception = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilderBuilder.build());
    // Substring match only: some expected messages (e.g. " cannot be less than") are fragments of the full
    // error text, so an exact-equality assertion would fail for them.
    assertTrue(exception.getMessage(), exception.getMessage().contains(expectedMessage));
}
}
Loading
Loading