Skip to content

Commit

Permalink
Prevent Clad from trying to create a void zero literal (#989)
Browse files Browse the repository at this point in the history
Previously, clad used to try to synthesize a void zero literal
when differentiating a call to a void function with
literal arguments in the forward mode. This caused it to crash.

Fixes: #988
  • Loading branch information
gojakuch authored Jul 21, 2024
1 parent f4dcf5c commit 4b5bec8
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 5 deletions.
3 changes: 1 addition & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,8 +1216,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
validLoc, llvm::MutableArrayRef<Expr*>(CallArgs),
validLoc)
.get();
auto* zero = ConstantFolder::synthesizeLiteral(CE->getType(), m_Context,
/*val=*/0);
auto* zero = getZeroInit(CE->getType());
return StmtDiff(call, zero);
}
}
Expand Down
2 changes: 2 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ namespace clad {

Expr* VisitorBase::getZeroInit(QualType T) {
// FIXME: Consolidate other uses of synthesizeLiteral for creation 0 or 1.
if (T->isVoidType())
return nullptr;
if (T->isScalarType()) {
ExprResult Zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down
2 changes: 1 addition & 1 deletion test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ float f_literal_args_func(float x, float y, float *z) {
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: printf("hello world ");
// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr);
// CHECK-NEXT: return _d_x * _t0 + x * 0.F;
// CHECK-NEXT: return _d_x * _t0 + x * 0;
// CHECK-NEXT: }

inline unsigned int getBin(double low, double high, double val, unsigned int numBins) {
Expand Down
16 changes: 16 additions & 0 deletions test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@ double test_9(double x) {
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

void some_important_void_func(double y) {
assert(y >= 1);
}

double test_10(double x) {
some_important_void_func(1);
return x;
}

// CHECK: double test_10_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: some_important_void_func(1);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

int main () {
clad::differentiate(test_1, 0);
clad::differentiate(test_2, 0);
Expand All @@ -196,6 +211,7 @@ int main () {
clad::differentiate<clad::opts::enable_tbr, clad::opts::disable_tbr>(test_8); // expected-error {{Both enable and disable TBR options are specified.}}
clad::differentiate<clad::opts::diagonal_only>(test_8); // expected-error {{Diagonal only option is only valid for Hessian mode.}}
clad::differentiate(test_9);
clad::differentiate(test_10);
return 0;

// CHECK: void increment_pushforward(int &i, int &_d_i) {
Expand Down
4 changes: 2 additions & 2 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ double fn4(double i, double j) {
// CHECK: double fn4_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = nonRealParamFn(0, 0);
// CHECK-NEXT: _d_res += _d_i;
// CHECK-NEXT: res += i;
Expand Down Expand Up @@ -266,7 +266,7 @@ double fn8(double i, double j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t1 = check_and_return_pushforward(_t0.value, 'a', _t0.pushforward, 0);
// CHECK-NEXT: double &_t2 = _t1.value;
// CHECK-NEXT: double _t3 = std::tanh(1.);
// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0.;
// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0;
// CHECK-NEXT: }

double g (double x) { return x; }
Expand Down

0 comments on commit 4b5bec8

Please sign in to comment.