Skip to content

Commit

Permalink
Revert not skipping cuda host functions with const args
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 3, 2024
1 parent bf26f58 commit 0a81882
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1823,8 +1823,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If all arguments are constant literals, then this does not contribute to
// the gradient.
// FIXME: revert this when this is integrated in the activity analysis pass.
if (!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE) &&
CE->getCallReturnType(m_Context).getAsString() != "cudaError_t") {
if (!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
for (const Expr* arg : CE->arguments()) {
// if it's of type MaterializeTemporaryExpr, then check its
Expand All @@ -1848,8 +1847,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derived function. In the case of member functions, `implicit`
// this object is always passed by reference.
if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE) &&
CE->getCallReturnType(m_Context).getAsString() != "cudaError_t") {
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
Expand Down

0 comments on commit 0a81882

Please sign in to comment.