diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 1defe45e8..59b508d8f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -7,6 +7,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; @@ -38,6 +39,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField VECTOR_FIELD = new ParseField("vector"); public static final ParseField K_FIELD = new ParseField("k"); + public static final ParseField FILTER_FIELD = new ParseField("filter"); public static int K_MAX = 10000; /** * The name for the knn query @@ -49,6 +51,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; + private QueryBuilder filter; /** * Constructs a new knn query @@ -58,6 +61,10 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { * @param k K nearest neighbours for the given vector */ public KNNQueryBuilder(String fieldName, float[] vector, int k) { + this(fieldName, vector, k, null); + } + + public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); } @@ -77,6 +84,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) { this.fieldName = fieldName; this.vector = vector; this.k = k; + this.filter = filter; } public static void initialize(ModelDao modelDao) { @@ -111,6 +119,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; int k = 0; + QueryBuilder filter = null; String queryName = null; String currentFieldName = null; XContentParser.Token token; @@ -139,6 +148,14 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep "[" + NAME + "] query does not support [" + currentFieldName + "]" ); } + } else if (token == XContentParser.Token.START_OBJECT) { + String tokenName = parser.currentName(); + if (FILTER_FIELD.getPreferredName().equals(tokenName)) { + filter = parseInnerQueryBuilder(parser); + } else { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); + } + } else { throw new ParsingException( parser.getTokenLocation(), @@ -153,7 +170,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter); knnQueryBuilder.queryName(queryName); knnQueryBuilder.boost(boost); return knnQueryBuilder; @@ -226,7 +243,16 @@ protected Query doToQuery(QueryShardContext context) { } String indexName = context.index().getName(); - return KNNQueryFactory.create(knnEngine, indexName, this.fieldName, this.vector, this.k); + KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(this.fieldName) + .vector(this.vector) + .k(this.k) + .filter(this.filter) + .context(context) + .build(); + return KNNQueryFactory.create(createQueryRequest); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index cbdb03ea8..10c8edd1a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -5,11 +5,21 @@ package org.opensearch.knn.index.query; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.util.KNNEngine; +import java.io.IOException; +import java.util.Optional; + /** * Creates the Lucene k-NN queries */ @@ -27,14 +37,97 @@ public class KNNQueryFactory { * @return Lucene Query */ public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) { + final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(fieldName) + .vector(vector) + .k(k) + .build(); + return create(createQueryRequest); + } + + /** + * Creates a Lucene query for a particular engine. + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + public static Query create(CreateQueryRequest createQueryRequest) { // Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to // use the custom query type created by the plugin - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) { - log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KNNQuery(fieldName, vector, k, indexName); + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + log.debug( + String.format( + "Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName(), + createQueryRequest.getK() + ) + ); + return new KNNQuery( + createQueryRequest.getFieldName(), + createQueryRequest.getVector(), + createQueryRequest.getK(), + createQueryRequest.getIndexName() + ); + } + + log.debug( + String.format( + "Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName(), + createQueryRequest.getK() + ) + ); + if (createQueryRequest.getFilter().isPresent()) { + final QueryShardContext queryShardContext = createQueryRequest.getContext() + .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + try { + final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); + return new KnnVectorQuery( + createQueryRequest.getFieldName(), + createQueryRequest.getVector(), + createQueryRequest.getK(), + filterQuery + ); + } catch (IOException e) { + throw new RuntimeException("Cannot create knn query with filter", e); + } + } + return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK()); + } + + /** + * DTO object to hold data required to create a Query instance. + */ + @AllArgsConstructor + @Builder + @Setter + static class CreateQueryRequest { + @Getter + @NonNull + private KNNEngine knnEngine; + @Getter + @NonNull + private String indexName; + @Getter + private String fieldName; + @Getter + private float[] vector; + @Getter + private int k; + // can be null in cases filter not passed with the knn query + private QueryBuilder filter; + // can be null in cases filter not passed with the knn query + private QueryShardContext context; + + public Optional getFilter() { + return Optional.ofNullable(filter); } - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnVectorQuery(fieldName, vector, k); + public Optional getContext() { + return Optional.ofNullable(context); + } } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 06b0ce6ca..908ea1021 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -7,6 +7,10 @@ import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.util.KNNEngine; @@ -14,6 +18,10 @@ import java.util.List; import java.util.stream.Collectors; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class KNNQueryFactoryTests extends KNNTestCase { private final int testQueryDimension = 17; private final float[] testQueryVector = new float[testQueryDimension]; @@ -42,4 +50,27 @@ public void testCreateLuceneDefaultQuery() { assertTrue(query instanceof KnnVectorQuery); } } + + public void testCreateLuceneQueryWithFilter() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + QueryBuilder filter = new TermQueryBuilder("foo", "fooval"); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(filter) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof KnnVectorQuery); + } + } }