From 6093904aaaecb8446d05e2361e9ebe85b205e810 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 24 Oct 2022 17:26:41 -0700 Subject: [PATCH] Merge efficient filtering from feature branch (#588) * Adding efficient filtering Signed-off-by: Martin Gaievski (cherry picked from commit f332ccbde6152f823bb2803591aa983a8a33591d) --- .../opensearch/knn/bwc/LuceneFilteringIT.java | 86 +++++++++ .../opensearch/knn/index/KNNClusterUtil.java | 58 ++++++ .../knn/index/query/KNNQueryBuilder.java | 81 +++++++- .../knn/index/query/KNNQueryFactory.java | 79 +++++++- .../org/opensearch/knn/plugin/KNNPlugin.java | 2 + .../knn/plugin/stats/KNNCounter.java | 3 +- .../knn/plugin/stats/KNNStatsConfig.java | 4 + .../knn/plugin/stats/StatNames.java | 3 +- .../knn/index/KNNClusterTestUtils.java | 35 ++++ .../knn/index/KNNClusterUtilTests.java | 51 +++++ .../opensearch/knn/index/LuceneEngineIT.java | 108 +++++++++++ .../knn/index/query/KNNQueryBuilderTests.java | 175 ++++++++++++++++-- .../knn/index/query/KNNQueryFactoryTests.java | 31 ++++ .../plugin/action/RestKNNStatsHandlerIT.java | 20 ++ 14 files changed, 718 insertions(+), 18 deletions(-) create mode 100644 qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java create mode 100644 src/main/java/org/opensearch/knn/index/KNNClusterUtil.java create mode 100644 src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java create mode 100644 src/test/java/org/opensearch/knn/index/KNNClusterUtilTests.java diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java new file mode 100644 index 000000000..3a7d0329d --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.bwc; + +import org.hamcrest.MatcherAssert; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; + +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; + +import java.io.IOException; + +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.CoreMatchers.containsString; +import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +/** + * Tests scenarios specific to filtering functionality in k-NN in case Lucene is set as an engine + */ +public class LuceneFilteringIT extends AbstractRollingUpgradeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 50; + private static final int K = 10; + private static final int NUM_DOCS = 100; + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("_id", "100"); + + public void testLuceneFiltering() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0]; + switch (getClusterType()) { + case OLD: + createKnnIndex( + testIndex, + getKNNDefaultIndexSettings(), + createKnnIndexMapping(TEST_FIELD, DIMENSIONS, METHOD_HNSW, LUCENE_NAME) + ); + bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS); + validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + break; + case MIXED: + validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + break; + case UPGRADED: + searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); + deleteKNNIndex(testIndex); + break; + } + } + + private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject().endObject(); + + Request request = new Request("POST", "/" + index + "/_search"); + + request.addParameter("size", Integer.toString(resultSize)); + request.addParameter("explain", Boolean.toString(true)); + request.addParameter("search_type", "query_then_fetch"); + request.setJsonEntity(Strings.toString(builder)); + + Exception exception = expectThrows(ResponseException.class, () -> client().performRequest(request)); + // assert for two possible exception messages, fist one can come from current version in case serialized request is coming from + // lower version, + // second exception is vise versa, when lower version node receives request with filter field from higher version + MatcherAssert.assertThat( + exception.getMessage(), + anyOf( + containsString("filter field is supported from version"), + containsString("[knn] unknown token [START_OBJECT] after [filter]") + ) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/index/KNNClusterUtil.java b/src/main/java/org/opensearch/knn/index/KNNClusterUtil.java new file mode 100644 index 000000000..63a49f095 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/KNNClusterUtil.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; + +/** + * Class abstracts information related to underlying OpenSearch cluster + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +@Log4j2 +public class KNNClusterUtil { + + private ClusterService clusterService; + private static KNNClusterUtil instance; + + /** + * Return instance of the cluster context, must be initialized first for proper usage + * @return instance of cluster context + */ + public static synchronized KNNClusterUtil instance() { + if (instance == null) { + instance = new KNNClusterUtil(); + } + return instance; + } + + /** + * Initializes instance of cluster context by injecting dependencies + * @param clusterService + */ + public void initialize(final ClusterService clusterService) { + this.clusterService = clusterService; + } + + /** + * Return minimal OpenSearch version based on all nodes currently discoverable in the cluster + * @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version + */ + public Version getClusterMinVersion() { + try { + return this.clusterService.state().getNodes().getMinNodeVersion(); + } catch (Exception exception) { + log.error( + String.format("Failed to get cluster minimum node version, returning current node version %s instead.", Version.CURRENT), + exception + ); + return Version.CURRENT; + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 1defe45e8..ebf721304 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -6,7 +6,10 @@ package org.opensearch.knn.index.query; import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; @@ -38,6 +41,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField VECTOR_FIELD = new ParseField("vector"); public static final ParseField K_FIELD = new ParseField("k"); + public static final ParseField FILTER_FIELD = new ParseField("filter"); public static int K_MAX = 10000; /** * The name for the knn query @@ -49,6 +53,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; + private QueryBuilder filter; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_3_0_0; /** * Constructs a new knn query @@ -58,6 +64,10 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { * @param k K nearest neighbours for the given vector */ public KNNQueryBuilder(String fieldName, float[] vector, int k) { + this(fieldName, vector, k, null); + } + + public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); } @@ -77,6 +87,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) { this.fieldName = fieldName; this.vector = vector; this.k = k; + this.filter = filter; } public static void initialize(ModelDao modelDao) { @@ -101,8 +112,13 @@ public KNNQueryBuilder(StreamInput in) throws IOException { fieldName = in.readString(); vector = in.readFloatArray(); k = in.readInt(); + // We're checking if all cluster nodes has at least that version or higher. This check is required + // to avoid issues with cluster upgrade + if (isClusterOnOrAfterMinRequiredVersion()) { + filter = in.readOptionalNamedWriteable(QueryBuilder.class); + } } catch (IOException ex) { - throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex); + throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } } @@ -111,6 +127,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; int k = 0; + QueryBuilder filter = null; String queryName = null; String currentFieldName = null; XContentParser.Token token; @@ -139,6 +156,35 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep "[" + NAME + "] query does not support [" + currentFieldName + "]" ); } + } else if (token == XContentParser.Token.START_OBJECT) { + String tokenName = parser.currentName(); + if (FILTER_FIELD.getPreferredName().equals(tokenName)) { + log.debug(String.format("Start parsing filter for field [%s]", fieldName)); + KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment(); + // Query filters are supported starting from a certain k-NN version only, exact version is defined by + // MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER variable. + // Here we're checking if all cluster nodes has at least that version or higher. This check is required + // to avoid issues with rolling cluster upgrade + if (isClusterOnOrAfterMinRequiredVersion()) { + filter = parseInnerQueryBuilder(parser); + } else { + log.debug( + String.format( + "This version of k-NN doesn't support [filter] field, minimal required version is [%s]", + MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + ) + ); + throw new IllegalArgumentException( + String.format( + "%s field is supported from version %s", + FILTER_FIELD.getPreferredName(), + MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + ) + ); + } + } else { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); + } } else { throw new ParsingException( parser.getTokenLocation(), @@ -153,7 +199,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter); knnQueryBuilder.queryName(queryName); knnQueryBuilder.boost(boost); return knnQueryBuilder; @@ -164,6 +210,11 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeFloatArray(vector); out.writeInt(k); + // We're checking if all cluster nodes has at least that version or higher. This check is required + // to avoid issues with cluster upgrade + if (isClusterOnOrAfterMinRequiredVersion()) { + out.writeOptionalNamedWriteable(filter); + } } /** @@ -184,6 +235,10 @@ public int getK() { return this.k; } + public QueryBuilder getFilter() { + return this.filter; + } + @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); @@ -191,6 +246,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio builder.field(VECTOR_FIELD.getPreferredName(), vector); builder.field(K_FIELD.getPreferredName(), k); + if (filter != null) { + builder.field(FILTER_FIELD.getPreferredName(), filter); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -225,8 +283,21 @@ protected Query doToQuery(QueryShardContext context) { ); } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) { + throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine)); + } + String indexName = context.index().getName(); - return KNNQueryFactory.create(knnEngine, indexName, this.fieldName, this.vector, this.k); + KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(this.fieldName) + .vector(this.vector) + .k(this.k) + .filter(this.filter) + .context(context) + .build(); + return KNNQueryFactory.create(createQueryRequest); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { @@ -257,4 +328,8 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static boolean isClusterOnOrAfterMinRequiredVersion() { + return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index cbdb03ea8..c68ce9502 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -5,11 +5,21 @@ package org.opensearch.knn.index.query; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.util.KNNEngine; +import java.io.IOException; +import java.util.Optional; + /** * Creates the Lucene k-NN queries */ @@ -27,14 +37,81 @@ public class KNNQueryFactory { * @return Lucene Query */ public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) { + final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(fieldName) + .vector(vector) + .k(k) + .build(); + return create(createQueryRequest); + } + + /** + * Creates a Lucene query for a particular engine. + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + public static Query create(CreateQueryRequest createQueryRequest) { // Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to // use the custom query type created by the plugin - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) { + final String indexName = createQueryRequest.getIndexName(); + final String fieldName = createQueryRequest.getFieldName(); + final int k = createQueryRequest.getK(); + final float[] vector = createQueryRequest.getVector(); + + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); return new KNNQuery(fieldName, vector, k, indexName); } + if (createQueryRequest.getFilter().isPresent()) { + final QueryShardContext queryShardContext = createQueryRequest.getContext() + .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + log.debug( + String.format("Creating Lucene k-NN query with filter for index [%s], field [%s] and k [%d]", indexName, fieldName, k) + ); + try { + final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); + return new KnnVectorQuery(fieldName, vector, k, filterQuery); + } catch (IOException e) { + throw new RuntimeException("Cannot create knn query with filter", e); + } + } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); return new KnnVectorQuery(fieldName, vector, k); } + + /** + * DTO object to hold data required to create a Query instance. + */ + @AllArgsConstructor + @Builder + @Setter + static class CreateQueryRequest { + @Getter + @NonNull + private KNNEngine knnEngine; + @Getter + @NonNull + private String indexName; + @Getter + private String fieldName; + @Getter + private float[] vector; + @Getter + private int k; + // can be null in cases filter not passed with the knn query + private QueryBuilder filter; + // can be null in cases filter not passed with the knn query + private QueryShardContext context; + + public Optional getFilter() { + return Optional.ofNullable(filter); + } + + public Optional getContext() { + return Optional.ofNullable(context); + } + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index c2564f179..670294802 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -11,6 +11,7 @@ import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; +import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -179,6 +180,7 @@ public Collection createComponents( NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); KNNSettings.state().initialize(client, clusterService); + KNNClusterUtil.instance().initialize(clusterService); ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java index d933ce66d..ce04c9078 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java @@ -21,7 +21,8 @@ public enum KNNCounter { SCRIPT_QUERY_REQUESTS("script_query_requests"), SCRIPT_QUERY_ERRORS("script_query_errors"), TRAINING_REQUESTS("training_requests"), - TRAINING_ERRORS("training_errors"); + TRAINING_ERRORS("training_errors"), + KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"); private String name; private AtomicLong count; diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java index c41170b32..8769e0e46 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java @@ -55,6 +55,10 @@ public class KNNStatsConfig { .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNCircuitBreakerSupplier())) .put(StatNames.MODEL_INDEX_STATUS.getName(), new KNNStat<>(true, new ModelIndexStatusSupplier<>(ModelDao::getHealthStatus))) .put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS))) + .put( + StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS)) + ) .put(StatNames.SCRIPT_COMPILATIONS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATIONS))) .put( StatNames.SCRIPT_COMPILATION_ERRORS.getName(), diff --git a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java index ffe5882bb..a098dd8b5 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java @@ -40,7 +40,8 @@ public enum StatNames { TRAINING_ERRORS(KNNCounter.TRAINING_ERRORS.getName()), TRAINING_MEMORY_USAGE("training_memory_usage"), TRAINING_MEMORY_USAGE_PERCENTAGE("training_memory_usage_percentage"), - SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName()); + SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName()), + KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); private String name; diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java new file mode 100644 index 000000000..6ded05d17 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Collection of util methods required for testing and related to OpenSearch cluster setup and functionality + */ +public class KNNClusterTestUtils { + + /** + * Create new mock for ClusterService + * @param version min version for cluster nodes + * @return + */ + public static ClusterService mockClusterService(final Version version) { + ClusterService clusterService = mock(ClusterService.class); + ClusterState clusterState = mock(ClusterState.class); + when(clusterService.state()).thenReturn(clusterState); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getMinNodeVersion()).thenReturn(version); + return clusterService; + } +} diff --git a/src/test/java/org/opensearch/knn/index/KNNClusterUtilTests.java b/src/test/java/org/opensearch/knn/index/KNNClusterUtilTests.java new file mode 100644 index 000000000..0e00a7f75 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/KNNClusterUtilTests.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.KNNTestCase; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; + +public class KNNClusterUtilTests extends KNNTestCase { + + public void testSingleNodeCluster() { + ClusterService clusterService = mockClusterService(Version.V_2_4_0); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + final Version minVersion = knnClusterUtil.getClusterMinVersion(); + + assertTrue(Version.V_2_4_0.equals(minVersion)); + } + + public void testMultipleNodesCluster() { + ClusterService clusterService = mockClusterService(Version.V_2_3_0); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + final Version minVersion = knnClusterUtil.getClusterMinVersion(); + + assertTrue(Version.V_2_3_0.equals(minVersion)); + } + + public void testWhenErrorOnClusterStateDiscover() { + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + final Version minVersion = knnClusterUtil.getClusterMinVersion(); + + assertTrue(Version.CURRENT.equals(minVersion)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 9fe42cb9f..a7f04cef4 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -17,12 +17,14 @@ import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.rest.RestStatus; import java.io.IOException; import java.util.Arrays; @@ -33,14 +35,19 @@ import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; public class LuceneEngineIT extends KNNRestTestCase { private static final int DIMENSION = 3; private static final String DOC_ID = "doc1"; + private static final String DOC_ID_2 = "doc2"; + private static final String DOC_ID_3 = "doc3"; private static final int EF_CONSTRUCTION = 128; private static final String INDEX_NAME = "test-index-1"; private static final String FIELD_NAME = "test-field-1"; + private static final String COLOR_FIELD_NAME = "color"; + private static final String TASTE_FIELD_NAME = "taste"; private static final int M = 16; private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; @@ -246,6 +253,107 @@ public void testDeleteDoc() throws Exception { assertEquals(0, getDocCount(INDEX_NAME)); } + public void testQueryWithFilter() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + + addKnnDocWithAttributes( + DOC_ID, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + refreshAllIndices(); + + final float[] searchVector = { 6.0f, 6.0f, 4.1f }; + int kGreaterThanFilterResult = 5; + List expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3); + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIds.size(), knnResults.size()); + assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + + int kLimitsFilterResult = 1; + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + } + + public void testQuery_filterWithNonLuceneEngine() throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, NMSLIB_NAME) + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + createKnnIndex(INDEX_NAME, mapping); + + addKnnDocWithAttributes( + DOC_ID, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + final float[] searchVector = { 6.0f, 6.0f, 5.6f }; + int k = 5; + expectThrows( + ResponseException.class, + () -> searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, k, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + k + ) + ); + } + + private void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues) throws IOException { + Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true"); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector); + for (String fieldName : fieldValues.keySet()) { + builder.field(fieldName, fieldValues.get(fieldName)); + } + builder.endObject(); + request.setJsonEntity(Strings.toString(builder)); + client().performRequest(request); + + request = new Request("POST", "/" + INDEX_NAME + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index c3d40cbc7..e3376dda9 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -5,6 +5,20 @@ package org.opensearch.knn.index.query; +import com.google.common.collect.ImmutableMap; +import org.apache.lucene.search.KnnVectorQuery; +import org.apache.lucene.search.Query; +import org.opensearch.Version; +import org.opensearch.cluster.ClusterModule; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; @@ -12,36 +26,50 @@ import org.opensearch.index.Index; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.KNNClusterUtil; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.plugins.SearchPlugin; import java.io.IOException; +import java.util.List; +import java.util.Optional; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; public class KNNQueryBuilderTests extends KNNTestCase { + private static final String FIELD_NAME = "myvector"; + private static final int K = 1; + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); + private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + public void testInvalidK() { float[] queryVector = { 1.0f, 1.0f }; /** * -ve k */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, -1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, -K)); /** * zero k */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 0)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, 0)); /** * k > KNNQueryBuilder.K_MAX */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, KNNQueryBuilder.K_MAX + 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, KNNQueryBuilder.K_MAX + K)); } public void testEmptyVector() { @@ -49,18 +77,18 @@ public void testEmptyVector() { * null query vector */ float[] queryVector = null; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, K)); /** * empty query vector */ float[] queryVector1 = {}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector1, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K)); } public void testFromXcontent() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -74,9 +102,74 @@ public void testFromXcontent() throws Exception { actualBuilder.equals(knnQueryBuilder); } + public void testFromXcontent_WithFilter() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + actualBuilder.equals(knnQueryBuilder); + } + + public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Exception { + final ClusterService clusterService = mockClusterService(Version.V_2_3_0); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + final XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + + expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser)); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List list = ClusterModule.getNamedXWriteables(); + SearchPlugin.QuerySpec spec = new SearchPlugin.QuerySpec<>( + TermQueryBuilder.NAME, + TermQueryBuilder::new, + TermQueryBuilder::fromXContent + ); + list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); + NamedXContentRegistry registry = new NamedXContentRegistry(list); + return registry; + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + final List entries = ClusterModule.getNamedWriteables(); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KNNQueryBuilder.NAME, KNNQueryBuilder::new)); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + return new NamedWriteableRegistry(entries); + } + public void testDoToQuery_Normal() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -89,16 +182,33 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query instanceof KnnVectorQuery); + } + public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); // Dimension is -1. In this case, model metadata will need to provide dimension - when(mockKNNVectorField.getDimension()).thenReturn(-1); + when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; when(mockKNNVectorField.getModelId()).thenReturn(modelId); @@ -120,7 +230,7 @@ public void testDoToQuery_FromModel() { public void testDoToQuery_InvalidDimensions() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -128,13 +238,13 @@ public void testDoToQuery_InvalidDimensions() { when(mockKNNVectorField.getDimension()).thenReturn(400); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); - when(mockKNNVectorField.getDimension()).thenReturn(1); + when(mockKNNVectorField.getDimension()).thenReturn(K); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } public void testDoToQuery_InvalidFieldType() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, 1); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); NumberFieldMapper.NumberFieldType mockNumberField = mock(NumberFieldMapper.NumberFieldType.class); @@ -142,4 +252,45 @@ public void testDoToQuery_InvalidFieldType() throws IOException { when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockNumberField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + + public void testSerialization() throws Exception { + assertSerialization(Version.CURRENT, Optional.empty()); + + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY)); + + assertSerialization(Version.V_2_3_0, Optional.empty()); + } + + private void assertSerialization(final Version version, final Optional queryBuilderOptional) throws Exception { + final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + + final ClusterService clusterService = mockClusterService(version); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(version); + output.writeNamedWriteable(knnQueryBuilder); + + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { + in.setVersion(Version.CURRENT); + final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); + + assertNotNull(deserializedQuery); + assertTrue(deserializedQuery instanceof KNNQueryBuilder); + final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery; + assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName()); + assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); + assertEquals(K, deserializedKnnQueryBuilder.getK()); + if (queryBuilderOptional.isPresent()) { + assertNotNull(deserializedKnnQueryBuilder.getFilter()); + assertEquals(queryBuilderOptional.get(), deserializedKnnQueryBuilder.getFilter()); + } else { + assertNull(deserializedKnnQueryBuilder.getFilter()); + } + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 06b0ce6ca..908ea1021 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -7,6 +7,10 @@ import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.util.KNNEngine; @@ -14,6 +18,10 @@ import java.util.List; import java.util.stream.Collectors; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class KNNQueryFactoryTests extends KNNTestCase { private final int testQueryDimension = 17; private final float[] testQueryVector = new float[testQueryDimension]; @@ -42,4 +50,27 @@ public void testCreateLuceneDefaultQuery() { assertTrue(query instanceof KnnVectorQuery); } } + + public void testCreateLuceneQueryWithFilter() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + QueryBuilder filter = new TermQueryBuilder("foo", "fooval"); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(filter) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof KnnVectorQuery); + } + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 2b6bb757e..0c51e77e1 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -22,6 +22,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.SpaceType; @@ -67,6 +68,7 @@ public class RestKNNStatsHandlerIT extends KNNRestTestCase { private boolean isDebuggingRemoteCluster = System.getProperty("cluster.debug", "false").equals("true"); private static final String FIELD_NAME_2 = "test_field_two"; private static final String FIELD_NAME_3 = "test_field_three"; + private static final String FIELD_LUCENE_NAME = "lucene_test_field"; private static final int DIMENSION = 4; private static int DOC_ID = 0; private static final int NUM_DOCS = 10; @@ -106,6 +108,7 @@ public void testStatsValueCheck() throws IOException { Map nodeStats0 = parseNodeStatsResponse(responseBody).get(0); Integer hitCount0 = (Integer) nodeStats0.get(StatNames.HIT_COUNT.getName()); Integer missCount0 = (Integer) nodeStats0.get(StatNames.MISS_COUNT.getName()); + Integer knnQueryWithFilterCount0 = (Integer) nodeStats0.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); // Setup index createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); @@ -124,9 +127,11 @@ public void testStatsValueCheck() throws IOException { Map nodeStats1 = parseNodeStatsResponse(responseBody).get(0); Integer hitCount1 = (Integer) nodeStats1.get(StatNames.HIT_COUNT.getName()); Integer missCount1 = (Integer) nodeStats1.get(StatNames.MISS_COUNT.getName()); + Integer knnQueryWithFilterCount1 = (Integer) nodeStats1.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); assertEquals(hitCount0, hitCount1); assertEquals((Integer) (missCount0 + 1), missCount1); + assertEquals(knnQueryWithFilterCount0, knnQueryWithFilterCount1); // Second search: Ensure that hits=1 searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); @@ -137,9 +142,24 @@ public void testStatsValueCheck() throws IOException { Map nodeStats2 = parseNodeStatsResponse(responseBody).get(0); Integer hitCount2 = (Integer) nodeStats2.get(StatNames.HIT_COUNT.getName()); Integer missCount2 = (Integer) nodeStats2.get(StatNames.MISS_COUNT.getName()); + Integer knnQueryWithFilterCount2 = (Integer) nodeStats2.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); assertEquals(missCount1, missCount2); assertEquals((Integer) (hitCount1 + 1), hitCount2); + assertEquals(knnQueryWithFilterCount0, knnQueryWithFilterCount2); + + putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_LUCENE_NAME, 2, METHOD_HNSW, LUCENE_NAME)); + addKnnDoc(INDEX_NAME, "2", FIELD_LUCENE_NAME, vector); + + searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_LUCENE_NAME, qvector, 1, QueryBuilders.termQuery("_id", "1")), 1); + + response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + responseBody = EntityUtils.toString(response.getEntity()); + + Map nodeStats3 = parseNodeStatsResponse(responseBody).get(0); + Integer knnQueryWithFilterCount3 = (Integer) nodeStats3.get(StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); + + assertEquals((Integer) (knnQueryWithFilterCount0 + 1), knnQueryWithFilterCount3); } /**