Skip to content

Commit

Permalink
Temporarily re-enable adjoints for non-differentiable parameters in e…
Browse files Browse the repository at this point in the history
…rror estimation.
  • Loading branch information
PetroZarytskyi committed Apr 30, 2024
1 parent 9619a4a commit 59be2a3
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 17 deletions.
12 changes: 10 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (request.Args) {
DVI = request.DVI;
for (const auto& dParam : DVI)
if (utils::IsDifferentiableType(dParam.param->getType()))
// no need to create adjoints for non-differentiable parameters.
// FIXME: we have to create adjoints for all parameters when any
// external sources are enabled because gradient overloads don't support
// additional parameters.
if (utils::IsDifferentiableType(dParam.param->getType()) ||
m_ExternalSource)
args.push_back(dParam.param);
}
else
Expand Down Expand Up @@ -597,7 +602,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (std::size_t i = 0; i < m_Function->getNumParams(); ++i) {
ParmVarDecl* param = paramsRef[i];
// no need to create adjoints for non-differentiable variables.
if (!utils::IsDifferentiableType(param->getType()))
// FIXME: we have to create adjoints for all parameters when any
// external sources are enabled because gradient overloads don't support
// additional parameters.
if (!utils::IsDifferentiableType(param->getType()) && !m_ExternalSource)
continue;
// derived variables are already created for independent variables.
if (m_Variables.count(param))
Expand Down
26 changes: 19 additions & 7 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ float func2(float x, int y) {
return x;
}

//CHECK: void func2_grad_0(float x, int y, float *_d_x, double &_final_error) {
//CHECK: void func2_grad(float x, int y, float *_d_x, int *_d_y, double &_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y * x + x * x;
Expand All @@ -57,6 +57,7 @@ float func2(float x, int y) {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_y += _r_d0 * x;
//CHECK-NEXT: *_d_x += y * _r_d0;
//CHECK-NEXT: *_d_x += _r_d0 * x;
//CHECK-NEXT: *_d_x += x * _r_d0;
Expand All @@ -69,14 +70,19 @@ float func3(int x, int y) {
return y;
}

//CHECK: void func3_grad(int x, int y, double &_final_error) {
//CHECK: void func3_grad(int x, int y, int *_d_x, int *_d_y, double &_final_error) {
//CHECK-NEXT: int _t0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: ;
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: *_d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: int _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

float func4(float x, float y) {
Expand Down Expand Up @@ -162,14 +168,20 @@ float func8(int x, int y) {
return x;
}

//CHECK: void func8_grad(int x, int y, double &_final_error) {
//CHECK: void func8_grad(int x, int y, int *_d_x, int *_d_y, double &_final_error) {
//CHECK-NEXT: int _t0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y * y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: ;
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: int _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_y += _r_d0 * y;
//CHECK-NEXT: *_d_y += y * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

int main() {
Expand Down
2 changes: 1 addition & 1 deletion test/ErrorEstimation/LoopsAndArrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ float func(float* p, int n) {
return sum;
}

//CHECK: void func_grad_0(float *p, int n, float *_d_p, double &_final_error) {
//CHECK: void func_grad(float *p, int n, float *_d_p, int *_d_n, double &_final_error) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int i = 0;
Expand Down
15 changes: 9 additions & 6 deletions test/ErrorEstimation/LoopsAndArraysExec.C
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ double runningSum(float* f, int n) {
return sum;
}

//CHECK: void runningSum_grad_0(float *f, int n, float *_d_f, double &_final_error) {
//CHECK: void runningSum_grad(float *f, int n, float *_d_f, int *_d_n, double &_final_error) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int i = 0;
Expand Down Expand Up @@ -58,7 +58,7 @@ double mulSum(float* a, float* b, int n) {
return sum;
}

//CHECK: void mulSum_grad_0_1(float *a, float *b, int n, float *_d_a, float *_d_b, double &_final_error) {
//CHECK: void mulSum_grad(float *a, float *b, int n, float *_d_a, float *_d_b, int *_d_n, double &_final_error) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int i = 0;
Expand Down Expand Up @@ -116,7 +116,7 @@ double divSum(float* a, float* b, int n) {
return sum;
}

//CHECK: void divSum_grad_0_1(float *a, float *b, int n, float *_d_a, float *_d_b, double &_final_error) {
//CHECK: void divSum_grad(float *a, float *b, int n, float *_d_a, float *_d_b, int *_d_n, double &_final_error) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int i = 0;
Expand Down Expand Up @@ -161,24 +161,27 @@ int main() {
float arrf[3] = {0.456, 0.77, 0.95};
double finalError = 0;
float darr[3] = {0, 0, 0};
df.execute(arrf, 3, darr, finalError);
int dn = 0;
df.execute(arrf, 3, darr, &dn, finalError);
printf("Result (RS) = {%.2f, %.2f, %.2f} error = %.5f\n", darr[0], darr[1],
darr[2], finalError); // CHECK-EXEC: Result (RS) = {1.00, 2.00, 1.00} error = 0.00000

finalError = 0;
darr[0] = darr[1] = darr[2] = 0;
dn = 0;
float darr2[3] = {0, 0, 0};
auto df2 = clad::estimate_error(mulSum);
df2.execute(arrf, arrf, 3, darr, darr2, finalError);
df2.execute(arrf, arrf, 3, darr, darr2, &dn, finalError);
printf("Result (MS) = {%.2f, %.2f, %.2f}, {%.2f, %.2f, %.2f} error = %.5f\n",
darr[0], darr[1], darr[2], darr2[0], darr2[1], darr2[2],
finalError); // CHECK-EXEC: Result (MS) = {2.18, 2.18, 2.18}, {2.18, 2.18, 2.18} error = 0.00000

finalError = 0;
darr[0] = darr[1] = darr[2] = 0;
darr2[0] = darr2[1] = darr2[2] = 0;
dn = 0;
auto df3 = clad::estimate_error(divSum);
df3.execute(arrf, arrf, 3, darr, darr2, finalError);
df3.execute(arrf, arrf, 3, darr, darr2, &dn, finalError);
printf("Result (DS) = {%.2f, %.2f, %.2f}, {%.2f, %.2f, %.2f} error = %.5f\n",
darr[0], darr[1], darr[2], darr2[0], darr2[1], darr2[2],
finalError); // CHECK-EXEC: Result (DS) = {2.19, 1.30, 1.05}, {-2.19, -1.30, -1.05} error = 0.00000
Expand Down
2 changes: 1 addition & 1 deletion test/Misc/RunDemos.C
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
// RUN: %cladclang %S/../../demos/ErrorEstimation/FloatSum.cpp -I%S/../../include 2>&1 | FileCheck -check-prefix CHECK_FLOAT_SUM %s
//CHECK_FLOAT_SUM-NOT: {{.*error|warning|note:.*}}

//CHECK_FLOAT_SUM: void vanillaSum_grad_0(float x, unsigned int n, float *_d_x, double &_final_error) {
//CHECK_FLOAT_SUM: void vanillaSum_grad(float x, unsigned int n, float *_d_x, unsigned int *_d_n, double &_final_error) {
//CHECK_FLOAT_SUM: float _d_sum = 0;
//CHECK_FLOAT_SUM: unsigned {{int|long}} _t0;
//CHECK_FLOAT_SUM: unsigned int i = 0;
Expand Down

0 comments on commit 59be2a3

Please sign in to comment.