Skip to content

Commit

Permalink
Resolve feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 18, 2024
1 parent 2cd766b commit 9484101
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 36 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Support k-NN radial search parameters in neural search([#697](https://github.com/opensearch-project/neural-search/pull/697))
- Support k-NN radial search parameters in neural search([#697](https://github.com/opensearch-project/neural-search/pull/697))
### Enhancements
### Bug Fixes
- Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private String queryImage;
private String modelId;
private Integer k = null;
private Float max_distance = null;
private Float min_score = null;
private Float maxDistance = null;
private Float minScore = null;
@VisibleForTesting
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
Expand Down Expand Up @@ -123,8 +123,8 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
if (isClusterOnOrAfterMinReqVersionForRadialSearch()) {
this.max_distance = in.readOptionalFloat();
this.min_score = in.readOptionalFloat();
this.maxDistance = in.readOptionalFloat();
this.minScore = in.readOptionalFloat();
}
}

Expand All @@ -141,8 +141,8 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
if (isClusterOnOrAfterMinReqVersionForRadialSearch()) {
out.writeOptionalFloat(this.max_distance);
out.writeOptionalFloat(this.min_score);
out.writeOptionalFloat(this.maxDistance);
out.writeOptionalFloat(this.minScore);
}
}

Expand All @@ -160,11 +160,11 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(filter)) {
xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (Objects.nonNull(max_distance)) {
xContentBuilder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance);
if (Objects.nonNull(maxDistance)) {
xContentBuilder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance);
}
if (Objects.nonNull(min_score)) {
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), min_score);
if (Objects.nonNull(minScore)) {
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
Expand Down Expand Up @@ -219,8 +219,8 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
}

long queryCountProvided = validateKNNQueryType(neuralQueryBuilder);
if (queryCountProvided == 0) {
long queryCount = validateKNNQueryType(neuralQueryBuilder);
if (queryCount == 0) {
neuralQueryBuilder.k(DEFAULT_K);
}

Expand All @@ -247,9 +247,9 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
} else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.boost(parser.floatValue());
} else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.max_distance(parser.floatValue());
neuralQueryBuilder.maxDistance(parser.floatValue());
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.min_score(parser.floatValue());
neuralQueryBuilder.minScore(parser.floatValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -283,17 +283,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (vectorSupplier() != null) {
if (vectorSupplier().get() == null) {
return this;
}
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (maxDistance != null) {
knnQueryBuilder.maxDistance(maxDistance);
} else if (minScore != null) {
knnQueryBuilder.minScore(minScore);
} else {
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (max_distance != null) {
knnQueryBuilder.maxDistance(max_distance);
} else if (min_score != null) {
knnQueryBuilder.minScore(min_score);
} else {
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand All @@ -316,8 +315,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
queryImage(),
modelId(),
k(),
max_distance(),
min_score(),
maxDistance(),
minScore(),
vectorSetOnce::get,
filter()
);
Expand Down Expand Up @@ -361,19 +360,19 @@ private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
}

private static int validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
int queryCountProvided = 0;
int queryCount = 0;
if (neuralQueryBuilder.k() != null) {
queryCountProvided++;
queryCount++;
}
if (neuralQueryBuilder.max_distance() != null) {
queryCountProvided++;
if (neuralQueryBuilder.maxDistance() != null) {
queryCount++;
}
if (neuralQueryBuilder.min_score() != null) {
queryCountProvided++;
if (neuralQueryBuilder.minScore() != null) {
queryCount++;
}
if (queryCountProvided > 1) {
if (queryCount > 1) {
throw new IllegalArgumentException("Only one of k, max_distance, or min_score can be provided");
}
return queryCountProvided;
return queryCount;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMaxDistance_then
assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
assertEquals(MAX_DISTANCE, neuralQueryBuilder.max_distance());
assertEquals(MAX_DISTANCE, neuralQueryBuilder.maxDistance());
}

@SneakyThrows
Expand Down Expand Up @@ -742,7 +742,7 @@ public void testFromXContent_whenBuiltWithDefaults_whenBuiltWithMinScore_thenBui
assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
assertEquals(MIN_SCORE, neuralQueryBuilder.min_score());
assertEquals(MIN_SCORE, neuralQueryBuilder.minScore());
}

@SneakyThrows
Expand Down

0 comments on commit 9484101

Please sign in to comment.