From 514e58c9ba655ed582e49f7b8bd55ca321fc0d7e Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Sat, 23 Dec 2023 12:19:42 +0530 Subject: [PATCH] add more tests --- test/Gradient/Pointers.C | 104 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 538c2e8eb..77339c178 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -180,6 +180,101 @@ double pointerParam(const double* arr, size_t n) { return sum; } +// CHECK: void pointerParam_grad_0(const double *arr, size_t n, clad::array_ref _d_arr) { +// CHECK-NEXT: size_t _d_n = 0; +// CHECK-NEXT: double _d_sum = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: size_t _d_i = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: size_t *_d_j = 0; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape > _t5 = {}; +// CHECK-NEXT: double sum = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: for (size_t i = 0; i < n; ++i) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: _d_j = &_d_i; +// CHECK-NEXT: clad::push(_t1, _d_j); +// CHECK-NEXT: size_t *j = &i; +// CHECK-NEXT: clad::push(_t3, sum); +// CHECK-NEXT: sum += arr[0] * (*j); +// CHECK-NEXT: clad::push(_t4, arr); +// CHECK-NEXT: clad::push(_t5, _d_arr); +// CHECK-NEXT: _d_arr.ptr_ref() = _d_arr.ptr_ref() + 1; +// CHECK-NEXT: arr = arr + 1; +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_sum += 1; +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: --i; +// CHECK-NEXT: size_t *_t2 = clad::pop(_t1); +// CHECK-NEXT: { +// CHECK-NEXT: arr = clad::pop(_t4); +// CHECK-NEXT: _d_arr = clad::pop(_t5); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = clad::pop(_t3); +// CHECK-NEXT: double _r_d0 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d0; +// CHECK-NEXT: double _r0 = _r_d0 * (*j); +// CHECK-NEXT: _d_arr[0] += _r0; +// CHECK-NEXT: double _r1 = arr[0] * _r_d0; +// CHECK-NEXT: *_t2 += _r1; +// CHECK-NEXT: _d_sum -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double pointerMultipleParams(const double* a, const double* b) { + double sum = b[2]; + a = 1+a; + b = a; + sum += a[0] + b[0]; + return sum; // 2*a[1] + b[2] +} + +// CHECK: void pointerMultipleParams_grad(const double *a, const double *b, clad::array_ref _d_a, clad::array_ref _d_b) { +// CHECK-NEXT: double _d_sum = 0; +// CHECK-NEXT: const double *_t0; +// CHECK-NEXT: clad::array_ref _t1; +// CHECK-NEXT: const double *_t2; +// CHECK-NEXT: clad::array_ref _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double sum = b[2]; +// CHECK-NEXT: _t0 = a; +// CHECK-NEXT: _t1 = _d_a; +// CHECK-NEXT: _d_a.ptr_ref() = 1 + _d_a.ptr_ref(); +// CHECK-NEXT: a = 1 + a; +// CHECK-NEXT: _t2 = b; +// CHECK-NEXT: _t3 = _d_b; +// CHECK-NEXT: _d_b.ptr_ref() = _d_a.ptr_ref(); +// CHECK-NEXT: b = a; +// CHECK-NEXT: _t4 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_sum += 1; +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t4; +// CHECK-NEXT: double _r_d0 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d0; +// CHECK-NEXT: _d_a[0] += _r_d0; +// CHECK-NEXT: _d_b[0] += _r_d0; +// CHECK-NEXT: _d_sum -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: b = _t2; +// CHECK-NEXT: _d_b = _t3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: a = _t0; +// CHECK-NEXT: _d_a = _t1; +// CHECK-NEXT: } +// CHECK-NEXT: _d_b[2] += _d_sum; +// CHECK-NEXT: } + #define NON_MEM_FN_TEST(var)\ res[0]=0;\ var.execute(5,res);\ @@ -265,4 +360,13 @@ int main() { d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0; d_pointerParam.execute(arr, 5, d_arr_ref); printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 1.00 2.00 3.00 4.00 + + auto d_pointerMultipleParams = clad::gradient(pointerMultipleParams); + double b_arr[5] = {1, 2, 3, 4, 5}; + double d_b_arr[5] = {0, 0, 0, 0, 0}; + clad::array_ref d_b_arr_ref(d_b_arr, 5); + d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0; + d_pointerMultipleParams.execute(arr, b_arr, d_arr_ref, d_b_arr_ref); + printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 2.00 0.00 0.00 0.00 + printf("%.2f %.2f %.2f %.2f %.2f\n", d_b_arr[0], d_b_arr[1], d_b_arr[2], d_b_arr[3], d_b_arr[4]); // CHECK-EXEC: 0.00 0.00 1.00 0.00 0.00 }