diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 83306a75e..b8b88b4fe 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -122,6 +122,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException { } docIdsToScoreMap.putAll(annResults); } + if (docIdsToScoreMap.isEmpty()) { + return KNNScorer.emptyScorer(this); + } return convertSearchResponseToScorer(docIdsToScoreMap); } @@ -134,12 +137,7 @@ private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight f return new FixedBitSet(0); } - final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc); - // TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the - // distance for K vectors. This can avoid calls to native layer and save some latency. - final int cost = acceptDocs.cardinality(); - log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost); - return acceptDocs; + return createBitSet(scorer.iterator(), liveDocs, maxDoc); } private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException { @@ -165,7 +163,7 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept final int[] filteredIds = new int[filteredDocsBitSet.cardinality()]; int filteredIdsIndex = 0; int docId = 0; - while (true) { + while (docId < filteredDocsBitSet.length()) { docId = filteredDocsBitSet.nextSetBit(docId); if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { break; @@ -291,23 +289,22 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); float[] queryVector = this.knnQuery.getQueryVector(); try { - final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.name); + final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE)); - //Creating min heap and init with MAX DocID and Score as -INF. + // Creating min heap and init with MAX DocID and Score as -INF. final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); for (int filterId : filterIdsArray) { int docId = values.advance(filterId); final BytesRef value = values.binaryValue(); - final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, - value.length); + final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); final float[] vector = vectorSerializer.byteToFloatArray(byteStream); // Calculates a similarity score between the two vectors with a specified function. Higher similarity // scores correspond to closer vectors. float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector); - if(score > topDoc.score) { + if (score > topDoc.score) { topDoc.score = score; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we @@ -329,7 +326,7 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont return docToScore; } catch (Exception e) { - log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery); + log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery, e); } return Collections.emptyMap(); } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 8eb99b625..bb0ade22d 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -12,11 +12,15 @@ package org.opensearch.knn.index; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Floats; +import lombok.SneakyThrows; import org.apache.http.util.EntityUtils; import org.junit.BeforeClass; +import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentFactory; @@ -26,6 +30,7 @@ import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.plugin.script.KNNScoringUtil; +import org.opensearch.rest.RestStatus; import java.io.IOException; import java.net.URL; @@ -44,6 +49,14 @@ public class FaissIT extends KNNRestTestCase { + private static final String INDEX_NAME = "test-index-1"; + private static final String FIELD_NAME = "test-field-1"; + private static final String DOC_ID_1 = "doc1"; + private static final String DOC_ID_2 = "doc2"; + private static final String DOC_ID_3 = "doc3"; + private static final String COLOR_FIELD_NAME = "color"; + private static final String TASTE_FIELD_NAME = "taste"; + static TestUtils.TestData testData; @BeforeClass @@ -280,4 +293,90 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException { assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId())); } } + + @SneakyThrows + public void testQueryWithFilter_withDifferentCombination_thenSuccess() { + setupKNNIndexForFilterQuery(); + final float[] searchVector = { 6.0f, 6.0f, 4.1f }; + // K > filteredResults + int kGreaterThanFilterResult = 5; + List expectedDocIds = Arrays.asList(DOC_ID_1, DOC_ID_3); + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIds.size(), knnResults.size()); + assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + + // K Limits Filter results + int kLimitsFilterResult = 1; + List expectedDocIdsKLimitsFilterResult = List.of(DOC_ID_1); + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + + // Empty filter docIds + int k = 10; + final Response emptyFilterResponse = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + searchVector, + kLimitsFilterResult, + QueryBuilders.termQuery(COLOR_FIELD_NAME, "color_not_present") + ), + k + ); + final String responseBodyForEmptyDocIds = EntityUtils.toString(emptyFilterResponse.getEntity()); + final List emptyKNNFilteredResultsFromResponse = parseSearchResponse(responseBodyForEmptyDocIds, FIELD_NAME); + + assertEquals(0, emptyKNNFilteredResultsFromResponse.size()); + } + + protected void setupKNNIndexForFilterQuery() throws Exception { + // Create Mappings + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", 3) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + final String mapping = Strings.toString(builder); + + createKnnIndex(INDEX_NAME, mapping); + + addKnnDocWithAttributes( + DOC_ID_1, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + refreshIndex(INDEX_NAME); + } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 2266d755d..080f0d497 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -361,22 +361,6 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } - private void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues) throws IOException { - Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true"); - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector); - for (String fieldName : fieldValues.keySet()) { - builder.field(fieldName, fieldValues.get(fieldName)); - } - builder.endObject(); - request.setJsonEntity(Strings.toString(builder)); - client().performRequest(request); - - request = new Request("POST", "/" + INDEX_NAME + "/_refresh"); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - } - private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 0f8f43bf2..674d1be39 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -5,8 +5,11 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.Term; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.mockito.Mockito; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; @@ -23,6 +26,10 @@ import static org.mockito.Mockito.when; public class KNNQueryFactoryTests extends KNNTestCase { + private static final String FILTER_FILED_NAME = "foo"; + private static final String FILTER_FILED_VALUE = "fooval"; + private static final QueryBuilder FILTER_QUERY_BUILDER = new TermQueryBuilder(FILTER_FILED_NAME, FILTER_FILED_VALUE); + private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE)); private final int testQueryDimension = 17; private final float[] testQueryVector = new float[testQueryDimension]; private final String testIndexName = "test-index"; @@ -59,7 +66,6 @@ public void testCreateLuceneQueryWithFilter() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); MappedFieldType testMapper = mock(MappedFieldType.class); when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); - QueryBuilder filter = new TermQueryBuilder("foo", "fooval"); final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) @@ -67,10 +73,35 @@ public void testCreateLuceneQueryWithFilter() { .vector(testQueryVector) .k(testK) .context(mockQueryShardContext) - .filter(filter) + .filter(FILTER_QUERY_BUILDER) .build(); Query query = KNNQueryFactory.create(createQueryRequest); assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } } + + public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { + final KNNEngine knnEngine = KNNEngine.FAISS; + final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + final Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof KNNQuery); + + assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); + assertEquals(testFieldName, ((KNNQuery) query).getField()); + assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); + assertEquals(testK, ((KNNQuery) query).getK()); + assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery()); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index ec8675ab0..53d0330f0 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -6,17 +6,24 @@ package org.opensearch.knn.index.query; import com.google.common.collect.Comparators; +import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentCommitInfo; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; import org.junit.BeforeClass; @@ -28,6 +35,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.util.KNNEngine; @@ -70,6 +78,10 @@ public class KNNWeightTests extends KNNTestCase { private static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; private static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); + private static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); + private static final Map EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.12048191f); + + private static final Query FILTER_QUERY = new TermQuery(new Term("foo", "fooValue")); private static MockedStatic nativeMemoryCacheManagerMockedStatic; private static MockedStatic jniServiceMockedStatic; @@ -313,6 +325,145 @@ public void testEmptyQueryResults() { assertNull(knnScorer); } + @SneakyThrows + public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))) + .thenReturn(getFilteredKNNQueryResults()); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(reader.maxDoc()).thenReturn(K + 1); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + // Just to make sure that we are not hitting the exact search condition + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(K + 1)); + + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + true, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName()); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { + float[] vector = new float[] { 0.1f, 0.3f }; + int filterDocId = 0; + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + // scorer will return 2 documents + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); + when(reader.maxDoc()).thenReturn(1); + + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + when(fieldInfo.getName()).thenReturn(FIELD_NAME); + when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); + when(binaryDocValues.advance(filterDocId)).thenReturn(filterDocId); + BytesRef vectorByteRef = new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector)); + when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.empty()); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + + final Scorer knnScorer = knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(0, docIdSetIterator.cost()); + assertEquals(0, docIdSetIterator.cost()); + } + private void testQueryScore( final Function scoreTranslator, final Set segmentFiles, @@ -384,4 +535,12 @@ private KNNQueryResult[] getKNNQueryResults() { .collect(Collectors.toList()) .toArray(new KNNQueryResult[0]); } + + private KNNQueryResult[] getFilteredKNNQueryResults() { + return FILTERED_DOC_ID_TO_SCORES.entrySet() + .stream() + .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()) + .toArray(new KNNQueryResult[0]); + } } diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 39f7384ad..739b5835b 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -696,10 +696,23 @@ public void testQueryIndex_faiss_valid() throws IOException { KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null); assertEquals(k, results.length); } + + // Filter will result in no ids + for (float[] query : testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }); + assertEquals(0, results.length); + } } } } + public void testQueryIndexWithFilterIds_whenNMSLibEngine_thenException() throws IOException { + expectThrows( + IllegalArgumentException.class, + () -> JNIService.queryIndex(0L, new float[] { 0.1f, 0.2f, 0.33f }, 10, KNNEngine.NMSLIB.getName(), new int[] { 1, 2, 3 }) + ); + } + public void testFree_invalidEngine() { expectThrows(IllegalArgumentException.class, () -> JNIService.free(0L, "invalid-engine")); } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index c4e3cbbc7..de0839e25 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -6,6 +6,7 @@ package org.opensearch.knn; import com.google.common.base.Charsets; +import com.google.common.collect.ImmutableMap; import com.google.common.io.Resources; import com.google.common.primitives.Floats; import org.apache.commons.lang.StringUtils; @@ -16,9 +17,11 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -1306,4 +1309,20 @@ protected void refreshIndex(final String index) throws IOException { Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + + protected void addKnnDocWithAttributes(String docId, float[] vector, Map fieldValues) throws IOException { + Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true"); + + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector); + for (String fieldName : fieldValues.keySet()) { + builder.field(fieldName, fieldValues.get(fieldName)); + } + builder.endObject(); + request.setJsonEntity(Strings.toString(builder)); + client().performRequest(request); + + request = new Request("POST", "/" + INDEX_NAME + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } }