diff --git a/.idea/copyright/SPDX_ALv2.xml b/.idea/copyright/SPDX_ALv2.xml
index a2485beef0..3475d15120 100644
--- a/.idea/copyright/SPDX_ALv2.xml
+++ b/.idea/copyright/SPDX_ALv2.xml
@@ -1,6 +1,6 @@
-
+
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 09b4c47f91..c90ea16213 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,4 +20,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Infrastructure
### Documentation
### Maintenance
-### Refactoring
\ No newline at end of file
+### Refactoring
+* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
diff --git a/micro-benchmarks/src/main/java/org/opensearch/knn/QueryParsingBenchmarks.java b/micro-benchmarks/src/main/java/org/opensearch/knn/QueryParsingBenchmarks.java
new file mode 100644
index 0000000000..1c5a3b8757
--- /dev/null
+++ b/micro-benchmarks/src/main/java/org/opensearch/knn/QueryParsingBenchmarks.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+import org.opensearch.cluster.ClusterModule;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.common.xcontent.json.JsonXContent;
+import org.opensearch.core.common.bytes.BytesArray;
+import org.opensearch.core.common.bytes.BytesReference;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.index.query.TermQueryBuilder;
+import org.opensearch.knn.index.query.KNNQueryBuilder;
+import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
+import org.opensearch.plugins.SearchPlugin;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Benchmarks for impact of changes around query parsing
+ */
+@Warmup(iterations = 5, time = 10)
+@Measurement(iterations = 3, time = 10)
+@Fork(3)
+@State(Scope.Benchmark)
+public class QueryParsingBenchmarks {
+ private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value");
+ private static final NamedXContentRegistry NAMED_X_CONTENT_REGISTRY = xContentRegistry();
+
+ @Param({ "128", "1024" })
+ private int dimension;
+ @Param({ "basic", "filter" })
+ private String type;
+
+ private BytesReference bytesReference;
+
+ @Setup
+ public void setup() throws IOException {
+ XContentBuilder builder = XContentFactory.jsonBuilder();
+ builder.startObject();
+ builder.startObject("test");
+ builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), generateVectorWithOnes(dimension));
+ builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), 1);
+ if (type.equals("filter")) {
+ builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), TERM_QUERY);
+ }
+ builder.endObject();
+ builder.endObject();
+ bytesReference = BytesReference.bytes(builder);
+ }
+
+ @Benchmark
+ public void fromXContent(final Blackhole bh) throws IOException {
+ XContentParser xContentParser = createParser();
+ bh.consume(KNNQueryBuilderParser.fromXContent(xContentParser));
+ }
+
+ private XContentParser createParser() throws IOException {
+ XContentParser contentParser = createParser(bytesReference);
+ contentParser.nextToken();
+ return contentParser;
+ }
+
+ private float[] generateVectorWithOnes(final int dimensions) {
+ float[] vector = new float[dimensions];
+ Arrays.fill(vector, (float) 1);
+ return vector;
+ }
+
+ private XContentParser createParser(final BytesReference data) throws IOException {
+ BytesArray array = (BytesArray) data;
+ return JsonXContent.jsonXContent.createParser(
+ NAMED_X_CONTENT_REGISTRY,
+ LoggingDeprecationHandler.INSTANCE,
+ array.array(),
+ array.offset(),
+ array.length()
+ );
+ }
+
+ private static NamedXContentRegistry xContentRegistry() {
+ List list = ClusterModule.getNamedXWriteables();
+ SearchPlugin.QuerySpec> spec = new SearchPlugin.QuerySpec<>(
+ TermQueryBuilder.NAME,
+ TermQueryBuilder::new,
+ TermQueryBuilder::fromXContent
+ );
+ list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p)));
+ return new NamedXContentRegistry(list);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
index 63d05540ee..69039e7c3d 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
@@ -14,19 +14,15 @@
import org.apache.lucene.search.Query;
import org.opensearch.common.ValidationException;
import org.opensearch.core.ParseField;
-import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
-import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
-import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
-import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
@@ -34,7 +30,7 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
-import org.opensearch.knn.index.query.parser.MethodParametersParser;
+import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.index.util.QueryContext;
@@ -44,7 +40,6 @@
import java.io.IOException;
import java.util.Arrays;
-import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
@@ -55,8 +50,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
-import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
-import static org.opensearch.knn.index.IndexUtil.minimalRequiredVersionMap;
import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters;
import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH;
import static org.opensearch.knn.validation.ParameterValidator.validateParameters;
@@ -79,6 +72,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder {
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);
+
public static final int K_MAX = 10000;
/**
* The name for the knn query
@@ -142,7 +136,7 @@ public static class Builder {
private String queryName;
private float boost = DEFAULT_BOOST;
- private Builder() {}
+ public Builder() {}
public Builder fieldName(String fieldName) {
this.fieldName = fieldName;
@@ -295,182 +289,26 @@ public static void initialize(ModelDao modelDao) {
KNNQueryBuilder.modelDao = modelDao;
}
- private static float[] ObjectsToFloats(List