Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for [[clad::non_differentiable]] in reverse mode #916

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
llvm::SmallVector<Expr*, 4> ClonedArgs;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i)
ClonedArgs.push_back(Clone(CE->getArg(i)));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simply cloning the argument seems incorrect. What if the arguments have side-effect which can affect the derivative computation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that I understand the issue here. It the arguments do have side effects then those would be kept when we clone them, is that not what is expected? When do you think that this wouldn't work correctly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider an example such as this:

some_non_differentiable_fn_call(r = u * v, s = u + v); 

Now, if we simply clone the arguments then we will not generate adjoint statements for r = u * v and s = u + v.

You don't necessarily need to fix this issue in this PR.


SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema);
Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc)
parth-07 marked this conversation as resolved.
Show resolved Hide resolved
.get();
// Creating a zero derivative
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);

// Returning the function call and zero derivative
return StmtDiff(Call, zero);
}

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
Expand Down Expand Up @@ -2061,6 +2081,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == UnaryOperatorKind::UO_Deref) {
diff = Visit(E);
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());

// If we have a pointer to a member expression, which is
// non-differentiable, we just return a clone of the original expression.
if (auto* ME = dyn_cast<MemberExpr>(diff.getExpr()))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be handled more uniformly in VisitMemberExpr?

Copy link
Collaborator

@parth-07 parth-07 Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If diff.getExpr_dx() is 0, then we would not need to add a special condition here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be handled more uniformly in VisitMemberExpr?

It is already handled in VisitMemberExpr where we return {clonedME, zero}, but what happened without the above check is that it tries to build something along the lines of *0 += ..., when visiting the UO_Deref.

If diff.getExpr_dx() is 0, then we would need to add a special condition here.

With the above check I eliminate one of the cases where we could end up with a 0 above, if you can think of anything else, then we should handle those too. Do you have anything in mind?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you can think of anything else, then we should handle those too.

I am concerned there might be many such cases... For example, -> operator.

It might be better to test if the diff.getExpr_dx() is a constant (or 0) instead of testing if the member has a non-differentiable attribute. This is because it will help us cover more cases. For example, the adjoint of member expressions of global class objects should also be 0 and consequently they should be handled similarly but they do not have non_differentiable attribute.

if (clad::utils::hasNonDifferentiableAttribute(ME->getMemberDecl()))
return {cloneE};

Expr* diff_dx = diff.getExpr_dx();
bool specialDThisCase = false;
Expr* derivedE = nullptr;
Expand Down Expand Up @@ -2655,9 +2682,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) {
Expr* derivedE = BuildDeclRef(VDDerived);
if (isInitializedByNewExpr)
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
Expr* derivedE = nullptr;

if (!clad::utils::hasNonDifferentiableAttribute(VD)) {
derivedE = BuildDeclRef(VDDerived);
if (isInitializedByNewExpr)
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
}

if (VD->getInit()) {
if (isa<CXXConstructExpr>(VD->getInit()))
initDiff = Visit(VD->getInit());
Expand Down Expand Up @@ -2689,6 +2721,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(assignToZero, direction::reverse);
}
}

VarDecl* VDClone = nullptr;
Expr* derivedVDE = nullptr;
if (VDDerived)
Expand Down Expand Up @@ -2815,19 +2848,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (declsBegin != DS->decls().end() && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
QualType QT = VD->getType();
if (!QT->isPointerType()) {
auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
if (QT->isPointerType())
QT = QT->getPointeeType();

auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda ||
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a test for a local variable declaration with non_differentiable attribute?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added fn_non_diff_var to Gradient/NonDifferentiable.C test, but I believe it's not working as expected. The correct output would be 0.00 0.00 right? I'll try to get that fixed now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please create an issue for this and resolve it in a follow-up pull-request?

(typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
}

Expand All @@ -2839,6 +2874,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D)) {
DeclDiff<VarDecl> VDDiff;

if (!isLambda)
VDDiff = DifferentiateVarDecl(VD);

Expand Down Expand Up @@ -3014,6 +3050,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
if (clad::utils::hasNonDifferentiableAttribute(ME)) {
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);
return {clonedME, zero};
}
if (!baseDiff.getExpr_dx())
return {clonedME, nullptr};
MemberExpr* derivedME = utils::BuildMemberExpr(
Expand Down
187 changes: 187 additions & 0 deletions test/Gradient/NonDifferentiable.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// RUN: %cladclang %s -I%S/../../include -oNonDifferentiable.out 2>&1 | %filecheck %s
// RUN: ./NonDifferentiable.out | %filecheck_exec %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#define non_differentiable __attribute__((annotate("another_attribute"), annotate("non_differentiable")))

#include "clad/Differentiator/Differentiator.h"

class SimpleFunctions1 {
public:
SimpleFunctions1() noexcept : x(0), y(0), x_pointer(&x), y_pointer(&y) {}
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y), x_pointer(&x), y_pointer(&y) {}
double x;
non_differentiable double y;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also test with some pointer member types.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added an fn_s1_field_pointer to the test, is that what you wanted me to test?

double* x_pointer;
non_differentiable double* y_pointer;
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; }
non_differentiable double mem_fn_2(double i, double j) { return i * j; }
double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; }
double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; }
double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; }
SimpleFunctions1 operator+(const SimpleFunctions1& other) const {
return SimpleFunctions1(x + other.x, y + other.y);
}
};

double fn_s1_mem_fn(double i, double j) {
SimpleFunctions1 obj(2, 3);
return obj.mem_fn_1(i, j) + i * j;
}

double fn_s1_field(double i, double j) {
SimpleFunctions1 obj(2, 3);
return obj.x * obj.y + i * j;
}

double fn_s1_field_pointer(double i, double j) {
SimpleFunctions1 obj(2, 3);
return (*obj.x_pointer) * (*obj.y_pointer) + i * j;
}

double fn_s1_operator(double i, double j) {
SimpleFunctions1 obj1(2, 3);
SimpleFunctions1 obj2(3, 5);
return (obj1 + obj2).mem_fn_1(i, j);
}

class non_differentiable SimpleFunctions2 {
public:
SimpleFunctions2() noexcept : x(0), y(0) {}
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
double y;
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; }
SimpleFunctions2 operator+(const SimpleFunctions2& other) const {
return SimpleFunctions2(x + other.x, y + other.y);
}
};

double fn_s2_mem_fn(double i, double j) {
SimpleFunctions2 obj(2, 3);
return obj.mem_fn(i, j) + i * j;
}

double fn_s2_field(double i, double j) {
SimpleFunctions2 *obj0, obj(2, 3);
return obj.x * obj.y + i * j;
}

double fn_s2_operator(double i, double j) {
SimpleFunctions2 obj1(2, 3);
SimpleFunctions2 obj2(3, 5);
return (obj1 + obj2).mem_fn(i, j);
}

double fn_non_diff_var(double i, double j) {
non_differentiable double k = i * i * j;
return k;
}

#define INIT_EXPR(classname) \
classname expr_1(2, 3); \
classname expr_2(3, 5);

#define TEST_CLASS(classname, name, i, j) \
auto d_##name = clad::gradient(&classname::name); \
double result_##name[2] = {}; \
d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]);

#define TEST_FUNC(name, i, j) \
auto d_##name = clad::gradient(&name); \
double result_##name[2] = {}; \
d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]);

int main() {
// FIXME: The parts of this test that are commented out are currently not working, due to bugs
// not related to the implementation of the non-differentiable attribute.
INIT_EXPR(SimpleFunctions1);

/*TEST_CLASS(SimpleFunctions1, mem_fn_1, 3, 5)*/

/*TEST_CLASS(SimpleFunctions1, mem_fn_3, 3, 5)*/

/*TEST_CLASS(SimpleFunctions1, mem_fn_4, 3, 5)*/

/*TEST_CLASS(SimpleFunctions1, mem_fn_5, 3, 5)*/

TEST_FUNC(fn_s1_mem_fn, 3, 5) // CHECK-EXEC: 35.00 33.00

TEST_FUNC(fn_s1_field, 3, 5) // CHECK-EXEC: 5.00 3.00

TEST_FUNC(fn_s1_field_pointer, 3, 5) // CHECK-EXEC: 5.00 3.00

/*TEST_FUNC(fn_s1_operator, 3, 5)*/

TEST_FUNC(fn_s2_mem_fn, 3, 5) // CHECK-EXEC: 5.00 3.00

/*TEST_FUNC(fn_s2_field, 3, 5)*/

/*TEST_FUNC(fn_s2_operator, 3, 5)*/

TEST_FUNC(fn_non_diff_var, 3, 5) // CHECK-EXEC: 0.00 0.00

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);

// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 _t0;
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: _t0 = obj;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: _d_obj.x += 1 * obj.y;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_field_pointer_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_obj.x_pointer += 1 * (*obj.y_pointer);
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions2 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_non_diff_var_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _d_k = 0;
// CHECK-NEXT: double k = i * i * j;
// CHECK-NEXT: _d_k += 1;
// CHECK-NEXT: }

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += _d_y * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y;
// CHECK-NEXT: *_d_i += _d_y * j * j;
// CHECK-NEXT: *_d_j += i * _d_y * j;
// CHECK-NEXT: *_d_j += i * j * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }
}
51 changes: 51 additions & 0 deletions test/Gradient/NonDifferentiableError.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1

#define non_differentiable __attribute__((annotate("non_differentiable")))

#include "clad/Differentiator/Differentiator.h"

extern "C" int printf(const char* fmt, ...);

class non_differentiable SimpleFunctions2 {
public:
SimpleFunctions2() noexcept : x(0), y(0) {}
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
double y;
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } // expected-error {{attempted differentiation of method 'mem_fn' in class 'SimpleFunctions2', which is marked as non-differentiable}}
};

namespace clad {
namespace custom_derivatives {
void fn_s2_mem_fn_pullback(double i, double j, double _d_y, double* _d_i, double* _d_j) {
*_d_i = 1.5;
*_d_j = 2.5;
}
} // namespace custom_derivatives
} // namespace clad

non_differentiable double fn_s2_mem_fn(double i, double j) {
SimpleFunctions2 obj(2, 3);
return obj.mem_fn(i, j) + i * j;
}

#define INIT_EXPR(classname) \
classname expr_1(2, 3); \
classname expr_2(3, 5);

#define TEST_CLASS(classname, name, i, j) \
auto d_##name = clad::differentiate(&classname::name, "i"); \
printf("%.2f\n", d_##name.execute(expr_1, i, j)); \
printf("%.2f\n", d_##name.execute(expr_2, i, j)); \
printf("\n");

#define TEST_FUNC(name, i, j) \
auto d_##name = clad::differentiate(&name, "i"); \
printf("%.2f\n", d_##name.execute(i, j)); \
printf("\n");

int main() {
INIT_EXPR(SimpleFunctions2);
TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5);
TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}}
parth-07 marked this conversation as resolved.
Show resolved Hide resolved
}
Loading