diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 9453e3c16..036d88601 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -215,6 +215,59 @@ namespace clad { } } + // Only verify forward declaration in reverse mode + if (request.Mode == DiffMode::reverse) { + bool hasForwardDecl = false; + FunctionDecl* deriv = nullptr; + FunctionDecl* forwdecl = nullptr; + + if (isa(DerivativeDecl)) { + // Derivative is a class/struct method + // Search for forw decl in same scope + DeclContextLookupResult DCLR = DerivativeDecl->getParent()->lookup( + DerivativeDecl->getDeclName()); + if (std::distance(DCLR.begin(), DCLR.end()) >= 2) { + // Forward declaration found + hasForwardDecl = true; + deriv = (*DCLR.begin())->getAsFunction(); + forwdecl = (*(++DCLR.begin()))->getAsFunction(); + } + } + + else { + LookupResult P(S, DerivativeDecl->getNameInfo(), + Sema::LookupNameKind::LookupUsingDeclName, + Sema::RedeclarationKind::ForVisibleRedeclaration); + S.LookupName(P, S.TUScope); + + if (std::distance(P.begin(), P.end()) >= 2) { + // Forward declaration found with possibly mismatched signature + hasForwardDecl = true; + deriv = (*P.begin())->getAsFunction(); + forwdecl = (*(++P.begin()))->getAsFunction(); + } + } + // Compare signatures including return type + if (hasForwardDecl) { + assert( + deriv->getReturnType() == forwdecl->getReturnType() && + "Forward declaration of a gradient must have void return type"); + + auto derivsignature = deriv->parameters(); + auto forwsignature = forwdecl->parameters(); + assert(derivsignature.size() == forwsignature.size() && + "Unequal number of args in forward declaration and generated " + "gradient"); + + for (auto it1 = derivsignature.begin(), it2 = forwsignature.begin(); + it1 != derivsignature.end() && it2 != forwsignature.end(); + it1++, it2++) { + assert((*it1)->getType() == (*it2)->getType() && + "Types of arguments do not match in forward decl"); + } + } + } + if (DerivativeDecl) { if (!alreadyDerived) { m_DFC.Add(