diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index e15d3bfc5..8da3200f4 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -250,6 +250,48 @@ namespace clad { /// If `S` is `null`, then nothing happens. void AppendIndividualStmts(llvm::SmallVectorImpl& block, clang::Stmt* S); + /// Builds a nested member expression that consist of base expression + /// specified by `base` argument and data members specified in `fields` + /// argument in the original sequence. + /// + /// For example, if `base` represents `b` -- an expression of a record type, + /// and `fields` is the sequence {'mem1', 'mem2', 'mem3'}, then the function + /// builds and returns the following expression: + /// ``` + /// b.mem1.mem2.mem3 + /// ``` + clang::MemberExpr* + BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base, + llvm::ArrayRef fields); + + /// Returns true if member expression path specified by `fields` is correct; + /// otherwise returns false. + /// + /// For example, if `base` represents `b` -- an expression of a record type, + /// and `fields` is the sequence {'mem1', 'mem2', 'mem3'}, then the function + /// returns true if `b.mem1.mem2.mem3` is a valid data member reference + /// expression, otherwise returns false. + /// + /// \note This function returns true if `fields` is an empty sequence. + bool IsValidMemExprPath(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields); + + /// Perform lookup for data member with name `name`. If lookup finds a + /// declaration, then return the field declaration; otherwise returns + /// `nullptr`. + clang::FieldDecl* LookupDataMember(clang::Sema& semaRef, + clang::RecordDecl* RD, + llvm::StringRef name); + + /// Computes the type of a data member of the record specified by `RD` + /// and nested fields specified in `fields` argument. + /// For example, if `RD` represents `std::pair, std::pair`, and `fields` is the sequence + /// {'first', 'first'}, then the corresponding data member is + // of type `std::complex`. + clang::QualType + ComputeMemExprPathType(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields); } // namespace utils } diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 0c6a46d80..3a85f435b 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -53,9 +53,13 @@ namespace clad { /// differentiated, for example, when we are computing higher /// order derivatives. const clang::CXXRecordDecl* Functor = nullptr; - DiffParamsWithIndices DiffParamsInfo; - /// Recomputes `DiffParamsInfo` using the current values of data members. + /// Stores differentiation parameters information. Stored information + /// includes info on indices range for array parameters, and nested data + /// member information for record (class) type parameters. + DiffInputVarsInfo DVI; + + /// Recomputes `DiffInputVarsInfo` using the current values of data members. /// /// Differentiation parameters info is computed by parsing the argument /// expression for the clad differentiation function calls. The argument is diff --git a/include/clad/Differentiator/ParseDiffArgsTypes.h b/include/clad/Differentiator/ParseDiffArgsTypes.h index e7ae0210e..fb53930ae 100644 --- a/include/clad/Differentiator/ParseDiffArgsTypes.h +++ b/include/clad/Differentiator/ParseDiffArgsTypes.h @@ -6,7 +6,9 @@ #define CLAD_PARSE_DIFF_ARGS_TYPES_H #include "clang/AST/Decl.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include #include @@ -34,10 +36,51 @@ namespace clad { return Start == rhs.Start && Finish == rhs.Finish; } }; + + using IndexIntervalTable = llvm::SmallVector; + + /// `DiffInputVarInfo` is designed to store all the essential information about a + /// differentiation input variable. Please note that here input variable corresponds + /// to mathematical variable, not a programming one. + // FIXME: 'DiffInputVarInfo' name is probably not accurate, since we can have multiple + // differentiation input variables for same parameter as well. 'DiffInputVarInfo' + // name implicitly guides that there would be at most one `DiffInputVarInfo` object for + // one parameter, but that is not strictly true. + struct DiffInputVarInfo { + /// Source string specified by user that defines differentiation + /// specification for the input variable. + /// For example, if complete input string specified by user is: + /// 'u, v.first, arr[3]' + /// then `source` data member value for 2nd input variable should be + /// 'v.first' + std::string source; + /// Parameter associated with the input variable. + const clang::ValueDecl* param = nullptr; + /// array index range associated with the parameter. + IndexInterval paramIndexInterval; + /// Nested field information. + llvm::SmallVector fields; + // FIXME: Add support for differentiating with respect to array fields. + // llvm::SmallVector fieldIndexIntervals; + + DiffInputVarInfo(const clang::ValueDecl* pParam = nullptr, + IndexInterval pParamIndexInterval = {}, + llvm::SmallVector pFields = {}) + : param(pParam), paramIndexInterval(pParamIndexInterval), + fields(pFields) {} + + // FIXME: Move function definitions to ParseDiffArgTypes.cpp + bool operator==(const DiffInputVarInfo& rhs) const { + return param == rhs.param && + paramIndexInterval == rhs.paramIndexInterval && + fields == rhs.fields; + } + }; + + using DiffInputVarsInfo = llvm::SmallVector; using DiffParams = llvm::SmallVector; - using IndexIntervalTable = llvm::SmallVector; using DiffParamsWithIndices = std::pair; -} // namespace clad + } // namespace clad #endif \ No newline at end of file diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index a2f114fec..928c4f601 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -420,5 +420,65 @@ namespace clad { else if (S) block.push_back(S); } + + MemberExpr* + BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base, + llvm::ArrayRef fields) { + MemberExpr* ME = nullptr; + for (auto field : fields) { + ME = BuildMemberExpr(semaRef, S, base, field); + base = ME; + } + return ME; + } + + clang::FieldDecl* LookupDataMember(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::StringRef name) { + LookupResult R(semaRef, BuildDeclarationNameInfo(semaRef, name), + Sema::LookupNameKind::LookupMemberName); + CXXScopeSpec CSS; + semaRef.LookupQualifiedName(R, RD, CSS); + if (R.empty()) + return nullptr; + assert(R.isSingleResult() && "Lookup in valid classes should always " + "return a single data member result."); + auto D = R.getFoundDecl(); + // We are looking data members only! + if (auto FD = dyn_cast(D)) + return FD; + return nullptr; + } + + bool IsValidMemExprPath(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields) { + for (std::size_t i = 0; i < fields.size(); ++i) { + FieldDecl* FD = LookupDataMember(semaRef, RD, fields[i]); + if (!FD) + return false; + if (FD->getType()->isRecordType()) + RD = FD->getType()->getAsCXXRecordDecl(); + // Current Field declaration is not of record type, therefore + // it cannot have any field within it. And any member access + // ('.') expression would be an invalid path. + else if (i != fields.size() - 1) + return false; + } + return true; + } + + clang::QualType + ComputeMemExprPathType(clang::Sema& semaRef, clang::RecordDecl* RD, + llvm::ArrayRef fields) { + assert(IsValidMemExprPath(semaRef, RD, fields) && + "Invalid field path specified!"); + QualType T = RD->getTypeForDecl()->getCanonicalTypeInternal(); + for (auto field : fields) { + auto FD = LookupDataMember(semaRef, RD, field); + if (FD->getType()->isRecordType()) + RD = FD->getType()->getAsCXXRecordDecl(); + T = FD->getType(); + } + return T; + } } // namespace utils } // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index d9a111eed..e9125b2a1 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -263,12 +263,12 @@ namespace clad { } void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) { + DVI.clear(); auto& C = semaRef.getASTContext(); const Expr* diffArgs = Args; const FunctionDecl* FD = Function; FD = FD->getDefinition(); if (!diffArgs || !FD) { - DiffParamsInfo = {{}, {}}; return; } DiffParams params{}; @@ -283,11 +283,11 @@ namespace clad { return; } // Split the string by ',' characters, trim whitespaces. - llvm::SmallVector names{}; - llvm::StringRef name{}; + llvm::SmallVector diffParamsSpec{}; do { - std::tie(name, string) = string.split(','); - names.push_back(name.trim()); + llvm::StringRef pInfo{}; + std::tie(pInfo, string) = string.split(','); + diffParamsSpec.push_back(pInfo.trim()); } while (!string.empty()); // Stores parameters and field declarations to be used as candidates for // independent arguments. @@ -303,38 +303,46 @@ namespace clad { if (FD->param_empty() && Functor) { for (FieldDecl* fieldDecl : Functor->fields()) candidates.emplace_back(fieldDecl->getName(), fieldDecl); - } else { for (auto PVD : FD->parameters()) candidates.emplace_back(PVD->getName(), PVD); } + auto computeParamName = [](llvm::StringRef diffSpec) { + std::size_t idx = diffSpec.find_first_of(".["); + return diffSpec.substr(0, idx); + }; + // Ensure that diff params are always considered in the same order. + // This is required to uniquely identify differentiation requests. std::sort( - names.begin(), names.end(), - [&candidates](llvm::StringRef a, llvm::StringRef b) { + diffParamsSpec.begin(), diffParamsSpec.end(), + [&candidates, &computeParamName](llvm::StringRef a, + llvm::StringRef b) { auto a_it = std::find_if( candidates.begin(), candidates.end(), - [a](const std::pair& candidate) { - return candidate.first == a; + [a, &computeParamName]( + const std::pair& candidate) { + return candidate.first == computeParamName(a); }); auto b_it = std::find_if( candidates.begin(), candidates.end(), - [b](const std::pair& candidate) { - return candidate.first == b; + [b, &computeParamName]( + const std::pair& candidate) { + return candidate.first == computeParamName(b); }); return a_it < b_it; }); - for (const auto& name : names) { - size_t loc = name.find('['); - loc = (loc == llvm::StringRef::npos) ? name.size() : loc; - llvm::StringRef base = name.substr(0, loc); + for (const auto& diffSpec : diffParamsSpec) { + DiffInputVarInfo dVarInfo; + dVarInfo.source = diffSpec.str(); + llvm::StringRef pName = computeParamName(diffSpec); auto it = std::find_if(std::begin(candidates), std::end(candidates), - [&base]( + [&pName]( const std::pair& - p) { return p.first == base; }); + p) { return p.first == pName; }); if (it == std::end(candidates)) { // Fail if the function has no parameter with specified name. @@ -343,13 +351,16 @@ namespace clad { "Requested parameter name '%0' was not found among " "function " "parameters", - {base}); + {pName}); return; } - auto f_it = std::find(std::begin(params), std::end(params), it->second); + auto f_it = std::find_if(std::begin(DVI), std::end(DVI), + [&it](const DiffInputVarInfo& dVarInfo) { + return dVarInfo.param == it->second; + }); - if (f_it != params.end()) { + if (f_it != DVI.end()) { utils:: EmitDiag(semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(), "Requested parameter '%0' was specified multiple times", @@ -357,10 +368,11 @@ namespace clad { return; } - params.push_back(it->second); - - if (loc != name.size()) { - llvm::StringRef interval(name.slice(loc + 1, name.find(']'))); + dVarInfo.param = it->second; + + std::size_t lSqBracketIdx = diffSpec.find("["); + if (lSqBracketIdx != llvm::StringRef::npos) { + llvm::StringRef interval(diffSpec.slice(lSqBracketIdx + 1, diffSpec.find(']'))); llvm::StringRef firstStr, lastStr; std::tie(firstStr, lastStr) = interval.split(':'); @@ -368,7 +380,7 @@ namespace clad { // The string is not a range just a single index size_t index; firstStr.getAsInteger(10, index); - indexes.push_back(IndexInterval(index)); + dVarInfo.paramIndexInterval = IndexInterval(index); } else { size_t first, last; firstStr.getAsInteger(10, first); @@ -377,23 +389,57 @@ namespace clad { utils::EmitDiag(semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(), "Range specified in '%0' is in incorrect format", - {name}); + {diffSpec}); return; } - indexes.push_back(IndexInterval(first, last)); + dVarInfo.paramIndexInterval = IndexInterval(first, last); } } else { - indexes.push_back(IndexInterval()); + dVarInfo.paramIndexInterval = IndexInterval(); + } + + std::size_t dotIdx = diffSpec.find("."); + dotIdx += (dotIdx != StringRef::npos); + StringRef fieldsSpec = diffSpec.substr(dotIdx); + while (!fieldsSpec.empty()) { + StringRef fieldName; + std::tie(fieldName, fieldsSpec) = fieldsSpec.split('.'); + dVarInfo.fields.push_back(fieldName.str()); + } + + if (!dVarInfo.param->getType()->isRecordType() && + !dVarInfo.fields.empty()) { + utils::EmitDiag( + semaRef, DiagnosticsEngine::Level::Error, diffArgs->getEndLoc(), + "Fields can only be provided for class type parameters. " + "Field information is incorrectly specified in '%0' " + "for non-class type parameter '%1'", + {diffSpec, pName}); + return; } + + if (!dVarInfo.fields.empty()) { + RecordDecl* RD = dVarInfo.param->getType()->getAsCXXRecordDecl(); + llvm::SmallVector ref(dVarInfo.fields.begin(), + dVarInfo.fields.end()); + bool isValid = utils::IsValidMemExprPath(semaRef, RD, ref); + if (!isValid) { + utils::EmitDiag( + semaRef, DiagnosticsEngine::Level::Error, diffArgs->getEndLoc(), + "Path specified by fields in '%0' is invalid.", {diffSpec}); + return; + } + } + + DVI.push_back(dVarInfo); } - // Return a sequence of function's parameters. - DiffParamsInfo = {params, indexes}; return; } // Case 2) // Check if the provided literal can be evaluated as an integral value. llvm::APSInt intValue; if (clad_compat::Expr_EvaluateAsInt(E, intValue, C)) { + DiffInputVarInfo dVarInfo; auto idx = intValue.getExtValue(); // If we are differentiating a call operator that have no parameters, then // we need to search for independent parameters in fields of the @@ -410,7 +456,7 @@ namespace clad { {std::to_string(idx), std::to_string(totalFields)}); return; } - params.push_back(*std::next(Functor->field_begin(), idx)); + dVarInfo.param = *std::next(Functor->field_begin(), idx); } else { // Fail if the specified index is invalid. if ((idx < 0) || (idx >= FD->getNumParams())) { @@ -421,10 +467,10 @@ namespace clad { std::to_string(FD->getNumParams())}); return; } - params.push_back(FD->getParamDecl(idx)); + dVarInfo.param = FD->getParamDecl(idx); } // Returns a single parameter. - DiffParamsInfo = {params, {}}; + DVI.push_back(dVarInfo); return; } // Case 3) @@ -443,26 +489,24 @@ namespace clad { return; } - // If it is a Vector valued function, the last parameter is to store the output vector - // and hence is not a differentiable parameter, so we must pop it out + // If it is a Vector valued function, the last parameter is to store the + // output vector and hence is not a differentiable parameter, so we must + // pop it out if (this->Mode == DiffMode::jacobian){ params.pop_back(); } - - IndexIntervalTable indexes{}; // insert an empty index for each parameter. - for (unsigned i=0; igetEndLoc(), "Failed to parse the parameters, must be a string or " "numeric literal"); - DiffParamsInfo = {{}, {}}; return; } diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp index 8b80b53ba..b6c75876e 100644 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ b/lib/Differentiator/ForwardModeVisitor.cpp @@ -199,22 +199,20 @@ namespace clad { "Doesn't support recursive diff. Use DiffPlan."); m_DerivativeInFlight = true; - DiffParams args{}; - IndexIntervalTable indexIntervalTable{}; - if (request.Args) - std::tie(args, indexIntervalTable) = request.DiffParamsInfo; - else { - // FIXME: implement gradient-vector products to fix the issue. - assert((FD->getNumParams() <= 1) && - "nested forward mode differentiation for several args is broken"); - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - } - if (args.empty()) + DiffInputVarsInfo DVI = request.DVI; + + DVI = request.DVI; + + // FIXME: Shouldn't we give error here that no arg is specified? + if (DVI.empty()) return {}; + + DiffInputVarInfo diffVarInfo = DVI.back(); + // Check that only one arg is requested and if the arg requested is of array // or pointer type, only one of the indices have been requested - if (args.size() > 1 || (isArrayOrPointerType(args[0]->getType()) && - (indexIntervalTable.size() != 1 || indexIntervalTable[0].size() != 1))) { + if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) && + (diffVarInfo.paramIndexInterval.size() != 1))) { diag(DiagnosticsEngine::Error, request.Args ? request.Args->getEndLoc() : noLoc, "Forward mode differentiation w.r.t. several parameters at once is " @@ -224,7 +222,17 @@ namespace clad { return {}; } - m_IndependentVar = args.back(); + // FIXME: implement gradient-vector products to fix the issue. + assert((DVI.size() == 1) && + "nested forward mode differentiation for several args is broken"); + + // FIXME: Differentiation variable cannot always be represented just by + // `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc. + // FIXME: independent variable is misleading terminology, what we actually + // mean here is 'variable' with respect to which differentiation is being + // performed. Mathematically, independent variables are all the function + // parameters, thus, does not convey the intendend meaning. + m_IndependentVar = DVI.back().param; std::string derivativeSuffix(""); // If param is not real (i.e. floating point or integral), a pointer to a // real type, or an array of a real type we cannot differentiate it. @@ -240,14 +248,26 @@ namespace clad { {m_IndependentVar->getNameAsString()}); return {}; } - m_IndependentVarIndex = indexIntervalTable[0].Start; + m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start; derivativeSuffix = "_" + std::to_string(m_IndependentVarIndex); - } else if (!IsRealNonReferenceType(m_IndependentVar->getType())) { - diag(DiagnosticsEngine::Error, m_IndependentVar->getEndLoc(), - "attempted differentiation w.r.t. a parameter ('%0') which is not " - "of a real type", - {m_IndependentVar->getNameAsString()}); - return {}; + } + else { + QualType T = m_IndependentVar->getType(); + bool isField = false; + if (auto RD = diffVarInfo.param->getType()->getAsCXXRecordDecl()) { + llvm::SmallVector ref(diffVarInfo.fields.begin(), + diffVarInfo.fields.end()); + T = utils::ComputeMemExprPathType(m_Sema, RD, ref); + isField = true; + } + if (!IsRealNonReferenceType(T)) { + diag(DiagnosticsEngine::Error, request.Args->getEndLoc(), + "Attempted differentiation w.r.t. %0 '%1' which is not " + "of real type.", + {(isField ? "member" : "parameter"), + diffVarInfo.source}); + return {}; + } } m_DerivativeOrder = request.CurrentDerivativeOrder; std::string s = std::to_string(m_DerivativeOrder); @@ -269,9 +289,13 @@ namespace clad { m_IndependentVar)); } + std::string argInfo = std::to_string(m_ArgIndex); + for (auto field : diffVarInfo.fields) + argInfo += "_" + field; + IdentifierInfo* II = &m_Context.Idents.get(request.BaseFunctionName + "_d" + s + "arg" + - std::to_string(m_ArgIndex) + derivativeSuffix); + argInfo + derivativeSuffix); SourceLocation loc{m_Function->getLocation()}; DeclarationNameInfo name(II, loc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); @@ -362,6 +386,18 @@ namespace clad { dParam); addToCurrentBlock(BuildDeclStmt(dParamDecl)); dParam = BuildDeclRef(dParamDecl); + if (dParamType->isRecordType() && param == m_IndependentVar) { + llvm::SmallVector ref(diffVarInfo.fields.begin(), + diffVarInfo.fields.end()); + Expr* memRef = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref); + assert(memRef->getType()->isRealType() && + "Forward mode can only differentiate w.r.t builtin scalar " + "numerical types."); + addToCurrentBlock(BuildOp( + BinaryOperatorKind::BO_Assign, memRef, + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1))); + } // Memorize the derivative of param, i.e. whenever the param is visited // in the future, it's derivative dParam is found (unless reassigned with // something new). diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 8142f3760..7f7c137c7 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -80,8 +80,14 @@ namespace clad { const DiffRequest& request) { DiffParams args{}; IndexIntervalTable indexIntervalTable{}; - if (request.Args) - std::tie(args, indexIntervalTable) = request.DiffParamsInfo; + DiffInputVarsInfo DVI; + if (request.Args) { + DVI = request.DVI; + for (auto dParam : DVI) { + args.push_back(dParam.param); + indexIntervalTable.push_back(dParam.paramIndexInterval); + } + } else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 10b662599..1cfa37721 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -240,8 +240,11 @@ namespace clad { assert(m_Function && "Must not be null."); DiffParams args{}; + DiffInputVarsInfo DVI; if (request.Args) { - std::tie(args, std::ignore) = request.DiffParamsInfo; + DVI = request.DVI; + for (auto dParam : DVI) + args.push_back(dParam.param); } else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); diff --git a/test/FirstDerivative/DiffInterface.C b/test/FirstDerivative/DiffInterface.C index 5a62e9b72..b3373a47a 100644 --- a/test/FirstDerivative/DiffInterface.C +++ b/test/FirstDerivative/DiffInterface.C @@ -89,6 +89,25 @@ double fn_with_no_params() { return 11; } +struct Complex { + double real, im; + double getReal() { + return real; + } +}; + +double fn_with_Complex_type_param(Complex c) { + return c.real + c.im; +} + +struct ComplexPair { + Complex c1, c2; +}; + +double fn_with_ComplexPair_type_param(ComplexPair cp) { + return 11; +}; + int main () { int x = 4 * 5; clad::differentiate(f_1, 0); @@ -133,5 +152,10 @@ int main () { clad::differentiate(f_2, ""); // expected-error {{No parameters were provided}} clad::differentiate(fn_with_no_params); // expected-error {{Attempted to differentiate a function without parameters}} + clad::differentiate(f_2, "x.mem1"); // expected-error {{Fields can only be provided for class type parameters. Field information is incorrectly specified in 'x.mem1' for non-class type parameter 'x'}} + clad::differentiate(fn_with_Complex_type_param, "c.real.im"); // expected-error {{Path specified by fields in 'c.real.im' is invalid.}} + clad::differentiate(fn_with_ComplexPair_type_param, "cp.c1"); // expected-error {{Attempted differentiation w.r.t. member 'cp.c1' which is not of real type.}} + clad::differentiate(fn_with_Complex_type_param, "c.getReal"); // expected-error {{Path specified by fields in 'c.getReal' is invalid.}} + clad::differentiate(fn_with_Complex_type_param, "c.invalidField"); // expected-error {{Path specified by fields in 'c.invalidField' is invalid.}} return 0; } diff --git a/test/ForwardMode/UserDefinedTypes.C b/test/ForwardMode/UserDefinedTypes.C index 69e4a2fd5..135598c0e 100644 --- a/test/ForwardMode/UserDefinedTypes.C +++ b/test/ForwardMode/UserDefinedTypes.C @@ -1145,6 +1145,57 @@ double fn14(double i, double j) { // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +using pairdd = std::pair; +using pair_of_pairdd = std::pair; + +double fn15(pairdd u, pairdd v) { + return u.first + 2*v.first; +} + +// CHECK: double fn15_darg0_first(pairdd u, pairdd v) { +// CHECK-NEXT: pairdd _d_u; +// CHECK-NEXT: _d_u.first = 1; +// CHECK-NEXT: pairdd _d_v; +// CHECK-NEXT: double &_t0 = v.first; +// CHECK-NEXT: return _d_u.first + 0 * _t0 + 2 * _d_v.first; +// CHECK-NEXT: } + +double fn16(pair_of_pairdd u, pair_of_pairdd v) { + return u.first.first + 2*v.second.second; +} + +// CHECK: double fn16_darg1_second_second(pair_of_pairdd u, pair_of_pairdd v) { +// CHECK-NEXT: pair_of_pairdd _d_u; +// CHECK-NEXT: pair_of_pairdd _d_v; +// CHECK-NEXT: _d_v.second.second = 1; +// CHECK-NEXT: double &_t0 = v.second.second; +// CHECK-NEXT: return _d_u.first.first + 0 * _t0 + 2 * _d_v.second.second; +// CHECK-NEXT: } + + +struct A { + double mem; + A(double p_mem = 0) : mem(p_mem) {} +}; + +struct B : public A { + double mem; + B(double p_mem = 0) : A(0), mem(p_mem) {} +}; + +double fn17(A a, B b) { + return a.mem * b.mem; +} + +// CHECK: double fn17_darg1_mem(A a, B b) { +// CHECK-NEXT: A _d_a; +// CHECK-NEXT: B _d_b; +// CHECK-NEXT: _d_b.mem = 1; +// CHECK-NEXT: double &_t0 = a.mem; +// CHECK-NEXT: double &_t1 = b.mem; +// CHECK-NEXT: return _d_a.mem * _t1 + _t0 * _d_b.mem; +// CHECK-NEXT: } + template void print(const Tensor& t) { for (int i=0; i