Skip to content

Commit

Permalink
Add support for non-differentiable attribute in reverse mode
Browse files Browse the repository at this point in the history
fixes #717
  • Loading branch information
MihailMihov authored and vgvassilev committed Jul 20, 2024
1 parent 42c4673 commit f4780e9
Showing 1 changed file with 57 additions and 16 deletions.
73 changes: 57 additions & 16 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
llvm::SmallVector<Expr*, 4> ClonedArgs;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i)
ClonedArgs.push_back(Clone(CE->getArg(i)));

SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema);
Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc)
.get();
// Creating a zero derivative
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);

// Returning the function call and zero derivative
return StmtDiff(Call, zero);
}

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
Expand Down Expand Up @@ -2061,6 +2081,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == UnaryOperatorKind::UO_Deref) {
diff = Visit(E);
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());

// If we have a pointer to a member expression, which is
// non-differentiable, we just return a clone of the original expression.
if (auto* ME = dyn_cast<MemberExpr>(diff.getExpr()))
if (clad::utils::hasNonDifferentiableAttribute(ME->getMemberDecl()))
return {cloneE};

Expr* diff_dx = diff.getExpr_dx();
bool specialDThisCase = false;
Expr* derivedE = nullptr;
Expand Down Expand Up @@ -2650,9 +2677,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) {
Expr* derivedE = BuildDeclRef(VDDerived);
if (isInitializedByNewExpr)
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
Expr* derivedE = nullptr;

if (!clad::utils::hasNonDifferentiableAttribute(VD)) {
derivedE = BuildDeclRef(VDDerived);
if (isInitializedByNewExpr)
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
}

if (VD->getInit()) {
if (isa<CXXConstructExpr>(VD->getInit()))
initDiff = Visit(VD->getInit());
Expand Down Expand Up @@ -2684,6 +2716,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(assignToZero, direction::reverse);
}
}

VarDecl* VDClone = nullptr;
Expr* derivedVDE = nullptr;
if (VDDerived)
Expand Down Expand Up @@ -2810,19 +2843,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (declsBegin != DS->decls().end() && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
QualType QT = VD->getType();
if (!QT->isPointerType()) {
auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
if (QT->isPointerType())
QT = QT->getPointeeType();

auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda ||
(typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
}

Expand All @@ -2834,6 +2869,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D)) {
DeclDiff<VarDecl> VDDiff;

if (!isLambda)
VDDiff = DifferentiateVarDecl(VD);

Expand Down Expand Up @@ -3009,6 +3045,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
if (clad::utils::hasNonDifferentiableAttribute(ME)) {
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);
return {clonedME, zero};
}
if (!baseDiff.getExpr_dx())
return {clonedME, nullptr};
MemberExpr* derivedME = utils::BuildMemberExpr(
Expand Down

0 comments on commit f4780e9

Please sign in to comment.