Skip to content

Commit

Permalink
working version
Browse files Browse the repository at this point in the history
  • Loading branch information
DeadSpheroid committed Mar 21, 2024
1 parent 159a330 commit 1ad01ef
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,59 @@ namespace clad {
}
}

// Only verify forward declaration in reverse mode
if (request.Mode == DiffMode::reverse) {
bool hasForwardDecl = false;
FunctionDecl* deriv = nullptr;
FunctionDecl* forwdecl = nullptr;

if (isa<CXXMethodDecl>(DerivativeDecl)) {
// Derivative is a class/struct method
// Search for forw decl in same scope
DeclContextLookupResult DCLR = DerivativeDecl->getParent()->lookup(
DerivativeDecl->getDeclName());
if (std::distance(DCLR.begin(), DCLR.end()) >= 2) {
// Forward declaration found
hasForwardDecl = true;
deriv = (*DCLR.begin())->getAsFunction();
forwdecl = (*(++DCLR.begin()))->getAsFunction();
}
}

else {
LookupResult P(S, DerivativeDecl->getNameInfo(),
Sema::LookupNameKind::LookupUsingDeclName,
Sema::RedeclarationKind::ForVisibleRedeclaration);
S.LookupName(P, S.TUScope);

if (std::distance(P.begin(), P.end()) >= 2) {
// Forward declaration found with possibly mismatched signature
hasForwardDecl = true;
deriv = (*P.begin())->getAsFunction();
forwdecl = (*(++P.begin()))->getAsFunction();
}
}
// Compare signatures including return type
if (hasForwardDecl) {
assert(
deriv->getReturnType() == forwdecl->getReturnType() &&
"Forward declaration of a gradient must have void return type");

auto derivsignature = deriv->parameters();
auto forwsignature = forwdecl->parameters();
assert(derivsignature.size() == forwsignature.size() &&
"Unequal number of args in forward declaration and generated "
"gradient");

for (auto it1 = derivsignature.begin(), it2 = forwsignature.begin();
it1 != derivsignature.end() && it2 != forwsignature.end();
it1++, it2++) {
assert((*it1)->getType() == (*it2)->getType() &&
"Types of arguments do not match in forward decl");
}
}
}

if (DerivativeDecl) {
if (!alreadyDerived) {
m_DFC.Add(
Expand Down

0 comments on commit 1ad01ef

Please sign in to comment.