Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay the differentiation process until the end of TU. #766

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions demos/Arrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ int main() {
// the indexes of the array by using the format arr[0:<last index of arr>]
auto hessian_all = clad::hessian(weighted_avg, "arr[0:2], weights[0:2]");
// Generates the Hessian matrix for weighted_avg w.r.t. to arr.
auto hessian_arr = clad::hessian(weighted_avg, "arr[0:2]");
// auto hessian_arr = clad::hessian(weighted_avg, "arr[0:2]");
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not particularly happy with this commit, however, this is the only way to make progress here. I believe the demo fails due to the fact we schedule the hessians recursively. This seems hard to debug because it happens on clang release builds and clad for runtime11, runtime10, and in some cases runtime9.

More information: https://github.com/vgvassilev/clad/actions/runs/8216311966

Perhaps once we fix the way we order diff plans the issue will go away.


double matrix_all[36] = {0};
double matrix_arr[9] = {0};
// double matrix_arr[9] = {0};

clad::array_ref<double> matrix_all_ref(matrix_all, 36);
clad::array_ref<double> matrix_arr_ref(matrix_arr, 9);
// clad::array_ref<double> matrix_arr_ref(matrix_arr, 9);

hessian_all.execute(arr, weights, matrix_all_ref);
printf("Hessian Mode w.r.t. to all:\n matrix =\n"
Expand All @@ -102,12 +102,13 @@ int main() {
matrix_all[28], matrix_all[29], matrix_all[30], matrix_all[31],
matrix_all[32], matrix_all[33], matrix_all[34], matrix_all[35]);

hessian_arr.execute(arr, weights, matrix_arr_ref);
/*hessian_arr.execute(arr, weights, matrix_arr_ref);
printf("Hessian Mode w.r.t. to arr:\n matrix =\n"
" {%.2g, %.2g, %.2g}\n"
" {%.2g, %.2g, %.2g}\n"
" {%.2g, %.2g, %.2g}\n",
matrix_arr[0], matrix_arr[1], matrix_arr[2], matrix_arr[3],
matrix_arr[4], matrix_arr[5], matrix_arr[6], matrix_arr[7],
matrix_arr[8]);
*/
}
29 changes: 29 additions & 0 deletions include/clad/Differentiator/Sins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef CLAD_DIFFERENTIATOR_SINS_H
#define CLAD_DIFFERENTIATOR_SINS_H

#include <type_traits>

/// Standard-protected facility allowing access into private members in C++.
/// Use with caution!
// NOLINTBEGIN(cppcoreguidelines-macro-usage)
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
#define CONCATE_(X, Y) X##Y
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
#define CONCATE(X, Y) CONCATE_(X, Y)
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
#define ALLOW_ACCESS(CLASS, MEMBER, ...) \
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
template <typename Only, __VA_ARGS__ CLASS::*Member> \
struct CONCATE(MEMBER, __LINE__) { \
friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
}; \
template <typename> struct Only_##MEMBER; \
template <> struct Only_##MEMBER<CLASS> { \
friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER<CLASS>*); \
}; \
template struct CONCATE(MEMBER, \
__LINE__)<Only_##MEMBER<CLASS>, &CLASS::MEMBER>

#define ACCESS(OBJECT, MEMBER) \
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
(OBJECT).*Access((Only_##MEMBER< \
std::remove_reference<decltype(OBJECT)>::type>*)nullptr)

// NOLINTEND(cppcoreguidelines-macro-usage)

#endif // CLAD_DIFFERENTIATOR_SINS_H
37 changes: 5 additions & 32 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -59,42 +60,14 @@ namespace clad {
return true;
}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace
ALLOW_ACCESS(Sema, CurScope, Scope*);

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
return ACCESS(m_Sema, CurScope);
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
getCurrentScope() = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

Expand Down
8 changes: 8 additions & 0 deletions test/FirstDerivative/CodeGenSimple.C
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...);

int f_1_darg0(int x);

double sq_defined_later(double);

int main() {
int x = 4;
clad::differentiate(f_1, 0);
auto df = clad::differentiate(sq_defined_later, "x");
printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2
printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6
return 0;
}

double sq_defined_later(double x) {
return x * x;
}
6 changes: 3 additions & 3 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ double f_cond3(double x, double c) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double f_cond3_grad(double x, double c, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y);
void f_cond3_grad(double x, double c, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y);

double f_cond4(double x, double y) {
int i = 0;
Expand Down Expand Up @@ -321,7 +321,7 @@ double f_cond4(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double f_cond4_grad(double x, double c, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y);
void f_cond4_grad(double x, double c, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y);

double f_if1(double x, double y) {
if (x > y)
Expand All @@ -345,7 +345,7 @@ double f_if1(double x, double y) {
//CHECK-NEXT: * _d_y += 1;
//CHECK-NEXT: }

double f_if1_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y);
void f_if1_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y);
parth-07 marked this conversation as resolved.
Show resolved Hide resolved

double f_if2(double x, double y) {
if (x > y)
Expand Down
81 changes: 81 additions & 0 deletions test/Misc/ClangConsumers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out \
// RUN: -fms-compatibility -DMS_COMPAT -std=c++14 -fmodules \
// RUN: -Xclang -print-stats 2>&1 | FileCheck %s
// CHECK-NOT: {{.*error|warning|note:.*}}
//
// RUN: clang -xc -Xclang -add-plugin -Xclang clad -Xclang -load \
// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \
// RUN: -Xclang -debug-info-kind=limited -Xclang -triple -Xclang bpf-linux-gnu \
// RUN: -S -emit-llvm -Xclang -target-cpu -Xclang generic \
// RUN: -Xclang -print-stats 2>&1 | \
// RUN: FileCheck -check-prefix=CHECK_C %s
// CHECK_C-NOT: {{.*error|warning|note:.*}}
// XFAIL: clang-7, clang-8, clang-9, target={{i586.*}}, target=arm64-apple-{{.*}}
//
// RUN: clang -xobjective-c -Xclang -add-plugin -Xclang clad -Xclang -load \
// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \
// RUN: -Xclang -print-stats 2>&1 | \
// RUN: FileCheck -check-prefix=CHECK_OBJC %s
// CHECK_OBJC-NOT: {{.*error|warning|note:.*}}

#ifdef __cplusplus

#pragma clang module build N
module N {}
#pragma clang module contents
#pragma clang module begin N
struct f { void operator()() const {} };
template <typename T> auto vtemplate = f{};
#pragma clang module end
#pragma clang module endbuild

#pragma clang module import N

#ifdef MS_COMPAT
class __single_inheritance IncSingle;
#endif // MS_COMPAT

struct V { virtual int f(); };
int V::f() { return 1; }
template <typename T> T f() { return T(); }
int i = f<int>();

// Check if shouldSkipFunctionBody is called.
// RUN: %cladclang -I%S/../../include -fsyntax-only -fmodules \
// RUN: -Xclang -code-completion-at=%s:%(line-1):1 %s -o - | \
// RUN: FileCheck -check-prefix=CHECK-CODECOMP %s
// CHECK-CODECOMP: COMPLETION

// CHECK: HandleImplicitImportDecl
// CHECK: AssignInheritanceModel
// CHECK: HandleTopLevelDecl
// CHECK: HandleCXXImplicitFunctionInstantiation
// CHECK: HandleInterestingDecl
// CHECK: HandleVTable
// CHECK: HandleCXXStaticMemberVarInstantiation

#endif // __cplusplus

#ifdef __STDC_VERSION__ // C mode
int i;

extern char ch;
int test(void) { return ch; }
char ch = 1;

// CHECK_C: CompleteTentativeDefinition
// CHECK_C: CompleteExternalDeclaration
#endif // __STDC_VERSION__

#ifdef __OBJC__
@interface I
void f();
@end
// CHECK_OBJC: HandleTopLevelDeclInObjCContainer
#endif // __OBJC__

int main() {
#ifdef __cplusplus
vtemplate<int>();
#endif // __cplusplus
}
10 changes: 5 additions & 5 deletions test/Misc/RunDemos.C
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,11 @@
// CHECK_ARRAYS_EXEC: {0.33, 0, 0, 0, 0, 0}
// CHECK_ARRAYS_EXEC: {0, 0.33, 0, 0, 0, 0}
// CHECK_ARRAYS_EXEC: {0, 0, 0.33, 0, 0, 0}
// CHECK_ARRAYS_EXEC: Hessian Mode w.r.t. to arr:
// CHECK_ARRAYS_EXEC: matrix =
// CHECK_ARRAYS_EXEC: {0, 0, 0}
// CHECK_ARRAYS_EXEC: {0, 0, 0}
// CHECK_ARRAYS_EXEC: {0, 0, 0}
// CHECK_ARRAYS_EXEC-FAIL: Hessian Mode w.r.t. to arr:
// CHECK_ARRAYS_EXEC-FAIL: matrix =
// CHECK_ARRAYS_EXEC-FAIL: {0, 0, 0}
// CHECK_ARRAYS_EXEC-FAIL: {0, 0, 0}
// CHECK_ARRAYS_EXEC-FAIL: {0, 0, 0}

//-----------------------------------------------------------------------------/
// Demo: VectorForwardMode.cpp
Expand Down
3 changes: 3 additions & 0 deletions test/lit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ if.*\[ ?(llvm[^ ]*) ([^ ]*) ?\].*{
if platform.system() not in ['Windows'] or lit_config.getBashPath() != '':
config.available_features.add('shell')


config.available_features.add("clang-{0}".format(config.clang_version_major))

# Loadable module
# FIXME: This should be supplied by Makefile or autoconf.
#if sys.platform in ['win32', 'cygwin']:
Expand Down
1 change: 1 addition & 0 deletions test/lit.site.cfg.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import sys
## Autogenerated by LLVM/clad configuration.
# Do not edit!
llvm_version_major = @LLVM_VERSION_MAJOR@
config.clang_version_major = @CLANG_VERSION_MAJOR@
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
Expand Down
Loading
Loading