Skip to content

Commit

Permalink
quick fix of cosine similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiawei-Yang committed Nov 7, 2023
1 parent 65c0e56 commit 47f921f
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions datasets/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,12 @@ def knn_predict(
"""
if similarity == "cosine":
# compute cos similarity between each feature vector and feature bank ---> [N_q, N_m]
assert queries.size(-1) == memory_bank.size(0)
memory_bank = memory_bank.T
memory_bank = memory_bank / (memory_bank.norm(dim=-1, keepdim=True) + 1e-7)
similarity_matrix = torch.mm(
queries / queries.norm(dim=-1, keepdim=True),
memory_bank / memory_bank.norm(dim=-1, keepdim=True),
queries / (queries.norm(dim=-1, keepdim=True) + 1e-7),
memory_bank.T,
)
elif similarity == "l2":
# compute the L2 distance using broadcasting
Expand Down

0 comments on commit 47f921f

Please sign in to comment.