Skip to content

Commit

Permalink
Fix some cases of std::vector::push_back in the rvs mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Sep 27, 2024
1 parent 2e5560e commit 71ca82a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
16 changes: 15 additions & 1 deletion include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,13 +392,27 @@ size_pushforward(const ::std::array<T, N>* a,
// vector reverse mode
// more can be found in tests: test/Gradient/STLCustomDerivatives.C

template <typename T, typename U, typename pU>
void push_back_reverse_forw(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
pU /*d_val*/) {
v->push_back(val);
d_v->push_back(0);
}

template <typename T, typename U>
void push_back_reverse_forw(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U d_val) {
U /*d_val*/) {
v->push_back(val);
d_v->push_back(0);
}

template <typename T, typename U, typename pU>
void push_back_pullback(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
pU* d_val) {
*d_val += d_v->back();
d_v->pop_back();
}

template <typename T, typename U>
void push_back_pullback(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U* d_val) {
Expand Down
39 changes: 39 additions & 0 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ double fn20(double x, double y) {
return res; // 11x+y
}

double fn21(double x, double y) {
std::vector<double> a;
a.push_back(0);
a[0] = x*x;
return a[0];
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
Expand All @@ -190,6 +197,7 @@ int main() {
INIT_GRADIENT(fn18);
INIT_GRADIENT(fn19);
INIT_GRADIENT(fn20);
INIT_GRADIENT(fn21);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
Expand All @@ -202,6 +210,7 @@ int main() {
TEST_GRADIENT(fn18, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {2.00, 0.00}
TEST_GRADIENT(fn19, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 2.00}
TEST_GRADIENT(fn20, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {11.00, 1.00}
TEST_GRADIENT(fn21, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
Expand Down Expand Up @@ -841,3 +850,33 @@ int main() {
// CHECK-NEXT: {{.*}}reserve_pullback(&_t0, 10, &_d_v, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn21_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::vector<double> _d_a({});
// CHECK-NEXT: std::vector<double> a;
// CHECK-NEXT: std::vector<double> _t0 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r0);
// CHECK-NEXT: std::vector<double> _t1 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r1);
// CHECK-NEXT: double _t3 = _t2.value;
// CHECK-NEXT: _t2.value = x * x;
// CHECK-NEXT: std::vector<double> _t4 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t2.value = _t3;
// CHECK-NEXT: double _r_d0 = _t2.adjoint;
// CHECK-NEXT: _t2.adjoint = 0{{.*}};
// CHECK-NEXT: *_d_x += _r_d0 * x;
// CHECK-NEXT: *_d_x += x * _r_d0;
// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}value_type _r0 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 71ca82a

Please sign in to comment.