Skip to content

Commit

Permalink
Move more cases computing the derivative name in ComputeDerivativeName.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 28, 2024
1 parent c893c31 commit ea536a6
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 60 deletions.
5 changes: 3 additions & 2 deletions include/clad/Differentiator/ParseDiffArgsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace clad {

std::size_t size() { return Finish - Start; }

bool isInInterval(std::size_t n) { return n >= Start && n <= Finish; }
bool isInInterval(std::size_t n) const { return n >= Start && n <= Finish; }
bool isValid() const { return Start != Finish; }

bool operator==(const IndexInterval& rhs) const {
return Start == rhs.Start && Finish == rhs.Finish;
Expand Down Expand Up @@ -83,4 +84,4 @@ namespace clad {
using DiffParamsWithIndices = std::pair<DiffParams, IndexIntervalTable>;
} // namespace clad

#endif
#endif
76 changes: 51 additions & 25 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,40 +670,66 @@ namespace clad {
}

std::string DiffRequest::ComputeDerivativeName() const {
if (Mode != DiffMode::forward && Mode != DiffMode::reverse)
return BaseFunctionName + "_" + DiffModeToString(Mode);
if (Mode != DiffMode::forward && Mode != DiffMode::reverse &&
Mode != DiffMode::vector_forward_mode) {
std::string name = BaseFunctionName + "_" + DiffModeToString(Mode);
for (auto index : CUDAGlobalArgsIndexes)
name += "_" + std::to_string(index);
return name;
}

if (DVI.empty())
return "<no independent variable specified>";

DiffInputVarInfo VarInfo = DVI.back();
const ValueDecl* IndependentVar = VarInfo.param;
unsigned argIndex = ~0;
// If we are differentiating a call operator, that has no parameters,
// then the specified independent argument is a member variable of the
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
if (Function->param_empty() && Functor)
argIndex = std::distance(Functor->field_begin(),
std::find(Functor->field_begin(),
Functor->field_end(), IndependentVar));
else
argIndex =
std::distance(Function->param_begin(),
std::find(Function->param_begin(),
Function->param_end(), IndependentVar));

std::string argInfo = std::to_string(argIndex);
for (const std::string& field : VarInfo.fields)
argInfo += "_" + field;
// FIXME: Harmonize names accross modes. We have darg0 but dvec_0 and _grad.
std::string argInfo = "";
for (DiffInputVarInfo dParamInfo : DVI) {
// If we differentiate w.r.t all arguments we do not need to specify them.
if (DVI.size() == Function->getNumParams() && Mode != DiffMode::forward)
break;

const ValueDecl* IndP = dParamInfo.param;
// If we are differentiating a call operator, that has no parameters,
// then the specified independent argument is a member variable of the
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
unsigned idx = ~0U;
if (Function->param_empty() && Functor) {
auto it = std::find(Functor->field_begin(), Functor->field_end(), IndP);
idx = std::distance(Functor->field_begin(), it);
} else {
auto it =
std::find(Function->param_begin(), Function->param_end(), IndP);
idx = std::distance(Function->param_begin(), it);
}
argInfo += ((Mode == DiffMode::forward) ? "" : "_") + std::to_string(idx);

if (dParamInfo.paramIndexInterval.isValid()) {
assert(utils::isArrayOrPointerType(IndP->getType()) && "Not array?");
// FIXME: What about ranges [Start;Finish)?
argInfo += "_" + std::to_string(dParamInfo.paramIndexInterval.Start);
}

for (const std::string& field : dParamInfo.fields)
argInfo += "_" + field;
}

if (Mode == DiffMode::vector_forward_mode) {
if (DVI.size() != Function->getNumParams())
return BaseFunctionName + "_dvec" + argInfo;
return BaseFunctionName + "_dvec";
}

if (Mode == DiffMode::reverse) {
if (DVI.size() != Function->getNumParams())
return BaseFunctionName + "_grad" + argInfo;
return BaseFunctionName + "_grad";
}

std::string s;
if (CurrentDerivativeOrder > 1)
s = std::to_string(CurrentDerivativeOrder);

if (utils::isArrayOrPointerType(IndependentVar->getType()))
argInfo += "_" + std::to_string(VarInfo.paramIndexInterval.Start);

return BaseFunctionName + "_d" + s + "arg" + argInfo;
}

Expand Down
24 changes: 5 additions & 19 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,21 +215,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args);

auto derivativeBaseName = m_DiffReq.BaseFunctionName;
std::string gradientName = derivativeBaseName + funcPostfix();
// To be consistent with older tests, nothing is appended to 'f_grad' if
// we differentiate w.r.t. all the parameters at once.
if (args.size() != FD->getNumParams()) {
for (const auto* arg : args) {
const auto* it = std::find(FD->param_begin(), FD->param_end(), arg);
auto idx = std::distance(FD->param_begin(), it);
gradientName += ('_' + std::to_string(idx));
}
}

IdentifierInfo* II = &m_Context.Idents.get(gradientName);
DeclarationNameInfo name(II, noLoc);

// If we are in error estimation mode, we have an extra `double&`
// parameter that stores the final error
unsigned numExtraParam = 0;
Expand Down Expand Up @@ -266,6 +251,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
originalFnType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
std::string gradientName = m_DiffReq.ComputeDerivativeName();
IdentifierInfo* II = &m_Context.Idents.get(gradientName);
DeclarationNameInfo name(II, noLoc);

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
Expand Down Expand Up @@ -385,10 +374,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args);

auto derivativeName =
utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback";
for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
derivativeName += "_" + std::to_string(index);
auto derivativeName = m_DiffReq.ComputeDerivativeName();
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down
21 changes: 7 additions & 14 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,11 @@ DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() {
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == DiffMode::vector_forward_mode);

// Generate the function type for the derivative.
DiffParams args{};
for (const auto& dParam : m_DiffReq.DVI)
args.push_back(dParam.param);

// Generate name for the derivative function.
std::string derivedFnName = m_DiffReq.BaseFunctionName + "_dvec";
if (args.size() != FD->getNumParams()) {
for (auto arg : args) {
auto it = std::find(FD->param_begin(), FD->param_end(), arg);
auto idx = std::distance(FD->param_begin(), it);
derivedFnName += ('_' + std::to_string(idx));
}
}
IdentifierInfo* II = &m_Context.Idents.get(derivedFnName);
SourceLocation loc{m_DiffReq->getLocation()};
DeclarationNameInfo name(II, loc);

// Generate the function type for the derivative.
llvm::SmallVector<clang::QualType, 8> paramTypes;
paramTypes.reserve(m_DiffReq->getNumParams() + args.size());
for (auto* PVD : m_DiffReq->parameters())
Expand Down Expand Up @@ -106,6 +93,12 @@ DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() {
dyn_cast<FunctionProtoType>(m_DiffReq->getType())->getExtProtoInfo());

// Create the function declaration for the derivative.
std::string derivedFnName = m_DiffReq.ComputeDerivativeName();

IdentifierInfo* II = &m_Context.Idents.get(derivedFnName);
SourceLocation loc{m_DiffReq->getLocation()};
DeclarationNameInfo name(II, loc);

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
Expand Down

0 comments on commit ea536a6

Please sign in to comment.