diff --git a/benchmarks/perf-tool/okpt/test/steps/steps.py b/benchmarks/perf-tool/okpt/test/steps/steps.py index b04a4af4de..cc1773330b 100644 --- a/benchmarks/perf-tool/okpt/test/steps/steps.py +++ b/benchmarks/perf-tool/okpt/test/steps/steps.py @@ -454,6 +454,9 @@ def _action(self): results['took'] = [ float(query_response['took']) for query_response in query_responses ] + results['client_time'] = [ + float(query_response['client_time']) for query_response in query_responses + ] results['memory_kb'] = get_cache_size_in_kb(self.endpoint, self.port) if self.calculate_recall: @@ -472,7 +475,7 @@ def _action(self): return results def _get_measures(self) -> List[str]: - measures = ['took', 'memory_kb'] + measures = ['took', 'memory_kb', 'client_time'] if self.calculate_recall: measures.extend(['recall@K', f'recall@{str(self.r)}']) @@ -783,9 +786,13 @@ def get_cache_size_in_kb(endpoint, port): def query_index(opensearch: OpenSearch, index_name: str, body: dict, excluded_fields: list): - return opensearch.search(index=index_name, + start_time = round(time.time()*1000) + queryResponse = opensearch.search(index=index_name, body=body, _source_excludes=excluded_fields) + end_time = round(time.time() * 1000) + queryResponse['client_time'] = end_time - start_time + return queryResponse def bulk_index(opensearch: OpenSearch, index_name: str, body: List): diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 5f96a79717..83306a75e8 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -9,6 +9,8 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; import org.apache.lucene.search.FilteredDocIdSetIterator; +import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; @@ -291,18 +293,40 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont try { final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.name); final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE)); - + //Creating min heap and init with MAX DocID and Score as -INF. + final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); + ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); - for (int j : filterIdsArray) { - int docId = values.advance(j); - BytesRef value = values.binaryValue(); - ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); + for (int filterId : filterIdsArray) { + int docId = values.advance(filterId); + final BytesRef value = values.binaryValue(); + final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, + value.length); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - // making min score as high score as this is closest to the vector + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector); - docToScore.put(docId, score); + if(score > topDoc.score) { + topDoc.score = score; + topDoc.doc = docId; + // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we + // have seen till now on top. + topDoc = queue.updateTop(); + } + } + // If scores are negative we will remove them. + // This is done, because there can be negative values in the Heap as we init the heap with Score as -INF. + // If filterIds < k, the some values in heap can have a negative score. + while (queue.size() > 0 && queue.top().score < 0) { + queue.pop(); } + + while (queue.size() > 0) { + final ScoreDoc doc = queue.pop(); + docToScore.put(doc.doc, doc.score); + } + return docToScore; } catch (Exception e) { log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery);