Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove llvm-muladd pass and move it's functionality to to llvm-simdloop #55802

Merged
merged 4 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions doc/src/devdocs/llvm-passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,6 @@ This pass is used to verify Julia's invariants about LLVM IR. This includes thin

These passes are used to perform transformations on LLVM IR that LLVM will not perform itself, e.g. fast math flag propagation, escape analysis, and optimizations on Julia-specific internal functions. They use knowledge about Julia's semantics to perform these optimizations.

### CombineMulAdd

* Filename: `llvm-muladd.cpp`
* Class Name: `CombineMulAddPass`
* Opt Name: `function(CombineMulAdd)`

This pass serves to optimize the particular combination of a regular `fmul` with a fast `fadd` into a contract `fmul` with a fast `fadd`. This is later optimized by the backend to a [fused multiply-add](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add) instruction, which can provide significantly faster operations at the cost of more [unpredictable semantics](https://simonbyrne.github.io/notes/fastmath/).

!!! note

This optimization only occurs when the `fmul` has a single use, which is the fast `fadd`.

### AllocOpt

* Filename: `llvm-alloc-opt.cpp`
Expand Down
1 change: 0 additions & 1 deletion doc/src/devdocs/llvm.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ The code for lowering Julia AST to LLVM IR or interpreting it directly is in dir
| `llvm-julia-licm.cpp` | Custom LLVM pass to hoist/sink Julia-specific intrinsics |
| `llvm-late-gc-lowering.cpp` | Custom LLVM pass to root GC-tracked values |
| `llvm-lower-handlers.cpp` | Custom LLVM pass to lower try-catch blocks |
| `llvm-muladd.cpp` | Custom LLVM pass for fast-match FMA |
| `llvm-multiversioning.cpp` | Custom LLVM pass to generate sysimg code on multiple architectures |
| `llvm-propagate-addrspaces.cpp` | Custom LLVM pass to canonicalize addrspaces |
| `llvm-ptls.cpp` | Custom LLVM pass to lower TLS operations |
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ RT_LLVMLINK :=
CG_LLVMLINK :=

ifeq ($(JULIACODEGEN),LLVM)
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd \
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop \
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,6 @@
YY(LLVMExtraMPMAddRemoveAddrspacesPass) \
YY(LLVMExtraMPMAddLowerPTLSPass) \
YY(LLVMExtraFPMAddDemoteFloat16Pass) \
YY(LLVMExtraFPMAddCombineMulAddPass) \
YY(LLVMExtraFPMAddLateLowerGCPass) \
YY(LLVMExtraFPMAddAllocOptPass) \
YY(LLVMExtraFPMAddPropagateJuliaAddrspacesPass) \
Expand Down
1 change: 0 additions & 1 deletion src/llvm-julia-passes.inc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ MODULE_PASS("LowerPTLSPass", LowerPTLSPass, LowerPTLSPass())
//Function passes
#ifdef FUNCTION_PASS
FUNCTION_PASS("DemoteFloat16", DemoteFloat16Pass, DemoteFloat16Pass())
FUNCTION_PASS("CombineMulAdd", CombineMulAddPass, CombineMulAddPass())
FUNCTION_PASS("LateLowerGCFrame", LateLowerGCPass, LateLowerGCPass())
FUNCTION_PASS("AllocOpt", AllocOptPass, AllocOptPass())
FUNCTION_PASS("PropagateJuliaAddrspaces", PropagateJuliaAddrspacesPass, PropagateJuliaAddrspacesPass())
Expand Down
117 changes: 0 additions & 117 deletions src/llvm-muladd.cpp

This file was deleted.

66 changes: 66 additions & 0 deletions src/llvm-simdloop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ STATISTIC(ReductionChainLength, "Total sum of instructions folded from reduction
STATISTIC(MaxChainLength, "Max length of reduction chain");
STATISTIC(AddChains, "Addition reduction chains");
STATISTIC(MulChains, "Multiply reduction chains");
STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");

#ifndef __clang_gcanalyzer__
#define REMARK(remark) ORE.emit(remark)
Expand All @@ -49,6 +50,49 @@ STATISTIC(MulChains, "Multiply reduction chains");
#endif
namespace {

/**
* Combine
* ```
* %v0 = fmul ... %a, %b
* %v = fadd contract ... %v0, %c
* ```
* to
* %v0 = fmul contract ... %a, %b
* %v = fadd contract ... %v0, %c
* when `%v0` has no other use
*/

static bool checkCombine(Value *maybeMul, Loop &L, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT
{
auto mulOp = dyn_cast<Instruction>(maybeMul);
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
return false;
if (!L.contains(mulOp))
return false;
if (!mulOp->hasOneUse()) {
LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n");
REMARK([&](){
return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp)
<< "fmul had multiple uses " << ore::NV("fmul", mulOp);
});
return false;
}
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
auto fmf = mulOp->getFastMathFlags();
if (!fmf.allowContract()) {
LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n");
REMARK([&](){
return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp)
<< "marked for fma " << ore::NV("fmul", mulOp);
});
++TotalContracted;
fmf.setAllowContract(true);
mulOp->copyFastMathFlags(fmf);
return true;
}
return false;
}

static unsigned getReduceOpcode(Instruction *J, Instruction *operand) JL_NOTSAFEPOINT
{
switch (J->getOpcode()) {
Expand Down Expand Up @@ -150,6 +194,28 @@ static void enableUnsafeAlgebraIfReduction(PHINode *Phi, Loop &L, OptimizationRe
});
(*K)->setHasAllowReassoc(true);
(*K)->setHasAllowContract(true);
switch ((*K)->getOpcode()) {
case Instruction::FAdd: {
if (!(*K)->hasAllowContract())
continue;
// (*K)->getOperand(0)->print(dbgs());
// (*K)->getOperand(1)->print(dbgs());
checkCombine((*K)->getOperand(0), L, ORE);
checkCombine((*K)->getOperand(1), L, ORE);
break;
}
case Instruction::FSub: {
if (!(*K)->hasAllowContract())
continue;
// (*K)->getOperand(0)->print(dbgs());
// (*K)->getOperand(1)->print(dbgs());
checkCombine((*K)->getOperand(0), L, ORE);
checkCombine((*K)->getOperand(1), L, ORE);
break;
}
default:
break;
}
if (SE)
SE->forgetValue(*K);
++length;
Expand Down
4 changes: 0 additions & 4 deletions src/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ struct DemoteFloat16Pass : PassInfoMixin<DemoteFloat16Pass> {
static bool isRequired() { return true; }
};

struct CombineMulAddPass : PassInfoMixin<CombineMulAddPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
};

struct LateLowerGCPass : PassInfoMixin<LateLowerGCPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
static bool isRequired() { return true; }
Expand Down
1 change: 0 additions & 1 deletion src/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ static void buildCleanupPipeline(ModulePassManager &MPM, PassBuilder *PB, Optimi
if (options.cleanup) {
if (O.getSpeedupLevel() >= 2) {
FunctionPassManager FPM;
JULIA_PASS(FPM.addPass(CombineMulAddPass()));
FPM.addPass(DivRemPairsPass());
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}
Expand Down
69 changes: 57 additions & 12 deletions test/llvmpasses/julia-simdloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='loop(LowerSIMDLoop)' -S %s | FileCheck %s

; CHECK-LABEL: @simd_test(
define void @simd_test(double *%a, double *%b) {
define void @simd_test(ptr %a, ptr %b) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%bptr = getelementptr double, double *%b, i64 %i
%aptr = getelementptr double, ptr %a, i64 %i
%bptr = getelementptr double, ptr %b, i64 %i
; CHECK: llvm.mem.parallel_loop_access
%aval = load double, double *%aptr
%bval = load double, double *%aptr
%aval = load double, ptr %aptr
%bval = load double, ptr %aptr
%cval = fadd double %aval, %bval
store double %cval, double *%bptr
store double %cval, ptr %bptr
%nexti = add i64 %i, 1
%done = icmp sgt i64 %nexti, 500
br i1 %done, label %loopdone, label %loop, !llvm.loop !1
Expand All @@ -23,15 +23,15 @@ loopdone:
}

; CHECK-LABEL: @simd_test_sub(
define double @simd_test_sub(double *%a) {
define double @simd_test_sub(ptr %a) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%aptr = getelementptr double, ptr %a, i64 %i
; CHECK: llvm.mem.parallel_loop_access
%aval = load double, double *%aptr
%aval = load double, ptr %aptr
%nextv = fsub double %v, %aval
; CHECK: fsub reassoc contract double %v, %aval
%nexti = add i64 %i, 1
Expand All @@ -42,14 +42,14 @@ loopdone:
}

; CHECK-LABEL: @simd_test_sub2(
define double @simd_test_sub2(double *%a) {
define double @simd_test_sub2(ptr %a) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%aval = load double, double *%aptr
%aptr = getelementptr double, ptr %a, i64 %i
%aval = load double, ptr %aptr
%nextv = fsub double %v, %aval
; CHECK: fsub reassoc contract double %v, %aval
%nexti = add i64 %i, 1
Expand All @@ -59,6 +59,26 @@ loopdone:
ret double %nextv
}

; CHECK-LABEL: @simd_test_sub4(
define double @simd_test_sub4(ptr %a) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%aval = load double, double *%aptr
%nextv2 = fmul double %aval, %aval
; CHECK: fmul contract double %aval, %aval
%nextv = fsub double %v, %nextv2
; CHECK: fsub reassoc contract double %v, %nextv2
%nexti = add i64 %i, 1
%done = icmp sgt i64 %nexti, 500
br i1 %done, label %loopdone, label %loop, !llvm.loop !0
loopdone:
ret double %nextv
}

; Tests if we correctly pass through other metadata
; CHECK-LABEL: @disabled(
define i32 @disabled(i32* noalias nocapture %a, i32* noalias nocapture readonly %b, i32 %N) {
Expand All @@ -82,6 +102,31 @@ for.end: ; preds = %for.body
ret i32 %1
}

; Check that we don't add contract to non loop things
; CHECK-LABEL: @dont_add_no_loop(
define double @dont_add_no_loop(ptr nocapture noundef nonnull readonly align 8 dereferenceable(72) %"a::Tuple", ptr nocapture noundef nonnull readonly align 8 dereferenceable(24) %"b::Tuple") #0 {
top:
%"a::Tuple[9]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 64
%"b::Tuple[3]_ptr" = getelementptr inbounds i8, ptr %"b::Tuple", i64 16
%"a::Tuple[6]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 40
%"b::Tuple[2]_ptr" = getelementptr inbounds i8, ptr %"b::Tuple", i64 8
%"a::Tuple[3]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 16
%"a::Tuple[3]_ptr.unbox" = load double, ptr %"a::Tuple[3]_ptr", align 8
%"b::Tuple.unbox" = load double, ptr %"b::Tuple", align 8
%0 = fmul double %"a::Tuple[3]_ptr.unbox", %"b::Tuple.unbox"
; CHECK: fmul double %
%"a::Tuple[6]_ptr.unbox" = load double, ptr %"a::Tuple[6]_ptr", align 8
%"b::Tuple[2]_ptr.unbox" = load double, ptr %"b::Tuple[2]_ptr", align 8
%1 = fmul contract double %"a::Tuple[6]_ptr.unbox", %"b::Tuple[2]_ptr.unbox"
%2 = fadd contract double %0, %1
%"a::Tuple[9]_ptr.unbox" = load double, ptr %"a::Tuple[9]_ptr", align 8
%"b::Tuple[3]_ptr.unbox" = load double, ptr %"b::Tuple[3]_ptr", align 8
%3 = fmul contract double %"a::Tuple[9]_ptr.unbox", %"b::Tuple[3]_ptr.unbox"
%4 = fadd contract double %2, %3
ret double %4
}


!0 = distinct !{!0, !"julia.simdloop"}
!1 = distinct !{!1, !"julia.simdloop", !"julia.ivdep"}
!2 = distinct !{!2, !"julia.simdloop", !"julia.ivdep", !3}
Expand Down
Loading