diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 5bd4e9359f..cd783331b0 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -117,7 +117,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * This improves the recall. */ if (filterWeight != null && canDoExactSearch(cardinality)) { - docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet)); + docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet, cardinality)); } else { Map annResults = doANNSearch(context, filterBitSet, cardinality); if (annResults == null) { @@ -131,7 +131,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { annResults.size(), cardinality ); - annResults = doExactSearch(context, filterBitSet); + annResults = doExactSearch(context, filterBitSet, cardinality); } docIdsToScoreMap.putAll(annResults); } @@ -308,10 +308,10 @@ private Map doANNSearch(final LeafReaderContext context, final B .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } - private Map doExactSearch(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) { + private Map doExactSearch(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet, int cardinality) { try { // Creating min heap and init with MAX DocID and Score as -INF. - final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); + final HitQueue queue = new HitQueue(Math.min(this.knnQuery.getK(), cardinality), true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsBitSet);