Skip to content

Commit

Permalink
Avoid undefined behaviour in ElementwiseOpToLLVM
Browse files Browse the repository at this point in the history
Triton uses the pattern `llvm:errs() << "Error message"; llvm::unreachable()` in some places. I suspect the author assumed that `llvm::errs()` aborts after printing the error message which it does not. So I replace the construct by `llvm::report_fatal_error("Error message")` instead which is used in many other places in the same file.

This recently causes flakyness in the `TritonSupportTest` in XLA. In this test we rely on the fact that Triton aborts when reaching these code paths but since invoking `llvm:unreachable()` leads to undefined behaviour it not always aborts, but rather does something else. On ARM for example test sometimes deadlocks on ARM - which resulted in the observed flakyness.

PiperOrigin-RevId: 694508541
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Nov 8, 2024
1 parent 6b43045 commit 9dee234
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
45 changes: 45 additions & 0 deletions third_party/triton/temporary/replace_unreachable_by_abort.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -388,10 +388,10 @@ struct FpToFpOpConversion
ptx = "cvt.rz.f16.f32";
break;
default:
- llvm::errs() << "WARNING: unsupported rounding mode for f32->f16 "
- "conversion: "
- << stringifyRoundingMode(rounding) << "\n";
- llvm_unreachable("");
+ llvm::report_fatal_error(
+ "WARNING: unsupported rounding mode for f32->f16 "
+ "conversion: " + stringifyRoundingMode(rounding) +
+ "\n");
}
auto &cvt = *builder.create(ptx.str());
auto res = builder.newOperand("=h");
@@ -448,10 +448,10 @@ struct FpToFpOpConversion
}
if (computeCapability < 89 &&
(srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
- llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
- "compute capability >= 89"
- << "\n";
- llvm_unreachable("");
+ llvm::report_fatal_error(
+ "Conversion from/to f8e4m3nv is only supported on "
+ "compute capability >= 89"
+ "\n");
}
auto convDesc = srcMap.lookup(key);
return {makeConverterFromPtx(
@@ -476,9 +476,9 @@ struct FpToFpOpConversion
// For now only RTNE is supported for conversions from fp16 to fp8
if (!srcElementType.isF32() &&
roundingMode.value() != RoundingMode::RTNE) {
- llvm::errs() << "Unsupported rounding mode for conversion to fp8: "
- << stringifyRoundingMode(roundingMode.value()) << "\n";
- llvm_unreachable("");
+ llvm::report_fatal_error(
+ "Unsupported rounding mode for conversion to fp8: " +
+ stringifyRoundingMode(roundingMode.value()) + "\n");
}
}

1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ those to this list.
"""

temporary_patch_list = [
"//third_party/triton:temporary/replace_unreachable_by_abort.patch",
# Add new patches just above this line
]

0 comments on commit 9dee234

Please sign in to comment.