Skip to content

Commit

Permalink
Fix rvalue reference name in pullbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 17, 2024
1 parent eb7d7f2 commit 71582a4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
7 changes: 7 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1893,6 +1893,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const Expr* arg = CE->getArg(i);
const auto* PVD = FD->getParamDecl(
i - static_cast<unsigned long>(isMethodOperatorCall));
if (PVD->getType()->isRValueReferenceType()) {
IdentifierInfo* PVDII = PVD->getIdentifier();
IdentifierInfo* RValueName = CreateUniqueIdentifier("_r");
auto* param = CloneParmVarDecl(PVD, PVDII);
param->setDeclName(RValueName);
PVD = param;
}
StmtDiff argDiff{};
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
Expand Down
20 changes: 10 additions & 10 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -817,28 +817,28 @@ int main() {
// CHECK-NEXT: std::vector<double> _d_a({});
// CHECK-NEXT: std::vector<double> a;
// CHECK-NEXT: std::vector<double> _t0 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r0);
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r1);
// CHECK-NEXT: std::vector<double> _t1 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r1);
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2);
// CHECK-NEXT: double _t3 = _t2.value;
// CHECK-NEXT: _t2.value = x * x;
// CHECK-NEXT: std::vector<double> _t4 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2);
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r3);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r2);
// CHECK-NEXT: {{.*}}size_type _r3 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r3);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t2.value = _t3;
// CHECK-NEXT: double _r_d0 = _t2.adjoint;
// CHECK-NEXT: _t2.adjoint = 0{{.*}};
// CHECK-NEXT: *_d_x += _r_d0 * x;
// CHECK-NEXT: *_d_x += x * _r_d0;
// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r1);
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}value_type _r0 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0);
// CHECK-NEXT: {{.*}}value_type _r1 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 71582a4

Please sign in to comment.