Skip to content

Commit

Permalink
[FRONTEND] fix default max_num_imprecise_acc (#2804) (#3851)
Browse files Browse the repository at this point in the history
(cherry picked from commit 39f4473)

Co-authored-by: Philippe Tillet <[email protected]>
  • Loading branch information
plotfi and ptillet authored May 14, 2024
1 parent 3f8d91b commit 730c825
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/triton/compiler/backends/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, device_type: tuple) -> None:
def parse_options(self, opts) -> Any:
args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
args["allow_fp8e4nv"] = self.capability >= 89
args["max_num_imprecise_acc_default"] = 0 if self.capability >= 89 else None
args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
return CUDAOptions(**args)

@staticmethod
Expand Down
7 changes: 3 additions & 4 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,11 +1263,10 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
assert acc.type == ret_ty

# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
max_num_imprecise_acc = 0
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc is None:
max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
if max_num_imprecise_acc is None:
max_num_imprecise_acc = 2**30
else:
max_num_imprecise_acc = 0

return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)

Expand Down

0 comments on commit 730c825

Please sign in to comment.