diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 80833751e..af8e410d4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -26,6 +26,8 @@ import org.opensearch.knn.index.engine.model.QueryContext; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.query.parser.RescoreParser; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -54,6 +56,8 @@ import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; import static org.opensearch.knn.index.engine.validation.ParameterValidator.validateParameters; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_OVERSAMPLE_PARAMETER; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; /** * Helper class to build the KNN query @@ -73,6 +77,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH); public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES); public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER); + public static final ParseField RESCORE_FIELD = new ParseField(RESCORE_PARAMETER); + public static final ParseField RESCORE_OVERSAMPLE_FIELD = new ParseField(RESCORE_OVERSAMPLE_PARAMETER); public static final int K_MAX = 10000; /** @@ -96,6 +102,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private QueryBuilder filter; @Getter private boolean ignoreUnmapped; + @Getter + private RescoreContext rescoreContext; /** * Constructs a new query with the given field name and vector @@ -136,6 +144,7 @@ public static class Builder { private boolean ignoreUnmapped; private String queryName; private float boost = DEFAULT_BOOST; + private RescoreContext rescoreContext; public Builder() {} @@ -189,11 +198,25 @@ public Builder boost(float boost) { return this; } + public Builder rescoreContext(RescoreContext rescoreContext) { + this.rescoreContext = rescoreContext; + return this; + } + public KNNQueryBuilder build() { validate(); int k = this.k == null ? 0 : this.k; - return new KNNQueryBuilder(fieldName, vector, k, maxDistance, minScore, methodParameters, filter, ignoreUnmapped).boost(boost) - .queryName(queryName); + return new KNNQueryBuilder( + fieldName, + vector, + k, + maxDistance, + minScore, + methodParameters, + filter, + ignoreUnmapped, + rescoreContext + ).boost(boost).queryName(queryName); } private void validate() { @@ -240,6 +263,15 @@ private void validate() { ); } } + + if (rescoreContext != null) { + ValidationException validationException = RescoreParser.validate(rescoreContext); + if (validationException != null) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] errors in rescore parameter [%s]", NAME, validationException.getMessage()) + ); + } + } } } @@ -284,6 +316,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.ignoreUnmapped = false; this.maxDistance = null; this.minScore = null; + this.rescoreContext = null; } public static void initialize(ModelDao modelDao) { @@ -305,6 +338,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException { maxDistance = builder.maxDistance; minScore = builder.minScore; methodParameters = builder.methodParameters; + rescoreContext = builder.rescoreContext; } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java index 456439212..02fbd0113 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -29,6 +30,8 @@ import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.IGNORE_UNMAPPED_FIELD; @@ -79,6 +82,17 @@ private static ObjectParser createInternalObjectP ); internalParser.declareObject(KNNQueryBuilder.Builder::filter, (p, v) -> parseInnerQueryBuilder(p), FILTER_FIELD); + internalParser.declareObjectOrDefault( + KNNQueryBuilder.Builder::rescoreContext, + (p, v) -> RescoreParser.fromXContent(p), + RescoreContext::getDefault, + RESCORE_FIELD + ); + + // Declare fields that cannot be set at the same time. Right now, rescore and radial is not supported + internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MAX_DISTANCE_FIELD.getPreferredName()); + internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MIN_SCORE_FIELD.getPreferredName()); + return internalParser; } @@ -110,6 +124,10 @@ public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function INTERNAL_PARSER = createInternalObjectParser(); + + private static ObjectParser createInternalObjectParser() { + ObjectParser internalParser = new ObjectParser<>( + RESCORE_PARAMETER, + RescoreContext::builder + ); + internalParser.declareFloat(RescoreContext.RescoreContextBuilder::oversampleFactor, RESCORE_OVERSAMPLE_FIELD); + return internalParser; + } + + /** + * Validate the rescore context + * + * @return ValidationException if validation fails, null otherwise + */ + public static ValidationException validate(RescoreContext rescoreContext) { + if (rescoreContext.getOversampleFactor() < RescoreContext.MIN_OVERSAMPLE_FACTOR) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError( + String.format( + Locale.ROOT, + "Oversample factor [%f] cannot be less than [%f]", + rescoreContext.getOversampleFactor(), + RescoreContext.MIN_OVERSAMPLE_FACTOR + ) + ); + return validationException; + } + + if (rescoreContext.getOversampleFactor() > RescoreContext.MAX_OVERSAMPLE_FACTOR) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError( + String.format( + Locale.ROOT, + "Oversample factor [%f] cannot be more than [%f]", + rescoreContext.getOversampleFactor(), + RescoreContext.MAX_OVERSAMPLE_FACTOR + ) + ); + return validationException; + } + return null; + } + + /** + * + * @param in stream input + * @return RescoreContext + * @throws IOException on stream failure + */ + public static RescoreContext streamInput(StreamInput in) throws IOException { + if (!IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), RESCORE_PARAMETER)) { + return null; + } + Float oversample = in.readOptionalFloat(); + if (oversample == null) { + return null; + } + return RescoreContext.builder().oversampleFactor(oversample).build(); + } + + /** + * + * @param out stream output + * @param rescoreContext RescoreContext + * @throws IOException on stream failure + */ + public static void streamOutput(StreamOutput out, RescoreContext rescoreContext) throws IOException { + if (!IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), RESCORE_PARAMETER)) { + return; + } + out.writeOptionalFloat(rescoreContext == null ? null : rescoreContext.getOversampleFactor()); + } + + /** + * + * @param builder XContentBuilder + * @param rescoreContext RescoreContext + * @throws IOException on XContent failure + */ + public static void doXContent(final XContentBuilder builder, final RescoreContext rescoreContext) throws IOException { + builder.startObject(RESCORE_PARAMETER); + builder.field(RESCORE_OVERSAMPLE_PARAMETER, rescoreContext.getOversampleFactor()); + builder.endObject(); + } + + /** + * + * @param parser input parser + * @return RescoreContext + */ + public static RescoreContext fromXContent(final XContentParser parser) { + return INTERNAL_PARSER.apply(parser, null).build(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java new file mode 100644 index 000000000..82b09807a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.rescore; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; + +@Getter +@AllArgsConstructor +@Builder +@EqualsAndHashCode +public final class RescoreContext { + + public static final float DEFAULT_OVERSAMPLE_FACTOR = 1.0f; + public static final float MAX_OVERSAMPLE_FACTOR = 100.0f; + public static final float MIN_OVERSAMPLE_FACTOR = 0.0f; + + @Builder.Default + private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR; + + /** + * + * @return default RescoreContext + */ + public static RescoreContext getDefault() { + return RescoreContext.builder().build(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index 7713ff40e..eb184efdd 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -36,6 +36,7 @@ import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; public class IndexUtil { @@ -48,6 +49,8 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0; + // TODO: Will update once 2.17 backport change is merged + private static final Version MINIMAL_RESCORE_FEATURE = Version.V_3_0_0; // public so neural search can access it public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); @@ -402,6 +405,7 @@ private static Map initializeMinimalRequiredVersionMap() { put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS); put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE); + put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE); } }; diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java index 74c1cca58..29f2d2368 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java @@ -8,6 +8,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import lombok.AllArgsConstructor; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.query.rescore.RescoreContext; import java.util.Arrays; import java.util.Collection; @@ -86,6 +87,15 @@ public static Collection invalidParameters() { "min score less than 0", "[knn] requires minScore to be greater than 0", KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(-1f) + ), + $( + "Rescore context", + " cannot be less than", + KNNQueryBuilder.builder() + .rescoreContext(RescoreContext.builder().oversampleFactor(RescoreContext.MIN_OVERSAMPLE_FACTOR - 1).build()) + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .k(1) ) ) ); @@ -93,6 +103,6 @@ public static Collection invalidParameters() { public void testInvalidBuilder() { Throwable exception = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilderBuilder.build()); - assertEquals(expectedMessage, expectedMessage, exception.getMessage()); + assertTrue(exception.getMessage(), exception.getMessage().contains(expectedMessage)); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 25982fb7d..762a36227 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -30,6 +30,7 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -776,19 +777,23 @@ public void testDoToQuery_InvalidZeroByteVector() { public void testSerialization() throws Exception { // For k-NN search - assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null); - assertSerialization(Version.CURRENT, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, Map.of("ef_search", EF_SEARCH), null, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, Map.of("ef_search", EF_SEARCH), null, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null, null); // For distance threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MAX_DISTANCE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MAX_DISTANCE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MAX_DISTANCE, null); // For score threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MIN_SCORE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MIN_SCORE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MIN_SCORE, null); + + // Test rescore + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null, RescoreContext.getDefault()); + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null, RescoreContext.getDefault()); } private void assertSerialization( @@ -797,7 +802,8 @@ private void assertSerialization( Integer k, Map methodParameters, Float distance, - Float score + Float score, + RescoreContext rescoreContext ) throws Exception { final KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() .fieldName(FIELD_NAME) @@ -807,6 +813,7 @@ private void assertSerialization( .k(k) .methodParameters(methodParameters) .filter(queryBuilderOptional.orElse(null)) + .rescoreContext(rescoreContext) .build(); final ClusterService clusterService = mockClusterService(version); @@ -818,7 +825,7 @@ private void assertSerialization( output.writeNamedWriteable(knnQueryBuilder); try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { - in.setVersion(Version.CURRENT); + in.setVersion(version); final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); assertNotNull(deserializedQuery); @@ -840,6 +847,7 @@ private void assertSerialization( assertNull(deserializedKnnQueryBuilder.getFilter()); } assertMethodParameters(version, methodParameters, deserializedKnnQueryBuilder.getMethodParameters()); + assertRescore(version, rescoreContext, deserializedKnnQueryBuilder.getRescoreContext()); } } } @@ -854,6 +862,17 @@ private void assertMethodParameters(Version version, Map expectedMeth } } + private void assertRescore(Version version, RescoreContext expectedRescoreContext, RescoreContext actualRescoreContext) { + if (!version.onOrAfter(Version.V_2_17_0)) { + assertNull(actualRescoreContext); + return; + } + + if (expectedRescoreContext != null) { + assertEquals(expectedRescoreContext, actualRescoreContext); + } + } + public void testIgnoreUnmapped() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder.Builder knnQueryBuilder = KNNQueryBuilder.builder() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java index 4b97df4b4..7b5e59224 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java @@ -8,6 +8,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import lombok.AllArgsConstructor; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.query.rescore.RescoreContext; import java.util.Arrays; import java.util.Collection; @@ -28,8 +29,9 @@ public class KNNQueryBuilderValidParamsTests extends KNNTestCase { private Map methodParameters; private Float maxDistance; private Float minScore; + private RescoreContext rescoreContext; - @ParametersFactory(argumentFormatting = "description:%1$s; k:%3$s, efSearch:%4$s, maxDist:%5$s, minScore:%6$s") + @ParametersFactory(argumentFormatting = "description:%1$s; k:%3$s, efSearch:%4$s, maxDist:%5$s, minScore:%6$s, rescoreContext:%6$s") public static Collection validParameters() { return Arrays.asList( $$( @@ -39,6 +41,7 @@ public static Collection validParameters() { 10, null, null, + null, null ), $( @@ -52,6 +55,7 @@ public static Collection validParameters() { 10, Map.of("ef_search", 12), null, + null, null ), $( @@ -60,6 +64,7 @@ public static Collection validParameters() { null, null, 10.0f, + null, null ), $( @@ -68,7 +73,17 @@ public static Collection validParameters() { null, null, null, - 10.0f + 10.0f, + null + ), + $( + "valid knn with rescore", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(10.0f).build(), + null, + null, + null, + 10.0f, + RescoreContext.getDefault() ) ) ); @@ -84,6 +99,7 @@ public void testValidBuilder() { .methodParameters(methodParameters) .maxDistance(maxDistance) .minScore(minScore) + .rescoreContext(rescoreContext) .build() ); } diff --git a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java index 713e532f9..6cac5580b 100644 --- a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java +++ b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java @@ -17,6 +17,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.plugins.SearchPlugin; @@ -31,6 +32,8 @@ import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME; import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_OVERSAMPLE_PARAMETER; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; public class KNNQueryBuilderParserTests extends KNNTestCase { @@ -470,6 +473,38 @@ public void testToXContent_whenMethodParams_thenSucceed() throws IOException { assertEquals(builder.toString(), testBuilder.toString()); } + public void testToXContent_whenRescore_thenSucceed() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float oversample = 1.0f; + XContentBuilder builderFromObject = XContentFactory.jsonBuilder() + .startObject() + .startObject(NAME) + .startObject(FIELD_NAME) + .field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector) + .field(KNNQueryBuilder.K_FIELD.getPreferredName(), K) + .startObject(RESCORE_PARAMETER) + .field(RESCORE_OVERSAMPLE_PARAMETER, oversample) + .endObject() + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject(); + + KNNQueryBuilder knnQueryBuilderFromObject = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .boost(BOOST) + .k(K) + .rescoreContext(RescoreContext.builder().oversampleFactor(oversample).build()) + .build(); + + XContentBuilder testBuilder = XContentFactory.jsonBuilder(); + testBuilder.startObject(); + KNNQueryBuilderParser.toXContent(testBuilder, EMPTY_PARAMS, knnQueryBuilderFromObject); + testBuilder.endObject(); + assertEquals(builderFromObject.toString(), testBuilder.toString()); + } + @Override protected NamedXContentRegistry xContentRegistry() { List list = ClusterModule.getNamedXWriteables(); diff --git a/src/test/java/org/opensearch/knn/index/query/parser/RescoreParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/RescoreParserTests.java new file mode 100644 index 000000000..2bb1f89fc --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/parser/RescoreParserTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.parser; + +import lombok.SneakyThrows; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.query.rescore.RescoreContext; + +import java.io.IOException; + +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_OVERSAMPLE_PARAMETER; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; + +public class RescoreParserTests extends KNNTestCase { + + @SneakyThrows + public void testStreams() { + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(RescoreContext.DEFAULT_OVERSAMPLE_FACTOR).build(); + validateStreams(rescoreContext); + validateStreams(null); + } + + private void validateStreams(RescoreContext rescoreContext) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + RescoreParser.streamOutput(output, rescoreContext); + + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { + RescoreContext parsedRescoreContext = RescoreParser.streamInput(in); + assertEquals(rescoreContext, parsedRescoreContext); + } + } + } + + @SneakyThrows + public void testDoXContent() { + float oversample = RescoreContext.MAX_OVERSAMPLE_FACTOR - 1; + XContentBuilder expectedBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(RESCORE_PARAMETER) + .field(RESCORE_OVERSAMPLE_PARAMETER, oversample) + .endObject() + .endObject(); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + RescoreParser.doXContent(builder, RescoreContext.builder().oversampleFactor(oversample).build()); + builder.endObject(); + assertEquals(expectedBuilder.toString(), builder.toString()); + } + + @SneakyThrows + public void testFromXContent_whenValid_thenSucceed() { + float oversample1 = RescoreContext.MAX_OVERSAMPLE_FACTOR - 1; + XContentBuilder builder1 = XContentFactory.jsonBuilder().startObject().field(RESCORE_OVERSAMPLE_PARAMETER, oversample1).endObject(); + validateOversample(oversample1, builder1); + XContentBuilder builder2 = XContentFactory.jsonBuilder().startObject().endObject(); + validateOversample(RescoreContext.DEFAULT_OVERSAMPLE_FACTOR, builder2); + } + + @SneakyThrows + public void testFromXContent_whenInvalid_thenFail() { + XContentBuilder invalidParamBuilder = XContentFactory.jsonBuilder().startObject().field("invalid", 0).endObject(); + expectValidationException(invalidParamBuilder); + + XContentBuilder invalidParamValueBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(RESCORE_OVERSAMPLE_PARAMETER, "c") + .endObject(); + expectValidationException(invalidParamValueBuilder); + + XContentBuilder extraParamBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(RESCORE_OVERSAMPLE_PARAMETER, RescoreContext.MAX_OVERSAMPLE_FACTOR - 1) + .field("invalid", 0) + .endObject(); + expectValidationException(extraParamBuilder); + } + + private void validateOversample(float expectedOversample, XContentBuilder builder) throws IOException { + XContentParser parser = createParser(builder); + RescoreContext rescoreContext = RescoreParser.fromXContent(parser); + assertEquals(expectedOversample, rescoreContext.getOversampleFactor(), 0.0001); + } + + private void expectValidationException(XContentBuilder builder) throws IOException { + XContentParser parser = createParser(builder); + expectThrows(IllegalArgumentException.class, () -> RescoreParser.fromXContent(parser)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/parser/RescoreValidationTests.java b/src/test/java/org/opensearch/knn/index/query/parser/RescoreValidationTests.java new file mode 100644 index 000000000..a23a5c757 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/parser/RescoreValidationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.parser; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.AllArgsConstructor; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.query.rescore.RescoreContext; + +import java.util.Arrays; +import java.util.Collection; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; + +@AllArgsConstructor +public class RescoreValidationTests extends KNNTestCase { + + private boolean isValid; + private RescoreContext rescoreContext; + + @ParametersFactory(argumentFormatting = "isValid:%1$s; rescoreContext:%2$s") + public static Collection validParams() { + return Arrays.asList( + $$( + $(true, RescoreContext.builder().build()), + $(true, RescoreContext.getDefault()), + $(true, RescoreContext.builder().oversampleFactor(RescoreContext.MAX_OVERSAMPLE_FACTOR - 1).build()), + $(false, RescoreContext.builder().oversampleFactor(RescoreContext.MAX_OVERSAMPLE_FACTOR + 1).build()), + $(false, RescoreContext.builder().oversampleFactor(RescoreContext.MIN_OVERSAMPLE_FACTOR - 1).build()) + ) + ); + } + + public void testValidate() { + if (isValid) { + assertNull(RescoreParser.validate(rescoreContext)); + } else { + assertNotNull(RescoreParser.validate(rescoreContext)); + } + } +} diff --git a/src/test/java/org/opensearch/knn/integ/QueryParseIT.java b/src/test/java/org/opensearch/knn/integ/QueryParseIT.java new file mode 100644 index 000000000..bcaf6be12 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/QueryParseIT.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.SneakyThrows; +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; + +import java.io.IOException; +import java.util.Locale; + +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_OVERSAMPLE_PARAMETER; +import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; + +public class QueryParseIT extends KNNRestTestCase { + + private final static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f }; + private final static int DIMENSION = 2; + private final static int K = 1; + + @SneakyThrows + public void testRescore() { + createTestIndex(); + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).startObject("rescore").endObject() + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RESCORE_PARAMETER) + .field(RESCORE_OVERSAMPLE_PARAMETER, 2) + .endObject() + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).startObject(RESCORE_PARAMETER).endObject() + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).field(RESCORE_PARAMETER, true) + ) + ) + ); + + assertValid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).field(RESCORE_PARAMETER, false) + ) + ) + ); + + // Invalid value for rescore + assertInvalid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR).field("k", K).field(RESCORE_PARAMETER, "invalid") + ) + ) + ); + + // Invalid rescore param + assertInvalid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RESCORE_OVERSAMPLE_PARAMETER) + .field("invalid_param", "invalid") + .endObject() + ) + ) + ); + + // Invalid rescore param value + assertInvalid( + buildRequest( + closeQueryXContentBuilder( + setupQueryXContentBuilder().field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RESCORE_PARAMETER) + .field(RESCORE_OVERSAMPLE_PARAMETER, "invalid") + .endObject() + ) + ) + ); + } + + private XContentBuilder setupQueryXContentBuilder() throws IOException { + return XContentFactory.jsonBuilder().startObject().startObject("query").startObject("knn").startObject(FIELD_NAME); + } + + private XContentBuilder closeQueryXContentBuilder(XContentBuilder xContentBuilder) throws IOException { + return xContentBuilder.endObject().endObject().endObject().endObject(); + } + + private void assertValid(Request request) throws IOException { + assertOK(client().performRequest(request)); + } + + private void assertInvalid(Request request) { + expectThrows(ResponseException.class, () -> client().performRequest(request)); + } + + private Request buildRequest(XContentBuilder xContentBuilder) { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", INDEX_NAME)); + request.addParameter("size", Integer.toString(10)); + request.addParameter("explain", Boolean.toString(true)); + request.setJsonEntity(xContentBuilder.toString()); + return request; + } + + private void createTestIndex() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + createKnnIndex(INDEX_NAME, mapping); + } +}