From bf8d2a909c9a8db0b0dc6dfd1283c84d2b119c43 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Tue, 14 May 2024 16:18:13 -0700 Subject: [PATCH] Add stats for radial search (#1684) (#1701) (cherry picked from commit 9a52b2bcd4d7e0a05368d8d689b50971f44c6489) Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + .../opensearch/knn/index/VectorQueryType.java | 57 ++++++++ .../knn/index/query/KNNQueryBuilder.java | 37 +++-- .../knn/plugin/stats/KNNCounter.java | 6 +- .../opensearch/knn/plugin/stats/KNNStats.java | 20 +++ .../knn/plugin/stats/StatNames.java | 6 +- .../knn/index/VectorQueryTypeTests.java | 30 ++++ .../plugin/action/RestKNNStatsHandlerIT.java | 128 ++++++++++++++---- 8 files changed, 245 insertions(+), 40 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/VectorQueryType.java create mode 100644 src/test/java/org/opensearch/knn/index/VectorQueryTypeTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f3d670f17..ee01f2ac7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements * Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688) +* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) ### Bug Fixes * Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692) ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/VectorQueryType.java b/src/main/java/org/opensearch/knn/index/VectorQueryType.java new file mode 100644 index 000000000..4697a917e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/VectorQueryType.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.Getter; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.plugin.stats.KNNCounter; + +@Getter +public enum VectorQueryType { + K(KNNConstants.K) { + @Override + public KNNCounter getQueryStatCounter() { + return KNNCounter.KNN_QUERY_REQUESTS; + } + + @Override + public KNNCounter getQueryWithFilterStatCounter() { + return KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS; + } + }, + MIN_SCORE(KNNConstants.MIN_SCORE) { + @Override + public KNNCounter getQueryStatCounter() { + return KNNCounter.MIN_SCORE_QUERY_REQUESTS; + } + + @Override + public KNNCounter getQueryWithFilterStatCounter() { + return KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS; + } + }, + MAX_DISTANCE(KNNConstants.MAX_DISTANCE) { + @Override + public KNNCounter getQueryStatCounter() { + return KNNCounter.MAX_DISTANCE_QUERY_REQUESTS; + } + + @Override + public KNNCounter getQueryWithFilterStatCounter() { + return KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS; + } + }; + + private final String queryTypeName; + + VectorQueryType(String queryTypeName) { + this.queryTypeName = queryTypeName; + } + + public abstract KNNCounter getQueryStatCounter(); + + public abstract KNNCounter getQueryWithFilterStatCounter(); +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 3d3b0969f..88bcc84bc 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -5,42 +5,43 @@ package org.opensearch.knn.index.query; -import java.io.IOException; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; import lombok.extern.log4j.Log4j2; -import org.apache.lucene.search.MatchNoDocsQuery; -import org.opensearch.core.common.Strings; import org.apache.commons.lang.StringUtils; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.VectorQueryType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; -import org.opensearch.knn.plugin.stats.KNNCounter; -import org.opensearch.index.query.AbstractQueryBuilder; -import org.opensearch.index.query.QueryShardContext; -import static org.opensearch.knn.index.IndexUtil.*; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; +import static org.opensearch.knn.index.IndexUtil.minimalRequiredVersionMap; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; /** @@ -246,7 +247,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String queryName = null; String currentFieldName = null; XContentParser.Token token; - KNNCounter.KNN_QUERY_REQUESTS.increment(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { currentFieldName = parser.currentName(); @@ -283,7 +283,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String tokenName = parser.currentName(); if (FILTER_FIELD.getPreferredName().equals(tokenName)) { log.debug(String.format("Start parsing filter for field [%s]", fieldName)); - KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment(); // Query filters are supported starting from a certain k-NN version only, exact version is defined by // MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER variable. // Here we're checking if all cluster nodes has at least that version or higher. This check is required @@ -322,7 +321,11 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - validateSingleQueryType(k, maxDistance, minScore); + VectorQueryType vectorQueryType = validateSingleQueryType(k, maxDistance, minScore); + vectorQueryType.getQueryStatCounter().increment(); + if (filter != null) { + vectorQueryType.getQueryWithFilterStatCounter().increment(); + } KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .boost(boost) @@ -580,21 +583,27 @@ public String getWriteableName() { return NAME; } - private static void validateSingleQueryType(Integer k, Float distance, Float score) { + private static VectorQueryType validateSingleQueryType(Integer k, Float distance, Float score) { int countSetFields = 0; + VectorQueryType vectorQueryType = null; if (k != null && k != 0) { countSetFields++; + vectorQueryType = VectorQueryType.K; } if (distance != null) { countSetFields++; + vectorQueryType = VectorQueryType.MAX_DISTANCE; } if (score != null) { countSetFields++; + vectorQueryType = VectorQueryType.MIN_SCORE; } if (countSetFields != 1) { throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); } + + return vectorQueryType; } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java index ce04c9078..3bcc3399c 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java @@ -22,7 +22,11 @@ public enum KNNCounter { SCRIPT_QUERY_ERRORS("script_query_errors"), TRAINING_REQUESTS("training_requests"), TRAINING_ERRORS("training_errors"), - KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"); + KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"), + MIN_SCORE_QUERY_REQUESTS("min_score_query_requests"), + MIN_SCORE_QUERY_WITH_FILTER_REQUESTS("min_score_query_with_filter_requests"), + MAX_DISTANCE_QUERY_REQUESTS("max_distance_query_requests"), + MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS("max_distance_query_with_filter_requests"); private String name; private AtomicLong count; diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index 07d129652..3ddc8d4b4 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -90,12 +90,32 @@ private Map> buildStatsMap() { } private void addQueryStats(ImmutableMap.Builder> builder) { + // KNN Query Stats builder.put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS))) .put( StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS)) ); + // Min Score Query Stats + builder.put( + StatNames.MIN_SCORE_QUERY_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_REQUESTS)) + ) + .put( + StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS)) + ); + + // Max Distance Query Stats + builder.put( + StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS)) + ) + .put( + StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS)) + ); } private void addNativeMemoryStats(ImmutableMap.Builder> builder) { diff --git a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java index e9ed2b126..e7f4fd4a2 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java @@ -44,7 +44,11 @@ public enum StatNames { KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()), GRAPH_STATS("graph_stats"), REFRESH("refresh"), - MERGE("merge"); + MERGE("merge"), + MIN_SCORE_QUERY_REQUESTS(KNNCounter.MIN_SCORE_QUERY_REQUESTS.getName()), + MIN_SCORE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()), + MAX_DISTANCE_QUERY_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.getName()), + MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName()); private String name; diff --git a/src/test/java/org/opensearch/knn/index/VectorQueryTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorQueryTypeTests.java new file mode 100644 index 000000000..d0fac3f59 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/VectorQueryTypeTests.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.plugin.stats.KNNCounter; + +public class VectorQueryTypeTests extends KNNTestCase { + + public void testGetQueryStatCounter() { + assertEquals(KNNCounter.KNN_QUERY_REQUESTS, VectorQueryType.K.getQueryStatCounter()); + assertEquals(KNNCounter.MIN_SCORE_QUERY_REQUESTS, VectorQueryType.MIN_SCORE.getQueryStatCounter()); + assertEquals(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS, VectorQueryType.MAX_DISTANCE.getQueryStatCounter()); + } + + public void testGetQueryWithFilterStatCounter() { + assertEquals(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.K.getQueryWithFilterStatCounter()); + assertEquals(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.MIN_SCORE.getQueryWithFilterStatCounter()); + assertEquals(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.MAX_DISTANCE.getQueryWithFilterStatCounter()); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 7d11f2e4a..9f28d5a71 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.action; +import lombok.SneakyThrows; import org.apache.http.util.EntityUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -16,41 +17,24 @@ import org.opensearch.cluster.health.ClusterHealthStatus; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.core.rest.RestStatus; import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.opensearch.knn.TestUtils.KNN_VECTOR; -import static org.opensearch.knn.TestUtils.PROPERTIES; -import static org.opensearch.knn.TestUtils.VECTOR_TYPE; -import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.NAME; -import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import java.util.*; + +import static org.opensearch.knn.TestUtils.*; +import static org.opensearch.knn.common.KNNConstants.*; /** * Integration tests to check the correctness of RestKNNStatsHandler @@ -432,6 +416,95 @@ public void testFieldsByEngineModelTraining() throws Exception { assertTrue(faissField); } + public void testRadialSearchStats_thenSucceed() throws Exception { + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2, METHOD_HNSW, LUCENE_NAME)); + Float[] vector = { 6.0f, 6.0f }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); + + // First search: radial search by min score + XContentBuilder queryBuilderMinScore = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMinScore.startObject("knn"); + queryBuilderMinScore.startObject(FIELD_NAME); + queryBuilderMinScore.field("vector", vector); + queryBuilderMinScore.field(MIN_SCORE, 0.95f); + queryBuilderMinScore.endObject(); + queryBuilderMinScore.endObject(); + queryBuilderMinScore.endObject().endObject(); + + Integer minScoreStatBeforeMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMinScore, 1); + Integer minScoreStatAfterMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + + assertEquals(1, minScoreStatAfterMinScoreSearch - minScoreStatBeforeMinScoreSearch); + + // Second search: radial search by min score with filter + XContentBuilder queryBuilderMinScoreWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMinScoreWithFilter.startObject("knn"); + queryBuilderMinScoreWithFilter.startObject(FIELD_NAME); + queryBuilderMinScoreWithFilter.field("vector", vector); + queryBuilderMinScoreWithFilter.field(MIN_SCORE, 0.95f); + queryBuilderMinScoreWithFilter.field("filter", QueryBuilders.termQuery("_id", "1")); + queryBuilderMinScoreWithFilter.endObject(); + queryBuilderMinScoreWithFilter.endObject(); + queryBuilderMinScoreWithFilter.endObject().endObject(); + + Integer minScoreWithFilterStatBeforeMinScoreWithFilterSearch = getStatCount( + StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer minScoreStatBeforeMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMinScoreWithFilter, 1); + Integer minScoreWithFilterStatAfterMinScoreWithFilterSearch = getStatCount( + StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer minScoreStatAfterMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + + assertEquals(1, minScoreWithFilterStatAfterMinScoreWithFilterSearch - minScoreWithFilterStatBeforeMinScoreWithFilterSearch); + assertEquals(1, minScoreStatAfterMinScoreWithFilterSearch - minScoreStatBeforeMinScoreWithFilterSearch); + + // Third search: radial search by max distance + XContentBuilder queryBuilderMaxDistance = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMaxDistance.startObject("knn"); + queryBuilderMaxDistance.startObject(FIELD_NAME); + queryBuilderMaxDistance.field("vector", vector); + queryBuilderMaxDistance.field(MAX_DISTANCE, 100f); + queryBuilderMaxDistance.endObject(); + queryBuilderMaxDistance.endObject(); + queryBuilderMaxDistance.endObject().endObject(); + + Integer maxDistanceStatBeforeMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMaxDistance, 0); + Integer maxDistanceStatAfterMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + + assertEquals(1, maxDistanceStatAfterMaxDistanceSearch - maxDistanceStatBeforeMaxDistanceSearch); + + // Fourth search: radial search by max distance with filter + XContentBuilder queryBuilderMaxDistanceWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMaxDistanceWithFilter.startObject("knn"); + queryBuilderMaxDistanceWithFilter.startObject(FIELD_NAME); + queryBuilderMaxDistanceWithFilter.field("vector", vector); + queryBuilderMaxDistanceWithFilter.field(MAX_DISTANCE, 100f); + queryBuilderMaxDistanceWithFilter.field("filter", QueryBuilders.termQuery("_id", "1")); + queryBuilderMaxDistanceWithFilter.endObject(); + queryBuilderMaxDistanceWithFilter.endObject(); + queryBuilderMaxDistanceWithFilter.endObject().endObject(); + + Integer maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch = getStatCount( + StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer maxDistanceStatBeforeMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMaxDistanceWithFilter, 0); + Integer maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch = getStatCount( + StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer maxDistanceStatAfterMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + + assertEquals( + 1, + maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch - maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch + ); + assertEquals(1, maxDistanceStatAfterMaxDistanceWithFilterSearch - maxDistanceStatBeforeMaxDistanceWithFilterSearch); + } + public void trainKnnModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension, String description) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -487,4 +560,11 @@ protected Settings restClientSettings() { return super.restClientSettings(); } } + + @SneakyThrows + private Integer getStatCount(String statName) { + Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + String responseBody = EntityUtils.toString(response.getEntity()); + return (Integer) parseNodeStatsResponse(responseBody).get(0).get(statName); + } }