diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 59b508d8f..aeefdbff4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -151,11 +151,11 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } else if (token == XContentParser.Token.START_OBJECT) { String tokenName = parser.currentName(); if (FILTER_FIELD.getPreferredName().equals(tokenName)) { + log.debug(String.format("Start parsing filter for field [%s]", fieldName)); filter = parseInnerQueryBuilder(parser); } else { throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); } - } else { throw new ParsingException( parser.getTokenLocation(), @@ -201,6 +201,10 @@ public int getK() { return this.k; } + public QueryBuilder getFilter() { + return this.filter; + } + @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); @@ -208,6 +212,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio builder.field(VECTOR_FIELD.getPreferredName(), vector); builder.field(K_FIELD.getPreferredName(), k); + if (filter != null) { + builder.field(FILTER_FIELD.getPreferredName(), filter); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -242,6 +249,10 @@ protected Query doToQuery(QueryShardContext context) { ); } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) { + throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine)); + } + String indexName = context.index().getName(); KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 10c8edd1a..c68ce9502 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -55,47 +55,31 @@ public static Query create(KNNEngine knnEngine, String indexName, String fieldNa public static Query create(CreateQueryRequest createQueryRequest) { // Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to // use the custom query type created by the plugin + final String indexName = createQueryRequest.getIndexName(); + final String fieldName = createQueryRequest.getFieldName(); + final int k = createQueryRequest.getK(); + final float[] vector = createQueryRequest.getVector(); + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { - log.debug( - String.format( - "Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", - createQueryRequest.getIndexName(), - createQueryRequest.getFieldName(), - createQueryRequest.getK() - ) - ); - return new KNNQuery( - createQueryRequest.getFieldName(), - createQueryRequest.getVector(), - createQueryRequest.getK(), - createQueryRequest.getIndexName() - ); + log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KNNQuery(fieldName, vector, k, indexName); } - log.debug( - String.format( - "Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", - createQueryRequest.getIndexName(), - createQueryRequest.getFieldName(), - createQueryRequest.getK() - ) - ); if (createQueryRequest.getFilter().isPresent()) { final QueryShardContext queryShardContext = createQueryRequest.getContext() .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + log.debug( + String.format("Creating Lucene k-NN query with filter for index [%s], field [%s] and k [%d]", indexName, fieldName, k) + ); try { final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); - return new KnnVectorQuery( - createQueryRequest.getFieldName(), - createQueryRequest.getVector(), - createQueryRequest.getK(), - filterQuery - ); + return new KnnVectorQuery(fieldName, vector, k, filterQuery); } catch (IOException e) { throw new RuntimeException("Cannot create knn query with filter", e); } } - return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK()); + log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KnnVectorQuery(fieldName, vector, k); } /** diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index fb26b893b..d0e3bae20 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -17,12 +17,14 @@ import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.rest.RestStatus; import java.io.IOException; import java.util.Arrays; @@ -33,14 +35,19 @@ import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; public class LuceneEngineIT extends KNNRestTestCase { private static final int DIMENSION = 3; private static final String DOC_ID = "doc1"; + private static final String DOC_ID_2 = "doc2"; + private static final String DOC_ID_3 = "doc3"; private static final int EF_CONSTRUCTION = 128; private static final String INDEX_NAME = "test-index-1"; private static final String FIELD_NAME = "test-field-1"; + private static final String COLOR_FIELD_NAME = "color"; + private static final String TASTE_FIELD_NAME = "taste"; private static final int M = 16; private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; @@ -246,6 +253,107 @@ public void testDeleteDoc() throws Exception { assertEquals(0, getDocCount(INDEX_NAME)); } + public void testQueryWithFilter() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + + addKnnDocWithAttributes( + DOC_ID, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + refreshAllIndices(); + + final float[] searchVector = { 6.0f, 6.0f, 4.1f }; + int kGreaterThanFilterResult = 5; + List expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3); + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIds.size(), knnResults.size()); + assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + + int kLimitsFilterResult = 1; + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + } + + public void testQuery_filterWithNonLuceneEngine() throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, NMSLIB_NAME) + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + createKnnIndex(INDEX_NAME, mapping); + + addKnnDocWithAttributes( + DOC_ID, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + final float[] searchVector = { 6.0f, 6.0f, 5.6f }; + int k = 5; + expectThrows( + ResponseException.class, + () -> searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, k, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + k + ) + ); + } + + private void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues) throws IOException { + Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true"); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector); + for (String fieldName : fieldValues.keySet()) { + builder.field(fieldName, fieldValues.get(fieldName)); + } + builder.endObject(); + request.setJsonEntity(Strings.toString(builder)); + client().performRequest(request); + + request = new Request("POST", "/" + INDEX_NAME + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index c3d40cbc7..4ebcf9ec4 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -5,6 +5,14 @@ package org.opensearch.knn.index.query; +import com.google.common.collect.ImmutableMap; +import org.apache.lucene.search.KnnVectorQuery; +import org.apache.lucene.search.Query; +import org.opensearch.cluster.ClusterModule; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; @@ -12,16 +20,22 @@ import org.opensearch.index.Index; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.plugins.SearchPlugin; import java.io.IOException; +import java.util.List; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; public class KNNQueryBuilderTests extends KNNTestCase { @@ -74,6 +88,36 @@ public void testFromXcontent() throws Exception { actualBuilder.equals(knnQueryBuilder); } + public void testFromXcontent_WithFilter() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + actualBuilder.equals(knnQueryBuilder); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List list = ClusterModule.getNamedXWriteables(); + SearchPlugin.QuerySpec spec = new SearchPlugin.QuerySpec<>( + TermQueryBuilder.NAME, + TermQueryBuilder::new, + TermQueryBuilder::fromXContent + ); + list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); + NamedXContentRegistry registry = new NamedXContentRegistry(list); + return registry; + } + public void testDoToQuery_Normal() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); @@ -89,6 +133,23 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value")); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query instanceof KnnVectorQuery); + } + public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1);