Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor parsing logic of the Query Builder (second try) #1824

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .idea/copyright/SPDX_ALv2.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Infrastructure
### Documentation
### Maintenance
### Refactoring
### Refactoring
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are using this copyright

/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import org.opensearch.cluster.ClusterModule;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/**
* Benchmarks for measuring the performance impact of changes to k-NN query parsing
*/
@Warmup(iterations = 5, time = 10)
@Measurement(iterations = 3, time = 10)
@Fork(3)
@State(Scope.Benchmark)
public class QueryParsingBenchmarks {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to have this file in knn repo?
To make this code useful, I think we should force it to run for every change in query parser and fail if there is a degradation. Otherwise, I don't think it is worth adding it here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be kept. I think it's fine to have this for testing in the future if we want to make some other change around this functionality.

I think we should force it to run for every change in query parser and fail if there is a degradation.

Sure, this is something we could consider in the future, but it is out of scope for this PR. I want to try to get this change in without having to change multiple other things.

private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value");
private static final NamedXContentRegistry NAMED_X_CONTENT_REGISTRY = xContentRegistry();

@Param({ "128", "1024" })
private int dimension;
@Param({ "basic", "filter" })
private String type;

private BytesReference bytesReference;

/**
 * Serializes the query body that the benchmark parses and stores it in {@code bytesReference}.
 *
 * The JSON produced is an outer object containing a single object named {@code "test"} that holds
 * the vector field (a vector of ones with {@code dimension} entries), the k field set to 1, and,
 * when {@code type} equals {@code "filter"}, the filter field populated with {@code TERM_QUERY}.
 *
 * @throws IOException if the content builder fails while writing
 */
@Setup
public void setup() throws IOException {
    final XContentBuilder queryBody = XContentFactory.jsonBuilder();

    // Outer wrapper object, then the per-field object named "test".
    queryBody.startObject().startObject("test");

    final float[] queryVector = generateVectorWithOnes(dimension);
    queryBody.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), queryVector);
    queryBody.field(KNNQueryBuilder.K_FIELD.getPreferredName(), 1);

    // Only the "filter" parameterization carries a filter clause.
    final boolean includeFilter = type.equals("filter");
    if (includeFilter) {
        queryBody.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), TERM_QUERY);
    }

    // Close "test" first, then the outer wrapper.
    queryBody.endObject().endObject();

    bytesReference = BytesReference.bytes(queryBody);
}

/**
 * Benchmarks {@link KNNQueryBuilderParser#fromXContent(XContentParser)} against the query body
 * prepared in {@code bytesReference}.
 *
 * @param bh JMH blackhole that consumes the parsed result so the call cannot be optimized away
 * @throws IOException if creating, reading, or closing the parser fails
 */
@Benchmark
public void fromXContent(final Blackhole bh) throws IOException {
    // XContentParser is Closeable; a fresh parser is created on every invocation, so release it
    // here rather than leaving it open for the rest of the run.
    try (XContentParser xContentParser = createParser()) {
        bh.consume(KNNQueryBuilderParser.fromXContent(xContentParser));
    }
}

/**
 * Creates a parser over {@code bytesReference} and advances it by one token before returning it.
 *
 * @return a parser whose first token has already been consumed
 * @throws IOException if the parser cannot be created or the first token cannot be read
 */
private XContentParser createParser() throws IOException {
    final XContentParser parser = createParser(bytesReference);
    // Step past the first token so the caller starts inside the document.
    parser.nextToken();
    return parser;
}

/**
 * Creates a vector in which every component is 1.
 *
 * @param dimensions number of components in the returned vector
 * @return a new array of length {@code dimensions} with every element set to {@code 1.0f}
 */
private float[] generateVectorWithOnes(final int dimensions) {
    final float[] ones = new float[dimensions];
    Arrays.fill(ones, 1.0f);
    return ones;
}

/**
 * Creates a JSON parser that reads directly from the byte array backing {@code data}.
 *
 * @param data serialized JSON; it is cast to {@link BytesArray}, so any other
 *             {@link BytesReference} implementation fails with a {@link ClassCastException}
 * @return a JSON parser over the backing array, bounded by the reference's offset and length
 * @throws IOException if the parser cannot be created
 */
private XContentParser createParser(final BytesReference data) throws IOException {
    final BytesArray backing = (BytesArray) data;
    final byte[] rawBytes = backing.array();
    final int start = backing.offset();
    final int size = backing.length();
    return JsonXContent.jsonXContent.createParser(NAMED_X_CONTENT_REGISTRY, LoggingDeprecationHandler.INSTANCE, rawBytes, start, size);
}

/**
 * Builds the named x-content registry used by the parsers in this benchmark.
 *
 * The registry holds the entries returned by {@link ClusterModule#getNamedXWriteables()} plus one
 * {@link QueryBuilder} entry for the term query. The extra entry is presumably what allows the
 * term-query filter clause to be parsed; confirm against the parser implementation.
 *
 * @return a registry containing the cluster entries and the term query entry
 */
private static NamedXContentRegistry xContentRegistry() {
    final SearchPlugin.QuerySpec<?> termQuerySpec = new SearchPlugin.QuerySpec<>(
        TermQueryBuilder.NAME,
        TermQueryBuilder::new,
        TermQueryBuilder::fromXContent
    );

    final List<NamedXContentRegistry.Entry> entries = ClusterModule.getNamedXWriteables();
    final NamedXContentRegistry.Entry termQueryEntry = new NamedXContentRegistry.Entry(
        QueryBuilder.class,
        termQuerySpec.getName(),
        (parser, context) -> termQuerySpec.getParser().fromXContent(parser)
    );
    entries.add(termQueryEntry);

    return new NamedXContentRegistry(entries);
}
}
183 changes: 14 additions & 169 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,23 @@
import org.apache.lucene.search.Query;
import org.opensearch.common.ValidationException;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.index.util.QueryContext;
Expand All @@ -44,7 +40,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
Expand All @@ -55,7 +50,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters;
import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH;
import static org.opensearch.knn.validation.ParameterValidator.validateParameters;
Expand All @@ -78,6 +72,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);

public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand Down Expand Up @@ -141,7 +136,7 @@ public static class Builder {
private String queryName;
private float boost = DEFAULT_BOOST;

private Builder() {}
public Builder() {}

public Builder fieldName(String fieldName) {
this.fieldName = fieldName;
Expand Down Expand Up @@ -294,154 +289,26 @@ public static void initialize(ModelDao modelDao) {
KNNQueryBuilder.modelDao = modelDao;
}

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs) || objs.isEmpty()) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be non-null and non-empty", NAME)
);
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if ((objs.get(i) instanceof Number) == false) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be an array of numbers", NAME)
);
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
return vec;
}

/**
* @param in Reads from stream
* @throws IOException Throws IO Exception
*/
public KNNQueryBuilder(StreamInput in) throws IOException {
super(in);
try {
fieldName = in.readString();
vector = in.readFloatArray();
k = in.readInt();
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
maxDistance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
minScore = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) {
methodParameters = MethodParametersParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}

} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
}

public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOException {
String fieldName = null;
List<Object> vector = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float maxDistance = null;
Float minScore = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
boolean ignoreUnmapped = false;
Map<String, ?> methodParameters = null;
XContentParser.Token token;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token == XContentParser.Token.START_OBJECT) {
throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
fieldName = currentFieldName;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token.isValue() || token == XContentParser.Token.START_ARRAY) {
if (VECTOR_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
vector = parser.list();
} else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
boost = parser.floatValue();
} else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) {
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = parser.booleanValue();
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
maxDistance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
minScore = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
filter = parseInnerQueryBuilder(parser);
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
methodParameters = MethodParametersParser.fromXContent(parser);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
);
}
}
} else {
throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, parser.currentName());
fieldName = parser.currentName();
vector = parser.list();
}
}

return KNNQueryBuilder.builder()
.queryName(queryName)
.boost(boost)
.fieldName(fieldName)
.vector(ObjectsToFloats(vector))
.k(k)
.maxDistance(maxDistance)
.minScore(minScore)
.methodParameters(methodParameters)
.ignoreUnmapped(ignoreUnmapped)
.filter(filter)
.build();
KNNQueryBuilder.Builder builder = KNNQueryBuilderParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
fieldName = builder.fieldName;
vector = builder.vector;
k = builder.k;
filter = builder.filter;
ignoreUnmapped = builder.ignoreUnmapped;
maxDistance = builder.maxDistance;
minScore = builder.minScore;
methodParameters = builder.methodParameters;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeFloatArray(vector);
out.writeInt(k);
out.writeOptionalNamedWriteable(filter);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(maxDistance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(minScore);
}
if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) {
MethodParametersParser.streamOutput(out, methodParameters, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}
KNNQueryBuilderParser.streamOutput(out, this, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}

/**
Expand All @@ -460,29 +327,7 @@ public Object vector() {

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.startObject(fieldName);

builder.field(VECTOR_FIELD.getPreferredName(), vector);
builder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (maxDistance != null) {
builder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (minScore != null) {
builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
if (methodParameters != null) {
MethodParametersParser.doXContent(builder, methodParameters);
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
KNNQueryBuilderParser.toXContent(builder, params, this);
}

@Override
Expand Down
Loading
Loading