Skip to content

Commit

Permalink
tmp + dbg
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Nov 9, 2024
1 parent 111dff3 commit ad6809b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 28 deletions.
6 changes: 3 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ namespace clad {
clang::Scope* scope, clang::Expr* Init = nullptr,
bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand All @@ -310,7 +310,7 @@ namespace clad {
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand All @@ -326,7 +326,7 @@ namespace clad {
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false);
/// Builds variable declaration to be used inside the derivative
/// body in the derivative function global scope.
clang::VarDecl*
Expand Down
72 changes: 54 additions & 18 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2133,13 +2133,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

const auto* MD = dyn_cast<CXXMethodDecl>(FD);
bool isLambda = (MD ? isLambdaCallOperator(MD) : false);
// Method operators have a base like methods do but it's included in the
// call arguments so we have to shift the indexing of call arguments.
bool isMethodOperatorCall = MD && isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isMethodOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
llvm::errs() << "i: " << i << '\n';
const Expr* arg = CE->getArg(i);
const auto* PVD = FD->getParamDecl(
i - static_cast<unsigned long>(isMethodOperatorCall));
Expand All @@ -2159,7 +2161,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// same as the call expression as it is the type used to declare the
// _gradX array
QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy));
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy), false, nullptr, clang::VarDecl::InitializationStyle::CInit, true);
PreCallStmts.push_back(BuildDeclStmt(dArgDecl));
CallArgDx.push_back(BuildDeclRef(dArgDecl));
// Visit using uninitialized reference.
Expand Down Expand Up @@ -2261,13 +2263,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// Add base derivative expression in the derived call output args list if
/// `CE` is a call to an instance member function.
if (MD) {
if (isLambdaCallOperator(MD)) {
QualType ptrType = m_Context.getPointerType(m_Context.getRecordType(
FD->getDeclContext()->getOuterLexicalRecordContext()));
baseDiff =
StmtDiff(Clone(dyn_cast<CXXOperatorCallExpr>(CE)->getArg(0)),
new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc));
} else if (MD->isInstance()) {
if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
Expand All @@ -2276,18 +2272,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

baseDiff = Visit(baseOriginalE);
baseExpr = baseDiff.getExpr();
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
baseDiff.updateStmt(baseDiffStore);
if (!isLambda) {
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
baseDiff.updateStmt(baseDiffStore);
}

llvm::errs() << "diff base: ";
baseExpr->dumpPretty(m_Context);
llvm::errs() << " ";
baseDiff.getExpr_dx()->dumpPretty(m_Context);
llvm::errs() << " ";
llvm::errs() << "\n";

Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
baseDerivative =
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative);
DerivedCallOutputArgs.push_back(baseDerivative);
if (!isLambda)
DerivedCallOutputArgs.push_back(baseDerivative);
}
}

for (auto* argDerivative : CallArgDx) {
Expr* gradArgExpr = nullptr;
llvm::errs() << "i: " << idx << '\n';
QualType paramTy = FD->getParamDecl(idx)->getType();
if (!argDerivative || utils::isArrayOrPointerType(paramTy) ||
isCladArrayType(argDerivative->getType()))
Expand Down Expand Up @@ -2358,12 +2366,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();

OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef, Loc,
pullbackCallArgs, Loc, CUDAExecConfig)
.get();
} else {
OverloadedDerivedFn = m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef,
Loc, pullbackCallArgs, Loc)
.get();
} else if (!isLambda) {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
pullbackCallArgs, PreCallStmts, dfdx());
Expand Down Expand Up @@ -2446,7 +2453,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (OverloadedDerivedFn) {
if (isLambda) {
Stmts& block = getCurrentBlock(direction::reverse);
Stmts::iterator it = std::begin(block) + insertionPoint;
// Insert PreCallStmts
it = block.insert(it, PreCallStmts.begin(), PreCallStmts.end());
it += PreCallStmts.size();
// Insert the call
Expr* baseEdx = baseDiff.getExpr_dx(); // The pullback lambda
const CXXRecordDecl* EdxRD = baseEdx->getType()->getAsCXXRecordDecl();
auto* CMD = const_cast<CXXMethodDecl*>(EdxRD->getLambdaCallOperator());
NestedNameSpecifierLoc NNS(CMD->getQualifier(),
/*Data=*/nullptr);
auto DAP = DeclAccessPair::make(CMD, CMD->getAccess());
auto* memberExpr = MemberExpr::Create(
m_Context, Clone(baseEdx), /*isArrow=*/false, Loc, NNS, noLoc,
CMD, DAP, CMD->getNameInfo(),
/*TemplateArgs=*/nullptr, m_Context.BoundMemberTy,
CLAD_COMPAT_ExprValueKind_R_or_PR_Value,
ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams(
NOUR_None));
OverloadedDerivedFn = m_Sema
.BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc,
pullbackCallArgs, Loc)
.get();

// OverloadedDerivedFn = BuildCallExprToMemFn(
// baseEdx, FD->getName(), pullbackCallArgs, Loc);
it = block.insert(it, OverloadedDerivedFn);
it++;
} else if (OverloadedDerivedFn) {
// Derivative was found.
FunctionDecl* fnDecl = dyn_cast<CallExpr>(OverloadedDerivedFn)
->getDirectCallee();
Expand Down
14 changes: 7 additions & 7 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ namespace clad {
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) {
return BuildVarDecl(Type, Identifier, getCurrentScope(), Init, DirectInit,
TSI, IS);
TSI, IS, pushCodeSynthCtxt);
}
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Scope* Scope, Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) {
// add namespace specifier in variable declaration if needed.
Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
auto* VD = VarDecl::Create(
Expand All @@ -126,7 +126,7 @@ namespace clad {
if (const CXXRecordDecl* RD = dyn_cast<CXXRecordDecl>(RT->getDecl()))
isLambda = RD->isLambda();
}
if (isLambda) {
if (isLambda || pushCodeSynthCtxt) {
clang::Sema::CodeSynthesisContext csc;
m_Sema.pushCodeSynthesisContext(csc);
}
Expand All @@ -138,7 +138,7 @@ namespace clad {
m_Sema.FinalizeDeclaration(VD);
// Add the identifier to the scope and IdResolver
m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false);
if (isLambda)
if (Init && (isLambda || pushCodeSynthCtxt))
m_Sema.popCodeSynthesisContext();
return VD;
}
Expand All @@ -152,9 +152,9 @@ namespace clad {
VarDecl* VisitorBase::BuildVarDecl(QualType Type, llvm::StringRef prefix,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) {
return BuildVarDecl(Type, CreateUniqueIdentifier(prefix), Init, DirectInit,
TSI, IS);
TSI, IS, pushCodeSynthCtxt);
}

VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,
Expand Down

0 comments on commit ad6809b

Please sign in to comment.