-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for non-differentiable attribute in reverse mode
fixes #717
- Loading branch information
1 parent
7b71e25
commit f7c65a8
Showing
3 changed files
with
295 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
// RUN: %cladclang %s -I%S/../../include -oNonDifferentiable.out 2>&1 | %filecheck %s | ||
// RUN: ./NonDifferentiable.out | %filecheck_exec %s | ||
// CHECK-NOT: {{.*error|warning|note:.*}} | ||
|
||
#define non_differentiable __attribute__((annotate("another_attribute"), annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
class SimpleFunctions1 { | ||
public: | ||
SimpleFunctions1() noexcept : x(0), y(0), x_pointer(&x), y_pointer(&y) {} | ||
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y), x_pointer(&x), y_pointer(&y) {} | ||
double x; | ||
non_differentiable double y; | ||
double* x_pointer; | ||
non_differentiable double* y_pointer; | ||
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; } | ||
non_differentiable double mem_fn_2(double i, double j) { return i * j; } | ||
double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; } | ||
double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; } | ||
double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; } | ||
SimpleFunctions1 operator+(const SimpleFunctions1& other) const { | ||
return SimpleFunctions1(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s1_mem_fn(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.mem_fn_1(i, j) + i * j; | ||
} | ||
|
||
double fn_s1_field(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s1_field_pointer(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return (*obj.x_pointer) * (*obj.y_pointer) + i * j; | ||
} | ||
|
||
double fn_s1_operator(double i, double j) { | ||
SimpleFunctions1 obj1(2, 3); | ||
SimpleFunctions1 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn_1(i, j); | ||
} | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } | ||
SimpleFunctions2 operator+(const SimpleFunctions2& other) const { | ||
return SimpleFunctions2(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
double fn_s2_field(double i, double j) { | ||
SimpleFunctions2 *obj0, obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s2_operator(double i, double j) { | ||
SimpleFunctions2 obj1(2, 3); | ||
SimpleFunctions2 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn(i, j); | ||
} | ||
|
||
double fn_non_diff_var(double i, double j) { | ||
non_differentiable double k = i * i * j; | ||
return k; | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::gradient(&classname::name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::gradient(&name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); | ||
|
||
int main() { | ||
// FIXME: The parts of this test that are commented out are currently not working, due to bugs | ||
// not related to the implementation of the non-differentiable attribute. | ||
INIT_EXPR(SimpleFunctions1); | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_1, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_3, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_4, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_5, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_s1_mem_fn, 3, 5) // CHECK-EXEC: 35.00 33.00 | ||
|
||
TEST_FUNC(fn_s1_field, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
TEST_FUNC(fn_s1_field_pointer, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
/*TEST_FUNC(fn_s1_operator, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_s2_mem_fn, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
/*TEST_FUNC(fn_s2_field, 3, 5)*/ | ||
|
||
/*TEST_FUNC(fn_s2_operator, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_non_diff_var, 3, 5) // CHECK-EXEC: 0.00 0.00 | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); | ||
|
||
// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 _t0; | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: _t0 = obj; | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: double _r0 = 0; | ||
// CHECK-NEXT: double _r1 = 0; | ||
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1); | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: _d_obj.x += 1 * obj.y; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s1_field_pointer_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_obj.x_pointer += 1 * (*obj.y_pointer); | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions2 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_non_diff_var_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: double _d_k = 0; | ||
// CHECK-NEXT: double k = i * i * j; | ||
// CHECK-NEXT: _d_k += 1; | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: (*_d_this).x += _d_y * i; | ||
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; | ||
// CHECK-NEXT: *_d_i += _d_y * j * j; | ||
// CHECK-NEXT: *_d_j += i * _d_y * j; | ||
// CHECK-NEXT: *_d_j += i * j * _d_y; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | ||
|
||
#define non_differentiable __attribute__((annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
extern "C" int printf(const char* fmt, ...); | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } // expected-error {{attempted differentiation of method 'mem_fn' in class 'SimpleFunctions2', which is marked as non-differentiable}} | ||
}; | ||
|
||
namespace clad { | ||
namespace custom_derivatives { | ||
void fn_s2_mem_fn_pullback(double i, double j, double _d_y, double* _d_i, double* _d_j) { | ||
*_d_i = 1.5; | ||
*_d_j = 2.5; | ||
} | ||
} // namespace custom_derivatives | ||
} // namespace clad | ||
|
||
non_differentiable double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::differentiate(&classname::name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(expr_1, i, j)); \ | ||
printf("%.2f\n", d_##name.execute(expr_2, i, j)); \ | ||
printf("\n"); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::differentiate(&name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(i, j)); \ | ||
printf("\n"); | ||
|
||
int main() { | ||
INIT_EXPR(SimpleFunctions2); | ||
TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5); | ||
TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}} | ||
} |