Skip to content

Commit

Permalink
Add initial support for pointers in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Dec 19, 2023
1 parent 6a42b46 commit 5bb4fe8
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 102 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ template <typename T> class array_ref {
/// Constructor for clad::array types
CUDA_HOST_DEVICE array_ref(array<T>& a) : m_arr(a.ptr()), m_size(a.size()) {}

/// Operator for conversion from array_ref<T> to T*.
CUDA_HOST_DEVICE operator T*() { return m_arr; }

template <typename U>
CUDA_HOST_DEVICE array_ref<T>& operator=(const array<U>& a) {
assert(m_size == a.size());
Expand Down
4 changes: 4 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@ namespace clad {
/// Cloning types is necessary since VariableArrayType
/// store a pointer to their size expression.
clang::QualType CloneType(clang::QualType T);

static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
};
} // end namespace clad

Expand Down
46 changes: 0 additions & 46 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,52 +1344,6 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
}
}

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType())) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr();
} else if (utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType())) {
derivedL = LDiff.getExpr();
derivedR = RDiff.getExpr_dx();
}
}

StmtDiff
BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
StmtDiff Ldiff = Visit(BinOp->getLHS());
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ namespace clad {
/// be more complex than just a DeclRefExpr.
/// (e.g. `__real (n++ ? z1 : z2)`)
m_Exprs.push_back(UnOp);
}
} else if (opCode == UnaryOperatorKind::UO_Deref)
m_Exprs.push_back(UnOp);
}

void VisitDeclRefExpr(clang::DeclRefExpr* DRE) {
Expand Down
152 changes: 111 additions & 41 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,7 +1340,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Create the (_d_param[idx] += dfdx) statement.
if (dfdx()) {
// FIXME: not sure if this is generic.
// Don't update derivatives of non-record types.
// Don't update derivatives of record types.
if (!VD->getType()->isRecordType()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
Expand Down Expand Up @@ -2035,6 +2035,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If it is a post-increment/decrement operator, its result is a reference
// and we should return it.
Expr* ResultRef = nullptr;

// For increment/decrement of pointer, perform the same on the
// derivative pointer also.
bool isPointerOp = E->getType()->isPointerType();

if (opCode == UO_Plus)
// xi = +xj
// dxi/dxj = +1.0
Expand All @@ -2048,6 +2053,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diff = Visit(E, d);
} else if (opCode == UO_PostInc || opCode == UO_PostDec) {
diff = Visit(E, dfdx());
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()),
direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
Expand All @@ -2060,10 +2068,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff);
} else if (opCode == UO_PreInc || opCode == UO_PreDec) {
diff = Visit(E, dfdx());
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()),
direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
if (isPointerOp)
addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse);
}
auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add
: BinaryOperatorKind::BO_Sub;
Expand All @@ -2081,42 +2094,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
} else {
// FIXME: This is not adding 'address-of' operator support.
// This is just making this special case differentiable that is required
// for computing hessian:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// This code snippet should be removed once reverse mode officially
// supports pointers.
if (opCode == UnaryOperatorKind::UO_AddrOf) {
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (MD->isInstance()) {
auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD);
if (utils::SameCanonicalType(thisType, UnOp->getType())) {
diff = Visit(E);
Expr* cloneE =
BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr());
Expr* derivedE =
BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx());
return {cloneE, derivedE};
}
}
} else if (opCode == UnaryOperatorKind::UO_AddrOf) {
diff = Visit(E);
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr());
Expr* derivedE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx());
return {cloneE, derivedE};
} else if (opCode == UnaryOperatorKind::UO_Deref) {
diff = Visit(E);
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());
Expr* diff_dx = diff.getExpr_dx();
bool specialDThisCase = false;
Expr* derivedE = nullptr;
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (MD->isInstance() && !diff_dx->getType()->isPointerType())
specialDThisCase = true; // _d_this is already dereferenced.
}
if (specialDThisCase)
derivedE = diff_dx;
else {
derivedE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff_dx);
// Create the (target += dfdx) statement.
if (dfdx()) {
auto* add_assign = BuildOp(BO_AddAssign, derivedE, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
}
return {cloneE, derivedE};
} else if (opCode != UO_LNot) {
// We should not output any warning on visiting boolean conditions
// FIXME: We should support boolean differentiation or ignore it
// completely
if (opCode != UO_LNot)
unsupportedOpWarn(UnOp->getEndLoc());

if (isa<DeclRefExpr>(E))
diff = Visit(E);
else
diff = StmtDiff(E);
}
unsupportedOpWarn(UnOp->getEndLoc());
} else if (isa<DeclRefExpr>(E))
diff = Visit(E);
else
diff = StmtDiff(E);
Expr* op = BuildOp(opCode, diff.getExpr());
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}
Expand All @@ -2134,6 +2147,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// we should return it.
Expr* ResultRef = nullptr;

bool isPointerOp =
L->getType()->isPointerType() || R->getType()->isPointerType();

if (opCode == BO_Add) {
// xi = xl + xr
// dxi/xl = 1.0
Expand Down Expand Up @@ -2306,6 +2322,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* Lblock = endBlock(direction::reverse);
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore);

// We need to store values of derivative pointer variables in forward pass
// and restore them in reverese pass.
if (isPointerOp) {
Expr* Edx = Ldiff.getExpr_dx();
ExprsToStore.push_back(Edx);
}

if (L->HasSideEffects(m_Context)) {
Expr* E = Ldiff.getExpr();
auto* storeE =
Expand Down Expand Up @@ -2352,21 +2376,35 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Save old value for the derivative of LHS, to avoid problems with cases
// like x = x.
auto* oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d",
/*forceDeclCreation=*/true);
clang::Expr* oldValue = nullptr;

// For pointer types, no need to store old derivatives.
if (!isPointerOp)
oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d",
/*forceDeclCreation=*/true);

if (opCode == BO_Assign) {
Rdiff = Visit(R, oldValue);
valueForRevPass = Rdiff.getRevSweepAsExpr();
// Ensure correct operation in case of pointer types.
// if (utils::isArrayOrPointerType(Ldiff.getExpr()->getType())) {
// Expr* derivedL = nullptr;
// Expr* derivedR = nullptr;
// ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
// addToCurrentBlock(BuildOp(opCode, derivedL,
// derivedR),direction::forward);
// }
} else if (opCode == BO_AddAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
if (!isPointerOp)
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, oldValue);
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_SubAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
if (!isPointerOp)
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, BuildOp(UO_Minus, oldValue));
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
Expand Down Expand Up @@ -2427,8 +2465,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue);

// Update the derivative.
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), direction::reverse);
// Update the derivative only if LHS is not a pointer type.
if (!isPointerOp)
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);

// Output statements from Visit(L).
for (auto it = Lblock_begin; it != Lblock_end; ++it)
addToCurrentBlock(*it, direction::reverse);
Expand Down Expand Up @@ -2460,6 +2501,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return BuildOp(opCode, LExpr, RExpr);
}
Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr());

// For pointer types.
if (isPointerOp) {
if (opCode == BO_Add || opCode == BO_Sub) {
Expr* derivedL = nullptr;
Expr* derivedR = nullptr;
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
if (opCode == BO_Sub)
derivedR = BuildParens(derivedR);
return StmtDiff(op, BuildOp(opCode, derivedL, derivedR), nullptr,
valueForRevPass);
}
if (opCode == BO_Assign || opCode == BO_AddAssign ||
opCode == BO_SubAssign) {
Expr* derivedL = nullptr;
Expr* derivedR = nullptr;
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
addToCurrentBlock(BuildOp(opCode, derivedL, derivedR),
direction::forward);
}
}
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

Expand All @@ -2469,6 +2531,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto VDDerivedType = ComputeAdjointType(VD->getType());
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
Expand Down Expand Up @@ -2529,6 +2592,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType && VD->getInit()) {
initDiff = Visit(VD->getInit());
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// Here separate behaviour for record and non-record types is only
// necessary to preserve the old tests.
if (VDDerivedType->isRecordType())
Expand All @@ -2546,7 +2616,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType) {
if (!isDerivativeOfRefType && !isPointerType) {
Expr* derivedE = BuildDeclRef(VDDerived);
initDiff = StmtDiff{};
if (VD->getInit()) {
Expand Down
46 changes: 46 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,4 +781,50 @@ namespace clad {
auto& TAL = specialization->getTemplateArgs();
return TAL.get(0).getAsType();
}

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
void VisitorBase::ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr()->getType())) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr();
} else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType())) {
derivedL = LDiff.getExpr();
derivedR = RDiff.getExpr_dx();
}
}
} // end namespace clad
Loading

0 comments on commit 5bb4fe8

Please sign in to comment.