From a2ed7f208e72ea398a3f295b50d5fd1a2ba650ea Mon Sep 17 00:00:00 2001 From: Parth Arora Date: Fri, 20 May 2022 14:19:17 +0530 Subject: [PATCH] Enable differentiation wrt data members of parameters in forw mode This PR adds initial support for differentiation with respect to data members in Forward Mode AD. Functions can be differentiated with respect to a data member by specifying the data member in the differentiation args. For example: ```cpp double fn(std::pair u, std::pair v) { return u.first + v.second; } auto d_fn = clad::differentiate(fn, "u.first"); ``` This PR does not add support for differentiating with respect to array data members. --- include/clad/Differentiator/CladUtils.h | 42 ++++++ include/clad/Differentiator/DiffPlanner.h | 8 +- .../clad/Differentiator/ParseDiffArgsTypes.h | 47 ++++++- lib/Differentiator/CladUtils.cpp | 60 ++++++++ lib/Differentiator/DiffPlanner.cpp | 130 ++++++++++++------ lib/Differentiator/ForwardModeVisitor.cpp | 80 ++++++++--- lib/Differentiator/HessianModeVisitor.cpp | 10 +- lib/Differentiator/ReverseModeVisitor.cpp | 5 +- test/FirstDerivative/DiffInterface.C | 24 ++++ test/ForwardMode/UserDefinedTypes.C | 59 +++++++- tools/DerivedFnInfo.cpp | 6 +- tools/DerivedFnInfo.h | 2 +- 12 files changed, 396 insertions(+), 77 deletions(-) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index e15d3bfc5..8da3200f4 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -250,6 +250,48 @@ namespace clad { /// If `S` is `null`, then nothing happens. void AppendIndividualStmts(llvm::SmallVectorImpl& block, clang::Stmt* S); + /// Builds a nested member expression that consist of base expression + /// specified by `base` argument and data members specified in `fields` + /// argument in the original sequence. + /// + /// For example, if `base` represents `b` -- an expression of a record type, + /// and `fields` is the sequence {'mem1', 'mem2', 'mem3'}, then the function + /// builds and returns the following expression: + /// ``` + /// b.mem1.mem2.mem3 + /// ``` + clang::MemberExpr* + BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base, + llvm::ArrayRef fields); + + /// Returns true if member expression path specified by `fields` is correct; + /// otherwise returns false. + /// + /// For example, if `base` represents `b` -- an expression of a record type, + /// and `fields` is the sequence {'mem1', 'mem2', 'mem3'}, then the function + /// returns true if `b.mem1.mem2.mem3` is a valid data member reference + /// expression, otherwise returns false. + /// + /// \note This function returns true if `fields` is an empty sequence. + bool IsValidMemExprPath(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields); + + /// Perform lookup for data member with name `name`. If lookup finds a + /// declaration, then return the field declaration; otherwise returns + /// `nullptr`. + clang::FieldDecl* LookupDataMember(clang::Sema& semaRef, + clang::RecordDecl* RD, + llvm::StringRef name); + + /// Computes the type of a data member of the record specified by `RD` + /// and nested fields specified in `fields` argument. + /// For example, if `RD` represents `std::pair, std::pair`, and `fields` is the sequence + /// {'first', 'first'}, then the corresponding data member is + // of type `std::complex`. + clang::QualType + ComputeMemExprPathType(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields); } // namespace utils } diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 0c6a46d80..3a85f435b 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -53,9 +53,13 @@ namespace clad { /// differentiated, for example, when we are computing higher /// order derivatives. const clang::CXXRecordDecl* Functor = nullptr; - DiffParamsWithIndices DiffParamsInfo; - /// Recomputes `DiffParamsInfo` using the current values of data members. + /// Stores differentiation parameters information. Stored information + /// includes info on indices range for array parameters, and nested data + /// member information for record (class) type parameters. + DiffInputVarsInfo DVI; + + /// Recomputes `DiffInputVarsInfo` using the current values of data members. /// /// Differentiation parameters info is computed by parsing the argument /// expression for the clad differentiation function calls. The argument is diff --git a/include/clad/Differentiator/ParseDiffArgsTypes.h b/include/clad/Differentiator/ParseDiffArgsTypes.h index e7ae0210e..fb53930ae 100644 --- a/include/clad/Differentiator/ParseDiffArgsTypes.h +++ b/include/clad/Differentiator/ParseDiffArgsTypes.h @@ -6,7 +6,9 @@ #define CLAD_PARSE_DIFF_ARGS_TYPES_H #include "clang/AST/Decl.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include #include @@ -34,10 +36,51 @@ namespace clad { return Start == rhs.Start && Finish == rhs.Finish; } }; + + using IndexIntervalTable = llvm::SmallVector; + + /// `DiffInputVarInfo` is designed to store all the essential information about a + /// differentiation input variable. Please note that here input variable corresponds + /// to mathematical variable, not a programming one. + // FIXME: 'DiffInputVarInfo' name is probably not accurate, since we can have multiple + // differentiation input variables for same parameter as well. 'DiffInputVarInfo' + // name implicitly guides that there would be at most one `DiffInputVarInfo` object for + // one parameter, but that is not strictly true. + struct DiffInputVarInfo { + /// Source string specified by user that defines differentiation + /// specification for the input variable. + /// For example, if complete input string specified by user is: + /// 'u, v.first, arr[3]' + /// then `source` data member value for 2nd input variable should be + /// 'v.first' + std::string source; + /// Parameter associated with the input variable. + const clang::ValueDecl* param = nullptr; + /// array index range associated with the parameter. + IndexInterval paramIndexInterval; + /// Nested field information. + llvm::SmallVector fields; + // FIXME: Add support for differentiating with respect to array fields. + // llvm::SmallVector fieldIndexIntervals; + + DiffInputVarInfo(const clang::ValueDecl* pParam = nullptr, + IndexInterval pParamIndexInterval = {}, + llvm::SmallVector pFields = {}) + : param(pParam), paramIndexInterval(pParamIndexInterval), + fields(pFields) {} + + // FIXME: Move function definitions to ParseDiffArgTypes.cpp + bool operator==(const DiffInputVarInfo& rhs) const { + return param == rhs.param && + paramIndexInterval == rhs.paramIndexInterval && + fields == rhs.fields; + } + }; + + using DiffInputVarsInfo = llvm::SmallVector; using DiffParams = llvm::SmallVector; - using IndexIntervalTable = llvm::SmallVector; using DiffParamsWithIndices = std::pair; -} // namespace clad + } // namespace clad #endif \ No newline at end of file diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index a2f114fec..928c4f601 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -420,5 +420,65 @@ namespace clad { else if (S) block.push_back(S); } + + MemberExpr* + BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base, + llvm::ArrayRef fields) { + MemberExpr* ME = nullptr; + for (auto field : fields) { + ME = BuildMemberExpr(semaRef, S, base, field); + base = ME; + } + return ME; + } + + clang::FieldDecl* LookupDataMember(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::StringRef name) { + LookupResult R(semaRef, BuildDeclarationNameInfo(semaRef, name), + Sema::LookupNameKind::LookupMemberName); + CXXScopeSpec CSS; + semaRef.LookupQualifiedName(R, RD, CSS); + if (R.empty()) + return nullptr; + assert(R.isSingleResult() && "Lookup in valid classes should always " + "return a single data member result."); + auto D = R.getFoundDecl(); + // We are looking data members only! + if (auto FD = dyn_cast(D)) + return FD; + return nullptr; + } + + bool IsValidMemExprPath(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields) { + for (std::size_t i = 0; i < fields.size(); ++i) { + FieldDecl* FD = LookupDataMember(semaRef, RD, fields[i]); + if (!FD) + return false; + if (FD->getType()->isRecordType()) + RD = FD->getType()->getAsCXXRecordDecl(); + // Current Field declaration is not of record type, therefore + // it cannot have any field within it. And any member access + // ('.') expression would be an invalid path. + else if (i != fields.size() - 1) + return false; + } + return true; + } + + clang::QualType + ComputeMemExprPathType(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields) { + assert(IsValidMemExprPath(semaRef, RD, fields) && + "Invalid field path specified!"); + QualType T = RD->getTypeForDecl()->getCanonicalTypeInternal(); + for (auto field : fields) { + auto FD = LookupDataMember(semaRef, RD, field); + if (FD->getType()->isRecordType()) + RD = FD->getType()->getAsCXXRecordDecl(); + T = FD->getType(); + } + return T; + } } // namespace utils } // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index d9a111eed..e9125b2a1 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -263,12 +263,12 @@ namespace clad { } void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) { + DVI.clear(); auto& C = semaRef.getASTContext(); const Expr* diffArgs = Args; const FunctionDecl* FD = Function; FD = FD->getDefinition(); if (!diffArgs || !FD) { - DiffParamsInfo = {{}, {}}; return; } DiffParams params{}; @@ -283,11 +283,11 @@ namespace clad { return; } // Split the string by ',' characters, trim whitespaces. - llvm::SmallVector names{}; - llvm::StringRef name{}; + llvm::SmallVector diffParamsSpec{}; do { - std::tie(name, string) = string.split(','); - names.push_back(name.trim()); + llvm::StringRef pInfo{}; + std::tie(pInfo, string) = string.split(','); + diffParamsSpec.push_back(pInfo.trim()); } while (!string.empty()); // Stores parameters and field declarations to be used as candidates for // independent arguments. @@ -303,38 +303,46 @@ namespace clad { if (FD->param_empty() && Functor) { for (FieldDecl* fieldDecl : Functor->fields()) candidates.emplace_back(fieldDecl->getName(), fieldDecl); - } else { for (auto PVD : FD->parameters()) candidates.emplace_back(PVD->getName(), PVD); } + auto computeParamName = [](llvm::StringRef diffSpec) { + std::size_t idx = diffSpec.find_first_of(".["); + return diffSpec.substr(0, idx); + }; + // Ensure that diff params are always considered in the same order. + // This is required to uniquely identify differentiation requests. std::sort( - names.begin(), names.end(), - [&candidates](llvm::StringRef a, llvm::StringRef b) { + diffParamsSpec.begin(), diffParamsSpec.end(), + [&candidates, &computeParamName](llvm::StringRef a, + llvm::StringRef b) { auto a_it = std::find_if( candidates.begin(), candidates.end(), - [a](const std::pair& candidate) { - return candidate.first == a; + [a, &computeParamName]( + const std::pair& candidate) { + return candidate.first == computeParamName(a); }); auto b_it = std::find_if( candidates.begin(), candidates.end(), - [b](const std::pair& candidate) { - return candidate.first == b; + [b, &computeParamName]( + const std::pair& candidate) { + return candidate.first == computeParamName(b); }); return a_it < b_it; }); - for (const auto& name : names) { - size_t loc = name.find('['); - loc = (loc == llvm::StringRef::npos) ? name.size() : loc; - llvm::StringRef base = name.substr(0, loc); + for (const auto& diffSpec : diffParamsSpec) { + DiffInputVarInfo dVarInfo; + dVarInfo.source = diffSpec.str(); + llvm::StringRef pName = computeParamName(diffSpec); auto it = std::find_if(std::begin(candidates), std::end(candidates), - [&base]( + [&pName]( const std::pair& - p) { return p.first == base; }); + p) { return p.first == pName; }); if (it == std::end(candidates)) { // Fail if the function has no parameter with specified name. @@ -343,13 +351,16 @@ namespace clad { "Requested parameter name '%0' was not found among " "function " "parameters", - {base}); + {pName}); return; } - auto f_it = std::find(std::begin(params), std::end(params), it->second); + auto f_it = std::find_if(std::begin(DVI), std::end(DVI), + [&it](const DiffInputVarInfo& dVarInfo) { + return dVarInfo.param == it->second; + }); - if (f_it != params.end()) { + if (f_it != DVI.end()) { utils:: EmitDiag(semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(), "Requested parameter '%0' was specified multiple times", @@ -357,10 +368,11 @@ namespace clad { return; } - params.push_back(it->second); - - if (loc != name.size()) { - llvm::StringRef interval(name.slice(loc + 1, name.find(']'))); + dVarInfo.param = it->second; + + std::size_t lSqBracketIdx = diffSpec.find("["); + if (lSqBracketIdx != llvm::StringRef::npos) { + llvm::StringRef interval(diffSpec.slice(lSqBracketIdx + 1, diffSpec.find(']'))); llvm::StringRef firstStr, lastStr; std::tie(firstStr, lastStr) = interval.split(':'); @@ -368,7 +380,7 @@ namespace clad { // The string is not a range just a single index size_t index; firstStr.getAsInteger(10, index); - indexes.push_back(IndexInterval(index)); + dVarInfo.paramIndexInterval = IndexInterval(index); } else { size_t first, last; firstStr.getAsInteger(10, first); @@ -377,23 +389,57 @@ namespace clad { utils::EmitDiag(semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(), "Range specified in '%0' is in incorrect format", - {name}); + {diffSpec}); return; } - indexes.push_back(IndexInterval(first, last)); + dVarInfo.paramIndexInterval = IndexInterval(first, last); } } else { - indexes.push_back(IndexInterval()); + dVarInfo.paramIndexInterval = IndexInterval(); + } + + std::size_t dotIdx = diffSpec.find("."); + dotIdx += (dotIdx != StringRef::npos); + StringRef fieldsSpec = diffSpec.substr(dotIdx); + while (!fieldsSpec.empty()) { + StringRef fieldName; + std::tie(fieldName, fieldsSpec) = fieldsSpec.split('.'); + dVarInfo.fields.push_back(fieldName.str()); + } + + if (!dVarInfo.param->getType()->isRecordType() && + !dVarInfo.fields.empty()) { + utils::EmitDiag( + semaRef, DiagnosticsEngine::Level::Error, diffArgs->getEndLoc(), + "Fields can only be provided for class type parameters. " + "Field information is incorrectly specified in '%0' " + "for non-class type parameter '%1'", + {diffSpec, pName}); + return; } + + if (!dVarInfo.fields.empty()) { + RecordDecl* RD = dVarInfo.param->getType()->getAsCXXRecordDecl(); + llvm::SmallVector ref(dVarInfo.fields.begin(), + dVarInfo.fields.end()); + bool isValid = utils::IsValidMemExprPath(semaRef, RD, ref); + if (!isValid) { + utils::EmitDiag( + semaRef, DiagnosticsEngine::Level::Error, diffArgs->getEndLoc(), + "Path specified by fields in '%0' is invalid.", {diffSpec}); + return; + } + } + + DVI.push_back(dVarInfo); } - // Return a sequence of function's parameters. - DiffParamsInfo = {params, indexes}; return; } // Case 2) // Check if the provided literal can be evaluated as an integral value. llvm::APSInt intValue; if (clad_compat::Expr_EvaluateAsInt(E, intValue, C)) { + DiffInputVarInfo dVarInfo; auto idx = intValue.getExtValue(); // If we are differentiating a call operator that have no parameters, then // we need to search for independent parameters in fields of the @@ -410,7 +456,7 @@ namespace clad { {std::to_string(idx), std::to_string(totalFields)}); return; } - params.push_back(*std::next(Functor->field_begin(), idx)); + dVarInfo.param = *std::next(Functor->field_begin(), idx); } else { // Fail if the specified index is invalid. if ((idx < 0) || (idx >= FD->getNumParams())) { @@ -421,10 +467,10 @@ namespace clad { std::to_string(FD->getNumParams())}); return; } - params.push_back(FD->getParamDecl(idx)); + dVarInfo.param = FD->getParamDecl(idx); } // Returns a single parameter. - DiffParamsInfo = {params, {}}; + DVI.push_back(dVarInfo); return; } // Case 3) @@ -443,26 +489,24 @@ namespace clad { return; } - // If it is a Vector valued function, the last parameter is to store the output vector - // and hence is not a differentiable parameter, so we must pop it out + // If it is a Vector valued function, the last parameter is to store the + // output vector and hence is not a differentiable parameter, so we must + // pop it out if (this->Mode == DiffMode::jacobian){ params.pop_back(); } - - IndexIntervalTable indexes{}; // insert an empty index for each parameter. - for (unsigned i=0; igetEndLoc(), "Failed to parse the parameters, must be a string or " "numeric literal"); - DiffParamsInfo = {{}, {}}; return; } diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp index 8b80b53ba..b6c75876e 100644 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ b/lib/Differentiator/ForwardModeVisitor.cpp @@ -199,22 +199,20 @@ namespace clad { "Doesn't support recursive diff. Use DiffPlan."); m_DerivativeInFlight = true; - DiffParams args{}; - IndexIntervalTable indexIntervalTable{}; - if (request.Args) - std::tie(args, indexIntervalTable) = request.DiffParamsInfo; - else { - // FIXME: implement gradient-vector products to fix the issue. - assert((FD->getNumParams() <= 1) && - "nested forward mode differentiation for several args is broken"); - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - } - if (args.empty()) + DiffInputVarsInfo DVI = request.DVI; + + DVI = request.DVI; + + // FIXME: Shouldn't we give error here that no arg is specified? + if (DVI.empty()) return {}; + + DiffInputVarInfo diffVarInfo = DVI.back(); + // Check that only one arg is requested and if the arg requested is of array // or pointer type, only one of the indices have been requested - if (args.size() > 1 || (isArrayOrPointerType(args[0]->getType()) && - (indexIntervalTable.size() != 1 || indexIntervalTable[0].size() != 1))) { + if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) && + (diffVarInfo.paramIndexInterval.size() != 1))) { diag(DiagnosticsEngine::Error, request.Args ? request.Args->getEndLoc() : noLoc, "Forward mode differentiation w.r.t. several parameters at once is " @@ -224,7 +222,17 @@ namespace clad { return {}; } - m_IndependentVar = args.back(); + // FIXME: implement gradient-vector products to fix the issue. + assert((DVI.size() == 1) && + "nested forward mode differentiation for several args is broken"); + + // FIXME: Differentiation variable cannot always be represented just by + // `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc. + // FIXME: independent variable is misleading terminology, what we actually + // mean here is 'variable' with respect to which differentiation is being + // performed. Mathematically, independent variables are all the function + // parameters, thus, does not convey the intendend meaning. + m_IndependentVar = DVI.back().param; std::string derivativeSuffix(""); // If param is not real (i.e. floating point or integral), a pointer to a // real type, or an array of a real type we cannot differentiate it. @@ -240,14 +248,26 @@ namespace clad { {m_IndependentVar->getNameAsString()}); return {}; } - m_IndependentVarIndex = indexIntervalTable[0].Start; + m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start; derivativeSuffix = "_" + std::to_string(m_IndependentVarIndex); - } else if (!IsRealNonReferenceType(m_IndependentVar->getType())) { - diag(DiagnosticsEngine::Error, m_IndependentVar->getEndLoc(), - "attempted differentiation w.r.t. a parameter ('%0') which is not " - "of a real type", - {m_IndependentVar->getNameAsString()}); - return {}; + } + else { + QualType T = m_IndependentVar->getType(); + bool isField = false; + if (auto RD = diffVarInfo.param->getType()->getAsCXXRecordDecl()) { + llvm::SmallVector ref(diffVarInfo.fields.begin(), + diffVarInfo.fields.end()); + T = utils::ComputeMemExprPathType(m_Sema, RD, ref); + isField = true; + } + if (!IsRealNonReferenceType(T)) { + diag(DiagnosticsEngine::Error, request.Args->getEndLoc(), + "Attempted differentiation w.r.t. %0 '%1' which is not " + "of real type.", + {(isField ? "member" : "parameter"), + diffVarInfo.source}); + return {}; + } } m_DerivativeOrder = request.CurrentDerivativeOrder; std::string s = std::to_string(m_DerivativeOrder); @@ -269,9 +289,13 @@ namespace clad { m_IndependentVar)); } + std::string argInfo = std::to_string(m_ArgIndex); + for (auto field : diffVarInfo.fields) + argInfo += "_" + field; + IdentifierInfo* II = &m_Context.Idents.get(request.BaseFunctionName + "_d" + s + "arg" + - std::to_string(m_ArgIndex) + derivativeSuffix); + argInfo + derivativeSuffix); SourceLocation loc{m_Function->getLocation()}; DeclarationNameInfo name(II, loc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); @@ -362,6 +386,18 @@ namespace clad { dParam); addToCurrentBlock(BuildDeclStmt(dParamDecl)); dParam = BuildDeclRef(dParamDecl); + if (dParamType->isRecordType() && param == m_IndependentVar) { + llvm::SmallVector ref(diffVarInfo.fields.begin(), + diffVarInfo.fields.end()); + Expr* memRef = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref); + assert(memRef->getType()->isRealType() && + "Forward mode can only differentiate w.r.t builtin scalar " + "numerical types."); + addToCurrentBlock(BuildOp( + BinaryOperatorKind::BO_Assign, memRef, + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1))); + } // Memorize the derivative of param, i.e. whenever the param is visited // in the future, it's derivative dParam is found (unless reassigned with // something new). diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 8142f3760..7f7c137c7 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -80,8 +80,14 @@ namespace clad { const DiffRequest& request) { DiffParams args{}; IndexIntervalTable indexIntervalTable{}; - if (request.Args) - std::tie(args, indexIntervalTable) = request.DiffParamsInfo; + DiffInputVarsInfo DVI; + if (request.Args) { + DVI = request.DVI; + for (auto dParam : DVI) { + args.push_back(dParam.param); + indexIntervalTable.push_back(dParam.paramIndexInterval); + } + } else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 10b662599..1cfa37721 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -240,8 +240,11 @@ namespace clad { assert(m_Function && "Must not be null."); DiffParams args{}; + DiffInputVarsInfo DVI; if (request.Args) { - std::tie(args, std::ignore) = request.DiffParamsInfo; + DVI = request.DVI; + for (auto dParam : DVI) + args.push_back(dParam.param); } else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); diff --git a/test/FirstDerivative/DiffInterface.C b/test/FirstDerivative/DiffInterface.C index 5a62e9b72..b3373a47a 100644 --- a/test/FirstDerivative/DiffInterface.C +++ b/test/FirstDerivative/DiffInterface.C @@ -89,6 +89,25 @@ double fn_with_no_params() { return 11; } +struct Complex { + double real, im; + double getReal() { + return real; + } +}; + +double fn_with_Complex_type_param(Complex c) { + return c.real + c.im; +} + +struct ComplexPair { + Complex c1, c2; +}; + +double fn_with_ComplexPair_type_param(ComplexPair cp) { + return 11; +}; + int main () { int x = 4 * 5; clad::differentiate(f_1, 0); @@ -133,5 +152,10 @@ int main () { clad::differentiate(f_2, ""); // expected-error {{No parameters were provided}} clad::differentiate(fn_with_no_params); // expected-error {{Attempted to differentiate a function without parameters}} + clad::differentiate(f_2, "x.mem1"); // expected-error {{Fields can only be provided for class type parameters. Field information is incorrectly specified in 'x.mem1' for non-class type parameter 'x'}} + clad::differentiate(fn_with_Complex_type_param, "c.real.im"); // expected-error {{Path specified by fields in 'c.real.im' is invalid.}} + clad::differentiate(fn_with_ComplexPair_type_param, "cp.c1"); // expected-error {{Attempted differentiation w.r.t. member 'cp.c1' which is not of real type.}} + clad::differentiate(fn_with_Complex_type_param, "c.getReal"); // expected-error {{Path specified by fields in 'c.getReal' is invalid.}} + clad::differentiate(fn_with_Complex_type_param, "c.invalidField"); // expected-error {{Path specified by fields in 'c.invalidField' is invalid.}} return 0; } diff --git a/test/ForwardMode/UserDefinedTypes.C b/test/ForwardMode/UserDefinedTypes.C index 69e4a2fd5..135598c0e 100644 --- a/test/ForwardMode/UserDefinedTypes.C +++ b/test/ForwardMode/UserDefinedTypes.C @@ -1145,6 +1145,57 @@ double fn14(double i, double j) { // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +using pairdd = std::pair; +using pair_of_pairdd = std::pair; + +double fn15(pairdd u, pairdd v) { + return u.first + 2*v.first; +} + +// CHECK: double fn15_darg0_first(pairdd u, pairdd v) { +// CHECK-NEXT: pairdd _d_u; +// CHECK-NEXT: _d_u.first = 1; +// CHECK-NEXT: pairdd _d_v; +// CHECK-NEXT: double &_t0 = v.first; +// CHECK-NEXT: return _d_u.first + 0 * _t0 + 2 * _d_v.first; +// CHECK-NEXT: } + +double fn16(pair_of_pairdd u, pair_of_pairdd v) { + return u.first.first + 2*v.second.second; +} + +// CHECK: double fn16_darg1_second_second(pair_of_pairdd u, pair_of_pairdd v) { +// CHECK-NEXT: pair_of_pairdd _d_u; +// CHECK-NEXT: pair_of_pairdd _d_v; +// CHECK-NEXT: _d_v.second.second = 1; +// CHECK-NEXT: double &_t0 = v.second.second; +// CHECK-NEXT: return _d_u.first.first + 0 * _t0 + 2 * _d_v.second.second; +// CHECK-NEXT: } + + +struct A { + double mem; + A(double p_mem = 0) : mem(p_mem) {} +}; + +struct B : public A { + double mem; + B(double p_mem = 0) : A(0), mem(p_mem) {} +}; + +double fn17(A a, B b) { + return a.mem * b.mem; +} + +// CHECK: double fn17_darg1_mem(A a, B b) { +// CHECK-NEXT: A _d_a; +// CHECK-NEXT: B _d_b; +// CHECK-NEXT: _d_b.mem = 1; +// CHECK-NEXT: double &_t0 = a.mem; +// CHECK-NEXT: double &_t1 = b.mem; +// CHECK-NEXT: return _d_a.mem * _t1 + _t0 * _d_b.mem; +// CHECK-NEXT: } + template void print(const Tensor& t) { for (int i=0; i