Skip to content

Commit

Permalink
Merge efficient filtering from feature branch (#588)
Browse files Browse the repository at this point in the history
* Adding efficient filtering 

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Oct 25, 2022
1 parent 3d0a9d7 commit f332ccb
Show file tree
Hide file tree
Showing 14 changed files with 718 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.bwc;

import org.hamcrest.MatcherAssert;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;

import org.opensearch.client.Request;
import org.opensearch.client.ResponseException;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;

import java.io.IOException;

import static org.hamcrest.CoreMatchers.anyOf;
import static org.hamcrest.CoreMatchers.containsString;
import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

/**
* Tests scenarios specific to filtering functionality in k-NN in case Lucene is set as an engine
*/
public class LuceneFilteringIT extends AbstractRollingUpgradeTestCase {
private static final String TEST_FIELD = "test-field";
private static final int DIMENSIONS = 50;
private static final int K = 10;
private static final int NUM_DOCS = 100;
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("_id", "100");

public void testLuceneFiltering() throws Exception {
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0];
switch (getClusterType()) {
case OLD:
createKnnIndex(
testIndex,
getKNNDefaultIndexSettings(),
createKnnIndexMapping(TEST_FIELD, DIMENSIONS, METHOD_HNSW, LUCENE_NAME)
);
bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS);
validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K);
break;
case MIXED:
validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K);
break;
case UPGRADED:
searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K);
deleteKNNIndex(testIndex);
break;
}
}

private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query");
knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject().endObject();

Request request = new Request("POST", "/" + index + "/_search");

request.addParameter("size", Integer.toString(resultSize));
request.addParameter("explain", Boolean.toString(true));
request.addParameter("search_type", "query_then_fetch");
request.setJsonEntity(Strings.toString(builder));

Exception exception = expectThrows(ResponseException.class, () -> client().performRequest(request));
// assert for two possible exception messages, fist one can come from current version in case serialized request is coming from
// lower version,
// second exception is vise versa, when lower version node receives request with filter field from higher version
MatcherAssert.assertThat(
exception.getMessage(),
anyOf(
containsString("filter field is supported from version"),
containsString("[knn] unknown token [START_OBJECT] after [filter]")
)
);
}
}
58 changes: 58 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNClusterUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;

/**
* Class abstracts information related to underlying OpenSearch cluster
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class KNNClusterUtil {

private ClusterService clusterService;
private static KNNClusterUtil instance;

/**
* Return instance of the cluster context, must be initialized first for proper usage
* @return instance of cluster context
*/
public static synchronized KNNClusterUtil instance() {
if (instance == null) {
instance = new KNNClusterUtil();
}
return instance;
}

/**
* Initializes instance of cluster context by injecting dependencies
* @param clusterService
*/
public void initialize(final ClusterService clusterService) {
this.clusterService = clusterService;
}

/**
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version
*/
public Version getClusterMinVersion() {
try {
return this.clusterService.state().getNodes().getMinNodeVersion();
} catch (Exception exception) {
log.error(
String.format("Failed to get cluster minimum node version, returning current node version %s instead.", Version.CURRENT),
exception
);
return Version.CURRENT;
}
}
}
81 changes: 78 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
Expand Down Expand Up @@ -38,6 +41,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {

public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -49,6 +53,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final String fieldName;
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_3_0_0;

/**
* Constructs a new knn query
Expand All @@ -58,6 +64,10 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
* @param k K nearest neighbours for the given vector
*/
public KNNQueryBuilder(String fieldName, float[] vector, int k) {
this(fieldName, vector, k, null);
}

public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException("[" + NAME + "] requires fieldName");
}
Expand All @@ -77,6 +87,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) {
this.fieldName = fieldName;
this.vector = vector;
this.k = k;
this.filter = filter;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -101,8 +112,13 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
fieldName = in.readString();
vector = in.readFloatArray();
k = in.readInt();
// We're checking if all cluster nodes has at least that version or higher. This check is required
// to avoid issues with cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion()) {
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex);
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
}

Expand All @@ -111,6 +127,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
List<Object> vector = null;
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
int k = 0;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
XContentParser.Token token;
Expand Down Expand Up @@ -139,6 +156,35 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment();
// Query filters are supported starting from a certain k-NN version only, exact version is defined by
// MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER variable.
// Here we're checking if all cluster nodes has at least that version or higher. This check is required
// to avoid issues with rolling cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion()) {
filter = parseInnerQueryBuilder(parser);
} else {
log.debug(
String.format(
"This version of k-NN doesn't support [filter] field, minimal required version is [%s]",
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
)
);
throw new IllegalArgumentException(
String.format(
"%s field is supported from version %s",
FILTER_FIELD.getPreferredName(),
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
)
);
}
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand All @@ -153,7 +199,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter);
knnQueryBuilder.queryName(queryName);
knnQueryBuilder.boost(boost);
return knnQueryBuilder;
Expand All @@ -164,6 +210,11 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeFloatArray(vector);
out.writeInt(k);
// We're checking if all cluster nodes has at least that version or higher. This check is required
// to avoid issues with cluster upgrade
if (isClusterOnOrAfterMinRequiredVersion()) {
out.writeOptionalNamedWriteable(filter);
}
}

/**
Expand All @@ -184,13 +235,20 @@ public int getK() {
return this.k;
}

public QueryBuilder getFilter() {
return this.filter;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.startObject(fieldName);

builder.field(VECTOR_FIELD.getPreferredName(), vector);
builder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -225,8 +283,21 @@ protected Query doToQuery(QueryShardContext context) {
);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
}

String indexName = context.index().getName();
return KNNQueryFactory.create(knnEngine, indexName, this.fieldName, this.vector, this.k);
KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(this.vector)
.k(this.k)
.filter(this.filter)
.context(context)
.build();
return KNNQueryFactory.create(createQueryRequest);
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
Expand Down Expand Up @@ -257,4 +328,8 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinRequiredVersion() {
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
}
}
Loading

0 comments on commit f332ccb

Please sign in to comment.