diff --git a/src/api.cpp b/src/api.cpp index 8713eec1..10558a8e 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -380,6 +380,13 @@ sparse_coo query_db_sparse(std::vector &ref_sketches, } } } + if (min_dists.size() < kNN || row_dist < min_dists.top().dist) { + SparseDist new_min = {row_dist, j}; + min_dists.push(new_min); + if (min_dists.size() > kNN) { + min_dists.pop(); + } + } if ((i * ref_sketches.size() + j) % update_every == 0) { #pragma omp critical { @@ -390,24 +397,15 @@ sparse_coo query_db_sparse(std::vector &ref_sketches, } } } + } - if (min_dists.size() < kNN || row_dist < min_dists.top().dist) { - SparseDist new_min = {row_dist, j}; - min_dists.push(new_min); - if (min_dists.size() > kNN) { - min_dists.pop(); - } - } - - long offset = i * kNN; - std::fill_n(i_vec.begin() + offset, kNN, i); - for (int k = 0; k < kNN; ++k) { - SparseDist entry = min_dists.top(); - j_vec[offset + k] = entry.j; - dists[offset + k] = entry.dist; - min_dists.pop(); - } - + long offset = i * kNN; + std::fill_n(i_vec.begin() + offset, kNN, i); + for (int k = 0; k < kNN; ++k) { + SparseDist entry = min_dists.top(); + j_vec[offset + k] = entry.j; + dists[offset + k] = entry.dist; + min_dists.pop(); } } }