diff --git a/datasets/metrics.py b/datasets/metrics.py index b7cb5e3..0dd89ca 100644 --- a/datasets/metrics.py +++ b/datasets/metrics.py @@ -203,9 +203,12 @@ def knn_predict( """ if similarity == "cosine": # compute cos similarity between each feature vector and feature bank ---> [N_q, N_m] + assert queries.size(-1) == memory_bank.size(0) + memory_bank = memory_bank.T + memory_bank = memory_bank / (memory_bank.norm(dim=-1, keepdim=True) + 1e-7) similarity_matrix = torch.mm( - queries / queries.norm(dim=-1, keepdim=True), - memory_bank / memory_bank.norm(dim=-1, keepdim=True), + queries / (queries.norm(dim=-1, keepdim=True) + 1e-7), + memory_bank.T, ) elif similarity == "l2": # compute the L2 distance using broadcasting