diff --git a/CHANGELOG.md b/CHANGELOG.md index 1075b68c4..f0413a82c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x) ### Features ### Enhancements +- Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 60156b4a7..f9809ba70 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import org.opensearch.Version; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.common.ValidationException; @@ -24,6 +25,7 @@ import java.io.File; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; @@ -32,6 +34,15 @@ public class IndexUtil { + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_10_0; + public static final Map minimalRequiredVersionMap = new HashMap() { + { + put("filter", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); + put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); + } + }; + /** * Determines the size of a file on disk in kilobytes * @@ -195,4 +206,12 @@ public static Map getParametersAtLoading(SpaceType spaceType, KN return Collections.unmodifiableMap(loadParameters); } + + public static boolean isClusterOnOrAfterMinRequiredVersion(String key) { + Version minimalRequiredVersion = minimalRequiredVersionMap.get(key); + if (minimalRequiredVersion == null) { + return false; + } + return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(minimalRequiredVersion); + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index b69b3dbfb..11912870f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -6,11 +6,10 @@ package org.opensearch.knn.index.query; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.commons.lang.StringUtils; -import org.opensearch.Version; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -33,6 +32,7 @@ import java.util.List; import java.util.Objects; +import static org.opensearch.knn.index.IndexUtil.*; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; /** @@ -45,6 +45,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField VECTOR_FIELD = new ParseField("vector"); public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); + public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); public static int K_MAX = 10000; /** * The name for the knn query @@ -57,7 +58,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final float[] vector; private int k = 0; private QueryBuilder filter; - private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0; + private boolean ignoreUnmapped = false; /** * Constructs a new knn query @@ -91,6 +92,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.vector = vector; this.k = k; this.filter = filter; + this.ignoreUnmapped = false; } public static void initialize(ModelDao modelDao) { @@ -117,9 +119,12 @@ public KNNQueryBuilder(StreamInput in) throws IOException { k = in.readInt(); // We're checking if all cluster nodes has at least that version or higher. This check is required // to avoid issues with cluster upgrade - if (isClusterOnOrAfterMinRequiredVersion()) { + if (isClusterOnOrAfterMinRequiredVersion("filter")) { filter = in.readOptionalNamedWriteable(QueryBuilder.class); } + if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { + ignoreUnmapped = in.readOptionalBoolean(); + } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -131,6 +136,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep float boost = AbstractQueryBuilder.DEFAULT_BOOST; int k = 0; QueryBuilder filter = null; + boolean ignoreUnmapped = false; String queryName = null; String currentFieldName = null; XContentParser.Token token; @@ -153,6 +159,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false); } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); + } else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals("ignore_unmapped")) { + if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { + ignoreUnmapped = parser.booleanValue(); + } } else { throw new ParsingException( parser.getTokenLocation(), @@ -168,20 +178,20 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep // MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER variable. // Here we're checking if all cluster nodes has at least that version or higher. This check is required // to avoid issues with rolling cluster upgrade - if (isClusterOnOrAfterMinRequiredVersion()) { + if (isClusterOnOrAfterMinRequiredVersion("filter")) { filter = parseInnerQueryBuilder(parser); } else { log.debug( String.format( "This version of k-NN doesn't support [filter] field, minimal required version is [%s]", - MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + minimalRequiredVersionMap.get("filter") ) ); throw new IllegalArgumentException( String.format( "%s field is supported from version %s", FILTER_FIELD.getPreferredName(), - MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER + minimalRequiredVersionMap.get("filter") ) ); } @@ -204,6 +214,9 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter); knnQueryBuilder.queryName(queryName); + if (isClusterOnOrAfterMinRequiredVersion("ignoreUnmapped")) { + knnQueryBuilder.ignoreUnmapped(ignoreUnmapped); + } knnQueryBuilder.boost(boost); return knnQueryBuilder; } @@ -215,9 +228,12 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeInt(k); // We're checking if all cluster nodes has at least that version or higher. This check is required // to avoid issues with cluster upgrade - if (isClusterOnOrAfterMinRequiredVersion()) { + if (isClusterOnOrAfterMinRequiredVersion("filter")) { out.writeOptionalNamedWriteable(filter); } + if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { + out.writeOptionalBoolean(ignoreUnmapped); + } } /** @@ -242,6 +258,20 @@ public QueryBuilder getFilter() { return this.filter; } + /** + * Sets whether the query builder should ignore unmapped paths (and run a + * {@link MatchNoDocsQuery} in place of this query) or throw an exception if + * the path is unmapped. + */ + public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) { + this.ignoreUnmapped = ignoreUnmapped; + return this; + } + + public boolean getIgnoreUnmapped() { + return this.ignoreUnmapped; + } + @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); @@ -252,6 +282,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (filter != null) { builder.field(FILTER_FIELD.getPreferredName(), filter); } + if (ignoreUnmapped) { + builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -261,6 +294,10 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio protected Query doToQuery(QueryShardContext context) { MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName); + if (mappedFieldType == null && ignoreUnmapped) { + return new MatchNoDocsQuery(); + } + if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) { throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName)); } @@ -345,8 +382,4 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } - - private static boolean isClusterOnOrAfterMinRequiredVersion() { - return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); - } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 236cbd644..1381e19be 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; @@ -41,6 +42,7 @@ import java.util.List; import java.util.Optional; +import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -297,6 +299,18 @@ public void testSerialization() throws Exception { assertSerialization(Version.V_2_3_0, Optional.empty()); } + public void testIgnoreUnmapped() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + knnQueryBuilder.ignoreUnmapped(true); + assertTrue(knnQueryBuilder.getIgnoreUnmapped()); + Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class)); + assertNotNull(query); + assertThat(query, instanceOf(MatchNoDocsQuery.class)); + knnQueryBuilder.ignoreUnmapped(false); + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class))); + } + private void assertSerialization(final Version version, final Optional queryBuilderOptional) throws Exception { final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent() ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get())