Skip to content

Commit

Permalink
Adding more tests and logs (#538)
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 14, 2022
1 parent 4c8cf93 commit 47b9ad4
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,11 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
} else if (token == XContentParser.Token.START_OBJECT) {
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
filter = parseInnerQueryBuilder(parser);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
}

} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -201,13 +201,20 @@ public int getK() {
return this.k;
}

public QueryBuilder getFilter() {
return this.filter;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.startObject(fieldName);

builder.field(VECTOR_FIELD.getPreferredName(), vector);
builder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
builder.field(FILTER_FIELD.getPreferredName(), filter);
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -242,6 +249,10 @@ protected Query doToQuery(QueryShardContext context) {
);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
}

String indexName = context.index().getName();
KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
Expand Down
42 changes: 13 additions & 29 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,47 +55,31 @@ public static Query create(KNNEngine knnEngine, String indexName, String fieldNa
public static Query create(CreateQueryRequest createQueryRequest) {
// Engines that create their own custom segment files cannot use the Lucene's KnnVectorQuery. They need to
// use the custom query type created by the plugin
final String indexName = createQueryRequest.getIndexName();
final String fieldName = createQueryRequest.getFieldName();
final int k = createQueryRequest.getK();
final float[] vector = createQueryRequest.getVector();

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
log.debug(
String.format(
"Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName(),
createQueryRequest.getK()
)
);
return new KNNQuery(
createQueryRequest.getFieldName(),
createQueryRequest.getVector(),
createQueryRequest.getK(),
createQueryRequest.getIndexName()
);
log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KNNQuery(fieldName, vector, k, indexName);
}

log.debug(
String.format(
"Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d",
createQueryRequest.getIndexName(),
createQueryRequest.getFieldName(),
createQueryRequest.getK()
)
);
if (createQueryRequest.getFilter().isPresent()) {
final QueryShardContext queryShardContext = createQueryRequest.getContext()
.orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
log.debug(
String.format("Creating Lucene k-NN query with filter for index [%s], field [%s] and k [%d]", indexName, fieldName, k)
);
try {
final Query filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext);
return new KnnVectorQuery(
createQueryRequest.getFieldName(),
createQueryRequest.getVector(),
createQueryRequest.getK(),
filterQuery
);
return new KnnVectorQuery(fieldName, vector, k, filterQuery);
} catch (IOException e) {
throw new RuntimeException("Cannot create knn query with filter", e);
}
}
return new KnnVectorQuery(createQueryRequest.getFieldName(), createQueryRequest.getVector(), createQueryRequest.getK());
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnVectorQuery(fieldName, vector, k);
}

/**
Expand Down
108 changes: 108 additions & 0 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.rest.RestStatus;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -33,14 +35,19 @@
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;

public class LuceneEngineIT extends KNNRestTestCase {

private static final int DIMENSION = 3;
private static final String DOC_ID = "doc1";
private static final String DOC_ID_2 = "doc2";
private static final String DOC_ID_3 = "doc3";
private static final int EF_CONSTRUCTION = 128;
private static final String INDEX_NAME = "test-index-1";
private static final String FIELD_NAME = "test-field-1";
private static final String COLOR_FIELD_NAME = "color";
private static final String TASTE_FIELD_NAME = "taste";
private static final int M = 16;

private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } };
Expand Down Expand Up @@ -246,6 +253,107 @@ public void testDeleteDoc() throws Exception {
assertEquals(0, getDocCount(INDEX_NAME));
}

public void testQueryWithFilter() throws Exception {
createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2);

addKnnDocWithAttributes(
DOC_ID,
new float[] { 6.0f, 7.9f, 3.1f },
ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet")
);
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

refreshAllIndices();

final float[] searchVector = { 6.0f, 6.0f, 4.1f };
int kGreaterThanFilterResult = 5;
List<String> expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3);
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kGreaterThanFilterResult
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(expectedDocIds.size(), knnResults.size());
assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds));

int kLimitsFilterResult = 1;
List<String> expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID);
final Response responseKLimitsFilterResult = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kLimitsFilterResult
);
final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity());
final List<KNNResult> knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME);

assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
assertTrue(
knnResultsKLimitsFilterResult.stream()
.map(KNNResult::getDocId)
.collect(Collectors.toList())
.containsAll(expectedDocIdsKLimitsFilterResult)
);
}

public void testQuery_filterWithNonLuceneEngine() throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, DIMENSION)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, NMSLIB_NAME)
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = Strings.toString(builder);
createKnnIndex(INDEX_NAME, mapping);

addKnnDocWithAttributes(
DOC_ID,
new float[] { 6.0f, 7.9f, 3.1f },
ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet")
);
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

final float[] searchVector = { 6.0f, 6.0f, 5.6f };
int k = 5;
expectThrows(
ResponseException.class,
() -> searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, k, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
k
)
);
}

private void addKnnDocWithAttributes(String docId, float[] vector, Map<String, String> fieldValues) throws IOException {
Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true");

XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector);
for (String fieldName : fieldValues.keySet()) {
builder.field(fieldName, fieldValues.get(fieldName));
}
builder.endObject();
request.setJsonEntity(Strings.toString(builder));
client().performRequest(request);

request = new Request("POST", "/" + INDEX_NAME + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,37 @@

package org.opensearch.knn.index.query;

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.cluster.ClusterModule;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.Index;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
import java.util.List;

import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

public class KNNQueryBuilderTests extends KNNTestCase {

Expand Down Expand Up @@ -74,6 +88,36 @@ public void testFromXcontent() throws Exception {
actualBuilder.equals(knnQueryBuilder);
}

public void testFromXcontent_WithFilter() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value"));
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(knnQueryBuilder.fieldName());
builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector());
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK());
builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter());
builder.endObject();
builder.endObject();
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser);
actualBuilder.equals(knnQueryBuilder);
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> list = ClusterModule.getNamedXWriteables();
SearchPlugin.QuerySpec<?> spec = new SearchPlugin.QuerySpec<>(
TermQueryBuilder.NAME,
TermQueryBuilder::new,
TermQueryBuilder::fromXContent
);
list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p)));
NamedXContentRegistry registry = new NamedXContentRegistry(list);
return registry;
}

public void testDoToQuery_Normal() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1);
Expand All @@ -89,6 +133,23 @@ public void testDoToQuery_Normal() throws Exception {
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
}

public void testDoToQuery_KnnQueryWithFilter() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value"));
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of());
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
Query query = knnQueryBuilder.doToQuery(mockQueryShardContext);
assertNotNull(query);
assertTrue(query instanceof KnnVectorQuery);
}

public void testDoToQuery_FromModel() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1);
Expand Down

0 comments on commit 47b9ad4

Please sign in to comment.