Skip to content

Commit

Permalink
Support std::initializer_list in the reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 22, 2024
1 parent 41f9b1c commit 9167618
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
11 changes: 11 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,17 @@ void constructor_pullback(::std::vector<T>* v, S count, U val,
d_v->clear();
}

// A specialization for std::initializer_list (which is replaced with
// clad::array).
template <typename T>
void constructor_pullback(::std::vector<T>* v, clad::array<T> init,
::std::vector<T>* d_v, clad::array<T>* d_init) {
for (unsigned i = 0; i < init.size(); ++i) {
(*d_init)[i] += (*d_v)[i];
(*d_v)[i] = 0;
}
}

template <typename T, typename U, typename dU>
void assign_pullback(::std::vector<T>* v,
typename ::std::vector<T>::size_type n, U /*val*/,
Expand Down
7 changes: 5 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4284,8 +4284,11 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
// double _r0 = 0;
// SomeClass_pullback(c, u, ..., &_d_c, &_r0, ...);
// _d_u += _r0;
QualType dArgTy = getNonConstType(ArgTy, m_Context, m_Sema);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy));
QualType dArgTy = getNonConstType(CloneType(ArgTy), m_Context, m_Sema);
Expr* init = getStdInitListSizeExpr(arg);
if (!init)
init = getZeroInit(dArgTy);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", init);
prePullbackCallStmts.push_back(BuildDeclStmt(dArgDecl));
adjointArg = BuildDeclRef(dArgDecl);
argDiff = Visit(arg, BuildDeclRef(dArgDecl));
Expand Down
32 changes: 31 additions & 1 deletion test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ double fn21(double x, double y) {
return a[0];
}

double fn22(double u, double v) {
std::vector<double> ls{u, v};
return ls[1] - 2 * ls[0];
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
Expand All @@ -198,6 +203,7 @@ int main() {
INIT_GRADIENT(fn19);
INIT_GRADIENT(fn20);
INIT_GRADIENT(fn21);
INIT_GRADIENT(fn22);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
Expand All @@ -211,6 +217,7 @@ int main() {
TEST_GRADIENT(fn19, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 2.00}
TEST_GRADIENT(fn20, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {11.00, 1.00}
TEST_GRADIENT(fn21, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00}
TEST_GRADIENT(fn22, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {-2.00, 1.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
Expand Down Expand Up @@ -848,4 +855,27 @@ int main() {
// CHECK-NEXT: {{.*}}value_type _r0 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn22_grad(double u, double v, double *_d_u, double *_d_v) {
// CHECK-NEXT: std::vector<double> ls{u, v};
// CHECK-NEXT: std::vector<double> _d_ls(ls);
// CHECK-NEXT: clad::zero_init(_d_ls);
// CHECK-NEXT: std::vector<double> _t0 = ls;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 1, &_d_ls, _r1);
// CHECK-NEXT: std::vector<double> _t3 = ls;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t4 = clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 0, &_d_ls, _r2);
// CHECK-NEXT: {{.*}}value_type _t2 = _t4.value;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t0, 1, 1, &_d_ls, &_r1);
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t3, 0, 2 * -1, &_d_ls, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<double> _r0 = {{2U|2UL|2ULL}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&ls, {u, v}, &_d_ls, &_r0);
// CHECK-NEXT: *_d_u += _r0[0];
// CHECK-NEXT: *_d_v += _r0[1];
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 9167618

Please sign in to comment.