diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index 22f4a43e..c07fa01a 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -117,6 +117,11 @@ def dot_general_raw_make( and 2 <= rhs_bits <= 8 ): dg_accumulator_dtype = jnp.int32 + elif ( + lhs_bits in fp8_numerics.fp8_map.keys() + or rhs_bits in fp8_numerics.fp8_map.keys() + ): + dg_accumulator_dtype = jnp.float32 else: dg_accumulator_dtype = None