Skip to content

Commit

Permalink
Correct and simplify sdot/ddot (rust-lang#498)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Feb 7, 2022
1 parent 0ef2b90 commit 44e7122
Show file tree
Hide file tree
Showing 22 changed files with 435 additions and 2,061 deletions.
45 changes: 31 additions & 14 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4652,6 +4652,11 @@ class AdjointGenerator

bool handleBLAS(llvm::CallInst &call, Function *called, StringRef funcName,
const std::map<Argument *, bool> &uncacheable_args) {
// Forward Mode not handled yet
assert(Mode != DerivativeMode::ForwardMode &&
Mode != DerivativeMode::ForwardModeSplit);
// Vector Mode not handled yet
assert(gutils->getWidth() == 1);
CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
IRBuilder<> BuilderZ(newCall);
BuilderZ.setFastMathFlags(getFast());
Expand All @@ -4671,9 +4676,6 @@ class AdjointGenerator
}
Type *castvals[2] = {call.getArgOperand(1)->getType(),
call.getArgOperand(3)->getType()};
auto *cachetype =
StructType::get(call.getContext(), ArrayRef<Type *>(castvals));
Value *undefinit = UndefValue::get(cachetype);
Value *cacheval;
auto in_arg = call.getCalledFunction()->arg_begin();
in_arg++;
Expand All @@ -4694,15 +4696,16 @@ class AdjointGenerator
if (xcache) {
auto dmemcpy =
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
PointerType::getUnqual(innerType), 0, 0);
cast<PointerType>(castvals[0]), 0, 0);
auto malins = CallInst::CreateMalloc(
gutils->getNewFromOriginal(&call), size->getType(), innerType,
size, call.getArgOperand(0), nullptr, "");
arg1 =
BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType());
size, gutils->getNewFromOriginal(call.getArgOperand(0)), nullptr,
"");
arg1 = BuilderZ.CreateBitCast(malins, castvals[0]);
Value *args[4] = {arg1,
gutils->getNewFromOriginal(call.getArgOperand(1)),
call.getArgOperand(0), call.getArgOperand(2)};
gutils->getNewFromOriginal(call.getArgOperand(0)),
gutils->getNewFromOriginal(call.getArgOperand(2))};

BuilderZ.CreateCall(
dmemcpy, args,
Expand All @@ -4715,15 +4718,16 @@ class AdjointGenerator
if (ycache) {
auto dmemcpy =
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
PointerType::getUnqual(innerType), 0, 0);
cast<PointerType>(castvals[1]), 0, 0);
auto malins = CallInst::CreateMalloc(
gutils->getNewFromOriginal(&call), size->getType(), innerType,
size, call.getArgOperand(0), nullptr, "");
arg2 =
BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType());
size, gutils->getNewFromOriginal(call.getArgOperand(0)), nullptr,
"");
arg2 = BuilderZ.CreateBitCast(malins, castvals[1]);
Value *args[4] = {arg2,
gutils->getNewFromOriginal(call.getArgOperand(3)),
call.getArgOperand(0), call.getArgOperand(4)};
gutils->getNewFromOriginal(call.getArgOperand(0)),
gutils->getNewFromOriginal(call.getArgOperand(4))};
BuilderZ.CreateCall(
dmemcpy, args,
gutils->getInvertedBundles(&call,
Expand All @@ -4733,7 +4737,10 @@ class AdjointGenerator
BuilderZ, /*lookup*/ false));
}
if (xcache && ycache) {
auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0);
Type *cachetype =
StructType::get(call.getContext(), ArrayRef<Type *>(castvals));
auto valins1 =
BuilderZ.CreateInsertValue(UndefValue::get(cachetype), arg1, 0);
cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1);
} else if (xcache)
cacheval = arg1;
Expand All @@ -4758,6 +4765,16 @@ class AdjointGenerator
if (Mode == DerivativeMode::ReverseModeGradient &&
(!gutils->isConstantValue(call.getArgOperand(1)) ||
!gutils->isConstantValue(call.getArgOperand(3)))) {
Type *cachetype = nullptr;
if (xcache && ycache)
cachetype = StructType::get(call.getContext(),
ArrayRef<Type *>(castvals));
else if (xcache)
cachetype = castvals[0];
else {
assert(ycache);
cachetype = castvals[1];
}
cacheval = BuilderZ.CreatePHI(cachetype, 0);
}
cacheval =
Expand Down
11 changes: 8 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3266,9 +3266,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
"return or non-constant");
}

if (key.todiff->empty() && CustomErrorHandler) {
std::string s = ("No derivative found for " + key.todiff->getName()).str();
CustomErrorHandler(s.c_str());
if (key.todiff->empty()) {
std::string str =
("No derivative found for " + key.todiff->getName()).str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str());
} else {
llvm_unreachable(str.c_str());
}
}
assert(!key.todiff->empty());

Expand Down
198 changes: 198 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/blas/cblas_ddot.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s

target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

declare dso_local void @__enzyme_autodiff(...)

declare double @cblas_ddot(i32, double*, i32, double*, i32)

define void @active(i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
entry:
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @f, i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
ret void
}

define void @inactiveFirst(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
entry:
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @f, i32 %len, metadata !"enzyme_const", double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
ret void
}

define void @inactiveSecond(i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, double* noalias %n, i32 %incn) {
entry:
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @f, i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, metadata !"enzyme_const", double* noalias %n, i32 %incn)
ret void
}

define void @activeMod(i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
entry:
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @modf, i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
ret void
}

define void @inactiveModFirst(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
entry:
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @modf, i32 %len, metadata !"enzyme_const", double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
ret void
}

define void @inactiveModSecond(i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, double* noalias %n, i32 %incn) {
entry:
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @modf, i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, metadata !"enzyme_const", double* noalias %n, i32 %incn)
ret void
}

define double @f(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, i32 %incn) {
entry:
%call = call double @cblas_ddot(i32 %len, double* %m, i32 %incm, double* %n, i32 %incn)
ret double %call
}

define double @modf(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, i32 %incn) {
entry:
%call = call double @f(i32 %len, double* %m, i32 %incm, double* %n, i32 %incn)
store double 0.000000e+00, double* %m
store double 0.000000e+00, double* %n
ret double %call
}


; CHECK: define void @active
; CHECK-NEXT: entry
; CHECK-NEXT: call void @[[active:.+]](

; CHECK: define void @inactiveFirst
; CHECK-NEXT: entry
; CHECK-NEXT: call void @[[inactiveFirst:.+]](

; CHECK: define void @inactiveSecond
; CHECK-NEXT: entry
; CHECK-NEXT: call void @[[inactiveSecond:.+]](


; CHECK: define void @activeMod
; CHECK-NEXT: entry
; CHECK-NEXT: call void @[[activeMod:.+]](

; CHECK: define void @inactiveModFirst
; CHECK-NEXT: entry
; CHECK-NEXT: call void @[[inactiveModFirst:.+]](

; CHECK: define void @inactiveModSecond
; CHECK-NEXT: entry
; CHECK-NEXT: call void @[[inactiveModSecond:.+]](


; CHECK: define internal void @[[active]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %m, i32 %incm, double* %"n'", i32 %incn)
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %n, i32 %incn, double* %"m'", i32 %incm)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal void @[[inactiveFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %m, i32 %incm, double* %"n'", i32 %incn)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal void @[[inactiveSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %n, i32 %incn, double* %"m'", i32 %incm)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal void @[[activeMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
; CHECK-NEXT: entry:
; CHECK: %call_augmented = call { double*, double* } @[[augMod:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, double* %"n'", i32 %incn)
; CHECK: call void @[[revMod:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, double* %"n'", i32 %incn, double %differeturn, { double*, double* } %call_augmented)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal { double*, double* } @[[augMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = zext i32 %len to i64
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %m, i32 %len, i32 %incm)
; CHECK-NEXT: %2 = zext i32 %len to i64
; CHECK-NEXT: %mallocsize1 = mul i64 %2, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
; CHECK-NEXT: %malloccall2 = tail call i8* @malloc(i64 %mallocsize1)
; CHECK-NEXT: %3 = bitcast i8* %malloccall2 to double*
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %3, double* %n, i32 %len, i32 %incn)
; CHECK-NEXT: %4 = insertvalue { double*, double* } undef, double* %1, 0
; CHECK-NEXT: %5 = insertvalue { double*, double* } %4, double* %3, 1
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
; CHECK-NEXT: ret { double*, double* } %5
; CHECK-NEXT: }

; CHECK: define internal void @[[revMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, { double*, double* }
; CHECK-NEXT: entry:
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 0
; CHECK-NEXT: %2 = extractvalue { double*, double* } %0, 1
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %1, i32 1, double* %"n'", i32 %incn)
; CHECK-NEXT: %3 = bitcast double* %1 to i8*
; CHECK-NEXT: tail call void @free(i8* %3)
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %2, i32 1, double* %"m'", i32 %incm)
; CHECK-NEXT: %4 = bitcast double* %2 to i8*
; CHECK-NEXT: tail call void @free(i8* %4)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal void @[[inactiveModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
; CHECK-NEXT: entry:
; CHECK: %call_augmented = call double* @[[augModFirst:.+]](i32 %len, double* %m, i32 %incm, double* %n, double* %"n'", i32 %incn)
; CHECK: call void @[[revModFirst:.+]](i32 %len, double* %m, i32 %incm, double* %n, double* %"n'", i32 %incn, double %differeturn, double* %call_augmented)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal double* @[[augModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = zext i32 %len to i64
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %m, i32 %len, i32 %incm)
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
; CHECK-NEXT: ret double* %1
; CHECK-NEXT: }

; CHECK: define internal void @[[revModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, double*
; CHECK-NEXT: entry:
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %0, i32 1, double* %"n'", i32 %incn)
; CHECK-NEXT: %1 = bitcast double* %0 to i8*
; CHECK-NEXT: tail call void @free(i8* %1)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal void @[[inactiveModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn, double %differeturn)
; CHECK-NEXT: entry:
; CHECK: %call_augmented = call double* @[[augModSecond:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, i32 %incn)
; CHECK: call void @[[revModSecond:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, i32 %incn, double %differeturn, double* %call_augmented)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal double* @[[augModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = zext i32 %len to i64
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %n, i32 %len, i32 %incn)
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
; CHECK-NEXT: ret double* %1
; CHECK-NEXT: }

; CHECK: define internal void @[[revModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn, double %differeturn, double*
; CHECK-NEXT: entry:
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %0, i32 1, double* %"m'", i32 %incm)
; CHECK-NEXT: %1 = bitcast double* %0 to i8*
; CHECK-NEXT: tail call void @free(i8* %1)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

Loading

0 comments on commit 44e7122

Please sign in to comment.