From 71582a4eaaa01189f295aa94ba416b61430082b9 Mon Sep 17 00:00:00 2001 From: kchristin Date: Sun, 17 Nov 2024 15:40:13 +0200 Subject: [PATCH] Fix rvalue reference name in pullbacks --- lib/Differentiator/ReverseModeVisitor.cpp | 7 +++++++ test/Gradient/STLCustomDerivatives.C | 20 ++++++++++---------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 19b902ce1..2498bf46f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1893,6 +1893,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const Expr* arg = CE->getArg(i); const auto* PVD = FD->getParamDecl( i - static_cast(isMethodOperatorCall)); + if (PVD->getType()->isRValueReferenceType()) { + IdentifierInfo* PVDII = PVD->getIdentifier(); + IdentifierInfo* RValueName = CreateUniqueIdentifier("_r"); + auto* param = CloneParmVarDecl(PVD, PVDII); + param->setDeclName(RValueName); + PVD = param; + } StmtDiff argDiff{}; // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index 58ae5d64d..a2612c60a 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -817,16 +817,16 @@ int main() { // CHECK-NEXT: std::vector _d_a({}); // CHECK-NEXT: std::vector a; // CHECK-NEXT: std::vector _t0 = a; -// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r0); +// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r1); // CHECK-NEXT: std::vector _t1 = a; -// CHECK-NEXT: {{.*}}ValueAndAdjoint _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r1); +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2); // CHECK-NEXT: double _t3 = _t2.value; // CHECK-NEXT: _t2.value = x * x; // CHECK-NEXT: std::vector _t4 = a; -// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2); +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r3); // CHECK-NEXT: { -// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}}; -// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r2); +// CHECK-NEXT: {{.*}}size_type _r3 = 0{{.*}}; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r3); // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: _t2.value = _t3; @@ -834,11 +834,11 @@ int main() { // CHECK-NEXT: _t2.adjoint = 0{{.*}}; // CHECK-NEXT: *_d_x += _r_d0 * x; // CHECK-NEXT: *_d_x += x * _r_d0; -// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}}; -// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r1); +// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}}; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r2); // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: {{.*}}value_type _r0 = 0.; -// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0); +// CHECK-NEXT: {{.*}}value_type _r1 = 0.; +// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r1); // CHECK-NEXT: } -// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file