Skip to content

Commit

Permalink
Move BuildCallToCustomDerivativeOrNumericalDiff and noOverloadExists …
Browse files Browse the repository at this point in the history
…to VisitorBase
  • Loading branch information
PetroZarytskyi committed Apr 15, 2024
1 parent 537cbbf commit 8f0ae29
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 150 deletions.
20 changes: 0 additions & 20 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,6 @@ namespace clad {
clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name,
clang::QualType functionType);
/// Looks for a suitable overload for a given function.
///
/// \param[in] Name The identification information of the function
/// overload to be found.
/// \param[in] CallArgs The call args to be used to resolve to the
/// correct overload.
/// \param[in] forCustomDerv A flag to keep track of which
/// namespace we should look in for the overloads.
/// \param[in] namespaceShouldExist A flag to enforce assertion failure
/// if the overload function namespace was not found. If false and
/// the function containing namespace was not found, nullptr is returned.
///
/// \returns The call expression if a suitable function overload was found,
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
/// Shorthand to issues a warning or error.
template <std::size_t N>
void diag(clang::DiagnosticsEngine::Level level, // Warning or Error
Expand Down
20 changes: 20 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,26 @@ namespace clad {
void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
/// Looks for a suitable overload for a given function.
///
/// \param[in] Name The identification information of the function
/// overload to be found.
/// \param[in] CallArgs The call args to be used to resolve to the
/// correct overload.
/// \param[in] forCustomDerv A flag to keep track of which
/// namespace we should look in for the overloads.
/// \param[in] namespaceShouldExist A flag to enforce assertion failure
/// if the overload function namespace was not found. If false and
/// the function containing namespace was not found, nullptr is returned.
///
/// \returns The call expression if a suitable function overload was found,
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
};
} // end namespace clad

Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// Try to find a user-defined overloaded derivative.
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Expr* callDiff = BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));

Expand Down
119 changes: 0 additions & 119 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,125 +134,6 @@ namespace clad {
return { returnedFD, enclosingNS };
}

// This method is derived from the source code of both
// buildOverloadedCallSet() in SemaOverload.cpp
// and ActOnCallExpr() in SemaExpr.cpp.
bool
DerivativeBuilder::noOverloadExists(Expr* UnresolvedLookup,
llvm::MutableArrayRef<Expr*> ARargs) {
if (UnresolvedLookup->getType() == m_Context.OverloadTy) {
OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup);

if (!find.HasFormOfMemberPointer) {
OverloadExpr* ovl = find.Expression;

if (isa<UnresolvedLookupExpr>(ovl)) {
ExprResult result;
SourceLocation Loc;
OverloadCandidateSet CandidateSet(Loc,
OverloadCandidateSet::CSK_Normal);
Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext);
auto* ULE = cast<UnresolvedLookupExpr>(ovl);
// Populate CandidateSet.
m_Sema.buildOverloadedCallSet(S, UnresolvedLookup, ULE, ARargs, Loc,
&CandidateSet, &result);
OverloadCandidateSet::iterator Best = nullptr;
OverloadingResult OverloadResult = CandidateSet.BestViableFunction(
m_Sema, UnresolvedLookup->getBeginLoc(), Best);
if (OverloadResult != 0U) // No overloads were found.
return true;
}
}
}
return false;
}

Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
NamespaceDecl* NSD = nullptr;
std::string namespaceID;
if (forCustomDerv) {
namespaceID = "custom_derivatives";
NamespaceDecl* cladNS = nullptr;
if (m_BuiltinDerivativesNSD)
NSD = m_BuiltinDerivativesNSD;
else {
cladNS = utils::LookupNSD(m_Sema, "clad", /*shouldExist=*/true);
NSD =
utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist, cladNS);
m_BuiltinDerivativesNSD = NSD;
}
} else {
NSD = m_NumericalDiffNSD;
namespaceID = "numerical_diff";
}
if (!NSD) {
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist);
if (!forCustomDerv && !NSD) {
diag(DiagnosticsEngine::Warning, noLoc,
"Numerical differentiation is diabled using the "
"-DCLAD_NO_NUM_DIFF "
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
"behaviour.");
return nullptr;
}
}
CXXScopeSpec SS;
DeclContext* DC = NSD;

// FIXME: Here `if` branch should be removed once we update
// numerical diff to use correct declaration context.
if (forCustomDerv) {
DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC);
// FIXME: We should ideally construct nested name specifier from the
// found custom derivative function. Current way will compute incorrect
// nested name specifier in some cases.
if (outermostDC &&
outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) {
utils::BuildNNS(m_Sema, originalFnDC, SS);
DC = originalFnDC;
} else {
if (isa<RecordDecl>(originalFnDC))
DC = utils::LookupNSD(m_Sema, "class_functions",
/*shouldExist=*/false, NSD);
else
DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC);
if (DC)
utils::BuildNNS(m_Sema, DC, SS);
}
} else {
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
IdentifierInfo* II = &m_Context.Idents.get(Name);
DeclarationName name(II);
DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema));

LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName);
if (DC)
m_Sema.LookupQualifiedName(R, DC);
Expr* OverloadedFn = nullptr;
if (!R.empty()) {
// FIXME: We should find a way to specify nested name specifier
// after finding the custom derivative.
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get();

auto MARargs = llvm::MutableArrayRef<Expr*>(CallArgs);

SourceLocation Loc;

if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
}
return OverloadedFn;
}

void DerivativeBuilder::AddErrorEstimationModel(
std::unique_ptr<FPErrorEstimationModel> estModel) {
m_EstModel.push_back(std::move(estModel));
Expand Down
16 changes: 7 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1669,10 +1669,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto pushforwardCallArgs = DerivedCallArgs;
pushforwardCallArgs.push_back(ConstantFolder::synthesizeLiteral(
DerivedCallArgs.front()->getType(), m_Context, 1));
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
OverloadedDerivedFn = BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
if (OverloadedDerivedFn)
asGrad = false;
}
Expand Down Expand Up @@ -1766,10 +1765,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
OverloadedDerivedFn = BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}
Expand Down Expand Up @@ -2044,7 +2042,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
NumDiffArgs.push_back(args[i]);
}
std::string Name = "central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
return BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
Expand Down
120 changes: 119 additions & 1 deletion lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ namespace clad {
NumDiffArgs.insert(NumDiffArgs.end(), args.begin(), args.begin() + numArgs);
// Return the found overload.
std::string Name = "forward_central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
return BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
Expand Down Expand Up @@ -815,4 +815,122 @@ namespace clad {
derivedR = BuildArrayRefPtrRefExpr(derivedR);
}
}

Expr* VisitorBase::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
NamespaceDecl* NSD = nullptr;
std::string namespaceID;
if (forCustomDerv) {
namespaceID = "custom_derivatives";
NamespaceDecl* cladNS = nullptr;
if (m_Builder.m_BuiltinDerivativesNSD)
NSD = m_Builder.m_BuiltinDerivativesNSD;
else {
cladNS = utils::LookupNSD(m_Sema, "clad", /*shouldExist=*/true);
NSD =
utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist, cladNS);
m_Builder.m_BuiltinDerivativesNSD = NSD;
}
} else {
NSD = m_Builder.m_NumericalDiffNSD;
namespaceID = "numerical_diff";
}
if (!NSD) {
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist);
if (!forCustomDerv && !NSD) {
diag(DiagnosticsEngine::Warning, noLoc,
"Numerical differentiation is diabled using the "
"-DCLAD_NO_NUM_DIFF "
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
"behaviour.");
return nullptr;
}
}
CXXScopeSpec SS;
DeclContext* DC = NSD;

// FIXME: Here `if` branch should be removed once we update
// numerical diff to use correct declaration context.
if (forCustomDerv) {
DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC);
// FIXME: We should ideally construct nested name specifier from the
// found custom derivative function. Current way will compute incorrect
// nested name specifier in some cases.
if (outermostDC &&
outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) {
utils::BuildNNS(m_Sema, originalFnDC, SS);
DC = originalFnDC;
} else {
if (isa<RecordDecl>(originalFnDC))
DC = utils::LookupNSD(m_Sema, "class_functions",
/*shouldExist=*/false, NSD);
else
DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC);
if (DC)
utils::BuildNNS(m_Sema, DC, SS);
}
} else {
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
IdentifierInfo* II = &m_Context.Idents.get(Name);
DeclarationName name(II);
DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema));

LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName);
if (DC)
m_Sema.LookupQualifiedName(R, DC);
Expr* OverloadedFn = nullptr;
if (!R.empty()) {
// FIXME: We should find a way to specify nested name specifier
// after finding the custom derivative.
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get();

auto MARargs = llvm::MutableArrayRef<Expr*>(CallArgs);

SourceLocation Loc;

if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
}
return OverloadedFn;
}

// This method is derived from the source code of both
// buildOverloadedCallSet() in SemaOverload.cpp
// and ActOnCallExpr() in SemaExpr.cpp.
bool VisitorBase::noOverloadExists(Expr* UnresolvedLookup,
llvm::MutableArrayRef<Expr*> ARargs) {
if (UnresolvedLookup->getType() == m_Context.OverloadTy) {
OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup);

if (!find.HasFormOfMemberPointer) {
OverloadExpr* ovl = find.Expression;

if (isa<UnresolvedLookupExpr>(ovl)) {
ExprResult result;
SourceLocation Loc;
OverloadCandidateSet CandidateSet(Loc,
OverloadCandidateSet::CSK_Normal);
Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext);
auto* ULE = cast<UnresolvedLookupExpr>(ovl);
// Populate CandidateSet.
m_Sema.buildOverloadedCallSet(S, UnresolvedLookup, ULE, ARargs, Loc,
&CandidateSet, &result);
OverloadCandidateSet::iterator Best = nullptr;
OverloadingResult OverloadResult = CandidateSet.BestViableFunction(
m_Sema, UnresolvedLookup->getBeginLoc(), Best);
if (OverloadResult != 0U) // No overloads were found.
return true;
}
}
}
return false;
}
} // end namespace clad

0 comments on commit 8f0ae29

Please sign in to comment.