Skip to content

Commit

Permalink
Fix builtin macros and zero init for memory operations
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 27, 2024
1 parent e37039d commit 89e3fde
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
30 changes: 26 additions & 4 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Lookup.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -400,14 +401,21 @@ namespace clad {
auto& C = semaRef.getASTContext();
if (!TSI)
TSI = C.getTrivialTypeSourceInfo(qType);
bool implicitInit = false;
if (initializer && isa<ImplicitValueInitExpr>(initializer))
// If the initializer is an implicit value init expression, then
// we don't need to pass it explicitly to the CXXNewExpr. As, clang
// internally adds it when initializer is nullptr and DirectInitRange
// is valid.
implicitInit = true;
auto newExpr =
semaRef
.BuildCXXNew(
SourceRange(), false, noLoc, MultiExprArg(), noLoc,
SourceRange(), qType, TSI,
(arraySize ? arraySize : clad_compat::ArraySize_None()),
initializer ? GetValidSRange(semaRef) : SourceRange(),
initializer)
implicitInit ? nullptr : initializer)
.getAs<CXXNewExpr>();
return newExpr;
}
Expand Down Expand Up @@ -643,17 +651,31 @@ namespace clad {
}

bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD) {

#if CLANG_VERSION_MAJOR > 12
if (FD->getBuiltinID() == Builtin::BImalloc)
return true;
if (FD->getBuiltinID() == Builtin::BIcalloc)
if (FD->getBuiltinID() == Builtin::ID::BIcalloc)
return true;
if (FD->getBuiltinID() == Builtin::ID::BIrealloc)
return true;
#else
if (FD->getNameAsString() == "malloc")
return true;
if (FD->getBuiltinID() == Builtin::BIrealloc)
if (FD->getNameAsString() == "calloc")
return true;
if (FD->getNameAsString() == "realloc")
return true;
#endif
return false;
}

bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) {
return FD->getBuiltinID() == Builtin::BIfree;
#if CLANG_VERSION_MAJOR > 12
return FD->getBuiltinID() == Builtin::ID::BIfree;
#else
return FD->getNameAsString() == "free";
#endif
}
} // namespace utils
} // namespace clad
17 changes: 13 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2649,8 +2649,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
// Check if the variable is pointer type and initialized by new expression
if (isPointerType && (VD->getInit() != nullptr) &&
isa<CXXNewExpr>(VD->getInit()))
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;

// VDDerivedInit now serves two purposes -- as the initial derivative value
Expand Down Expand Up @@ -3824,9 +3823,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* clonedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), clonedArraySizeE,
initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo());
Expr* diffInit = initializerDiff.getExpr_dx();
if (!diffInit) {
// we should initialize it implicitly using ImplicitValueInitExpr
QualType type = CNE->getAllocatedType();
if (CNE->isArray()) {
type = m_Context.getVariableArrayType(
type, derivedArraySizeE, ArrayType::Normal, 0, SourceRange());
}
diffInit = new (m_Context) ImplicitValueInitExpr(type);
}
Expr* derivedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), derivedArraySizeE,
initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo());
m_Sema, CNE->getAllocatedType(), derivedArraySizeE, diffInit,
CNE->getAllocatedTypeSourceInfo());
return {clonedNewE, derivedNewE};
}

Expand Down
4 changes: 2 additions & 2 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ double newAndDeletePointer(double i, double j) {
// CHECK-NEXT: double *p = new double(i);
// CHECK-NEXT: _d_q = new double(* _d_j);
// CHECK-NEXT: double *q = new double(j);
// CHECK-NEXT: _d_r = new double [2];
// CHECK-NEXT: _d_r = new double [2](/*implicit*/(double[2])0);
// CHECK-NEXT: double *r = new double [2];
// CHECK-NEXT: _t0 = r[0];
// CHECK-NEXT: r[0] = i + j;
Expand Down Expand Up @@ -418,7 +418,7 @@ double structPointer (double x) {
// CHECK: void structPointer_grad(double x, clad::array_ref<double> _d_x) {
// CHECK-NEXT: T *_d_t = 0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: _d_t = new T;
// CHECK-NEXT: _d_t = new T();
// CHECK-NEXT: T *t = new T({x, /*implicit*/(int)0});
// CHECK-NEXT: double res = t->x;
// CHECK-NEXT: goto _label0;
Expand Down

0 comments on commit 89e3fde

Please sign in to comment.