Skip to content

Commit

Permalink
Respect shadow declarations when writing propagators.
Browse files Browse the repository at this point in the history
In cases where the public declaration is introduced with using declaration
pointing to an internal namespace with the implementation details, we should
put the propagator function in the namespace of the public function and not the
implementation. That would allow users to position their pullbacks in the same
namespace structure as the used functions.
  • Loading branch information
vgvassilev committed Dec 15, 2024
1 parent fb10bdc commit 58d68d6
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 44 deletions.
6 changes: 4 additions & 2 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ namespace clad {
/// overload to be found.
/// \param[in] CallArgs The call args to be used to resolve to the
/// correct overload.
/// \param[in] callSite - The call expression which triggers the custom
/// derivative call.
/// \param[in] forCustomDerv A flag to keep track of which
/// namespace we should look in for the overloads.
/// \param[in] namespaceShouldExist A flag to enforce assertion failure
Expand All @@ -117,8 +119,8 @@ namespace clad {
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, const clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true,
clang::Scope* S, const clang::Expr* callSite, bool forCustomDerv = true,
bool namespaceShouldExist = true,
clang::Expr* CUDAExecConfig = nullptr);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
Expand Down
4 changes: 3 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include <llvm/ADT/ArrayRef.h>

#include <array>
#include <limits>
#include <memory>
Expand Down Expand Up @@ -107,7 +109,7 @@ namespace clad {

/// Tries to find and build call to user-provided `_forw` function.
clang::Expr* BuildCallToCustomForwPassFn(
const clang::FunctionDecl* FD, llvm::ArrayRef<clang::Expr*> primalArgs,
const clang::Expr* callSite, llvm::ArrayRef<clang::Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, clang::Expr* baseExpr);

public:
Expand Down
9 changes: 3 additions & 6 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1226,17 +1226,15 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
FD->getDeclContext());
customPushforward, customDerivativeArgs, getCurrentScope(), CE);
// Custom derivative templates can be written in a
// general way that works for both vectorized and non-vectorized
// modes. We have to also look for the pushforward with the regular name.
if (!callDiff && m_DiffReq.Mode != DiffMode::forward) {
customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward";
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
FD->getDeclContext());
customPushforward, customDerivativeArgs, getCurrentScope(), CE);
}
if (!isLambda) {
// Check if it is a recursive call.
Expand Down Expand Up @@ -2316,8 +2314,7 @@ clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall(
clad::utils::ComputeEffectiveFnName(CE->getConstructor()) +
GetPushForwardFunctionSuffix();
Expr* pushforwardCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforwardName, customPushforwardArgs, getCurrentScope(),
CE->getConstructor()->getDeclContext());
customPushforwardName, customPushforwardArgs, getCurrentScope(), CE);
return pushforwardCall;
}
} // end namespace clad
23 changes: 22 additions & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "JacobianModeVisitor.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
Expand Down Expand Up @@ -247,9 +248,29 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {

Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, const clang::DeclContext* originalFnDC,
clang::Scope* S, const clang::Expr* callSite,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/,
Expr* CUDAExecConfig /*=nullptr*/) {
const DeclContext* originalFnDC = nullptr;

// FIXME: callSite must not be null but it comes when we try to build
// a numerical diff call. We should merge both paths and remove the
// special branches being taken for propagators and numerical diff.
if (callSite) {
// Check if the callSite is not associated with a shadow declaration.
if (const auto* ME = dyn_cast<CXXMemberCallExpr>(callSite)) {
originalFnDC = ME->getMethodDecl()->getParent();
} else if (const auto* CE = dyn_cast<CallExpr>(callSite)) {
const Expr* Callee = CE->getCallee()->IgnoreParenCasts();
if (const auto* DRE = dyn_cast<DeclRefExpr>(Callee))
originalFnDC = DRE->getFoundDecl()->getDeclContext();
else if (const auto* MemberE = dyn_cast<MemberExpr>(Callee))
originalFnDC = MemberE->getFoundDecl().getDecl()->getDeclContext();
} else if (const auto* CtorExpr = dyn_cast<CXXConstructExpr>(callSite)) {
originalFnDC = CtorExpr->getConstructor()->getDeclContext();
}
}

CXXScopeSpec SS;
LookupResult R = LookupCustomDerivativeOrNumericalDiff(
Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);
Expand Down
33 changes: 18 additions & 15 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1827,8 +1827,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DerivedCallArgs.front()->getType(), m_Context, 1));
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
FD->getDeclContext(),
customPushforward, pushforwardCallArgs, getCurrentScope(), CE,
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (OverloadedDerivedFn)
Expand Down Expand Up @@ -1931,8 +1930,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
FD->getDeclContext(),
customPullback, pullbackCallArgs, getCurrentScope(), CE,
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (baseDiff.getExpr())
Expand Down Expand Up @@ -2064,7 +2062,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
baseDiff.getExpr_dx(), Loc));

if (Expr* customForwardPassCE =
BuildCallToCustomForwPassFn(FD, CallArgs, CallArgDx, baseExpr)) {
BuildCallToCustomForwPassFn(CE, CallArgs, CallArgDx, baseExpr)) {
if (!utils::isNonConstReferenceType(returnType) &&
!returnType->isPointerType())
return StmtDiff{customForwardPassCE};
Expand Down Expand Up @@ -2214,7 +2212,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::string Name = "central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*callSite=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false, CUDAExecConfig);
}
Expand Down Expand Up @@ -4247,8 +4245,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::string customPullbackName = "constructor_pullback";
if (Expr* customPullbackCall =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullbackName, pullbackArgs, getCurrentScope(),
CE->getConstructor()->getDeclContext())) {
customPullbackName, pullbackArgs, getCurrentScope(), CE)) {
curRevBlock.insert(it, customPullbackCall);
if (m_TrackConstructorPullbackInfo) {
setConstructorPullbackCallInfo(llvm::cast<CallExpr>(customPullbackCall),
Expand Down Expand Up @@ -4278,9 +4275,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// SomeClass _d_c = _t0.adjoint;
// SomeClass c = _t0.value;
// ```
if (Expr* customReverseForwFnCall = BuildCallToCustomForwPassFn(
CE->getConstructor(), primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
if (Expr* customReverseForwFnCall =
BuildCallToCustomForwPassFn(CE, primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
if (RD->isAggregate()) {
SmallString<128> Name_class;
llvm::raw_svector_ostream OS_class(Name_class);
Expand Down Expand Up @@ -4555,16 +4552,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn(
const FunctionDecl* FD, llvm::ArrayRef<Expr*> primalArgs,
const Expr* callSite, llvm::ArrayRef<Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, Expr* baseExpr) {
std::string forwPassFnName =
clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw";
llvm::SmallVector<Expr*, 4> args;
if (baseExpr) {
baseExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, baseExpr,
m_DiffReq->getLocation());
args.push_back(baseExpr);
}
const FunctionDecl* FD = nullptr;
if (const auto* CE = dyn_cast<CallExpr>(callSite))
FD = CE->getDirectCallee();
else
FD = cast<CXXConstructExpr>(callSite)->getConstructor();

if (auto CD = llvm::dyn_cast<CXXConstructorDecl>(FD)) {
const RecordDecl* RD = CD->getParent();
QualType constructorReverseForwTagT =
Expand All @@ -4582,9 +4583,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
args.append(primalArgs.begin(), primalArgs.end());
args.append(derivedArgs.begin(), derivedArgs.end());
std::string forwPassFnName =
clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw";
Expr* customForwPassCE =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
forwPassFnName, args, getCurrentScope(), FD->getDeclContext());
forwPassFnName, args, getCurrentScope(), callSite);
return customForwPassCE;
}

Expand Down
33 changes: 32 additions & 1 deletion test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,34 @@
#include "../TestUtils.h"
extern "C" int printf(const char* fmt, ...);


namespace N {
namespace impl {
double sq(double x) { return x * x;}
}
using impl::sq; // using shadow
}

namespace clad {
namespace custom_derivatives {
namespace N {
clad::ValueAndPushforward<double, double> sq_pushforward(double x, double d_x) {
return { x * x, 2 * x * d_x };
}
}
}
}

float f0 (float x) {
return N::sq(x); // must find the sq_pushforward.
}

// CHECK: float f0_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: {{.*}} _t0 = clad::custom_derivatives::N::sq_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

namespace clad{
namespace custom_derivatives{
float f1_darg0(float x) {
Expand Down Expand Up @@ -208,7 +236,7 @@ double f12(double a, double b) { return std::fma(a, b, b); }
//CHECK: double f12_darg1(double a, double b) {
//CHECK-NEXT: double _d_a = 0;
//CHECK-NEXT: double _d_b = 1;
//CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<decltype(::std::fma(double(), double(), double())), decltype(::std::fma(double(), double(), double()))> _t0 = clad::custom_derivatives::fma_pushforward(a, b, b, _d_a, _d_b, _d_b);
//CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<decltype(::std::fma(double(), double(), double())), decltype(::std::fma(double(), double(), double()))> _t0 = clad::custom_derivatives::std::fma_pushforward(a, b, b, _d_a, _d_b, _d_b);
//CHECK-NEXT: return _t0.pushforward;
//CHECK-NEXT: }

Expand Down Expand Up @@ -296,6 +324,9 @@ int main () { //expected-no-diagnostics
double d_result[2];
int i_result[1];

auto f0_darg0 = clad::differentiate(f0, 0);
printf("Result is = %f\n", f0_darg0.execute(1)); // CHECK-EXEC: Result is = 2

auto f1_darg0 = clad::differentiate(f1, 0);
printf("Result is = %f\n", f1_darg0.execute(60)); // CHECK-EXEC: Result is = -0.952413

Expand Down
4 changes: 2 additions & 2 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ double f4_darg0(double x, int y);
// CHECK-NEXT: {
// CHECK-NEXT: _d_i = 0;
// CHECK-NEXT: for (i = 0; i < y; [&] {
// CHECK-NEXT: ValueAndPushforward<double, double> _t2 = clad::custom_derivatives::sin_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<double, double> _t2 = clad::custom_derivatives::std::sin_pushforward(x, _d_x);
// CHECK-NEXT: double &_t3 = _t2.value;
// CHECK-NEXT: _d_r = _d_r * _t3 + r * _t2.pushforward;
// CHECK-NEXT: r = r * _t3;
Expand Down Expand Up @@ -204,7 +204,7 @@ double f4_inc_darg0(double x, int y);
//CHECK-NEXT: {
//CHECK-NEXT: _d_i = 0;
//CHECK-NEXT: for (i = 0; i < y; [&] {
//CHECK-NEXT: ValueAndPushforward<double, double> _t3 = clad::custom_derivatives::sin_pushforward(x, _d_x);
//CHECK-NEXT: ValueAndPushforward<double, double> _t3 = clad::custom_derivatives::std::sin_pushforward(x, _d_x);
//CHECK-NEXT: double &_t4 = _t3.pushforward;
//CHECK-NEXT: double &_t5 = _t3.value;
//CHECK-NEXT: _d_r = _d_r * _t5 + r * _t4;
Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@ int main() {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::pow_pushforward(a.data[i], b.data[i], _d_a.data[i], _d_b.data[i]);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::std::pow_pushforward(a.data[i], b.data[i], _d_a.data[i], _d_b.data[i]);
// CHECK-NEXT: _d_res.data[i] = _t0.pushforward;
// CHECK-NEXT: res.data[i] = _t0.value;
// CHECK-NEXT: }
Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ double f19(double a, double b) {
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: double _r2 = 0.;
//CHECK-NEXT: clad::custom_derivatives::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2);
//CHECK-NEXT: clad::custom_derivatives::std::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2);
//CHECK-NEXT: *_d_a += _r0;
//CHECK-NEXT: *_d_b += _r1;
//CHECK-NEXT: *_d_b += _r2;
Expand Down
6 changes: 3 additions & 3 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ void f_norm_grad(double x,
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: double _r5 = 0.;
//CHECK-NEXT: clad::custom_derivatives::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_r0, &_r5);
//CHECK-NEXT: clad::custom_derivatives::std::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_r0, &_r5);
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: double _r2 = 0.;
//CHECK-NEXT: double _r3 = 0.;
Expand All @@ -430,10 +430,10 @@ void f_sin_grad(double x, double y, double *_d_x, double *_d_y);
//CHECK-NEXT: double _t0 = (std::sin(x) + std::sin(y));
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: _r0 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: _r0 += 1 * (x + y) * clad::custom_derivatives::std::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: _r1 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(y, 1.).pushforward;
//CHECK-NEXT: _r1 += 1 * (x + y) * clad::custom_derivatives::std::sin_pushforward(y, 1.).pushforward;
//CHECK-NEXT: *_d_y += _r1;
//CHECK-NEXT: *_d_x += _t0 * 1;
//CHECK-NEXT: *_d_y += _t0 * 1;
Expand Down
8 changes: 4 additions & 4 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -400,20 +400,20 @@ double f_log_gaus(double* x, double* p /*means*/, double n, double sigma) {
// CHECK-NEXT: double gaus = 1. / _t6 * _t5;
// CHECK-NEXT: {
// CHECK-NEXT: double _r8 = 0.;
// CHECK-NEXT: _r8 += 1 * clad::custom_derivatives::log_pushforward(gaus, 1.).pushforward;
// CHECK-NEXT: _r8 += 1 * clad::custom_derivatives::std::log_pushforward(gaus, 1.).pushforward;
// CHECK-NEXT: _d_gaus += _r8;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r3 = _d_gaus * _t5 * -(1. / (_t6 * _t6));
// CHECK-NEXT: double _r4 = 0.;
// CHECK-NEXT: _r4 += _r3 * clad::custom_derivatives::sqrt_pushforward(_t7 * sigma, 1.).pushforward;
// CHECK-NEXT: _r4 += _r3 * clad::custom_derivatives::std::sqrt_pushforward(_t7 * sigma, 1.).pushforward;
// CHECK-NEXT: double _r5 = 0.;
// CHECK-NEXT: double _r6 = 0.;
// CHECK-NEXT: clad::custom_derivatives::pow_pullback(2 * 3.1415926535897931, n, _r4 * sigma, &_r5, &_r6);
// CHECK-NEXT: clad::custom_derivatives::std::pow_pullback(2 * 3.1415926535897931, n, _r4 * sigma, &_r5, &_r6);
// CHECK-NEXT: _d_n += _r6;
// CHECK-NEXT: _d_sigma += _t7 * _r4;
// CHECK-NEXT: double _r7 = 0.;
// CHECK-NEXT: _r7 += 1. / _t6 * _d_gaus * clad::custom_derivatives::exp_pushforward(power, 1.).pushforward;
// CHECK-NEXT: _r7 += 1. / _t6 * _d_gaus * clad::custom_derivatives::std::exp_pushforward(power, 1.).pushforward;
// CHECK-NEXT: _d_power += _r7;
// CHECK-NEXT: }
// CHECK-NEXT: {
Expand Down
4 changes: 2 additions & 2 deletions test/Jacobian/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ void fn1(double i, double j, double* output) {
// CHECK-NEXT: clad::array<double> _d_vector_i = clad::one_hot_vector(indepVarCount, {{0U|0UL|0ULL}});
// CHECK-NEXT: clad::array<double> _d_vector_j = clad::one_hot_vector(indepVarCount, {{1U|1UL|1ULL}});
// CHECK-NEXT: *_d_vector_output = clad::identity_matrix(_d_vector_output->rows(), indepVarCount, {{2U|2UL|2ULL}});
// CHECK-NEXT: {{.*}} _t0 = clad::custom_derivatives::pow_pushforward(i, j, _d_vector_i, _d_vector_j);
// CHECK-NEXT: {{.*}} _t0 = clad::custom_derivatives::std::pow_pushforward(i, j, _d_vector_i, _d_vector_j);
// CHECK-NEXT: *_d_vector_output[0] = _t0.pushforward;
// CHECK-NEXT: output[0] = _t0.value;
// CHECK-NEXT: {{.*}} _t1 = clad::custom_derivatives::pow_pushforward(j, i, _d_vector_j, _d_vector_i);
// CHECK-NEXT: {{.*}} _t1 = clad::custom_derivatives::std::pow_pushforward(j, i, _d_vector_j, _d_vector_i);
// CHECK-NEXT: *_d_vector_output[1] = _t1.pushforward;
// CHECK-NEXT: output[1] = _t1.value;
// CHECK-NEXT: }
Expand Down
8 changes: 4 additions & 4 deletions test/NestedCalls/NestedCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ int main () { // expected-no-diagnostics
// CHECK: clad::ValueAndPushforward<double, double> sq_pushforward(double x, double _d_x);

// CHECK: clad::ValueAndPushforward<double, double> one_pushforward(double x, double _d_x) {
// CHECK-NEXT: ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::sin_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::std::sin_pushforward(x, _d_x);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t1 = sq_pushforward(_t0.value, _t0.pushforward);
// CHECK-NEXT: ValueAndPushforward<double, double> _t2 = clad::custom_derivatives::cos_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<double, double> _t2 = clad::custom_derivatives::std::cos_pushforward(x, _d_x);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t3 = sq_pushforward(_t2.value, _t2.pushforward);
// CHECK-NEXT: return {_t1.value + _t3.value, _t1.pushforward + _t3.pushforward};
// CHECK-NEXT: }
Expand All @@ -71,12 +71,12 @@ int main () { // expected-no-diagnostics
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: sq_pullback(std::sin(x), _d_y, &_r0);
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: _r1 += _r0 * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: _r1 += _r0 * clad::custom_derivatives::std::sin_pushforward(x, 1.).pushforward;
//CHECK-NEXT: *_d_x += _r1;
//CHECK-NEXT: double _r2 = 0.;
//CHECK-NEXT: sq_pullback(std::cos(x), _d_y, &_r2);
//CHECK-NEXT: double _r3 = 0.;
//CHECK-NEXT: _r3 += _r2 * clad::custom_derivatives::cos_pushforward(x, 1.).pushforward;
//CHECK-NEXT: _r3 += _r2 * clad::custom_derivatives::std::cos_pushforward(x, 1.).pushforward;
//CHECK-NEXT: *_d_x += _r3;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down
Loading

0 comments on commit 58d68d6

Please sign in to comment.