diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fb52ca874..433e8132a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1480,6 +1480,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) { + if (!dfdx()) + return StmtDiff(Clone(ILE)); QualType ILEType = ILE->getType(); llvm::SmallVector clonedExprs(ILE->getNumInits()); if (isArrayOrPointerType(ILEType)) { diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index b9c880682..846b9fee5 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -383,6 +383,27 @@ double fn11(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: } +struct MyStruct{ + double a; + double b; +}; + +MyStruct fn12(MyStruct s) { + s = {2 * s.a, 2 * s.b + 2}; + return s; +} + +// CHECK: void fn12_grad(MyStruct s, MyStruct *_d_s) { +// CHECK-NEXT: MyStruct _t0 = s; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {}); +// CHECK-NEXT: { +// CHECK-NEXT: MyStruct _r0 = {}; +// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {}, &(*_d_s), &_r0); +// CHECK-NEXT: (*_d_s).a += 2 * _r0.a; +// CHECK-NEXT: (*_d_s).b += 2 * _r0.b; +// CHECK-NEXT: } +// CHECK-NEXT:} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -391,6 +412,10 @@ void print(const Tangent& t) { } } +void print(const MyStruct& s) { + printf("{%.2f, %.2f}\n", s.a, s.b); +} + int main() { pairdd p(3, 5), d_p; double i = 3, d_i, d_j; @@ -425,6 +450,10 @@ int main() { TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00} TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 5, 10, &d_i, &d_j); // CHECK-EXEC: {1.00, 0.00} TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, -14, &d_i, &d_j); // CHECK-EXEC: {1.00, -1.00} + MyStruct s = {1.0, 2.0}, d_s = {1.0, 1.0}; + auto fn12_test = clad::gradient(fn12); + fn12_test.execute(s, &d_s); + print(d_s); // CHECK-EXEC: {2.00, 2.00} } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) { @@ -546,4 +575,31 @@ int main() { // CHECK-NEXT: *_d_x += _d_y; // CHECK-NEXT: (*_d_t).data[0] += _d_y; // CHECK-NEXT: } -// CHECK-NEXT: } \ No newline at end of file +// CHECK-NEXT: } + +// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&_r0, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d__r0) noexcept { +// CHECK-NEXT: double _t0 = this->a; +// CHECK-NEXT: this->a = _r0.a; +// CHECK-NEXT: double _t1 = this->b; +// CHECK-NEXT: this->b = _r0.b; +// CHECK-NEXT: { +// CHECK-NEXT: this->b = _t1; +// CHECK-NEXT: double _r_d1 = (*_d_this).b; +// CHECK-NEXT: (*_d_this).b = 0.; +// CHECK-NEXT: (*_d__r0).b += _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: this->a = _t0; +// CHECK-NEXT: double _r_d0 = (*_d_this).a; +// CHECK-NEXT: (*_d_this).a = 0.; +// CHECK-NEXT: (*_d__r0).a += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT:} + +// CHECK: inline constexpr clad::ValueAndAdjoint operator_equal_forw(MyStruct &&_r0, MyStruct *_d_this, MyStruct &&_d__r0) noexcept { +// CHECK-NEXT: double _t0 = this->a; +// CHECK-NEXT: this->a = _r0.a; +// CHECK-NEXT: double _t1 = this->b; +// CHECK-NEXT: this->b = _r0.b; +// CHECK-NEXT: return {*this, (*_d_this)}; +// CHECK-NEXT:} \ No newline at end of file