Support distance type radius search for Lucene engine (opensearch-project#1498)

* Optimize Faiss Query With Filters: Reduce iteration and memory for id filter (opensearch-project#1402)

* Optimize Faiss Query With Filters. Reduce iteration copy for docid set iterator

Signed-off-by: luyuncheng <[email protected]>

* Optimize Faiss Query With Filters. Reduce iteration copy for docid set iterator.
Use a bitmap and batching for the id filter, and use a sparse or fixed bitset for exact ANN search

Signed-off-by: luyuncheng <[email protected]>

* Using int64_t instead of long type for GetLongArrayElements

Signed-off-by: luyuncheng <[email protected]>

* Add IDSelectorJlongBitmap

Signed-off-by: luyuncheng <[email protected]>

* 1. Add IDSelectorJlongBitmap and UT for it
2. Move FilterIdsSelectorType to a util class

Signed-off-by: luyuncheng <[email protected]>

* 1. Add IDSelectorJlongBitmap and UT for it
2. Move FilterIdsSelectorType to a util class
3. Spotless apply

Signed-off-by: luyuncheng <[email protected]>

* Rebase remote-tracking branch 'origin/main' into Filter

Signed-off-by: luyuncheng <[email protected]>

* tidy

Signed-off-by: luyuncheng <[email protected]>

* Add Changelog

Signed-off-by: luyuncheng <[email protected]>

* fix javadoc tasks

Signed-off-by: luyuncheng <[email protected]>

* fix bwc javadoc

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector

Signed-off-by: luyuncheng <[email protected]>

* Rebase faiss_wrapper.cpp

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector For description Select different FilterIdsSelectorType

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector For description Select different FilterIdsSelectorType

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector as Byte.SIZE

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector For comments

Signed-off-by: luyuncheng <[email protected]>

---------

Signed-off-by: luyuncheng <[email protected]>

* Increment 2.12.0-SNAPSHOT to 2.13.0-SNAPSHOT in BWC workflow (opensearch-project#1505)

Signed-off-by: Varun Jain <[email protected]>

* Manually install zlib for win CI (opensearch-project#1513)

Signed-off-by: John Mazanec <[email protected]>

* Upgrade faiss to 12b92e9 (opensearch-project#1509)

Upgrades faiss to facebookresearch/faiss@12b92e9. Cleanup outdated patches.

Signed-off-by: John Mazanec <[email protected]>

* Disable sdc table for HNSWPQ read-only indices (opensearch-project#1518)

Passes flag to disable sdc table for the HNSWPQ indices. This table is
only used by HNSWPQ during graph creation to compare nodes already
present in the graph. When we load an index, the graph is read-only.
Hence, we won't be doing any ingestion, so the table can be disabled
to save some memory.

Along with this, added a unit test and a couple test helper methods for
generating random data.

Signed-off-by: John Mazanec <[email protected]>

* Support distance type radius search for Lucene engine

Signed-off-by: Junqiu Lei <[email protected]>

* Resolve feedback

Signed-off-by: Junqiu Lei <[email protected]>

* Resolve feedback

Signed-off-by: Junqiu Lei <[email protected]>

* Resolve comments

Signed-off-by: Junqiu Lei <[email protected]>

* Resolve comments

Signed-off-by: Junqiu Lei <[email protected]>

* Add RNNQueryFactory class

Signed-off-by: Junqiu Lei <[email protected]>

* Add javadoc

Signed-off-by: Junqiu Lei <[email protected]>

* Resolve feedback

Signed-off-by: Junqiu Lei <[email protected]>

* Resolve feedback

Signed-off-by: Junqiu Lei <[email protected]>

* Resolve feedback

Signed-off-by: Junqiu Lei <[email protected]>

---------

Signed-off-by: luyuncheng <[email protected]>
Signed-off-by: Varun Jain <[email protected]>
Signed-off-by: John Mazanec <[email protected]>
Signed-off-by: Junqiu Lei <[email protected]>
Co-authored-by: luyuncheng <[email protected]>
Co-authored-by: Varun Jain <[email protected]>
Co-authored-by: John Mazanec <[email protected]>
4 people committed Mar 14, 2024
1 parent bfcf7dc commit 1cd948e
Showing 20 changed files with 863 additions and 137 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.12...2.x)
### Features
* Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498)
### Enhancements
* Optimize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
* Detect AVX2 Dynamically on the System [#1502](https://github.com/opensearch-project/k-NN/pull/1502)
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
@@ -66,6 +66,8 @@ public class KNNConstants {
public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String RADIAL_SEARCH_KEY = "radial_search";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";

3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/knn/index/IndexUtil.java
@@ -37,13 +37,14 @@
public class IndexUtil {

public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_13_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH);
}
};
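
The map above ties each feature key to the first version that understands it, and the query builder consults it before reading or writing the optional `distance` field on the wire. A minimal sketch of that gating pattern, assuming a simplified integer-encoded version in place of OpenSearch's `Version` class; the `FeatureGateSketch` class and its helpers are hypothetical and not part of the plugin:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical illustration of version gating; this is not the plugin's IndexUtil.
final class FeatureGateSketch {
    // Assumption: a version is encoded as major * 10000 + minor * 100 + patch.
    static int version(int major, int minor, int patch) {
        return major * 10_000 + minor * 100 + patch;
    }

    // Feature key -> first version that supports the feature.
    private static final Map<String, Integer> MIN_REQUIRED = new HashMap<>();
    static {
        MIN_REQUIRED.put("ignore_unmapped", version(2, 11, 0));
        MIN_REQUIRED.put("radial_search", version(2, 13, 0));
    }

    // In this sketch, an unknown key is treated as unsupported.
    static boolean isOnOrAfterMinRequiredVersion(int oldestNodeVersion, String key) {
        Integer required = MIN_REQUIRED.get(key);
        return required != null && oldestNodeVersion >= required;
    }

    public static void main(String[] args) {
        // A cluster whose oldest node is still on 2.12.0.
        int mixedCluster = version(2, 12, 0);
        System.out.println(isOnOrAfterMinRequiredVersion(mixedCluster, "ignore_unmapped"));
        System.out.println(isOnOrAfterMinRequiredVersion(mixedCluster, "radial_search"));
    }
}
```

Gating on the oldest node keeps a mixed-version cluster from sending a field that an older node would not know how to read.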

95 changes: 95 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Optional;

/**
* Base class for creating vector search queries.
*/
@Log4j2
public abstract class BaseQueryFactory {
/**
* DTO object to hold data required to create a Query instance.
*/
@AllArgsConstructor
@Builder
@Getter
public static class CreateQueryRequest {
@NonNull
private KNNEngine knnEngine;
@NonNull
private String indexName;
private String fieldName;
private float[] vector;
private byte[] byteVector;
private VectorDataType vectorDataType;
private Integer k;
private Float radius;
private QueryBuilder filter;
private QueryShardContext context;

public Optional<QueryBuilder> getFilter() {
return Optional.ofNullable(filter);
}

public Optional<QueryShardContext> getContext() {
return Optional.ofNullable(context);
}
}

/**
* Creates a query filter.
*
* @param createQueryRequest request object that has all required fields to construct the query
* @return Lucene Query
*/
protected static Query getFilterQuery(BaseQueryFactory.CreateQueryRequest createQueryRequest) {
if (!createQueryRequest.getFilter().isPresent()) {
return null;
}

final QueryShardContext queryShardContext = createQueryRequest.getContext()
.orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
log.debug(
String.format(
"Creating query with filter for index [%s], field [%s]",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName()
)
);
final Query filterQuery;
try {
filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext);
} catch (IOException e) {
throw new RuntimeException("Cannot create query with filter", e);
}
BitSetProducer parentFilter = queryShardContext.getParentFilter();
if (parentFilter != null) {
boolean mightMatch = new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery);
if (mightMatch) {
return filterQuery;
}
return new ToChildBlockJoinQuery(filterQuery, parentFilter);
}
return filterQuery;
}
}
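
The Lombok annotations on `CreateQueryRequest` hide most of the mechanics. A hand-written sketch of roughly the same shape, reduced to three fields and with `String` standing in for `QueryBuilder`; the class name and field selection are illustrative assumptions, and Lombok's generated code differs in detail:

```java
import java.util.Objects;
import java.util.Optional;

// Hypothetical, reduced stand-in for CreateQueryRequest; String replaces QueryBuilder.
final class CreateQueryRequestSketch {
    private final String indexName; // required
    private final Float radius;     // optional, null for top-k requests
    private final String filter;    // optional

    private CreateQueryRequestSketch(String indexName, Float radius, String filter) {
        // Required fields are null-checked at construction time.
        this.indexName = Objects.requireNonNull(indexName, "indexName is required");
        this.radius = radius;
        this.filter = filter;
    }

    static Builder builder() {
        return new Builder();
    }

    String getIndexName() {
        return indexName;
    }

    Float getRadius() {
        return radius;
    }

    // Callers never see a raw null filter; they branch on the Optional instead.
    Optional<String> getFilter() {
        return Optional.ofNullable(filter);
    }

    static final class Builder {
        private String indexName;
        private Float radius;
        private String filter;

        Builder indexName(String indexName) { this.indexName = indexName; return this; }
        Builder radius(Float radius) { this.radius = radius; return this; }
        Builder filter(String filter) { this.filter = filter; return this; }

        CreateQueryRequestSketch build() {
            return new CreateQueryRequestSketch(indexName, radius, filter);
        }
    }

    public static void main(String[] args) {
        CreateQueryRequestSketch request = builder().indexName("my-index").radius(0.5f).build();
        System.out.println(request.getFilter().isPresent());
    }
}
```

Wrapping the nullable `filter` in an `Optional` at the getter is what lets `getFilterQuery` return early with a single `isPresent()` check.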
162 changes: 144 additions & 18 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
@@ -7,6 +7,7 @@

import java.io.IOException;
import java.util.Arrays;

import java.util.List;
import java.util.Objects;
import lombok.extern.log4j.Log4j2;
@@ -24,6 +25,7 @@
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
@@ -35,6 +37,8 @@

import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH;

/**
* Helper class to build the KNN query
@@ -47,6 +51,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField DISTANCE_FIELD = new ParseField("distance");
public static final int K_MAX = 10000;
/**
* The name for the knn query
@@ -58,11 +63,74 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final String fieldName;
private final float[] vector;
private int k = 0;
private Float distance = null;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

/**
* Constructs a new knn query
* Constructs a new query with the given field name and vector
*
* @param fieldName Name of the field
* @param vector Array of floating points
*/
public KNNQueryBuilder(String fieldName, float[] vector) {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException("[" + NAME + "] requires fieldName");
}
if (vector == null) {
throw new IllegalArgumentException("[" + NAME + "] requires query vector");
}
if (vector.length == 0) {
throw new IllegalArgumentException("[" + NAME + "] query vector is empty");
}
this.fieldName = fieldName;
this.vector = vector;
}

/**
* Builder method for k
*
* @param k K nearest neighbours for the given vector
*/
public KNNQueryBuilder k(int k) {
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX);
}
if (distance != null) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
this.k = k;
return this;
}

/**
* Builder method for distance
*
* @param distance the distance threshold for the nearest neighbours
*/
public KNNQueryBuilder distance(Float distance) {
if (distance == null || distance < 0) {
throw new IllegalArgumentException("[" + NAME + "] requires distance >= 0");
}
if (k != 0) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
this.distance = distance;
return this;
}

/**
* Builder method for filter
*
* @param filter QueryBuilder
*/
public KNNQueryBuilder filter(QueryBuilder filter) {
this.filter = filter;
return this;
}

/**
* Constructs a new query for top k search
*
* @param fieldName Name of the field
* @param vector Array of floating points
@@ -94,6 +162,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
this.distance = null;
}

public static void initialize(ModelDao modelDao) {
@@ -128,6 +197,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
distance = in.readOptionalFloat();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
@@ -137,7 +209,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String fieldName = null;
List<Object> vector = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
int k = 0;
Integer k = null;
Float distance = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
@@ -166,6 +239,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
@@ -195,10 +270,21 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);
knnQueryBuilder.queryName(queryName);
knnQueryBuilder.boost(boost);
if ((k != null && distance != null) || (k == null && distance == null)) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
.boost(boost)
.queryName(queryName);

if (k != null) {
knnQueryBuilder.k(k);
} else {
knnQueryBuilder.distance(distance);
}

return knnQueryBuilder;
}

@@ -211,6 +297,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(distance);
}
}

/**
@@ -231,6 +320,10 @@ public int getK() {
return this.k;
}

public float getDistance() {
return this.distance;
}

public QueryBuilder getFilter() {
return this.filter;
}
@@ -259,6 +352,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (distance != null) {
builder.field(DISTANCE_FIELD.getPreferredName(), distance);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
@@ -298,6 +394,14 @@ protected Query doToQuery(QueryShardContext context) {
} else if (knnMethodContext != null) {
// If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping
knnEngine = knnMethodContext.getKnnEngine();
spaceType = knnMethodContext.getSpaceType();
}

// Currently, k-NN supports distance-type radius search.
// We need to transform the distance radius into the radius type required by the engine.
Float radius = null;
if (this.distance != null) {
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
}

if (fieldDimension != vector.length) {
@@ -325,18 +429,40 @@ protected Query doToQuery(QueryShardContext context) {
}

String indexName = context.index().getName();
KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null)
.byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null)
.vectorDataType(vectorDataType)
.k(this.k)
.filter(this.filter)
.context(context)
.build();
return KNNQueryFactory.create(createQueryRequest);

if (k != 0) {
KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null)
.byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null)
.vectorDataType(vectorDataType)
.k(this.k)
.filter(this.filter)
.context(context)
.build();
return KNNQueryFactory.create(createQueryRequest);
}
if (radius != null) {
if (!ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) {
throw new UnsupportedOperationException(String.format("Engine [%s] does not support radial search", knnEngine));
}
RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null)
.byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null)
.vectorDataType(vectorDataType)
.radius(radius)
.filter(this.filter)
.context(context)
.radius(radius)
.build();
return RNNQueryFactory.create(createQueryRequest);
}
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
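
The builder changes in this file enforce that a query carries exactly one of `k` or `distance`, and `doToQuery` then dispatches on whichever is set. A dependency-free sketch of that rule follows; the `NeighborQuerySketch` class, its `searchMode` method, and the message wording are hypothetical and only mirror the checks visible in the diff:

```java
// Hypothetical sketch of the k/distance exclusivity check; not the plugin's KNNQueryBuilder.
final class NeighborQuerySketch {
    static final int K_MAX = 10_000;

    private int k = 0;             // 0 means "not set", as in the diff
    private Float distance = null; // null means "not set"

    NeighborQuerySketch k(int k) {
        if (k <= 0 || k > K_MAX) {
            throw new IllegalArgumentException("requires 0 < k <= " + K_MAX);
        }
        if (distance != null) {
            throw new IllegalArgumentException("k and distance are mutually exclusive");
        }
        this.k = k;
        return this;
    }

    NeighborQuerySketch distance(Float distance) {
        if (distance == null || distance < 0) {
            throw new IllegalArgumentException("requires distance >= 0");
        }
        if (k != 0) {
            throw new IllegalArgumentException("k and distance are mutually exclusive");
        }
        this.distance = distance;
        return this;
    }

    // Same dispatch order as doToQuery: top-k first, then radial, otherwise reject.
    String searchMode() {
        if (k != 0) {
            return "top-k";
        }
        if (distance != null) {
            return "radial";
        }
        throw new IllegalArgumentException("exactly one of k or distance must be set");
    }

    public static void main(String[] args) {
        System.out.println(new NeighborQuerySketch().k(5).searchMode());
        System.out.println(new NeighborQuerySketch().distance(1.5f).searchMode());
    }
}
```

Because each setter checks the other field, the order in which `k` and `distance` are applied does not matter: the second call always fails.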