Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move more cases computing the derivative name in ComputeDerivativeName. #1195

Merged
merged 1 commit into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/clad/Differentiator/ParseDiffArgsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace clad {

std::size_t size() { return Finish - Start; }

bool isInInterval(std::size_t n) { return n >= Start && n <= Finish; }
bool isInInterval(std::size_t n) const { return n >= Start && n <= Finish; }
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
bool isValid() const { return Start != Finish; }
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved

bool operator==(const IndexInterval& rhs) const {
return Start == rhs.Start && Finish == rhs.Finish;
Expand Down Expand Up @@ -83,4 +84,4 @@ namespace clad {
using DiffParamsWithIndices = std::pair<DiffParams, IndexIntervalTable>;
} // namespace clad

#endif
#endif
81 changes: 56 additions & 25 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#include "TBRAnalyzer.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/DeclarationName.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Sema.h"
Expand All @@ -20,6 +22,7 @@
#include "clad/Differentiator/Compatibility.h"

#include <algorithm>
#include <string>

using namespace clang;

Expand Down Expand Up @@ -670,40 +673,68 @@
}

std::string DiffRequest::ComputeDerivativeName() const {
if (Mode != DiffMode::forward && Mode != DiffMode::reverse)
return BaseFunctionName + "_" + DiffModeToString(Mode);
if (Mode != DiffMode::forward && Mode != DiffMode::reverse &&
Mode != DiffMode::vector_forward_mode) {
std::string name = BaseFunctionName + "_" + DiffModeToString(Mode);
for (auto index : CUDAGlobalArgsIndexes)
name += "_" + std::to_string(index);
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
return name;
}

if (DVI.empty())
return "<no independent variable specified>";

DiffInputVarInfo VarInfo = DVI.back();
const ValueDecl* IndependentVar = VarInfo.param;
unsigned argIndex = ~0;
// If we are differentiating a call operator, that has no parameters,
// then the specified independent argument is a member variable of the
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
if (Function->param_empty() && Functor)
argIndex = std::distance(Functor->field_begin(),
std::find(Functor->field_begin(),
Functor->field_end(), IndependentVar));
else
argIndex =
std::distance(Function->param_begin(),
std::find(Function->param_begin(),
Function->param_end(), IndependentVar));

std::string argInfo = std::to_string(argIndex);
for (const std::string& field : VarInfo.fields)
argInfo += "_" + field;
// FIXME: Harmonize names accross modes. We have darg0 but dvec_0 and _grad.
std::string argInfo;
for (const DiffInputVarInfo& dParamInfo : DVI) {
// If we differentiate w.r.t all arguments we do not need to specify them.
if (DVI.size() == Function->getNumParams() && Mode != DiffMode::forward)
break;

const ValueDecl* IndP = dParamInfo.param;
// If we are differentiating a call operator, that has no parameters,
// then the specified independent argument is a member variable of the
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
unsigned idx = ~0U;
if (Function->param_empty() && Functor) {
auto it = std::find(Functor->field_begin(), Functor->field_end(), IndP);
idx = std::distance(Functor->field_begin(), it);
} else {
const auto* it =
std::find(Function->param_begin(), Function->param_end(), IndP);
idx = std::distance(Function->param_begin(), it);
}
argInfo += ((Mode == DiffMode::forward) ? "" : "_") + std::to_string(idx);

if (dParamInfo.paramIndexInterval.isValid()) {
assert(utils::isArrayOrPointerType(IndP->getType()) && "Not array?");
// FIXME: What about ranges [Start;Finish)?
argInfo += "_" + std::to_string(dParamInfo.paramIndexInterval.Start);
}

for (const std::string& field : dParamInfo.fields)
argInfo += "_" + field;
}

if (Mode == DiffMode::vector_forward_mode) {
if (DVI.size() != Function->getNumParams())
return BaseFunctionName + "_dvec" + argInfo;
return BaseFunctionName + "_dvec";
}

if (Mode == DiffMode::reverse) {
if (DVI.size() != Function->getNumParams())
return BaseFunctionName + "_grad" + argInfo;
if (use_enzyme)
return BaseFunctionName + "_grad" + "_enzyme";

Check warning on line 730 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L730

Added line #L730 was not covered by tests
return BaseFunctionName + "_grad";
}

std::string s;
if (CurrentDerivativeOrder > 1)
s = std::to_string(CurrentDerivativeOrder);

if (utils::isArrayOrPointerType(IndependentVar->getType()))
argInfo += "_" + std::to_string(VarInfo.paramIndexInterval.Start);

return BaseFunctionName + "_d" + s + "arg" + argInfo;
}

Expand Down
24 changes: 5 additions & 19 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,21 +215,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args);

auto derivativeBaseName = m_DiffReq.BaseFunctionName;
std::string gradientName = derivativeBaseName + funcPostfix();
// To be consistent with older tests, nothing is appended to 'f_grad' if
// we differentiate w.r.t. all the parameters at once.
if (args.size() != FD->getNumParams()) {
for (const auto* arg : args) {
const auto* it = std::find(FD->param_begin(), FD->param_end(), arg);
auto idx = std::distance(FD->param_begin(), it);
gradientName += ('_' + std::to_string(idx));
}
}

IdentifierInfo* II = &m_Context.Idents.get(gradientName);
DeclarationNameInfo name(II, noLoc);

// If we are in error estimation mode, we have an extra `double&`
// parameter that stores the final error
unsigned numExtraParam = 0;
Expand Down Expand Up @@ -266,6 +251,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
originalFnType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
std::string gradientName = m_DiffReq.ComputeDerivativeName();
IdentifierInfo* II = &m_Context.Idents.get(gradientName);
DeclarationNameInfo name(II, noLoc);
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
Expand Down Expand Up @@ -385,10 +374,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args);

auto derivativeName =
utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback";
for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
derivativeName += "_" + std::to_string(index);
auto derivativeName = m_DiffReq.ComputeDerivativeName();
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down
21 changes: 7 additions & 14 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,11 @@ DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() {
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == DiffMode::vector_forward_mode);

// Generate the function type for the derivative.
DiffParams args{};
for (const auto& dParam : m_DiffReq.DVI)
args.push_back(dParam.param);

// Generate name for the derivative function.
std::string derivedFnName = m_DiffReq.BaseFunctionName + "_dvec";
if (args.size() != FD->getNumParams()) {
for (auto arg : args) {
auto it = std::find(FD->param_begin(), FD->param_end(), arg);
auto idx = std::distance(FD->param_begin(), it);
derivedFnName += ('_' + std::to_string(idx));
}
}
IdentifierInfo* II = &m_Context.Idents.get(derivedFnName);
SourceLocation loc{m_DiffReq->getLocation()};
DeclarationNameInfo name(II, loc);

// Generate the function type for the derivative.
llvm::SmallVector<clang::QualType, 8> paramTypes;
paramTypes.reserve(m_DiffReq->getNumParams() + args.size());
for (auto* PVD : m_DiffReq->parameters())
Expand Down Expand Up @@ -106,6 +93,12 @@ DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() {
dyn_cast<FunctionProtoType>(m_DiffReq->getType())->getExtProtoInfo());

// Create the function declaration for the derivative.
std::string derivedFnName = m_DiffReq.ComputeDerivativeName();

IdentifierInfo* II = &m_Context.Idents.get(derivedFnName);
SourceLocation loc{m_DiffReq->getLocation()};
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
DeclarationNameInfo name(II, loc);
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
Expand Down
Loading