Skip to content

Commit

Permalink
[Backport 2.x] Add ignore_unmapped support in KNNQueryBuilder (opense…
Browse files Browse the repository at this point in the history
…arch-project#1152)

* Add ignore_unmapped support in KNNQueryBuilder

Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan authored Oct 3, 2023
1 parent 0dd42ef commit 39836c6
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x)
### Features
### Enhancements
- Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
19 changes: 19 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import org.opensearch.Version;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.common.ValidationException;
Expand All @@ -24,6 +25,7 @@

import java.io.File;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand All @@ -32,6 +34,15 @@

public class IndexUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_10_0;
/**
 * Minimal cluster version required by each version-gated k-NN query feature, keyed by the
 * feature name used for lookups ("filter", "ignore_unmapped").
 *
 * The map is public and shared, so it is exposed as an unmodifiable view: callers can look
 * versions up but cannot add, replace, or remove entries.
 */
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();

/**
 * Builds the feature-name to minimal-version table.
 *
 * A plain HashMap is populated and then wrapped, instead of using double-brace initialization,
 * which would create an extra anonymous HashMap subclass just to run two put calls.
 *
 * @return an unmodifiable map from feature name to the minimal version supporting that feature
 */
private static Map<String, Version> initializeMinimalRequiredVersionMap() {
    final Map<String, Version> versionMap = new HashMap<>();
    versionMap.put("filter", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
    versionMap.put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
    return Collections.unmodifiableMap(versionMap);
}

/**
* Determines the size of a file on disk in kilobytes
*
Expand Down Expand Up @@ -195,4 +206,12 @@ public static Map<String, Object> getParametersAtLoading(SpaceType spaceType, KN

return Collections.unmodifiableMap(loadParameters);
}

/**
 * Tells whether the feature registered under {@code key} may be used on the current cluster.
 *
 * @param key feature name to look up in {@code minimalRequiredVersionMap}
 * @return {@code true} when a minimal version is registered for {@code key} and the version
 *         reported by {@code KNNClusterUtil.instance().getClusterMinVersion()} is on or after it;
 *         {@code false} when no version is registered for {@code key} or the cluster is older
 */
public static boolean isClusterOnOrAfterMinRequiredVersion(String key) {
    final Version requiredVersion = minimalRequiredVersionMap.get(key);
    // Short-circuit keeps the cluster lookup from running for unregistered keys.
    return requiredVersion != null && KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(requiredVersion);
}
}
57 changes: 45 additions & 12 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.commons.lang.StringUtils;
import org.opensearch.Version;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand All @@ -33,6 +32,7 @@
import java.util.List;
import java.util.Objects;

import static org.opensearch.knn.index.IndexUtil.*;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;

/**
Expand All @@ -45,6 +45,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -57,7 +58,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0;
private boolean ignoreUnmapped = false;

/**
* Constructs a new knn query
Expand Down Expand Up @@ -91,6 +92,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.vector = vector;
this.k = k;
this.filter = filter;
this.ignoreUnmapped = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -117,9 +119,12 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
k = in.readInt();
// We're checking whether all cluster nodes have at least the minimal required version or higher.
// This check is required to avoid issues during a cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion()) {
if (isClusterOnOrAfterMinRequiredVersion("filter")) {
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -131,6 +136,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
int k = 0;
QueryBuilder filter = null;
boolean ignoreUnmapped = false;
String queryName = null;
String currentFieldName = null;
XContentParser.Token token;
Expand All @@ -153,6 +159,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (IGNORE_UNMAPPED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
    // Match on the field currently being parsed. Comparing the constant's own preferred name
    // with the literal "ignore_unmapped" is always true, which would route every unrecognized
    // field into this branch and make the ParsingException branch below unreachable.
    if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
        ignoreUnmapped = parser.booleanValue();
    }
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand All @@ -168,20 +178,20 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
// MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER variable.
// Here we're checking whether all cluster nodes have at least that version or higher. This check is required
// to avoid issues during a rolling cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion()) {
if (isClusterOnOrAfterMinRequiredVersion("filter")) {
filter = parseInnerQueryBuilder(parser);
} else {
log.debug(
String.format(
"This version of k-NN doesn't support [filter] field, minimal required version is [%s]",
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
minimalRequiredVersionMap.get("filter")
)
);
throw new IllegalArgumentException(
String.format(
"%s field is supported from version %s",
FILTER_FIELD.getPreferredName(),
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
minimalRequiredVersionMap.get("filter")
)
);
}
Expand All @@ -204,6 +214,9 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.queryName(queryName);
// The key must be the one registered in minimalRequiredVersionMap ("ignore_unmapped").
// An unregistered key such as "ignoreUnmapped" makes the version check return false,
// so the parsed flag would never reach the builder.
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
    knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);
}
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
}
Expand All @@ -215,9 +228,12 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeInt(k);
// We're checking whether all cluster nodes have at least the minimal required version or higher.
// This check is required to avoid issues during a cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion()) {
if (isClusterOnOrAfterMinRequiredVersion("filter")) {
out.writeOptionalNamedWriteable(filter);
}
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
}

/**
Expand All @@ -242,6 +258,20 @@ public QueryBuilder getFilter() {
return this.filter;
}

/**
* Sets how this query builder behaves when the queried path is unmapped: either ignore the
* unmapped path (and run a {@link MatchNoDocsQuery} in place of this query) or throw an
* exception.
*
* @param ignoreUnmapped {@code true} to ignore an unmapped path; {@code false} to throw an
*                       exception when the path is unmapped
* @return this query builder, so calls can be chained
*/
public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) {
this.ignoreUnmapped = ignoreUnmapped;
return this;
}

/**
* Returns the ignore-unmapped setting currently stored on this query builder.
*
* @return the value of the {@code ignoreUnmapped} field
*/
public boolean getIgnoreUnmapped() {
return this.ignoreUnmapped;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
Expand All @@ -252,6 +282,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand All @@ -261,6 +294,10 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
return new MatchNoDocsQuery();
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
}
Expand Down Expand Up @@ -345,8 +382,4 @@ protected int doHashCode() {
public String getWriteableName() {
// The writeable name of this builder is the class-level NAME constant.
return NAME;
}

/**
* Checks whether the version returned by KNNClusterUtil.instance().getClusterMinVersion() is on
* or after MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER.
*
* NOTE(review): other call sites visible here use a keyed variant that takes a feature name;
* this no-argument helper looks superseded by it - confirm whether it is still needed.
*
* @return {@code true} when the cluster minimum version is on or after the filter minimum
*/
private static boolean isClusterOnOrAfterMinRequiredVersion() {
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
Expand Down Expand Up @@ -41,6 +42,7 @@
import java.util.List;
import java.util.Optional;

import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -297,6 +299,18 @@ public void testSerialization() throws Exception {
assertSerialization(Version.V_2_3_0, Optional.empty());
}

/**
 * Exercises both settings of ignore_unmapped against a mocked QueryShardContext:
 * enabled must produce a MatchNoDocsQuery, disabled must raise IllegalArgumentException.
 */
public void testIgnoreUnmapped() throws IOException {
    final float[] vector = { 1.0f, 2.0f, 3.0f, 4.0f };
    final KNNQueryBuilder builder = new KNNQueryBuilder(FIELD_NAME, vector, K);

    // Flag enabled: the setter is reflected by the getter and the query degrades to match-none.
    builder.ignoreUnmapped(true);
    assertTrue(builder.getIgnoreUnmapped());
    final Query result = builder.doToQuery(mock(QueryShardContext.class));
    assertNotNull(result);
    assertThat(result, instanceOf(MatchNoDocsQuery.class));

    // Flag disabled: building the query is rejected.
    builder.ignoreUnmapped(false);
    expectThrows(IllegalArgumentException.class, () -> builder.doToQuery(mock(QueryShardContext.class)));
}

private void assertSerialization(final Version version, final Optional<QueryBuilder> queryBuilderOptional) throws Exception {
final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent()
? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get())
Expand Down

0 comments on commit 39836c6

Please sign in to comment.