diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 89f3a17ad..6db6ee915 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1976,13 +1976,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, else pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest); - // Clad failed to derive it. - // FIXME: Add support for reference arguments to the numerical diff. If - // it already correctly support reference arguments then confirm the - // support and add tests for the same. - if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(FD)) { - // Try numerically deriving it. + if (pullbackFD) { + if (MD) { + Expr* baseE = baseDiff.getExpr(); + OverloadedDerivedFn = BuildCallExprToMemFn( + baseE, pullbackFD->getName(), pullbackCallArgs, Loc); + } else { + OverloadedDerivedFn = + m_Sema + .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD), + Loc, pullbackCallArgs, Loc, CUDAExecConfig) + .get(); + } + } else if (!utils::HasAnyReferenceOrPointerArgument(FD) && !MD) { + // FIXME: Add support for reference arguments to the numerical diff. If + // it already correctly support reference arguments then confirm the + // support and add tests for the same. + // + // Clad failed to derive it. Try numerically deriving it. if (NArgs == 1) { OverloadedDerivedFn = GetSingleArgCentralDiffCall( Clone(CE->getCallee()), DerivedCallArgs[0], @@ -2002,18 +2013,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, block.insert(block.begin(), PreCallStmts.begin(), PreCallStmts.end()); return StmtDiff(Clone(CE)); } - } else if (pullbackFD) { - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); - OverloadedDerivedFn = BuildCallExprToMemFn( - baseE, pullbackFD->getName(), pullbackCallArgs, Loc); - } else { - OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD), - Loc, pullbackCallArgs, Loc, CUDAExecConfig) - .get(); - } } }