-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for [[clad::non_differentiable]]
in reverse mode
#916
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1420,6 +1420,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
return StmtDiff(Clone(CE)); | ||
} | ||
|
||
// If the function is non_differentiable, return zero derivative. | ||
if (clad::utils::hasNonDifferentiableAttribute(CE)) { | ||
// Calling the function without computing derivatives | ||
llvm::SmallVector<Expr*, 4> ClonedArgs; | ||
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) | ||
ClonedArgs.push_back(Clone(CE->getArg(i))); | ||
|
||
SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema); | ||
Expr* Call = m_Sema | ||
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), | ||
validLoc, ClonedArgs, validLoc) | ||
parth-07 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
.get(); | ||
// Creating a zero derivative | ||
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, | ||
/*val=*/0); | ||
|
||
// Returning the function call and zero derivative | ||
return StmtDiff(Call, zero); | ||
} | ||
|
||
auto NArgs = FD->getNumParams(); | ||
// If the function has no args and is not a member function call then we | ||
// assume that it is not related to independent variables and does not | ||
|
@@ -2061,6 +2081,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
} else if (opCode == UnaryOperatorKind::UO_Deref) { | ||
diff = Visit(E); | ||
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr()); | ||
|
||
// If we have a pointer to a member expression, which is | ||
// non-differentiable, we just return a clone of the original expression. | ||
if (auto* ME = dyn_cast<MemberExpr>(diff.getExpr())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can it be handled more uniformly in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It is already handled in
With the above check I eliminate one of the cases where we could end up with a 0 above, if you can think of anything else, then we should handle those too. Do you have anything in mind? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I am concerned there might be many such cases... For example, It might be better to test if the |
||
if (clad::utils::hasNonDifferentiableAttribute(ME->getMemberDecl())) | ||
return {cloneE}; | ||
|
||
Expr* diff_dx = diff.getExpr_dx(); | ||
bool specialDThisCase = false; | ||
Expr* derivedE = nullptr; | ||
|
@@ -2655,9 +2682,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
// If `VD` is a reference to a non-local variable then also there's no | ||
// need to call `Visit` since non-local variables are not differentiated. | ||
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) { | ||
Expr* derivedE = BuildDeclRef(VDDerived); | ||
if (isInitializedByNewExpr) | ||
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); | ||
Expr* derivedE = nullptr; | ||
|
||
if (!clad::utils::hasNonDifferentiableAttribute(VD)) { | ||
derivedE = BuildDeclRef(VDDerived); | ||
if (isInitializedByNewExpr) | ||
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); | ||
} | ||
|
||
if (VD->getInit()) { | ||
if (isa<CXXConstructExpr>(VD->getInit())) | ||
initDiff = Visit(VD->getInit()); | ||
|
@@ -2689,6 +2721,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
addToCurrentBlock(assignToZero, direction::reverse); | ||
} | ||
} | ||
|
||
VarDecl* VDClone = nullptr; | ||
Expr* derivedVDE = nullptr; | ||
if (VDDerived) | ||
|
@@ -2815,19 +2848,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
if (declsBegin != DS->decls().end() && isa<VarDecl>(*declsBegin)) { | ||
auto* VD = dyn_cast<VarDecl>(*declsBegin); | ||
QualType QT = VD->getType(); | ||
if (!QT->isPointerType()) { | ||
auto* typeDecl = QT->getAsCXXRecordDecl(); | ||
// We should also simply copy the original lambda. The differentiation | ||
// of lambdas is happening in the `VisitCallExpr`. For now, only the | ||
// declarations with lambda expressions without captures are supported. | ||
isLambda = typeDecl && typeDecl->isLambda(); | ||
if (isLambda) { | ||
for (auto* D : DS->decls()) | ||
if (auto* VD = dyn_cast<VarDecl>(D)) | ||
decls.push_back(VD); | ||
Stmt* DSClone = BuildDeclStmt(decls); | ||
return StmtDiff(DSClone, nullptr); | ||
} | ||
if (QT->isPointerType()) | ||
QT = QT->getPointeeType(); | ||
|
||
auto* typeDecl = QT->getAsCXXRecordDecl(); | ||
// We should also simply copy the original lambda. The differentiation | ||
// of lambdas is happening in the `VisitCallExpr`. For now, only the | ||
// declarations with lambda expressions without captures are supported. | ||
isLambda = typeDecl && typeDecl->isLambda(); | ||
if (isLambda || | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please add a test for a local variable declaration with non_differentiable attribute? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please create an issue for this and resolve it in a follow-up pull-request? |
||
(typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) { | ||
for (auto* D : DS->decls()) | ||
if (auto* VD = dyn_cast<VarDecl>(D)) | ||
decls.push_back(VD); | ||
Stmt* DSClone = BuildDeclStmt(decls); | ||
return StmtDiff(DSClone, nullptr); | ||
} | ||
} | ||
|
||
|
@@ -2839,6 +2874,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
for (auto* D : DS->decls()) { | ||
if (auto* VD = dyn_cast<VarDecl>(D)) { | ||
DeclDiff<VarDecl> VDDiff; | ||
|
||
if (!isLambda) | ||
VDDiff = DifferentiateVarDecl(VD); | ||
|
||
|
@@ -3014,6 +3050,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
"CXXMethodDecl nodes not supported yet!"); | ||
MemberExpr* clonedME = utils::BuildMemberExpr( | ||
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); | ||
if (clad::utils::hasNonDifferentiableAttribute(ME)) { | ||
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, | ||
/*val=*/0); | ||
return {clonedME, zero}; | ||
} | ||
if (!baseDiff.getExpr_dx()) | ||
return {clonedME, nullptr}; | ||
MemberExpr* derivedME = utils::BuildMemberExpr( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
// RUN: %cladclang %s -I%S/../../include -oNonDifferentiable.out 2>&1 | %filecheck %s | ||
// RUN: ./NonDifferentiable.out | %filecheck_exec %s | ||
// CHECK-NOT: {{.*error|warning|note:.*}} | ||
|
||
#define non_differentiable __attribute__((annotate("another_attribute"), annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
class SimpleFunctions1 { | ||
public: | ||
SimpleFunctions1() noexcept : x(0), y(0), x_pointer(&x), y_pointer(&y) {} | ||
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y), x_pointer(&x), y_pointer(&y) {} | ||
double x; | ||
non_differentiable double y; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also test with some pointer member types. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just added an |
||
double* x_pointer; | ||
non_differentiable double* y_pointer; | ||
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; } | ||
non_differentiable double mem_fn_2(double i, double j) { return i * j; } | ||
double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; } | ||
double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; } | ||
double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; } | ||
SimpleFunctions1 operator+(const SimpleFunctions1& other) const { | ||
return SimpleFunctions1(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s1_mem_fn(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.mem_fn_1(i, j) + i * j; | ||
} | ||
|
||
double fn_s1_field(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s1_field_pointer(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return (*obj.x_pointer) * (*obj.y_pointer) + i * j; | ||
} | ||
|
||
double fn_s1_operator(double i, double j) { | ||
SimpleFunctions1 obj1(2, 3); | ||
SimpleFunctions1 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn_1(i, j); | ||
} | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } | ||
SimpleFunctions2 operator+(const SimpleFunctions2& other) const { | ||
return SimpleFunctions2(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
double fn_s2_field(double i, double j) { | ||
SimpleFunctions2 *obj0, obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s2_operator(double i, double j) { | ||
SimpleFunctions2 obj1(2, 3); | ||
SimpleFunctions2 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn(i, j); | ||
} | ||
|
||
double fn_non_diff_var(double i, double j) { | ||
non_differentiable double k = i * i * j; | ||
return k; | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::gradient(&classname::name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::gradient(&name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); | ||
|
||
int main() { | ||
// FIXME: The parts of this test that are commented out are currently not working, due to bugs | ||
// not related to the implementation of the non-differentiable attribute. | ||
INIT_EXPR(SimpleFunctions1); | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_1, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_3, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_4, 3, 5)*/ | ||
|
||
/*TEST_CLASS(SimpleFunctions1, mem_fn_5, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_s1_mem_fn, 3, 5) // CHECK-EXEC: 35.00 33.00 | ||
|
||
TEST_FUNC(fn_s1_field, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
TEST_FUNC(fn_s1_field_pointer, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
/*TEST_FUNC(fn_s1_operator, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_s2_mem_fn, 3, 5) // CHECK-EXEC: 5.00 3.00 | ||
|
||
/*TEST_FUNC(fn_s2_field, 3, 5)*/ | ||
|
||
/*TEST_FUNC(fn_s2_operator, 3, 5)*/ | ||
|
||
TEST_FUNC(fn_non_diff_var, 3, 5) // CHECK-EXEC: 0.00 0.00 | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); | ||
|
||
// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 _t0; | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: _t0 = obj; | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: double _r0 = 0; | ||
// CHECK-NEXT: double _r1 = 0; | ||
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1); | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: _d_obj.x += 1 * obj.y; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s1_field_pointer_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_obj.x_pointer += 1 * (*obj.y_pointer); | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions2 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_non_diff_var_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: double _d_k = 0; | ||
// CHECK-NEXT: double k = i * i * j; | ||
// CHECK-NEXT: _d_k += 1; | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: (*_d_this).x += _d_y * i; | ||
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; | ||
// CHECK-NEXT: *_d_i += _d_y * j * j; | ||
// CHECK-NEXT: *_d_j += i * _d_y * j; | ||
// CHECK-NEXT: *_d_j += i * j * _d_y; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | ||
|
||
#define non_differentiable __attribute__((annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
extern "C" int printf(const char* fmt, ...); | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } // expected-error {{attempted differentiation of method 'mem_fn' in class 'SimpleFunctions2', which is marked as non-differentiable}} | ||
}; | ||
|
||
namespace clad { | ||
namespace custom_derivatives { | ||
void fn_s2_mem_fn_pullback(double i, double j, double _d_y, double* _d_i, double* _d_j) { | ||
*_d_i = 1.5; | ||
*_d_j = 2.5; | ||
} | ||
} // namespace custom_derivatives | ||
} // namespace clad | ||
|
||
non_differentiable double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::differentiate(&classname::name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(expr_1, i, j)); \ | ||
printf("%.2f\n", d_##name.execute(expr_2, i, j)); \ | ||
printf("\n"); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::differentiate(&name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(i, j)); \ | ||
printf("\n"); | ||
|
||
int main() { | ||
INIT_EXPR(SimpleFunctions2); | ||
TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5); | ||
TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}} | ||
parth-07 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simply cloning the argument seems incorrect. What if the arguments have side-effect which can affect the derivative computation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that I understand the issue here. It the arguments do have side effects then those would be kept when we clone them, is that not what is expected? When do you think that this wouldn't work correctly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider an example such as this:
some_non_differentiable_fn_call(r = u * v, s = u + v);
Now, if we simply clone the arguments then we will not generate adjoint statements for
r = u * v
ands = u + v
.You don't necessarily need to fix this issue in this PR.