Skip to content

Commit

Permalink
Move Assertions (rust-lang#306)
Browse files Browse the repository at this point in the history
* move assertions

* don't use lookup in invertPointerM when running in forward mode
  • Loading branch information
tgymnich authored Sep 19, 2021
1 parent 1010096 commit 51988b0
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 97 deletions.
85 changes: 53 additions & 32 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,8 @@ class AdjointGenerator
if (constantval) {
ts = setPtrDiffe(orig_ptr, Constant::getNullValue(valType), Builder2);
} else {
auto dif1 =
Builder2.CreateLoad(gutils->invertPointerM(orig_ptr, Builder2));
auto dif1 = Builder2.CreateLoad(
lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2));
#if LLVM_VERSION_MAJOR >= 10
dif1->setAlignment(SI.getAlign());
#else
Expand Down Expand Up @@ -1340,8 +1340,6 @@ class AdjointGenerator

std::vector<SelectInst *> addToDiffe(Value *val, Value *dif,
IRBuilder<> &Builder, Type *T) {
assert(Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeCombined);
return ((DiffeGradientUtils *)gutils)->addToDiffe(val, dif, Builder, T);
}

Expand Down Expand Up @@ -1928,7 +1926,8 @@ class AdjointGenerator
// (which thus == src and may be illegal)
if (gutils->isConstantValue(orig_src)) {
SmallVector<Value *, 4> args;
args.push_back(gutils->invertPointerM(orig_dst, Builder2));
args.push_back(
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2));
if (args[0]->getType()->isIntegerTy())
args[0] = Builder2.CreateIntToPtr(
args[0], Type::getInt8PtrTy(MTI->getContext()));
Expand Down Expand Up @@ -1958,7 +1957,8 @@ class AdjointGenerator

} else {
SmallVector<Value *, 4> args;
auto dsto = gutils->invertPointerM(orig_dst, Builder2);
auto dsto =
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2);
if (dsto->getType()->isIntegerTy())
dsto = Builder2.CreateIntToPtr(
dsto, Type::getInt8PtrTy(dsto->getContext()));
Expand All @@ -1968,7 +1968,8 @@ class AdjointGenerator
if (offset != 0)
dsto = Builder2.CreateConstInBoundsGEP1_64(dsto, offset);
args.push_back(Builder2.CreatePointerCast(dsto, secretpt));
auto srco = gutils->invertPointerM(orig_src, Builder2);
auto srco =
lookup(gutils->invertPointerM(orig_src, Builder2), Builder2);
if (srco->getType()->isIntegerTy())
srco = Builder2.CreateIntToPtr(
srco, Type::getInt8PtrTy(srco->getContext()));
Expand Down Expand Up @@ -2949,7 +2950,8 @@ class AdjointGenerator
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);
args.push_back(
gutils->invertPointerM(call.getArgOperand(i), Builder2));
lookup(gutils->invertPointerM(call.getArgOperand(i), Builder2),
Builder2));
}
pre_args.push_back(
gutils->invertPointerM(call.getArgOperand(i), BuilderZ));
Expand Down Expand Up @@ -3715,7 +3717,8 @@ class AdjointGenerator
llvm::errs() << " warning could not automatically determine mpi "
"status type, assuming [24 x i8]\n";
}
Value *d_req = gutils->invertPointerM(call.getOperand(6), Builder2);
Value *d_req = lookup(
gutils->invertPointerM(call.getOperand(6), Builder2), Builder2);
Value *args[] = {/*req*/ d_req,
/*status*/ IRBuilder<>(gutils->inversionAllocs)
.CreateAlloca(statusType)};
Expand Down Expand Up @@ -3769,7 +3772,8 @@ class AdjointGenerator
ConstantInt::get(Type::getInt8Ty(Builder2.getContext()), 0);
auto volatile_arg = ConstantInt::getFalse(Builder2.getContext());
assert(!gutils->isConstantValue(call.getOperand(0)));
auto dbuf = gutils->invertPointerM(call.getOperand(0), Builder2);
auto dbuf = lookup(
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
if (dbuf->getType()->isIntegerTy())
dbuf = Builder2.CreateIntToPtr(
dbuf, Type::getInt8PtrTy(call.getContext()));
Expand All @@ -3790,8 +3794,8 @@ class AdjointGenerator
memset->addParamAttr(0, Attribute::NonNull);
} else if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
assert(!gutils->isConstantValue(call.getOperand(0)));
Value *shadow =
gutils->invertPointerM(call.getOperand(0), Builder2);
Value *shadow = lookup(
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
if (Mode == DerivativeMode::ReverseModeCombined) {
assert(firstallocation);
firstallocation = lookup(firstallocation, Builder2);
Expand Down Expand Up @@ -3830,7 +3834,8 @@ class AdjointGenerator
getReverseBuilder(Builder2);

assert(!gutils->isConstantValue(call.getOperand(0)));
Value *d_req = gutils->invertPointerM(call.getOperand(0), Builder2);
Value *d_req = lookup(
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
if (d_req->getType()->isIntegerTy()) {
d_req = Builder2.CreateIntToPtr(
d_req,
Expand Down Expand Up @@ -3908,8 +3913,8 @@ class AdjointGenerator
assert(!gutils->isConstantValue(call.getOperand(1)));
Value *count =
lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2);
Value *d_req_orig =
gutils->invertPointerM(call.getOperand(1), Builder2);
Value *d_req_orig = lookup(
gutils->invertPointerM(call.getOperand(1), Builder2), Builder2);
if (d_req_orig->getType()->isIntegerTy()) {
d_req_orig = Builder2.CreateIntToPtr(
d_req_orig,
Expand Down Expand Up @@ -4007,7 +4012,8 @@ class AdjointGenerator
Mode == DerivativeMode::ReverseModeCombined) {
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);
Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
Value *shadow = lookup(
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);

if (shadow->getType()->isIntegerTy())
shadow = Builder2.CreateIntToPtr(
Expand Down Expand Up @@ -4095,7 +4101,8 @@ class AdjointGenerator
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);

Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
Value *shadow = lookup(
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
if (shadow->getType()->isIntegerTy())
shadow = Builder2.CreateIntToPtr(
shadow, Type::getInt8PtrTy(call.getContext()));
Expand Down Expand Up @@ -4165,7 +4172,8 @@ class AdjointGenerator
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);

Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
Value *shadow = lookup(
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
if (shadow->getType()->isIntegerTy())
shadow = Builder2.CreateIntToPtr(
shadow, Type::getInt8PtrTy(call.getContext()));
Expand Down Expand Up @@ -4365,11 +4373,13 @@ class AdjointGenerator
report_fatal_error("unhandled mpi_allreduce op");
}

Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
Value *shadow_recvbuf =
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
if (shadow_recvbuf->getType()->isIntegerTy())
shadow_recvbuf = Builder2.CreateIntToPtr(
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
Value *shadow_sendbuf =
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
if (shadow_sendbuf->getType()->isIntegerTy())
shadow_sendbuf = Builder2.CreateIntToPtr(
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
Expand Down Expand Up @@ -4552,11 +4562,13 @@ class AdjointGenerator
report_fatal_error("unhandled mpi_allreduce op");
}

Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
Value *shadow_recvbuf =
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
if (shadow_recvbuf->getType()->isIntegerTy())
shadow_recvbuf = Builder2.CreateIntToPtr(
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
Value *shadow_sendbuf =
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
if (shadow_sendbuf->getType()->isIntegerTy())
shadow_sendbuf = Builder2.CreateIntToPtr(
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
Expand Down Expand Up @@ -4665,11 +4677,13 @@ class AdjointGenerator
Value *orig_root = call.getOperand(6);
Value *orig_comm = call.getOperand(7);

Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
Value *shadow_recvbuf =
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
if (shadow_recvbuf->getType()->isIntegerTy())
shadow_recvbuf = Builder2.CreateIntToPtr(
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
Value *shadow_sendbuf =
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
if (shadow_sendbuf->getType()->isIntegerTy())
shadow_sendbuf = Builder2.CreateIntToPtr(
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
Expand Down Expand Up @@ -4820,11 +4834,13 @@ class AdjointGenerator
Value *orig_root = call.getOperand(6);
Value *orig_comm = call.getOperand(7);

Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
Value *shadow_recvbuf =
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
if (shadow_recvbuf->getType()->isIntegerTy())
shadow_recvbuf = Builder2.CreateIntToPtr(
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
Value *shadow_sendbuf =
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
if (shadow_sendbuf->getType()->isIntegerTy())
shadow_sendbuf = Builder2.CreateIntToPtr(
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
Expand Down Expand Up @@ -5008,11 +5024,13 @@ class AdjointGenerator
Value *orig_recvcount = call.getOperand(4);
Value *orig_comm = call.getOperand(6);

Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
Value *shadow_recvbuf =
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
if (shadow_recvbuf->getType()->isIntegerTy())
shadow_recvbuf = Builder2.CreateIntToPtr(
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
Value *shadow_sendbuf =
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
if (shadow_sendbuf->getType()->isIntegerTy())
shadow_sendbuf = Builder2.CreateIntToPtr(
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
Expand Down Expand Up @@ -5502,7 +5520,8 @@ class AdjointGenerator
diffe(orig, Builder2),
structarg1,
estride,
gutils->invertPointerM(orig->getArgOperand(3), Builder2),
lookup(gutils->invertPointerM(orig->getArgOperand(3), Builder2),
Builder2),
lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)),
Builder2)};
firstdcall = Builder2.CreateCall(derivcall, args1);
Expand All @@ -5520,7 +5539,8 @@ class AdjointGenerator
diffe(orig, Builder2),
structarg2,
estride,
gutils->invertPointerM(orig->getArgOperand(1), Builder2),
lookup(gutils->invertPointerM(orig->getArgOperand(1), Builder2),
Builder2),
lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)),
Builder2)};
seconddcall = Builder2.CreateCall(derivcall, args2);
Expand Down Expand Up @@ -7267,7 +7287,8 @@ class AdjointGenerator
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);
args.push_back(
gutils->invertPointerM(orig->getArgOperand(i), Builder2));
lookup(gutils->invertPointerM(orig->getArgOperand(i), Builder2),
Builder2));
}
pre_args.push_back(
gutils->invertPointerM(orig->getArgOperand(i), BuilderZ));
Expand Down Expand Up @@ -7702,7 +7723,7 @@ class AdjointGenerator
llvm::errs() << " orig: " << *orig << " callval: " << *callval << "\n";
}
assert(!gutils->isConstantValue(callval));
newcalled = gutils->invertPointerM(callval, Builder2);
newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2);

auto ft = cast<FunctionType>(
cast<PointerType>(callval->getType())->getElementType());
Expand Down
Loading

0 comments on commit 51988b0

Please sign in to comment.