Skip to content

Commit

Permalink
fix increment operators in pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Dec 24, 2023
1 parent 514e58c commit c86c6d2
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 51 deletions.
28 changes: 0 additions & 28 deletions include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,34 +68,6 @@ template <typename T> class array_ref {
/// Returns the reference to the underlying array
CUDA_HOST_DEVICE T& operator*() { return *m_arr; }

// Increment and decrement operators - update the underlying pointer.
/// Prefix increment operator.
CUDA_HOST_DEVICE array_ref<T>& operator++() {
++m_arr;
--m_size;
return *this;
}
/// Postfix increment operator.
CUDA_HOST_DEVICE array_ref<T> operator++(int) {
array_ref<T> tmp(*this);
++m_arr;
--m_size;
return tmp;
}
/// Prefix decrement operator.
CUDA_HOST_DEVICE array_ref<T>& operator--() {
--m_arr;
++m_size;
return *this;
}
/// Postfix decrement operator.
CUDA_HOST_DEVICE array_ref<T> operator--(int) {
array_ref<T> tmp(*this);
--m_arr;
++m_size;
return tmp;
}

// Arithmetic overloads
/// Divides the arrays element wise
template <typename U>
Expand Down
18 changes: 11 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2053,32 +2053,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diff = Visit(E, d);
} else if (opCode == UO_PostInc || opCode == UO_PostDec) {
diff = Visit(E, dfdx());
Expr* diff_dx = diff.getExpr_dx();
if (isPointerOp && isCladArrayType(diff_dx->getType()))
diff_dx = BuildArrayRefPtrRefExpr(diff_dx);
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()),
direction::forward);
addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
if (isPointerOp)
addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse);
addToCurrentBlock(BuildOp(op, diff_dx), direction::reverse);
}

ResultRef = diff.getExpr_dx();
ResultRef = diff_dx;
valueForRevPass = diff.getRevSweepAsExpr();
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff);
} else if (opCode == UO_PreInc || opCode == UO_PreDec) {
diff = Visit(E, dfdx());
Expr* diff_dx = diff.getExpr_dx();
if (isPointerOp && isCladArrayType(diff_dx->getType()))
diff_dx = BuildArrayRefPtrRefExpr(diff_dx);
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()),
direction::forward);
addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
if (isPointerOp)
addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse);
addToCurrentBlock(BuildOp(op, diff_dx), direction::reverse);
}
auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add
: BinaryOperatorKind::BO_Sub;
Expand Down
118 changes: 102 additions & 16 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,18 @@ double pointerParam(const double* arr, size_t n) {
// CHECK-NEXT: }

double pointerMultipleParams(const double* a, const double* b) {
double sum = b[2];
a = 1+a;
double sum = b[2];
b = a;
sum += a[0] + b[0];
return sum; // 2*a[1] + b[2]
a = 1+a;
++b;
sum += a[0] + b[0]; // += 2*a[1]
b++; a++;
sum += a[0] + b[0]; // += 2*a[2]
b--; a--;
sum += a[0] + b[0]; // += 2*a[1]
--b; --a;
sum += a[0] + b[0]; // += 2*a[0]
return sum; // 2*a[0] + 4*a[1] + 2*a[2] + b[2]
}

// CHECK: void pointerMultipleParams_grad(const double *a, const double *b, clad::array_ref<double> _d_a, clad::array_ref<double> _d_b) {
Expand All @@ -242,21 +249,96 @@ double pointerMultipleParams(const double* a, const double* b) {
// CHECK-NEXT: const double *_t2;
// CHECK-NEXT: clad::array_ref<double> _t3;
// CHECK-NEXT: double _t4;
// CHECK-NEXT: double _t5;
// CHECK-NEXT: double _t6;
// CHECK-NEXT: double _t7;
// CHECK-NEXT: double sum = b[2];
// CHECK-NEXT: _t0 = a;
// CHECK-NEXT: _t1 = _d_a;
// CHECK-NEXT: _d_a.ptr_ref() = 1 + _d_a.ptr_ref();
// CHECK-NEXT: a = 1 + a;
// CHECK-NEXT: _t2 = b;
// CHECK-NEXT: _t3 = _d_b;
// CHECK-NEXT: _t0 = b;
// CHECK-NEXT: _t1 = _d_b;
// CHECK-NEXT: _d_b.ptr_ref() = _d_a.ptr_ref();
// CHECK-NEXT: b = a;
// CHECK-NEXT: _t4 = sum;
// CHECK-NEXT: _t2 = a;
// CHECK-NEXT: _t3 = _d_a;
// CHECK-NEXT: _d_a.ptr_ref() = 1 + _d_a.ptr_ref();
// CHECK-NEXT: a = 1 + a;
// CHECK-NEXT: ++_d_b.ptr_ref();
// CHECK-NEXT: ++b;
// CHECK-NEXT: _t4 = sum;
// CHECK-NEXT: sum += a[0] + b[0];
// CHECK-NEXT: _d_b.ptr_ref()++;
// CHECK-NEXT: b++;
// CHECK-NEXT: _d_a.ptr_ref()++;
// CHECK-NEXT: a++;
// CHECK-NEXT: _t5 = sum;
// CHECK-NEXT: sum += a[0] + b[0];
// CHECK-NEXT: _d_b.ptr_ref()--;
// CHECK-NEXT: b--;
// CHECK-NEXT: _d_a.ptr_ref()--;
// CHECK-NEXT: a--;
// CHECK-NEXT: _t6 = sum;
// CHECK-NEXT: sum += a[0] + b[0];
// CHECK-NEXT: --_d_b.ptr_ref();
// CHECK-NEXT: --b;
// CHECK-NEXT: --_d_a.ptr_ref();
// CHECK-NEXT: --a;
// CHECK-NEXT: _t7 = sum;
// CHECK-NEXT: sum += a[0] + b[0];
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_sum += 1;
// CHECK-NEXT: {
// CHECK-NEXT: sum = _t7;
// CHECK-NEXT: double _r_d3 = _d_sum;
// CHECK-NEXT: _d_sum += _r_d3;
// CHECK-NEXT: _d_a[0] += _r_d3;
// CHECK-NEXT: _d_b[0] += _r_d3;
// CHECK-NEXT: _d_sum -= _r_d3;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: ++a;
// CHECK-NEXT: ++_d_a.ptr_ref();
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: ++b;
// CHECK-NEXT: ++_d_b.ptr_ref();
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: sum = _t6;
// CHECK-NEXT: double _r_d2 = _d_sum;
// CHECK-NEXT: _d_sum += _r_d2;
// CHECK-NEXT: _d_a[0] += _r_d2;
// CHECK-NEXT: _d_b[0] += _r_d2;
// CHECK-NEXT: _d_sum -= _r_d2;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: a++;
// CHECK-NEXT: _d_a.ptr_ref()++;
// CHECK-NEXT: _d_a.ptr_ref();
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: b++;
// CHECK-NEXT: _d_b.ptr_ref()++;
// CHECK-NEXT: _d_b.ptr_ref();
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: sum = _t5;
// CHECK-NEXT: double _r_d1 = _d_sum;
// CHECK-NEXT: _d_sum += _r_d1;
// CHECK-NEXT: _d_a[0] += _r_d1;
// CHECK-NEXT: _d_b[0] += _r_d1;
// CHECK-NEXT: _d_sum -= _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: a--;
// CHECK-NEXT: _d_a.ptr_ref()--;
// CHECK-NEXT: _d_a.ptr_ref();
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: b--;
// CHECK-NEXT: _d_b.ptr_ref()--;
// CHECK-NEXT: _d_b.ptr_ref();
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: sum = _t4;
// CHECK-NEXT: double _r_d0 = _d_sum;
// CHECK-NEXT: _d_sum += _r_d0;
Expand All @@ -265,12 +347,16 @@ double pointerMultipleParams(const double* a, const double* b) {
// CHECK-NEXT: _d_sum -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: b = _t2;
// CHECK-NEXT: _d_b = _t3;
// CHECK-NEXT: --b;
// CHECK-NEXT: --_d_b.ptr_ref();
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: a = _t2;
// CHECK-NEXT: _d_a = _t3;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: a = _t0;
// CHECK-NEXT: _d_a = _t1;
// CHECK-NEXT: b = _t0;
// CHECK-NEXT: _d_b = _t1;
// CHECK-NEXT: }
// CHECK-NEXT: _d_b[2] += _d_sum;
// CHECK-NEXT: }
Expand Down Expand Up @@ -367,6 +453,6 @@ int main() {
clad::array_ref<double> d_b_arr_ref(d_b_arr, 5);
d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0;
d_pointerMultipleParams.execute(arr, b_arr, d_arr_ref, d_b_arr_ref);
printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 2.00 0.00 0.00 0.00
printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 2.00 4.00 2.00 0.00 0.00
printf("%.2f %.2f %.2f %.2f %.2f\n", d_b_arr[0], d_b_arr[1], d_b_arr[2], d_b_arr[3], d_b_arr[4]); // CHECK-EXEC: 0.00 0.00 1.00 0.00 0.00
}

0 comments on commit c86c6d2

Please sign in to comment.