Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge efficient filtering from feature branch #588

Merged
merged 8 commits into from
Oct 25, 2022
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.bwc;

import org.hamcrest.MatcherAssert;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;

import org.opensearch.client.Request;
import org.opensearch.client.ResponseException;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;

import java.io.IOException;

import static org.hamcrest.CoreMatchers.anyOf;
import static org.hamcrest.CoreMatchers.containsString;
import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

/**
* Tests scenarios specific to filtering functionality in k-NN in case Lucene is set as an engine
*/
public class LuceneFilteringIT extends AbstractRollingUpgradeTestCase {
private static final String TEST_FIELD = "test-field";
private static final int DIMENSIONS = 50;
private static final int K = 10;
private static final int NUM_DOCS = 100;
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("_id", "100");

public void testLuceneFiltering() throws Exception {
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0];
switch (getClusterType()) {
case OLD:
createKnnIndex(
testIndex,
getKNNDefaultIndexSettings(),
createKnnIndexMapping(TEST_FIELD, DIMENSIONS, METHOD_HNSW, LUCENE_NAME)
);
bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS);
validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K);
break;
case MIXED:
validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when we start upgrading from 2.4 to 2.5 or 3.x?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll need to disable this test for higher versions similarly to what we're doing for some other IT, this will work for cases when previous version doesn't have filtering and next does have it

break;
case UPGRADED:
searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K);
deleteKNNIndex(testIndex);
break;
}
}

private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query");
knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject().endObject();

Request request = new Request("POST", "/" + index + "/_search");

request.addParameter("size", Integer.toString(resultSize));
request.addParameter("explain", Boolean.toString(true));
request.addParameter("search_type", "query_then_fetch");
request.setJsonEntity(Strings.toString(builder));

Exception exception = expectThrows(ResponseException.class, () -> client().performRequest(request));
// assert for two possible exception messages, fist one can come from current version in case serialized request is coming from
// lower version,
// second exception is vise versa, when lower version node receives request with filter field from higher version
MatcherAssert.assertThat(
exception.getMessage(),
anyOf(
containsString("filter field is supported from version"),
containsString("[knn] unknown token [START_OBJECT] after [filter]")
)
);
}
}
58 changes: 58 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNClusterUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;

/**
* Class abstracts information related to underlying OpenSearch cluster
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class KNNClusterUtil {

private ClusterService clusterService;
private static KNNClusterUtil instance;

/**
* Return instance of the cluster context, must be initialized first for proper usage
* @return instance of cluster context
*/
public static synchronized KNNClusterUtil instance() {
if (instance == null) {
instance = new KNNClusterUtil();
}
return instance;
}

/**
* Initializes instance of cluster context by injecting dependencies
* @param clusterService
*/
public void initialize(final ClusterService clusterService) {
this.clusterService = clusterService;
}

/**
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version
*/
public Version getClusterMinVersion() {
try {
return this.clusterService.state().getNodes().getMinNodeVersion();
} catch (Exception exception) {
log.error(
String.format("Failed to get cluster minimum node version, returning current node version %s instead.", Version.CURRENT),
exception
);
return Version.CURRENT;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Supplier;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
*/
@AllArgsConstructor
@Log4j2
public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isKnnVectorFieldType(field) == false) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
defaultMaxConnections,
defaultBeamWidth
);
return defaultFormatSupplier.get();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
maxConnections,
beamWidth
);
return formatSupplier.apply(maxConnections, beamWidth);
}

private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,23 @@
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate;

/**
* Extends the Codec to support a new file format for KNN index
* based on the mappings.
*
*/
public final class KNN910Codec extends FilterCodec {

private static final String KNN910 = "KNN910Codec";
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_1_0;
private final KNNFormatFacade knnFormatFacade;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN910Codec() {
this(createKNN91DefaultDelegate());
this(VERSION.getDefaultCodecDelegate());
}

/**
Expand All @@ -36,8 +33,8 @@ public KNN910Codec() {
* @param delegate codec that will perform all operations this codec does not override
*/
public KNN910Codec(Codec delegate) {
super(KNN910, delegate);
knnFormatFacade = KNNFormatFactory.createKNN910Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,23 @@
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import java.util.Optional;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate;

/**
* KNN codec that is based on Lucene92 codec
*/
@Log4j2
public final class KNN920Codec extends FilterCodec {

private static final String KNN920 = "KNN920Codec";

private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_2_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN920Codec() {
this(createKNN92DefaultDelegate(), new KNN920PerFieldKnnVectorsFormat(Optional.empty()));
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
Expand All @@ -45,8 +39,8 @@ public KNN920Codec() {
*/
@Builder
public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN920, delegate);
knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,74 +5,24 @@

package org.opensearch.knn.index.codec.KNN920Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;

import java.util.Map;
import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
@AllArgsConstructor
@Log4j2
public class KNN920PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isNotKnnVectorFieldType(field)) {
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH
)
);
return new Lucene92HnswVectorsFormat();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
maxConnections,
beamWidth
)
public class KNN920PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {

public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
);
return new Lucene92HnswVectorsFormat(maxConnections, beamWidth);
}

private boolean isNotKnnVectorFieldType(final String field) {
return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType);
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
}
}
Loading