Skip to content

Commit

Permalink
Remove llvm-muladd pass and move it's functionality to to llvm-simdlo…
Browse files Browse the repository at this point in the history
…op (#55802)

Closes #55785

I'm not sure if we want to backport this like this. Because that removes
some functionality (the pass itself). So LLVM.jl and friends might need
annoying version code. We can maybe keep the code there and just not run
the pass in a backport.
  • Loading branch information
gbaraldi authored Oct 15, 2024
1 parent b02d671 commit 924dc17
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 213 deletions.
12 changes: 0 additions & 12 deletions doc/src/devdocs/llvm-passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,6 @@ This pass is used to verify Julia's invariants about LLVM IR. This includes thin

These passes are used to perform transformations on LLVM IR that LLVM will not perform itself, e.g. fast math flag propagation, escape analysis, and optimizations on Julia-specific internal functions. They use knowledge about Julia's semantics to perform these optimizations.

### CombineMulAdd

* Filename: `llvm-muladd.cpp`
* Class Name: `CombineMulAddPass`
* Opt Name: `function(CombineMulAdd)`

This pass serves to optimize the particular combination of a regular `fmul` with a fast `fadd` into a contract `fmul` with a fast `fadd`. This is later optimized by the backend to a [fused multiply-add](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add) instruction, which can provide significantly faster operations at the cost of more [unpredictable semantics](https://simonbyrne.github.io/notes/fastmath/).

!!! note

This optimization only occurs when the `fmul` has a single use, which is the fast `fadd`.

### AllocOpt

* Filename: `llvm-alloc-opt.cpp`
Expand Down
1 change: 0 additions & 1 deletion doc/src/devdocs/llvm.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ The code for lowering Julia AST to LLVM IR or interpreting it directly is in dir
| `llvm-julia-licm.cpp` | Custom LLVM pass to hoist/sink Julia-specific intrinsics |
| `llvm-late-gc-lowering.cpp` | Custom LLVM pass to root GC-tracked values |
| `llvm-lower-handlers.cpp` | Custom LLVM pass to lower try-catch blocks |
| `llvm-muladd.cpp` | Custom LLVM pass for fast-match FMA |
| `llvm-multiversioning.cpp` | Custom LLVM pass to generate sysimg code on multiple architectures |
| `llvm-propagate-addrspaces.cpp` | Custom LLVM pass to canonicalize addrspaces |
| `llvm-ptls.cpp` | Custom LLVM pass to lower TLS operations |
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ RT_LLVMLINK :=
CG_LLVMLINK :=

ifeq ($(JULIACODEGEN),LLVM)
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd \
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop \
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,6 @@
YY(LLVMExtraMPMAddRemoveAddrspacesPass) \
YY(LLVMExtraMPMAddLowerPTLSPass) \
YY(LLVMExtraFPMAddDemoteFloat16Pass) \
YY(LLVMExtraFPMAddCombineMulAddPass) \
YY(LLVMExtraFPMAddLateLowerGCPass) \
YY(LLVMExtraFPMAddAllocOptPass) \
YY(LLVMExtraFPMAddPropagateJuliaAddrspacesPass) \
Expand Down
1 change: 0 additions & 1 deletion src/llvm-julia-passes.inc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ MODULE_PASS("LowerPTLSPass", LowerPTLSPass, LowerPTLSPass())
//Function passes
#ifdef FUNCTION_PASS
FUNCTION_PASS("DemoteFloat16", DemoteFloat16Pass, DemoteFloat16Pass())
FUNCTION_PASS("CombineMulAdd", CombineMulAddPass, CombineMulAddPass())
FUNCTION_PASS("LateLowerGCFrame", LateLowerGCPass, LateLowerGCPass())
FUNCTION_PASS("AllocOpt", AllocOptPass, AllocOptPass())
FUNCTION_PASS("PropagateJuliaAddrspaces", PropagateJuliaAddrspacesPass, PropagateJuliaAddrspacesPass())
Expand Down
117 changes: 0 additions & 117 deletions src/llvm-muladd.cpp

This file was deleted.

66 changes: 66 additions & 0 deletions src/llvm-simdloop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ STATISTIC(ReductionChainLength, "Total sum of instructions folded from reduction
STATISTIC(MaxChainLength, "Max length of reduction chain");
STATISTIC(AddChains, "Addition reduction chains");
STATISTIC(MulChains, "Multiply reduction chains");
STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");

#ifndef __clang_gcanalyzer__
#define REMARK(remark) ORE.emit(remark)
Expand All @@ -49,6 +50,49 @@ STATISTIC(MulChains, "Multiply reduction chains");
#endif
namespace {

/**
* Combine
* ```
* %v0 = fmul ... %a, %b
* %v = fadd contract ... %v0, %c
* ```
* to
* %v0 = fmul contract ... %a, %b
* %v = fadd contract ... %v0, %c
* when `%v0` has no other use
*/

static bool checkCombine(Value *maybeMul, Loop &L, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT
{
auto mulOp = dyn_cast<Instruction>(maybeMul);
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
return false;
if (!L.contains(mulOp))
return false;
if (!mulOp->hasOneUse()) {
LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n");
REMARK([&](){
return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp)
<< "fmul had multiple uses " << ore::NV("fmul", mulOp);
});
return false;
}
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
auto fmf = mulOp->getFastMathFlags();
if (!fmf.allowContract()) {
LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n");
REMARK([&](){
return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp)
<< "marked for fma " << ore::NV("fmul", mulOp);
});
++TotalContracted;
fmf.setAllowContract(true);
mulOp->copyFastMathFlags(fmf);
return true;
}
return false;
}

static unsigned getReduceOpcode(Instruction *J, Instruction *operand) JL_NOTSAFEPOINT
{
switch (J->getOpcode()) {
Expand Down Expand Up @@ -150,6 +194,28 @@ static void enableUnsafeAlgebraIfReduction(PHINode *Phi, Loop &L, OptimizationRe
});
(*K)->setHasAllowReassoc(true);
(*K)->setHasAllowContract(true);
switch ((*K)->getOpcode()) {
case Instruction::FAdd: {
if (!(*K)->hasAllowContract())
continue;
// (*K)->getOperand(0)->print(dbgs());
// (*K)->getOperand(1)->print(dbgs());
checkCombine((*K)->getOperand(0), L, ORE);
checkCombine((*K)->getOperand(1), L, ORE);
break;
}
case Instruction::FSub: {
if (!(*K)->hasAllowContract())
continue;
// (*K)->getOperand(0)->print(dbgs());
// (*K)->getOperand(1)->print(dbgs());
checkCombine((*K)->getOperand(0), L, ORE);
checkCombine((*K)->getOperand(1), L, ORE);
break;
}
default:
break;
}
if (SE)
SE->forgetValue(*K);
++length;
Expand Down
4 changes: 0 additions & 4 deletions src/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ struct DemoteFloat16Pass : PassInfoMixin<DemoteFloat16Pass> {
static bool isRequired() { return true; }
};

struct CombineMulAddPass : PassInfoMixin<CombineMulAddPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
};

struct LateLowerGCPass : PassInfoMixin<LateLowerGCPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
static bool isRequired() { return true; }
Expand Down
1 change: 0 additions & 1 deletion src/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ static void buildCleanupPipeline(ModulePassManager &MPM, PassBuilder *PB, Optimi
if (options.cleanup) {
if (O.getSpeedupLevel() >= 2) {
FunctionPassManager FPM;
JULIA_PASS(FPM.addPass(CombineMulAddPass()));
FPM.addPass(DivRemPairsPass());
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}
Expand Down
69 changes: 57 additions & 12 deletions test/llvmpasses/julia-simdloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='loop(LowerSIMDLoop)' -S %s | FileCheck %s

; CHECK-LABEL: @simd_test(
define void @simd_test(double *%a, double *%b) {
define void @simd_test(ptr %a, ptr %b) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%bptr = getelementptr double, double *%b, i64 %i
%aptr = getelementptr double, ptr %a, i64 %i
%bptr = getelementptr double, ptr %b, i64 %i
; CHECK: llvm.mem.parallel_loop_access
%aval = load double, double *%aptr
%bval = load double, double *%aptr
%aval = load double, ptr %aptr
%bval = load double, ptr %aptr
%cval = fadd double %aval, %bval
store double %cval, double *%bptr
store double %cval, ptr %bptr
%nexti = add i64 %i, 1
%done = icmp sgt i64 %nexti, 500
br i1 %done, label %loopdone, label %loop, !llvm.loop !1
Expand All @@ -23,15 +23,15 @@ loopdone:
}

; CHECK-LABEL: @simd_test_sub(
define double @simd_test_sub(double *%a) {
define double @simd_test_sub(ptr %a) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%aptr = getelementptr double, ptr %a, i64 %i
; CHECK: llvm.mem.parallel_loop_access
%aval = load double, double *%aptr
%aval = load double, ptr %aptr
%nextv = fsub double %v, %aval
; CHECK: fsub reassoc contract double %v, %aval
%nexti = add i64 %i, 1
Expand All @@ -42,14 +42,14 @@ loopdone:
}

; CHECK-LABEL: @simd_test_sub2(
define double @simd_test_sub2(double *%a) {
define double @simd_test_sub2(ptr %a) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%aval = load double, double *%aptr
%aptr = getelementptr double, ptr %a, i64 %i
%aval = load double, ptr %aptr
%nextv = fsub double %v, %aval
; CHECK: fsub reassoc contract double %v, %aval
%nexti = add i64 %i, 1
Expand All @@ -59,6 +59,26 @@ loopdone:
ret double %nextv
}

; CHECK-LABEL: @simd_test_sub4(
define double @simd_test_sub4(ptr %a) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%aval = load double, double *%aptr
%nextv2 = fmul double %aval, %aval
; CHECK: fmul contract double %aval, %aval
%nextv = fsub double %v, %nextv2
; CHECK: fsub reassoc contract double %v, %nextv2
%nexti = add i64 %i, 1
%done = icmp sgt i64 %nexti, 500
br i1 %done, label %loopdone, label %loop, !llvm.loop !0
loopdone:
ret double %nextv
}

; Tests if we correctly pass through other metadata
; CHECK-LABEL: @disabled(
define i32 @disabled(i32* noalias nocapture %a, i32* noalias nocapture readonly %b, i32 %N) {
Expand All @@ -82,6 +102,31 @@ for.end: ; preds = %for.body
ret i32 %1
}

; Check that we don't add contract to non loop things
; CHECK-LABEL: @dont_add_no_loop(
define double @dont_add_no_loop(ptr nocapture noundef nonnull readonly align 8 dereferenceable(72) %"a::Tuple", ptr nocapture noundef nonnull readonly align 8 dereferenceable(24) %"b::Tuple") #0 {
top:
%"a::Tuple[9]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 64
%"b::Tuple[3]_ptr" = getelementptr inbounds i8, ptr %"b::Tuple", i64 16
%"a::Tuple[6]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 40
%"b::Tuple[2]_ptr" = getelementptr inbounds i8, ptr %"b::Tuple", i64 8
%"a::Tuple[3]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 16
%"a::Tuple[3]_ptr.unbox" = load double, ptr %"a::Tuple[3]_ptr", align 8
%"b::Tuple.unbox" = load double, ptr %"b::Tuple", align 8
%0 = fmul double %"a::Tuple[3]_ptr.unbox", %"b::Tuple.unbox"
; CHECK: fmul double %
%"a::Tuple[6]_ptr.unbox" = load double, ptr %"a::Tuple[6]_ptr", align 8
%"b::Tuple[2]_ptr.unbox" = load double, ptr %"b::Tuple[2]_ptr", align 8
%1 = fmul contract double %"a::Tuple[6]_ptr.unbox", %"b::Tuple[2]_ptr.unbox"
%2 = fadd contract double %0, %1
%"a::Tuple[9]_ptr.unbox" = load double, ptr %"a::Tuple[9]_ptr", align 8
%"b::Tuple[3]_ptr.unbox" = load double, ptr %"b::Tuple[3]_ptr", align 8
%3 = fmul contract double %"a::Tuple[9]_ptr.unbox", %"b::Tuple[3]_ptr.unbox"
%4 = fadd contract double %2, %3
ret double %4
}


!0 = distinct !{!0, !"julia.simdloop"}
!1 = distinct !{!1, !"julia.simdloop", !"julia.ivdep"}
!2 = distinct !{!2, !"julia.simdloop", !"julia.ivdep", !3}
Expand Down
Loading

0 comments on commit 924dc17

Please sign in to comment.