Skip to content

Commit

Permalink
Support k-NN radial search parameters in neural search
Browse files Browse the repository at this point in the history
  • Loading branch information
junqiu-lei committed Apr 18, 2024
1 parent e69752c commit cf2d2ec
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
@VisibleForTesting
static final ParseField K_FIELD = new ParseField("k");

@VisibleForTesting
static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");

@VisibleForTesting
static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");

private static final int DEFAULT_K = 10;

private static MLCommonsClientAccessor ML_CLIENT;
Expand All @@ -87,13 +93,16 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private String queryText;
private String queryImage;
private String modelId;
private int k = DEFAULT_K;
private Integer k = null;
private Float max_distance = null;
private Float min_score = null;
@VisibleForTesting
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;

/**
* Constructor from stream input
Expand All @@ -113,6 +122,10 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
}
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
if (isClusterOnOrAfterMinReqVersionForRadialSearch()) {
this.max_distance = in.readOptionalFloat();
this.min_score = in.readOptionalFloat();
}
}

@Override
Expand All @@ -127,6 +140,10 @@ protected void doWriteTo(StreamOutput out) throws IOException {
}
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
if (isClusterOnOrAfterMinReqVersionForRadialSearch()) {
out.writeOptionalFloat(this.max_distance);
out.writeOptionalFloat(this.min_score);
}
}

@Override
Expand All @@ -137,10 +154,18 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(modelId)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
xContentBuilder.field(K_FIELD.getPreferredName(), k);
if (Objects.nonNull(k)) {
xContentBuilder.field(K_FIELD.getPreferredName(), k);
}
if (Objects.nonNull(filter)) {
xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (Objects.nonNull(max_distance)) {
xContentBuilder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance);
}
if (Objects.nonNull(min_score)) {
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), min_score);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand Down Expand Up @@ -193,6 +218,12 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
}

long queryCountProvided = validateKNNQueryType(neuralQueryBuilder);
if (queryCountProvided == 0) {
neuralQueryBuilder.k(DEFAULT_K);
}

return neuralQueryBuilder;
}

Expand All @@ -215,6 +246,10 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
neuralQueryBuilder.queryName(parser.text());
} else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.boost(parser.floatValue());
} else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.max_distance(parser.floatValue());
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.min_score(parser.floatValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -246,7 +281,19 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
// return the current unmodified query builder.
if (vectorSupplier() != null) {
return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k(), filter());
if (vectorSupplier().get() == null) {
return this;
} else {
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (max_distance != null) {
knnQueryBuilder.maxDistance(max_distance);
} else if (min_score != null) {
knnQueryBuilder.minScore(min_score);
} else {
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
}
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand All @@ -263,7 +310,17 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
actionListener.onResponse(null);
}, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), queryImage(), modelId(), k(), vectorSetOnce::get, filter());
return new NeuralQueryBuilder(
fieldName(),
queryText(),
queryImage(),
modelId(),
k(),
max_distance(),
min_score(),
vectorSetOnce::get,
filter()
);
}

@Override
Expand Down Expand Up @@ -298,4 +355,25 @@ public String getWriteableName() {
private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}

private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

private static int validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
int queryCountProvided = 0;
if (neuralQueryBuilder.k() != null) {
queryCountProvided++;
}
if (neuralQueryBuilder.max_distance() != null) {
queryCountProvided++;
}
if (neuralQueryBuilder.min_score() != null) {
queryCountProvided++;
}
if (queryCountProvided > 1) {
throw new IllegalArgumentException("Only one of k, max_distance, or min_score can be provided");
}
return queryCountProvided;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
modelId,
5,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -142,6 +144,8 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
modelId,
5,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -179,6 +183,8 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
modelId,
6,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -248,7 +248,9 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null));
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Map<String, Object> searchResponseAsMapL2Norm = search(
Expand Down Expand Up @@ -297,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -321,7 +323,9 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null));
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Map<String, Object> searchResponseAsMapL2Norm = search(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {

HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder();
hybridQueryBuilderArithmeticMean.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
hybridQueryBuilderHarmonicMean.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -140,7 +140,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
hybridQueryBuilderGeometricMean.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down Expand Up @@ -190,7 +190,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {

HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder();
hybridQueryBuilderArithmeticMean.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -215,7 +215,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
hybridQueryBuilderHarmonicMean.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -240,7 +240,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
hybridQueryBuilderGeometricMean.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null)
);
hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0);
assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
assertEquals(K, neuralQueryBuilder.k());
assertEquals(K, (int) neuralQueryBuilder.k());
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
assertEquals(BOOST, neuralQueryBuilder.boost(), 0f);
// verify term query
Expand Down Expand Up @@ -602,7 +602,7 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery(
assertTrue(queryBuilders.get(0) instanceof KNNQueryBuilder);
KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilders.get(0);
assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName());
assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK());
assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK());
assertTrue(queryBuilders.get(1) instanceof TermQueryBuilder);
TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryBuilders.get(1);
assertEquals(termSubQuery.fieldName(), termQueryBuilder.fieldName());
Expand Down
Loading

0 comments on commit cf2d2ec

Please sign in to comment.