From 8cdcaa9587a4f5de5dd0cf2541cdad53619d5758 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Mon, 25 Dec 2023 00:30:39 +0530 Subject: [PATCH] fix increment operators in pointer --- include/clad/Differentiator/ArrayRef.h | 28 ----- lib/Differentiator/ReverseModeVisitor.cpp | 18 ++-- test/Gradient/Pointers.C | 118 +++++++++++++++++++--- 3 files changed, 113 insertions(+), 51 deletions(-) diff --git a/include/clad/Differentiator/ArrayRef.h b/include/clad/Differentiator/ArrayRef.h index d38e93eb7..6fd89247d 100644 --- a/include/clad/Differentiator/ArrayRef.h +++ b/include/clad/Differentiator/ArrayRef.h @@ -68,34 +68,6 @@ template class array_ref { /// Returns the reference to the underlying array CUDA_HOST_DEVICE T& operator*() { return *m_arr; } - // Increment and decrement operators - update the underlying pointer. - /// Prefix increment operator. - CUDA_HOST_DEVICE array_ref& operator++() { - ++m_arr; - --m_size; - return *this; - } - /// Postfix increment operator. - CUDA_HOST_DEVICE array_ref operator++(int) { - array_ref tmp(*this); - ++m_arr; - --m_size; - return tmp; - } - /// Prefix decrement operator. - CUDA_HOST_DEVICE array_ref& operator--() { - --m_arr; - ++m_size; - return *this; - } - /// Postfix decrement operator. - CUDA_HOST_DEVICE array_ref operator--(int) { - array_ref tmp(*this); - --m_arr; - ++m_size; - return tmp; - } - // Arithmetic overloads /// Divides the arrays element wise template diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b8eed80b8..45bda28be 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2051,32 +2051,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diff = Visit(E, d); } else if (opCode == UO_PostInc || opCode == UO_PostDec) { diff = Visit(E, dfdx()); + Expr* diff_dx = diff.getExpr_dx(); + if (isPointerOp && isCladArrayType(diff_dx->getType())) + diff_dx = BuildArrayRefPtrRefExpr(diff_dx); if (isPointerOp) - addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()), - direction::forward); + addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward); if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) { auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc; addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())), direction::reverse); if (isPointerOp) - addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse); + addToCurrentBlock(BuildOp(op, diff_dx), direction::reverse); } - ResultRef = diff.getExpr_dx(); + ResultRef = diff_dx; valueForRevPass = diff.getRevSweepAsExpr(); if (m_ExternalSource) m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff); } else if (opCode == UO_PreInc || opCode == UO_PreDec) { diff = Visit(E, dfdx()); + Expr* diff_dx = diff.getExpr_dx(); + if (isPointerOp && isCladArrayType(diff_dx->getType())) + diff_dx = BuildArrayRefPtrRefExpr(diff_dx); if (isPointerOp) - addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()), - direction::forward); + addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward); if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) { auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc; addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())), direction::reverse); if (isPointerOp) - addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse); + addToCurrentBlock(BuildOp(op, diff_dx), direction::reverse); } auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add : BinaryOperatorKind::BO_Sub; diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 77339c178..65da7df43 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -228,11 +228,18 @@ double pointerParam(const double* arr, size_t n) { // CHECK-NEXT: } double pointerMultipleParams(const double* a, const double* b) { - double sum = b[2]; - a = 1+a; + double sum = b[2]; b = a; - sum += a[0] + b[0]; - return sum; // 2*a[1] + b[2] + a = 1+a; + ++b; + sum += a[0] + b[0]; // += 2*a[1] + b++; a++; + sum += a[0] + b[0]; // += 2*a[2] + b--; a--; + sum += a[0] + b[0]; // += 2*a[1] + --b; --a; + sum += a[0] + b[0]; // += 2*a[0] + return sum; // 2*a[0] + 4*a[1] + 2*a[2] + b[2] } // CHECK: void pointerMultipleParams_grad(const double *a, const double *b, clad::array_ref _d_a, clad::array_ref _d_b) { @@ -242,21 +249,96 @@ double pointerMultipleParams(const double* a, const double* b) { // CHECK-NEXT: const double *_t2; // CHECK-NEXT: clad::array_ref _t3; // CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: double _t6; +// CHECK-NEXT: double _t7; // CHECK-NEXT: double sum = b[2]; -// CHECK-NEXT: _t0 = a; -// CHECK-NEXT: _t1 = _d_a; -// CHECK-NEXT: _d_a.ptr_ref() = 1 + _d_a.ptr_ref(); -// CHECK-NEXT: a = 1 + a; -// CHECK-NEXT: _t2 = b; -// CHECK-NEXT: _t3 = _d_b; +// CHECK-NEXT: _t0 = b; +// CHECK-NEXT: _t1 = _d_b; // CHECK-NEXT: _d_b.ptr_ref() = _d_a.ptr_ref(); // CHECK-NEXT: b = a; -// CHECK-NEXT: _t4 = sum; +// CHECK-NEXT: _t2 = a; +// CHECK-NEXT: _t3 = _d_a; +// CHECK-NEXT: _d_a.ptr_ref() = 1 + _d_a.ptr_ref(); +// CHECK-NEXT: a = 1 + a; +// CHECK-NEXT: ++_d_b.ptr_ref(); +// CHECK-NEXT: ++b; +// CHECK-NEXT: _t4 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: _d_b.ptr_ref()++; +// CHECK-NEXT: b++; +// CHECK-NEXT: _d_a.ptr_ref()++; +// CHECK-NEXT: a++; +// CHECK-NEXT: _t5 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: _d_b.ptr_ref()--; +// CHECK-NEXT: b--; +// CHECK-NEXT: _d_a.ptr_ref()--; +// CHECK-NEXT: a--; +// CHECK-NEXT: _t6 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: --_d_b.ptr_ref(); +// CHECK-NEXT: --b; +// CHECK-NEXT: --_d_a.ptr_ref(); +// CHECK-NEXT: --a; +// CHECK-NEXT: _t7 = sum; // CHECK-NEXT: sum += a[0] + b[0]; // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: _d_sum += 1; // CHECK-NEXT: { +// CHECK-NEXT: sum = _t7; +// CHECK-NEXT: double _r_d3 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d3; +// CHECK-NEXT: _d_a[0] += _r_d3; +// CHECK-NEXT: _d_b[0] += _r_d3; +// CHECK-NEXT: _d_sum -= _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: ++a; +// CHECK-NEXT: ++_d_a.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: ++b; +// CHECK-NEXT: ++_d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t6; +// CHECK-NEXT: double _r_d2 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d2; +// CHECK-NEXT: _d_a[0] += _r_d2; +// CHECK-NEXT: _d_b[0] += _r_d2; +// CHECK-NEXT: _d_sum -= _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: a++; +// CHECK-NEXT: _d_a.ptr_ref()++; +// CHECK-NEXT: _d_a.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: b++; +// CHECK-NEXT: _d_b.ptr_ref()++; +// CHECK-NEXT: _d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t5; +// CHECK-NEXT: double _r_d1 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d1; +// CHECK-NEXT: _d_a[0] += _r_d1; +// CHECK-NEXT: _d_b[0] += _r_d1; +// CHECK-NEXT: _d_sum -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: a--; +// CHECK-NEXT: _d_a.ptr_ref()--; +// CHECK-NEXT: _d_a.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: b--; +// CHECK-NEXT: _d_b.ptr_ref()--; +// CHECK-NEXT: _d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { // CHECK-NEXT: sum = _t4; // CHECK-NEXT: double _r_d0 = _d_sum; // CHECK-NEXT: _d_sum += _r_d0; @@ -265,12 +347,16 @@ double pointerMultipleParams(const double* a, const double* b) { // CHECK-NEXT: _d_sum -= _r_d0; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: b = _t2; -// CHECK-NEXT: _d_b = _t3; +// CHECK-NEXT: --b; +// CHECK-NEXT: --_d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: a = _t2; +// CHECK-NEXT: _d_a = _t3; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: a = _t0; -// CHECK-NEXT: _d_a = _t1; +// CHECK-NEXT: b = _t0; +// CHECK-NEXT: _d_b = _t1; // CHECK-NEXT: } // CHECK-NEXT: _d_b[2] += _d_sum; // CHECK-NEXT: } @@ -367,6 +453,6 @@ int main() { clad::array_ref d_b_arr_ref(d_b_arr, 5); d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0; d_pointerMultipleParams.execute(arr, b_arr, d_arr_ref, d_b_arr_ref); - printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 2.00 0.00 0.00 0.00 + printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 2.00 4.00 2.00 0.00 0.00 printf("%.2f %.2f %.2f %.2f %.2f\n", d_b_arr[0], d_b_arr[1], d_b_arr[2], d_b_arr[3], d_b_arr[4]); // CHECK-EXEC: 0.00 0.00 1.00 0.00 0.00 }