From 41ff6a877f1431d71e993e8fb1b48bcce804526d Mon Sep 17 00:00:00 2001
From: Matthijs Douze
Date: Tue, 26 Nov 2024 09:54:04 -0800
Subject: [PATCH] support bfloat16 in python (#4037)

Summary: Add a test for the PyTorch interface

Reviewed By: asadoughi

Differential Revision: D66074156
---
 faiss/gpu/test/torch_test_contrib_gpu.py | 44 +++++++++++++++++-------
 1 file changed, 32 insertions(+), 12 deletions(-)

diff --git a/faiss/gpu/test/torch_test_contrib_gpu.py b/faiss/gpu/test/torch_test_contrib_gpu.py
index e24e511401..654e7641ce 100644
--- a/faiss/gpu/test/torch_test_contrib_gpu.py
+++ b/faiss/gpu/test/torch_test_contrib_gpu.py
@@ -345,13 +345,16 @@ def test_knn_gpu(self, use_cuvs=False):
     def test_knn_gpu_cuvs(self):
         self.test_knn_gpu(use_cuvs=True)
 
-    def test_knn_gpu_datatypes(self, use_cuvs=False):
+    def test_knn_gpu_datatypes(self, use_cuvs=False, use_bf16=False):
         torch.manual_seed(10)
         d = 10
         nb = 1024
-        nq = 5
+        nq = 50
         k = 10
         res = faiss.StandardGpuResources()
+        if use_bf16 and not res.supportsBFloat16CurrentDevice():
+            print("WARNING bfloat16 not supported -- test not executed")
+            return
 
         # make GT on torch cpu and test using IndexFlatL2
         xb = torch.rand(nb, d, dtype=torch.float32)
@@ -361,8 +364,13 @@ def test_knn_gpu_datatypes(self, use_cuvs=False):
         index.add(xb)
         gt_D, gt_I = index.search(xq, k)
 
-        xb_c = xb.cuda().half()
-        xq_c = xq.cuda().half()
+        # convert to float16 or bfloat16
+        if use_bf16:
+            xb_c = xb.cuda().bfloat16()
+            xq_c = xq.cuda().bfloat16()
+        else:
+            xb_c = xb.cuda().half()
+            xq_c = xq.cuda().half()
 
         # use i32 output indices
         D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
@@ -370,20 +378,32 @@ def test_knn_gpu_datatypes(self, use_cuvs=False):
 
         faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)
 
-        self.assertTrue(torch.equal(I.long().cpu(), gt_I))
-        self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3)
+        ndiff = (I.cpu() != gt_I).sum().item()
+        MSE = ((D.float().cpu() - gt_D) ** 2).sum().item()
+        if use_bf16:
+            # 57 -- bf16 is not as accurate as fp16
+            self.assertLess(ndiff, 80)
+            # 0.00515
+            self.assertLess(MSE, 8e-3)
+        else:
+            # 5
+            self.assertLess(ndiff, 10)
+            # 8.565e-5
+            self.assertLess(MSE, 1e-4)
 
         # Test using numpy
-        D = np.zeros((nq, k), dtype=np.float32)
-        I = np.zeros((nq, k), dtype=np.int32)
+        if not use_bf16:  # bf16 not supported by numpy
+            # use i32 output indices
+            D = np.zeros((nq, k), dtype=np.float32)
+            I = np.zeros((nq, k), dtype=np.int32)
 
-        xb_c = xb.half().numpy()
-        xq_c = xq.half().numpy()
+            xb_c = xb.half().numpy()
+            xq_c = xq.half().numpy()
 
         faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)
 
-        self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
-        self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)
+    def test_knn_gpu_bf16(self):
+        self.test_knn_gpu_datatypes(use_bf16=True)
 
 
 class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase):
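
Usage sketch (not part of the patch): assuming a CUDA-enabled faiss build and a
GPU that supports bfloat16, the code path this patch tests can be exercised as
below. The names follow the test above; faiss.contrib.torch_utils must be
imported so that faiss functions accept torch tensors, and
supportsBFloat16CurrentDevice() is the capability check the test itself uses.

    import faiss
    import faiss.contrib.torch_utils  # lets faiss functions accept torch tensors
    import torch

    d, nb, nq, k = 10, 1024, 50, 10
    res = faiss.StandardGpuResources()

    # skip gracefully on devices without bf16 support, as the test does
    if res.supportsBFloat16CurrentDevice():
        xb = torch.rand(nb, d).cuda().bfloat16()  # database vectors
        xq = torch.rand(nq, d).cuda().bfloat16()  # query vectors

        # preallocated outputs: float32 distances, int32 indices
        D = torch.zeros(nq, k, device=xb.device, dtype=torch.float32)
        I = torch.zeros(nq, k, device=xb.device, dtype=torch.int32)

        # brute-force k-nearest-neighbor search in bfloat16
        faiss.knn_gpu(res, xq, xb, k, D, I)

Note that the test accepts looser thresholds for bf16 (ndiff < 80, MSE < 8e-3)
than for fp16 (ndiff < 10, MSE < 1e-4): bfloat16 trades mantissa bits for
exponent range, so its distances are less accurate than float16 at this scale.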