Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable computation of CUDA global kernels derivative in reverse mode #1059

Merged
merged 12 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 107 additions & 26 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,23 @@
return count;
}

/// Tape type used for storing values in reverse-mode AD inside loops.
template <typename T>
using tape = tape_impl<T>;
#ifdef __CUDACC__
#define CUDA_ARGS bool CUDAkernel, dim3 grid, dim3 block,
#define CUDA_REST_ARGS size_t shared_mem, cudaStream_t stream,
#else
#define CUDA_ARGS
#define CUDA_REST_ARGS
#endif

/// Add value to the end of the tape, return the same value.
template <typename T, typename... ArgsT>
CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
to.emplace_back(std::forward<ArgsT>(val)...);
return to.back();
}
/// Tape type used for storing values in reverse-mode AD inside loops.
template <typename T> using tape = tape_impl<T>;

/// Add value to the end of the tape, return the same value.
template <typename T, typename... ArgsT>
CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
to.emplace_back(std::forward<ArgsT>(val)...);
return to.back();
}

/// Add value to the end of the tape, return the same value.
/// A specialization for clad::array_ref types to use in reverse mode.
Expand Down Expand Up @@ -115,17 +122,35 @@
typename std::enable_if<EnablePadding, bool>::type = true>
CUDA_HOST_DEVICE return_type_t<F>
execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
Args&&... args) {
CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
if (CUDAkernel) {
void* argPtrs[] = {(void*)&args..., (void*)static_cast<Rest>(nullptr)...};
cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
} else {
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
}
#else
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
#endif
}

template <bool EnablePadding, class... Rest, class F, class... Args,
class... fArgTypes,
typename std::enable_if<!EnablePadding, bool>::type = true>
return_type_t<F> execute_with_default_args(list<Rest...>, F f,
list<fArgTypes...>,
Args&&... args) {
return_type_t<F>
execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
if (CUDAkernel) {
void* argPtrs[] = {(void*)&args...};
cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
} else {
return f(static_cast<Args>(args)...);
}
#else
return f(static_cast<Args>(args)...);
#endif
}

// for executing member-functions
Expand Down Expand Up @@ -167,12 +192,13 @@
CladFunctionType m_Function;
char* m_Code;
FunctorType *m_Functor = nullptr;
bool m_CUDAkernel = false;

public:
CUDA_HOST_DEVICE CladFunction(CladFunctionType f,
const char* code,
FunctorType* functor = nullptr)
: m_Functor(functor) {
CUDA_HOST_DEVICE CladFunction(CladFunctionType f, const char* code,
FunctorType* functor = nullptr,
bool CUDAkernel = false)
: m_Functor(functor), m_CUDAkernel(CUDAkernel) {
assert(f && "Must pass a non-0 argument.");
if (size_t length = GetLength(code)) {
m_Function = f;
Expand Down Expand Up @@ -210,9 +236,37 @@
printf("CladFunction is invalid\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}
if (m_CUDAkernel) {
printf("Use execute_kernel() for global CUDA kernels\n");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not call c-style vararg functions [cppcoreguidelines-pro-type-vararg]

        printf("Use execute_kernel() for global CUDA kernels\n");
        ^

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably assert-out if users use execute instead of execute_kernel for CUDA kernels. 'printf' seems to be too subtle for reporting such a big error. @vgvassilev What do you think?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that a programmatic error or user error. We can use assert/abort if the error is programmer error.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a user error basically. They should use the appropriate execute function depending on whether their function is a CUDA kernel or not.

return static_cast<return_type_t<F>>(return_type_t<F>());

Check warning on line 241 in include/clad/Differentiator/Differentiator.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/Differentiator.h#L240-L241

Added lines #L240 - L241 were not covered by tests
}
// here static_cast is used to achieve perfect forwarding
#ifdef __CUDACC__
return execute_helper(m_Function, m_CUDAkernel, dim3(0), dim3(0),
static_cast<Args>(args)...);
#else
return execute_helper(m_Function, static_cast<Args>(args)...);
#endif
}

#ifdef __CUDACC__
template <typename... Args, class FnType = CladFunctionType>
typename std::enable_if<!std::is_same<FnType, NoFunction*>::value,
return_type_t<F>>::type
execute_kernel(dim3 grid, dim3 block, Args&&... args) CUDA_HOST_DEVICE {
if (!m_Function) {
printf("CladFunction is invalid\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}
if (!m_CUDAkernel) {
printf("Use execute() for non-global CUDA kernels\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}

return execute_helper(m_Function, m_CUDAkernel, grid, block,
static_cast<Args>(args)...);
}
#endif

/// `Execute` overload to be used when derived function type cannot be
/// deduced. One reason for this can be when user tries to differentiate
Expand Down Expand Up @@ -258,12 +312,39 @@
/// Helper function for executing non-member derived functions.
template <class Fn, class... Args>
CUDA_HOST_DEVICE return_type_t<CladFunctionType>
execute_helper(Fn f, Args&&... args) {
execute_helper(Fn f, CUDA_ARGS Args&&... args) {
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
#if defined(__CUDACC__)
if constexpr (sizeof...(Args) >= 2) {
auto secondArg =
std::get<1>(std::forward_as_tuple(std::forward<Args>(args)...));
if constexpr (std::is_same<std::decay_t<decltype(secondArg)>,
cudaStream_t>::value) {
return [&](auto shared_mem, cudaStream_t stream, auto&&... args_) {
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args) - 2, F>{}, f,
TakeNFirstArgs_t<sizeof...(Args) - 2, decltype(f)>{},
CUDAkernel, grid, block, shared_mem, stream,
static_cast<decltype(args_)>(args_)...);
}(static_cast<Args>(args)...);
} else {
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{}, CUDAkernel,
grid, block, 0, nullptr, static_cast<Args>(args)...);
}
} else {
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{}, CUDAkernel,
grid, block, 0, nullptr, static_cast<Args>(args)...);
}
#else
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
#endif
}

/// Helper functions for executing member derived functions.
Expand Down Expand Up @@ -393,10 +474,10 @@
annotate("G"))) CUDA_HOST_DEVICE
gradient(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
const char* code = "", bool CUDAkernel = false) {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, nullptr, CUDAkernel);
}

/// Specialization for differentiating functors.
Expand Down
21 changes: 18 additions & 3 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,6 @@ namespace clad {
void DiffRequest::updateCall(FunctionDecl* FD, FunctionDecl* OverloadedFD,
Sema& SemaRef) {
CallExpr* call = this->CallContext;
// Index of "code" parameter:
auto codeArgIdx = static_cast<int>(call->getNumArgs()) - 1;
auto derivedFnArgIdx = codeArgIdx - 1;

assert(call && "Must be set");
assert(FD && "Trying to update with null FunctionDecl");
Expand All @@ -191,6 +188,24 @@ namespace clad {
ASTContext& C = SemaRef.getASTContext();

FunctionDecl* replacementFD = OverloadedFD ? OverloadedFD : FD;

// Index of "CUDAkernel" parameter:
int numArgs = static_cast<int>(call->getNumArgs());
if (numArgs > 4) {
auto kernelArgIdx = numArgs - 1;
auto* cudaKernelFlag =
SemaRef
.ActOnCXXBoolLiteral(noLoc,
replacementFD->hasAttr<CUDAGlobalAttr>()
? tok::kw_true
: tok::kw_false)
.get();
call->setArg(kernelArgIdx, cudaKernelFlag);
numArgs--;
}
auto codeArgIdx = numArgs - 1;
auto derivedFnArgIdx = numArgs - 2;

// Create ref to generated FD.
DeclRefExpr* DRE =
DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc, replacementFD,
Expand Down
8 changes: 8 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(BuildDeclStmt(gradientVD));
}

// If the function is a global kernel, we need to transform it
// into a device function when calling it inside the overload function
// which is the final global kernel returned.
if (m_Derivative->hasAttr<clang::CUDAGlobalAttr>()) {
m_Derivative->dropAttr<clang::CUDAGlobalAttr>();
m_Derivative->addAttr(clang::CUDADeviceAttr::CreateImplicit(m_Context));
}

Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs,
/*UseRefQualifiedThisObj=*/true);
addToCurrentBlock(callExpr);
Expand Down
79 changes: 79 additions & 0 deletions test/CUDA/GradientKernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// RUN: %cladclang_cuda -I%S/../../include %s -fsyntax-only \
// RUN: %cudasmlevel --cuda-path=%cudapath -Xclang -verify 2>&1 | %filecheck %s

// RUN: %cladclang_cuda -I%S/../../include %s -xc++ %cudasmlevel \
// RUN: --cuda-path=%cudapath -L/usr/local/cuda/lib64 -lcudart_static \
// RUN: -L%cudapath/lib64/stubs \
// RUN: -ldl -lrt -pthread -lm -lstdc++ -lcuda -lnvrtc

// REQUIRES: cuda-runtime

// expected-no-diagnostics

// XFAIL: clang-15
parth-07 marked this conversation as resolved.
Show resolved Hide resolved

#include "clad/Differentiator/Differentiator.h"

__global__ void kernel(int *a) {
*a *= *a;
}

// CHECK: void kernel_grad(int *a, int *_d_a) {
//CHECK-NEXT: int _t0 = *a;
//CHECK-NEXT: *a *= *a;
//CHECK-NEXT: {
//CHECK-NEXT: *a = _t0;
//CHECK-NEXT: int _r_d0 = *_d_a;
//CHECK-NEXT: *_d_a = 0;
//CHECK-NEXT: *_d_a += _r_d0 * *a;
//CHECK-NEXT: *_d_a += *a * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

void fake_kernel(int *a) {
*a *= *a;
}

int main(void) {
int *a = (int*)malloc(sizeof(int));
*a = 2;
int *d_a;
cudaMalloc(&d_a, sizeof(int));
cudaMemcpy(d_a, a, sizeof(int), cudaMemcpyHostToDevice);

int *asquare = (int*)malloc(sizeof(int));
*asquare = 1;
int *d_square;
cudaMalloc(&d_square, sizeof(int));
cudaMemcpy(d_square, asquare, sizeof(int), cudaMemcpyHostToDevice);

auto test = clad::gradient(kernel);
dim3 grid(1);
dim3 block(1);
cudaStream_t cudaStream;
cudaStreamCreate(&cudaStream);
test.execute_kernel(grid, block, 0, cudaStream, d_a, d_square);

cudaDeviceSynchronize();

cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 2, a^2 = 4

auto error = clad::gradient(fake_kernel);
error.execute_kernel(grid, block, d_a, d_square); // CHECK-EXEC: Use execute() for non-global CUDA kernels

test.execute(d_a, d_square); // CHECK-EXEC: Use execute_kernel() for global CUDA kernels

cudaMemset(d_a, 5, 1); // first byte is set to 5
cudaMemset(d_square, 1, 1);

test.execute_kernel(grid, block, d_a, d_square);
cudaDeviceSynchronize();

cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 5, a^2 = 10

return 0;
}
6 changes: 3 additions & 3 deletions test/lit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ lit.util.usePlatformSdkOnDarwin(config, lit_config)
#\ -plugin-arg-ad -Xclang -fdump-derived-fn -Xclang -load -Xclang../../Debug+Asserts/lib/libclad.so
#FIXME: we need to introduce a better way to check compatible version of clang, propagating
#-fvalidate-clang-version flag is not enough.
flags = ' -std=c++11 -Xclang -add-plugin -Xclang clad -Xclang \
flags = ' -Xclang -add-plugin -Xclang clad -Xclang \
-plugin-arg-clad -Xclang -fdump-derived-fn -Xclang \
-load -Xclang ' + config.cladlib

config.substitutions.append( ('%cladclang_cuda', config.clang + flags) )
config.substitutions.append( ('%cladclang_cuda', config.clang + ' -std=c++17' + flags) )

config.substitutions.append( ('%cladclang', config.clang + '++ -DCLAD_NO_NUM_DIFF ' + flags) )
config.substitutions.append( ('%cladclang', config.clang + '++ -DCLAD_NO_NUM_DIFF ' + ' -std=c++11' + flags) )

config.substitutions.append( ('%cladlib', config.cladlib) )

Expand Down
Loading