From c570a76d2de3db38416b3c75f1d103e008d19a2c Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Tue, 21 May 2024 19:54:04 +0800
Subject: [PATCH] improve e4m3 decoding. (#43)

Co-authored-by: LeiWang199 <leiwang199>
---
 python/bitblas/quantization/quantization.py         | 6 +++---
 testing/python/operators/test_general_matmul_fp8.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py
index d9f36094794d..d68d437d49e3 100644
--- a/python/bitblas/quantization/quantization.py
+++ b/python/bitblas/quantization/quantization.py
@@ -142,9 +142,9 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
 def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
     assert nbit == 8
     assert dtype == "float16"
-    s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16")
-    offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16"))
-    e_f16 = ((val << tir.const(7, "int16")) + offset)
+    s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
+    prefix = tir.Select(s_f16 == 0, tir.const(0x2000, "uint16"), tir.const(0xc000, "uint16"))
+    e_f16 = (((val & tir.const(127, "uint16")) << tir.const(7, "uint16"))) | prefix
     return tir.reinterpret("float16", s_f16 | e_f16)
 
 
diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py
index 3d0a7be2f583..5b7de9ab0056 100644
--- a/testing/python/operators/test_general_matmul_fp8.py
+++ b/testing/python/operators/test_general_matmul_fp8.py
@@ -166,7 +166,7 @@ def map_torch_type(intype):
     print("torch_ref_out", ref_out)
     print("bitblas_out", bitblas_out)
 
-    torch.testing.assert_allclose(ref_out, bitblas_out, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(ref_out, bitblas_out, rtol=1e-1, atol=1e-1)
 
 
 # fmt: on