diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 56f9ffaf89..ff2db13142 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -27,6 +27,8 @@ public class KNNConstants { public static final String TYPE_KNN_VECTOR = "knn_vector"; public static final String PROPERTIES = "properties"; public static final String METHOD_PARAMETER = "method_parameters"; + public static final String RESCORE_PARAMETER = "rescore"; + public static final String RESCORE_OVERSAMPLE_PARAMETER = "oversample_factor"; public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search"; public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; public static final String METHOD_PARAMETER_M = "m"; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 80833751ec..dd635322b8 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -26,6 +26,7 @@ import org.opensearch.knn.index.engine.model.QueryContext; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -50,6 +51,8 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; +import static org.opensearch.knn.common.KNNConstants.RESCORE_OVERSAMPLE_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.RESCORE_PARAMETER; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; @@ -73,6 +76,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH); public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES); public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER); + public static final ParseField RESCORE_FIELD = new ParseField(RESCORE_PARAMETER); + public static final ParseField RESCORE_OVERSAMPLE_FIELD = new ParseField(RESCORE_OVERSAMPLE_PARAMETER); public static final int K_MAX = 10000; /** @@ -96,6 +101,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private QueryBuilder filter; @Getter private boolean ignoreUnmapped; + @Getter + private RescoreContext rescoreContext; /** * Constructs a new query with the given field name and vector @@ -136,6 +143,7 @@ public static class Builder { private boolean ignoreUnmapped; private String queryName; private float boost = DEFAULT_BOOST; + private RescoreContext rescoreContext; public Builder() {} @@ -189,11 +197,25 @@ public Builder boost(float boost) { return this; } + public Builder rescoreContext(RescoreContext rescoreContext) { + this.rescoreContext = rescoreContext; + return this; + } + public KNNQueryBuilder build() { validate(); int k = this.k == null ? 0 : this.k; - return new KNNQueryBuilder(fieldName, vector, k, maxDistance, minScore, methodParameters, filter, ignoreUnmapped).boost(boost) - .queryName(queryName); + return new KNNQueryBuilder( + fieldName, + vector, + k, + maxDistance, + minScore, + methodParameters, + filter, + ignoreUnmapped, + rescoreContext + ).boost(boost).queryName(queryName); } private void validate() { @@ -240,6 +262,15 @@ private void validate() { ); } } + + if (rescoreContext != null) { + ValidationException validationException = rescoreContext.validate(); + if (validationException != null) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] errors in rescore parameter [%s]", NAME, validationException.getMessage()) + ); + } + } } } @@ -284,6 +315,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.ignoreUnmapped = false; this.maxDistance = null; this.minScore = null; + this.rescoreContext = null; } public static void initialize(ModelDao modelDao) { @@ -305,6 +337,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException { maxDistance = builder.maxDistance; minScore = builder.minScore; methodParameters = builder.methodParameters; + rescoreContext = builder.rescoreContext; } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java index 4564392121..6996987821 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -29,6 +30,8 @@ import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.RESCORE_PARAMETER; +import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD; import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.IGNORE_UNMAPPED_FIELD; @@ -79,6 +82,17 @@ private static ObjectParser createInternalObjectP ); internalParser.declareObject(KNNQueryBuilder.Builder::filter, (p, v) -> parseInnerQueryBuilder(p), FILTER_FIELD); + internalParser.declareObjectOrDefault( + KNNQueryBuilder.Builder::rescoreContext, + (p, v) -> RescoreParser.fromXContent(p), + RescoreContext::getDefault, + RESCORE_FIELD + ); + + // Declare fields that cannot be set at the same time. Right now, rescore and radial is not supported + internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MAX_DISTANCE_FIELD.getPreferredName()); + internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MIN_SCORE_FIELD.getPreferredName()); + return internalParser; } @@ -110,6 +124,10 @@ public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function INTERNAL_PARSER = createInternalObjectParser(); + + private static ObjectParser createInternalObjectParser() { + ObjectParser internalParser = new ObjectParser<>( + RESCORE_PARAMETER, + RescoreContext::builder + ); + internalParser.declareFloat(RescoreContext.RescoreContextBuilder::oversampleFactor, RESCORE_OVERSAMPLE_FIELD); + return internalParser; + } + + /** + * + * @param in stream input + * @param minClusterVersionCheck function to check if the cluster version meets the minimum requirement + * @return RescoreContext + * @throws IOException on stream failure + */ + public static RescoreContext streamInput(StreamInput in, Function minClusterVersionCheck) throws IOException { + if (!in.readBoolean()) { + return null; + } + + RescoreContext.RescoreContextBuilder builder = RescoreContext.builder(); + if (minClusterVersionCheck.apply(RESCORE_PARAMETER)) { + builder.oversampleFactor(in.readFloat()); + } + return builder.build(); + } + + /** + * + * @param out stream output + * @param rescoreContext RescoreContext + * @param minClusterVersionCheck function to check if the cluster version meets the minimum requirement + * @throws IOException on stream failure + */ + public static void streamOutput(StreamOutput out, RescoreContext rescoreContext, Function minClusterVersionCheck) + throws IOException { + if (rescoreContext == null) { + out.writeBoolean(false); + return; + } + + out.writeBoolean(true); + if (minClusterVersionCheck.apply(RESCORE_PARAMETER)) { + out.writeFloat(rescoreContext.getOversampleFactor()); + } + } + + /** + * + * @param builder XContentBuilder + * @param rescoreContext RescoreContext + * @throws IOException on XContent failure + */ + public static void doXContent(final XContentBuilder builder, final RescoreContext rescoreContext) throws IOException { + builder.startObject(RESCORE_PARAMETER); + builder.field(RESCORE_OVERSAMPLE_PARAMETER, rescoreContext.getOversampleFactor()); + builder.endObject(); + } + + /** + * + * @param parser input parser + * @return RescoreContext + */ + public static RescoreContext fromXContent(final XContentParser parser) { + return INTERNAL_PARSER.apply(parser, null).build(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java new file mode 100644 index 0000000000..14ffc8edcb --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.rescore; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.opensearch.common.ValidationException; + +@Slf4j +@Getter +@AllArgsConstructor +@Builder +@EqualsAndHashCode +public final class RescoreContext { + + public static final float DEFAULT_OVERSAMPLE_FACTOR = 1.0f; + public static final float MAX_OVERSAMPLE_FACTOR = 5.0f; + public static final float MIN_OVERSAMPLE_FACTOR = 0.0f; + + @Builder.Default + private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR; + + /** + * Validate the rescore context + * + * @return ValidationException if validation fails, null otherwise + */ + public ValidationException validate() { + if (oversampleFactor < RescoreContext.MIN_OVERSAMPLE_FACTOR) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError( + String.format( + "Oversample factor [%f] cannot be less than [%f]", + getOversampleFactor(), + RescoreContext.MIN_OVERSAMPLE_FACTOR + ) + ); + return validationException; + } + + if (oversampleFactor > RescoreContext.MAX_OVERSAMPLE_FACTOR) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError( + String.format( + "Oversample factor [%f] cannot be more than [%f]", + getOversampleFactor(), + RescoreContext.MAX_OVERSAMPLE_FACTOR + ) + ); + return validationException; + } + return null; + } + + /** + * + * @return default RescoreContext + */ + public static RescoreContext getDefault() { + return RescoreContext.builder().build(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index 7713ff40ee..306ecbdd71 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -48,6 +48,7 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0; + private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0; // public so neural search can access it public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); @@ -402,6 +403,7 @@ private static Map initializeMinimalRequiredVersionMap() { put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS); put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE); + put(KNNConstants.RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE); } }; diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 25982fb7d6..40a68c0de8 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -30,6 +30,7 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -776,19 +777,23 @@ public void testDoToQuery_InvalidZeroByteVector() { public void testSerialization() throws Exception { // For k-NN search - assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null); - assertSerialization(Version.CURRENT, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, Map.of("ef_search", EF_SEARCH), null, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, Map.of("ef_search", EF_SEARCH), null, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null, null); // For distance threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MAX_DISTANCE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MAX_DISTANCE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MAX_DISTANCE, null); // For score threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MIN_SCORE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MIN_SCORE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MIN_SCORE, null); + + // Test rescore + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null, RescoreContext.getDefault()); + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null, RescoreContext.getDefault()); } private void assertSerialization( @@ -797,7 +802,8 @@ private void assertSerialization( Integer k, Map methodParameters, Float distance, - Float score + Float score, + RescoreContext rescoreContext ) throws Exception { final KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() .fieldName(FIELD_NAME) @@ -807,6 +813,7 @@ private void assertSerialization( .k(k) .methodParameters(methodParameters) .filter(queryBuilderOptional.orElse(null)) + .rescoreContext(rescoreContext) .build(); final ClusterService clusterService = mockClusterService(version); @@ -840,6 +847,7 @@ private void assertSerialization( assertNull(deserializedKnnQueryBuilder.getFilter()); } assertMethodParameters(version, methodParameters, deserializedKnnQueryBuilder.getMethodParameters()); + assertRescore(version, rescoreContext, deserializedKnnQueryBuilder.getRescoreContext()); } } } @@ -854,6 +862,17 @@ private void assertMethodParameters(Version version, Map expectedMeth } } + private void assertRescore(Version version, RescoreContext expectedRescoreContext, RescoreContext actualRescoreContext) { + if (!version.onOrAfter(Version.V_2_17_0)) { + assertNull(actualRescoreContext); + return; + } + + if (expectedRescoreContext != null) { + assertEquals(expectedRescoreContext, actualRescoreContext); + } + } + public void testIgnoreUnmapped() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder.Builder knnQueryBuilder = KNNQueryBuilder.builder() diff --git a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java index 713e532f9f..b450eb0271 100644 --- a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java +++ b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java @@ -17,6 +17,8 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.plugins.SearchPlugin; @@ -470,6 +472,38 @@ public void testToXContent_whenMethodParams_thenSucceed() throws IOException { assertEquals(builder.toString(), testBuilder.toString()); } + public void testToXContent_whenRescore_thenSucceed() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float oversample = 1.0f; + XContentBuilder builderFromObject = XContentFactory.jsonBuilder() + .startObject() + .startObject(NAME) + .startObject(FIELD_NAME) + .field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector) + .field(KNNQueryBuilder.K_FIELD.getPreferredName(), K) + .startObject(KNNConstants.RESCORE_PARAMETER) + .field(KNNConstants.RESCORE_OVERSAMPLE_PARAMETER, oversample) + .endObject() + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject(); + + KNNQueryBuilder knnQueryBuilderFromObject = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .boost(BOOST) + .k(K) + .rescoreContext(RescoreContext.builder().oversampleFactor(oversample).build()) + .build(); + + XContentBuilder testBuilder = XContentFactory.jsonBuilder(); + testBuilder.startObject(); + KNNQueryBuilderParser.toXContent(testBuilder, EMPTY_PARAMS, knnQueryBuilderFromObject); + testBuilder.endObject(); + assertEquals(builderFromObject.toString(), testBuilder.toString()); + } + @Override protected NamedXContentRegistry xContentRegistry() { List list = ClusterModule.getNamedXWriteables(); diff --git a/src/test/java/org/opensearch/knn/index/query/parser/RescoreParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/RescoreParserTests.java new file mode 100644 index 0000000000..94c324f30b --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/parser/RescoreParserTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.parser; + +import lombok.SneakyThrows; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.query.rescore.RescoreContext; + +import java.io.IOException; + +import static org.opensearch.knn.common.KNNConstants.RESCORE_OVERSAMPLE_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.RESCORE_PARAMETER; + +public class RescoreParserTests extends KNNTestCase { + + @SneakyThrows + public void testStreams() { + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(RescoreContext.DEFAULT_OVERSAMPLE_FACTOR).build(); + validateStreams(rescoreContext); + validateStreams(null); + } + + private void validateStreams(RescoreContext rescoreContext) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + RescoreParser.streamOutput(output, rescoreContext, x -> true); + + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { + RescoreContext parsedRescoreContext = RescoreParser.streamInput(in, x -> true); + assertEquals(rescoreContext, parsedRescoreContext); + } + } + } + + @SneakyThrows + public void testDoXContent() { + float oversample = RescoreContext.MAX_OVERSAMPLE_FACTOR - 1; + XContentBuilder expectedBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(RESCORE_PARAMETER) + .field(RESCORE_OVERSAMPLE_PARAMETER, oversample) + .endObject() + .endObject(); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + RescoreParser.doXContent(builder, RescoreContext.builder().oversampleFactor(oversample).build()); + builder.endObject(); + assertEquals(expectedBuilder.toString(), builder.toString()); + } + + @SneakyThrows + public void testFromXContent_whenValid_thenSucceed() { + float oversample1 = RescoreContext.MAX_OVERSAMPLE_FACTOR - 1; + XContentBuilder builder1 = XContentFactory.jsonBuilder().startObject().field(RESCORE_OVERSAMPLE_PARAMETER, oversample1).endObject(); + validateOversample(oversample1, builder1); + XContentBuilder builder2 = XContentFactory.jsonBuilder().startObject().endObject(); + validateOversample(RescoreContext.DEFAULT_OVERSAMPLE_FACTOR, builder2); + } + + @SneakyThrows + public void testFromXContent_whenInvalid_thenFail() { + XContentBuilder invalidParamBuilder = XContentFactory.jsonBuilder().startObject().field("invalid", 0).endObject(); + expectValidationException(invalidParamBuilder); + + XContentBuilder invalidParamValueBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(RESCORE_OVERSAMPLE_PARAMETER, "c") + .endObject(); + expectValidationException(invalidParamValueBuilder); + + XContentBuilder extraParamBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(RESCORE_OVERSAMPLE_PARAMETER, RescoreContext.MAX_OVERSAMPLE_FACTOR - 1) + .field("invalid", 0) + .endObject(); + expectValidationException(extraParamBuilder); + } + + private void validateOversample(float expectedOversample, XContentBuilder builder) throws IOException { + XContentParser parser = createParser(builder); + RescoreContext rescoreContext = RescoreParser.fromXContent(parser); + assertEquals(expectedOversample, rescoreContext.getOversampleFactor(), 0.0001); + } + + private void expectValidationException(XContentBuilder builder) throws IOException { + XContentParser parser = createParser(builder); + expectThrows(IllegalArgumentException.class, () -> RescoreParser.fromXContent(parser)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java new file mode 100644 index 0000000000..11bc0681cf --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.rescore; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.AllArgsConstructor; +import org.opensearch.knn.KNNTestCase; + +import java.util.Arrays; +import java.util.Collection; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; + +@AllArgsConstructor +public class RescoreContextTests extends KNNTestCase { + + private boolean isValid; + private RescoreContext rescoreContext; + + @ParametersFactory(argumentFormatting = "isValid:%1$s; rescoreContext:%2$s") + public static Collection validParams() { + return Arrays.asList( + $$( + $(true, RescoreContext.builder().build()), + $(true, RescoreContext.getDefault()), + $(true, RescoreContext.builder().oversampleFactor(RescoreContext.MAX_OVERSAMPLE_FACTOR - 1).build()), + $(false, RescoreContext.builder().oversampleFactor(RescoreContext.MAX_OVERSAMPLE_FACTOR + 1).build()), + $(false, RescoreContext.builder().oversampleFactor(RescoreContext.MIN_OVERSAMPLE_FACTOR - 1).build()) + ) + ); + } + + public void testValidate() { + if (isValid) { + assertNull(rescoreContext.validate()); + } else { + assertNotNull(rescoreContext.validate()); + } + } +} diff --git a/src/test/java/org/opensearch/knn/integ/QueryParseIT.java b/src/test/java/org/opensearch/knn/integ/QueryParseIT.java new file mode 100644 index 0000000000..9308e08112 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/QueryParseIT.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.SneakyThrows; +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; + +import java.io.IOException; +import java.util.Locale; + +import static org.opensearch.knn.common.KNNConstants.RESCORE_OVERSAMPLE_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.RESCORE_PARAMETER; + +public class QueryParseIT extends KNNRestTestCase { + + private final static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f }; + private final static int DIMENSION = 2; + private final static int K = 1; + + @SneakyThrows + public void testRescore() { + createTestIndex(); + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).startObject("rescore").endObject() + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RESCORE_PARAMETER) + .field(RESCORE_OVERSAMPLE_PARAMETER, 2) + .endObject() + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).startObject(RESCORE_PARAMETER).endObject() + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).field(RESCORE_PARAMETER, true) + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).field(RESCORE_PARAMETER, false) + ) + ) + ); + + // Invalid value for rescore + assertInvalid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).field(RESCORE_PARAMETER, "invalid") + ) + ) + ); + + // Invalid rescore param + assertInvalid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RESCORE_OVERSAMPLE_PARAMETER) + .field("invalid_param", "invalid") + .endObject() + ) + ) + ); + + // Invalid rescore param value + assertInvalid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RESCORE_PARAMETER) + .field(RESCORE_OVERSAMPLE_PARAMETER, "invalid") + .endObject() + ) + ) + ); + } + + private XContentBuilder setupQueryXContentBuilder() throws IOException { + return XContentFactory.jsonBuilder().startObject().startObject("query").startObject("knn").startObject(FIELD_NAME); + } + + private XContentBuilder closeQueryXContentBuilder(XContentBuilder xContentBuilder) throws IOException { + return xContentBuilder.endObject().endObject().endObject().endObject(); + } + + private void assertValid(Request request) throws IOException { + assertOK(client().performRequest(request)); + } + + private void assertInvalid(Request request) { + expectThrows(ResponseException.class, () -> client().performRequest(request)); + } + + private Request buildRequest(XContentBuilder xContentBuilder) { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", INDEX_NAME)); + request.addParameter("size", Integer.toString(10)); + request.addParameter("explain", Boolean.toString(true)); + request.setJsonEntity(xContentBuilder.toString()); + return request; + } + + private void createTestIndex() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + createKnnIndex(INDEX_NAME, mapping); + } +}