diff --git a/faiss/gpu/test/torch_test_contrib_gpu.py b/faiss/gpu/test/torch_test_contrib_gpu.py index e24e511401..ab12da65cc 100644 --- a/faiss/gpu/test/torch_test_contrib_gpu.py +++ b/faiss/gpu/test/torch_test_contrib_gpu.py @@ -345,11 +345,11 @@ def test_knn_gpu(self, use_cuvs=False): def test_knn_gpu_cuvs(self): self.test_knn_gpu(use_cuvs=True) - def test_knn_gpu_datatypes(self, use_cuvs=False): + def test_knn_gpu_datatypes(self, use_cuvs=False, use_bf16=False): torch.manual_seed(10) d = 10 nb = 1024 - nq = 5 + nq = 50 k = 10 res = faiss.StandardGpuResources() @@ -361,8 +361,13 @@ def test_knn_gpu_datatypes(self, use_cuvs=False): index.add(xb) gt_D, gt_I = index.search(xq, k) - xb_c = xb.cuda().half() - xq_c = xq.cuda().half() + # convert to float16 + if use_bf16: + xb_c = xb.cuda().bfloat16() + xq_c = xq.cuda().bfloat16() + else: + xb_c = xb.cuda().half() + xq_c = xq.cuda().half() # use i32 output indices D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32) @@ -370,20 +375,32 @@ def test_knn_gpu_datatypes(self, use_cuvs=False): faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs) - self.assertTrue(torch.equal(I.long().cpu(), gt_I)) - self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3) + ndiff = (I.cpu() != gt_I).sum().item() + MSE = ((D.float().cpu() - gt_D) ** 2).sum().item() + if use_bf16: + # 57 -- bf16 is not as accurate as fp16 + self.assertLess(ndiff, 80) + # 0.00515 + self.assertLess(MSE, 8e-3) + else: + # 5 + self.assertLess(ndiff, 10) + # 8.565e-5 + self.assertLess(MSE, 1e-4) # Test using numpy - D = np.zeros((nq, k), dtype=np.float32) - I = np.zeros((nq, k), dtype=np.int32) + if not use_bf16: # bf16 not supported by numpy + # use i32 output indices + D = np.zeros((nq, k), dtype=np.float32) + I = np.zeros((nq, k), dtype=np.int32) - xb_c = xb.half().numpy() - xq_c = xq.half().numpy() + xb_c = xb.half().numpy() + xq_c = xq.half().numpy() faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs) - self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I)) - self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3) + def test_knn_gpu_bf16(self): + self.test_knn_gpu_datatypes(use_bf16=True) class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase):