diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index d2fd22e3325f..5aa4d07b088d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -18,6 +18,7 @@ #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -80,11 +81,11 @@ namespace mlir::tpu { // because it's easier to debug from Python (particularly from OSS where symbols // are removed) #define TPU_ASSERT_IMPL(stream, cond) \ - if (!(cond)) { \ + if (LLVM_UNLIKELY(!(cond))) { \ (stream) << "Internal error: assert failed: " #cond; \ } #define TPU_ASSERT_CMP_IMPL(stream, lhs, rhs, cmp) \ - if (!((lhs)cmp(rhs))) { \ + if (LLVM_UNLIKELY(!((lhs)cmp(rhs)))) { \ (stream) << "Internal error: assert failed: " #lhs " " #cmp " " #rhs " (" \ << (lhs) << " vs. " << (rhs) << ")"; \ return failure(); \