Skip to content

Commit

Permalink
Update the accumulation dtype if FP8 precision is used for AQT.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686533904
  • Loading branch information
Cerebra Catalyst Team authored and copybara-github committed Nov 5, 2024
1 parent 6134c4b commit 05778b8
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions aqt/jax/v2/aqt_dot_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ def dot_general_raw_make(
and 2 <= rhs_bits <= 8
):
dg_accumulator_dtype = jnp.int32
elif (
lhs_bits in fp8_numerics.fp8_map.keys()
or rhs_bits in fp8_numerics.fp8_map.keys()
):
dg_accumulator_dtype = jnp.float32
else:
dg_accumulator_dtype = None

Expand Down

0 comments on commit 05778b8

Please sign in to comment.