-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for non-differentiable attribute in reverse mode
- Loading branch information
1 parent
ef668c7
commit e225aa5
Showing
2 changed files
with
199 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
// RUN: %cladclang %s -I%S/../../include -oNonDifferentiable.out 2>&1 | %filecheck %s | ||
// RUN: ./NonDifferentiable.out | %filecheck_exec %s | ||
// CHECK-NOT: {{.*error|warning|note:.*}} | ||
|
||
#define non_differentiable __attribute__((annotate("another_attribute"), annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
extern "C" int printf(const char* fmt, ...); | ||
|
||
class SimpleFunctions1 { | ||
public: | ||
SimpleFunctions1() noexcept : x(0), y(0) {} | ||
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
non_differentiable double y; | ||
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; } | ||
non_differentiable double mem_fn_2(double i, double j) { return i * j; } | ||
double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; } | ||
double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; } | ||
double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; } | ||
SimpleFunctions1 operator+(const SimpleFunctions1& other) const { | ||
return SimpleFunctions1(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s1_mem_fn(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.mem_fn_1(i, j) + i * j; | ||
} | ||
|
||
double fn_s1_field(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s1_operator(double i, double j) { | ||
SimpleFunctions1 obj1(2, 3); | ||
SimpleFunctions1 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn_1(i, j); | ||
} | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } | ||
SimpleFunctions2 operator+(const SimpleFunctions2& other) const { | ||
return SimpleFunctions2(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
double fn_s2_field(double i, double j) { | ||
SimpleFunctions2 *obj0, obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s2_operator(double i, double j) { | ||
SimpleFunctions2 obj1(2, 3); | ||
SimpleFunctions2 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn(i, j); | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::gradient(&classname::name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::gradient(&name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); | ||
|
||
int main() { | ||
// FIXME: The parts of this test that are commented out are currently not working, due to bugs | ||
// not related to the implementation of the non-differentiable attribute. | ||
INIT_EXPR(SimpleFunctions1); | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_1, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_3, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_4, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_5, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_s1_mem_fn, 3, 5) // CHECK-EXEC: 35.00 33.00 | ||
|
||
TEST_FUNC(fn_s1_field, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
/*TEST_FUNC(fn_s1_operator, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_s2_mem_fn, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
/*TEST_FUNC(fn_s2_field, 3, 5)*/ | ||
|
||
/*TEST_FUNC(fn_s2_operator, 3, 5)*/ | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); | ||
|
||
// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 _t0; | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: _t0 = obj; | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: double _r0 = 0; | ||
// CHECK-NEXT: double _r1 = 0; | ||
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1); | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: _d_obj.x += 1 * obj.y; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions2 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: (*_d_this).x += _d_y * i; | ||
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; | ||
// CHECK-NEXT: *_d_i += _d_y * j * j; | ||
// CHECK-NEXT: *_d_j += i * _d_y * j; | ||
// CHECK-NEXT: *_d_j += i * j * _d_y; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | ||
|
||
#define non_differentiable __attribute__((annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
extern "C" int printf(const char* fmt, ...); | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } // expected-error {{attempted differentiation of method 'mem_fn' in class 'SimpleFunctions2', which is marked as non-differentiable}} | ||
}; | ||
|
||
non_differentiable double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::differentiate(&classname::name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(expr_1, i, j)); \ | ||
printf("%.2f\n", d_##name.execute(expr_2, i, j)); \ | ||
printf("\n"); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::differentiate(&name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(i, j)); \ | ||
printf("\n"); | ||
|
||
int main() { | ||
INIT_EXPR(SimpleFunctions2); | ||
TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5); | ||
TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}} | ||
} |