diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java index 223f9d09..6654dc34 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.EmptyDocValuesProducer; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; @@ -42,9 +43,13 @@ public BinaryDocValues getBinary(FieldInfo field) { try { List subs = new ArrayList<>(this.mergeState.docValuesProducers.length); for (int i = 0; i < this.mergeState.docValuesProducers.length; i++) { + BinaryDocValues values = null; DocValuesProducer docValuesProducer = mergeState.docValuesProducers[i]; if (docValuesProducer != null) { - BinaryDocValues values = docValuesProducer.getBinary(field); + FieldInfo readerFieldInfo = mergeState.fieldInfos[i].fieldInfo(field.name); + if (readerFieldInfo != null && readerFieldInfo.getDocValuesType() == DocValuesType.BINARY) { + values = docValuesProducer.getBinary(readerFieldInfo); + } if (values != null) { subs.add(new BinaryDocValuesSub(mergeState.docMaps[i], values)); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/KNNRestTestCase.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/KNNRestTestCase.java index fe6aa8f1..4b34f637 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/KNNRestTestCase.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/KNNRestTestCase.java @@ -251,10 +251,36 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[] XContentBuilder builder = XContentFactory.jsonBuilder().startObject() .field(fieldName, vector) .endObject(); + request.setJsonEntity(Strings.toString(builder)); + Response response = client().performRequest(request); + + request = new Request( + "POST", + "/" + index + "/_refresh" + ); + response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Add a single numeric field Doc to an index + */ + protected void addDocWithNumericField(String index, String docId, String fieldName, int value) throws IOException { + Request request = new Request( + "POST", + "/" + index + "/_doc/" + docId + "?refresh=true" + ); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject() + .field(fieldName, value) + .endObject(); request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); + + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java index 2a02a140..6a71d34c 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java @@ -67,6 +67,29 @@ public void testKNNResultsWithForceMerge() throws Exception { } } + public void testKNNResultsUpdateDocAndForceMerge() throws Exception { + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + addDocWithNumericField(INDEX_NAME, "1", "abc", 100 ); + addTestData(); + forceMergeKnnIndex(INDEX_NAME); + + /** + * Query params + */ + float[] queryVector = {1.0f, 1.0f}; // vector to be queried + int k = 1; // nearest 1 neighbor + + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); + + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + assertEquals(k, results.size()); + for(KNNResult result : results) { + assertEquals("2", result.getDocId()); + } + } + public void testKNNResultsWithoutForceMerge() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData();