Skip to content

Commit

Permalink
Enable computation of CUDA global kernels derivative in reverse mode (#…
Browse files Browse the repository at this point in the history
kchristin22 authored Sep 5, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 685bcbf commit 431791f
Showing 5 changed files with 215 additions and 32 deletions.
133 changes: 107 additions & 26 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
@@ -38,16 +38,23 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
return count;
}

/// Tape type used for storing values in reverse-mode AD inside loops.
template <typename T>
using tape = tape_impl<T>;
#ifdef __CUDACC__
#define CUDA_ARGS bool CUDAkernel, dim3 grid, dim3 block,
#define CUDA_REST_ARGS size_t shared_mem, cudaStream_t stream,
#else
#define CUDA_ARGS
#define CUDA_REST_ARGS
#endif

/// Add value to the end of the tape, return the same value.
template <typename T, typename... ArgsT>
CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
to.emplace_back(std::forward<ArgsT>(val)...);
return to.back();
}
/// Tape type used for storing values in reverse-mode AD inside loops.
template <typename T> using tape = tape_impl<T>;

/// Add value to the end of the tape, return the same value.
template <typename T, typename... ArgsT>
CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
to.emplace_back(std::forward<ArgsT>(val)...);
return to.back();
}

/// Add value to the end of the tape, return the same value.
/// A specialization for clad::array_ref types to use in reverse mode.
@@ -115,17 +122,35 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
typename std::enable_if<EnablePadding, bool>::type = true>
CUDA_HOST_DEVICE return_type_t<F>
execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
Args&&... args) {
CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
if (CUDAkernel) {
void* argPtrs[] = {(void*)&args..., (void*)static_cast<Rest>(nullptr)...};
cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
} else {
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
}
#else
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
#endif
}

template <bool EnablePadding, class... Rest, class F, class... Args,
class... fArgTypes,
typename std::enable_if<!EnablePadding, bool>::type = true>
return_type_t<F> execute_with_default_args(list<Rest...>, F f,
list<fArgTypes...>,
Args&&... args) {
return_type_t<F>
execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
if (CUDAkernel) {
void* argPtrs[] = {(void*)&args...};
cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
} else {
return f(static_cast<Args>(args)...);
}
#else
return f(static_cast<Args>(args)...);
#endif
}

// for executing member-functions
@@ -167,12 +192,13 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
CladFunctionType m_Function;
char* m_Code;
FunctorType *m_Functor = nullptr;
bool m_CUDAkernel = false;

public:
CUDA_HOST_DEVICE CladFunction(CladFunctionType f,
const char* code,
FunctorType* functor = nullptr)
: m_Functor(functor) {
CUDA_HOST_DEVICE CladFunction(CladFunctionType f, const char* code,
FunctorType* functor = nullptr,
bool CUDAkernel = false)
: m_Functor(functor), m_CUDAkernel(CUDAkernel) {
assert(f && "Must pass a non-0 argument.");
if (size_t length = GetLength(code)) {
m_Function = f;
@@ -210,9 +236,37 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
printf("CladFunction is invalid\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}
if (m_CUDAkernel) {
printf("Use execute_kernel() for global CUDA kernels\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}
// here static_cast is used to achieve perfect forwarding
#ifdef __CUDACC__
return execute_helper(m_Function, m_CUDAkernel, dim3(0), dim3(0),
static_cast<Args>(args)...);
#else
return execute_helper(m_Function, static_cast<Args>(args)...);
#endif
}

#ifdef __CUDACC__
template <typename... Args, class FnType = CladFunctionType>
typename std::enable_if<!std::is_same<FnType, NoFunction*>::value,
return_type_t<F>>::type
execute_kernel(dim3 grid, dim3 block, Args&&... args) CUDA_HOST_DEVICE {
if (!m_Function) {
printf("CladFunction is invalid\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}
if (!m_CUDAkernel) {
printf("Use execute() for non-global CUDA kernels\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}

return execute_helper(m_Function, m_CUDAkernel, grid, block,
static_cast<Args>(args)...);
}
#endif

/// `Execute` overload to be used when derived function type cannot be
/// deduced. One reason for this can be when user tries to differentiate
@@ -258,12 +312,39 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// Helper function for executing non-member derived functions.
template <class Fn, class... Args>
CUDA_HOST_DEVICE return_type_t<CladFunctionType>
execute_helper(Fn f, Args&&... args) {
execute_helper(Fn f, CUDA_ARGS Args&&... args) {
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
#if defined(__CUDACC__)
if constexpr (sizeof...(Args) >= 2) {
auto secondArg =
std::get<1>(std::forward_as_tuple(std::forward<Args>(args)...));
if constexpr (std::is_same<std::decay_t<decltype(secondArg)>,
cudaStream_t>::value) {
return [&](auto shared_mem, cudaStream_t stream, auto&&... args_) {
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args) - 2, F>{}, f,
TakeNFirstArgs_t<sizeof...(Args) - 2, decltype(f)>{},
CUDAkernel, grid, block, shared_mem, stream,
static_cast<decltype(args_)>(args_)...);
}(static_cast<Args>(args)...);
} else {
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{}, CUDAkernel,
grid, block, 0, nullptr, static_cast<Args>(args)...);
}
} else {
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{}, CUDAkernel,
grid, block, 0, nullptr, static_cast<Args>(args)...);
}
#else
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
#endif
}

/// Helper functions for executing member derived functions.
@@ -393,10 +474,10 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
annotate("G"))) CUDA_HOST_DEVICE
gradient(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
const char* code = "", bool CUDAkernel = false) {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, nullptr, CUDAkernel);
}

/// Specialization for differentiating functors.
21 changes: 18 additions & 3 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
@@ -177,9 +177,6 @@ namespace clad {
void DiffRequest::updateCall(FunctionDecl* FD, FunctionDecl* OverloadedFD,
Sema& SemaRef) {
CallExpr* call = this->CallContext;
// Index of "code" parameter:
auto codeArgIdx = static_cast<int>(call->getNumArgs()) - 1;
auto derivedFnArgIdx = codeArgIdx - 1;

assert(call && "Must be set");
assert(FD && "Trying to update with null FunctionDecl");
@@ -191,6 +188,24 @@ namespace clad {
ASTContext& C = SemaRef.getASTContext();

FunctionDecl* replacementFD = OverloadedFD ? OverloadedFD : FD;

// Index of "CUDAkernel" parameter:
int numArgs = static_cast<int>(call->getNumArgs());
if (numArgs > 4) {
auto kernelArgIdx = numArgs - 1;
auto* cudaKernelFlag =
SemaRef
.ActOnCXXBoolLiteral(noLoc,
replacementFD->hasAttr<CUDAGlobalAttr>()
? tok::kw_true
: tok::kw_false)
.get();
call->setArg(kernelArgIdx, cudaKernelFlag);
numArgs--;
}
auto codeArgIdx = numArgs - 1;
auto derivedFnArgIdx = numArgs - 2;

// Create ref to generated FD.
DeclRefExpr* DRE =
DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc, replacementFD,
8 changes: 8 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -240,6 +240,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(BuildDeclStmt(gradientVD));
}

// If the function is a global kernel, we need to transform it
// into a device function when calling it inside the overload function
// which is the final global kernel returned.
if (m_Derivative->hasAttr<clang::CUDAGlobalAttr>()) {
m_Derivative->dropAttr<clang::CUDAGlobalAttr>();
m_Derivative->addAttr(clang::CUDADeviceAttr::CreateImplicit(m_Context));
}

Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs,
/*UseRefQualifiedThisObj=*/true);
addToCurrentBlock(callExpr);
79 changes: 79 additions & 0 deletions test/CUDA/GradientKernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// RUN: %cladclang_cuda -I%S/../../include %s -fsyntax-only \
// RUN: %cudasmlevel --cuda-path=%cudapath -Xclang -verify 2>&1 | %filecheck %s

// RUN: %cladclang_cuda -I%S/../../include %s -xc++ %cudasmlevel \
// RUN: --cuda-path=%cudapath -L/usr/local/cuda/lib64 -lcudart_static \
// RUN: -L%cudapath/lib64/stubs \
// RUN: -ldl -lrt -pthread -lm -lstdc++ -lcuda -lnvrtc

// REQUIRES: cuda-runtime

// expected-no-diagnostics

// XFAIL: clang-15

#include "clad/Differentiator/Differentiator.h"

__global__ void kernel(int *a) {
*a *= *a;
}

// CHECK: void kernel_grad(int *a, int *_d_a) {
//CHECK-NEXT: int _t0 = *a;
//CHECK-NEXT: *a *= *a;
//CHECK-NEXT: {
//CHECK-NEXT: *a = _t0;
//CHECK-NEXT: int _r_d0 = *_d_a;
//CHECK-NEXT: *_d_a = 0;
//CHECK-NEXT: *_d_a += _r_d0 * *a;
//CHECK-NEXT: *_d_a += *a * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

void fake_kernel(int *a) {
*a *= *a;
}

int main(void) {
int *a = (int*)malloc(sizeof(int));
*a = 2;
int *d_a;
cudaMalloc(&d_a, sizeof(int));
cudaMemcpy(d_a, a, sizeof(int), cudaMemcpyHostToDevice);

int *asquare = (int*)malloc(sizeof(int));
*asquare = 1;
int *d_square;
cudaMalloc(&d_square, sizeof(int));
cudaMemcpy(d_square, asquare, sizeof(int), cudaMemcpyHostToDevice);

auto test = clad::gradient(kernel);
dim3 grid(1);
dim3 block(1);
cudaStream_t cudaStream;
cudaStreamCreate(&cudaStream);
test.execute_kernel(grid, block, 0, cudaStream, d_a, d_square);

cudaDeviceSynchronize();

cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 2, a^2 = 4

auto error = clad::gradient(fake_kernel);
error.execute_kernel(grid, block, d_a, d_square); // CHECK-EXEC: Use execute() for non-global CUDA kernels

test.execute(d_a, d_square); // CHECK-EXEC: Use execute_kernel() for global CUDA kernels

cudaMemset(d_a, 5, 1); // first byte is set to 5
cudaMemset(d_square, 1, 1);

test.execute_kernel(grid, block, d_a, d_square);
cudaDeviceSynchronize();

cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 5, a^2 = 10

return 0;
}
6 changes: 3 additions & 3 deletions test/lit.cfg
Original file line number Diff line number Diff line change
@@ -257,13 +257,13 @@ lit.util.usePlatformSdkOnDarwin(config, lit_config)
#\ -plugin-arg-ad -Xclang -fdump-derived-fn -Xclang -load -Xclang../../Debug+Asserts/lib/libclad.so
#FIXME: we need to introduce a better way to check compatible version of clang, propagating
#-fvalidate-clang-version flag is not enough.
flags = ' -std=c++11 -Xclang -add-plugin -Xclang clad -Xclang \
flags = ' -Xclang -add-plugin -Xclang clad -Xclang \
-plugin-arg-clad -Xclang -fdump-derived-fn -Xclang \
-load -Xclang ' + config.cladlib

config.substitutions.append( ('%cladclang_cuda', config.clang + flags) )
config.substitutions.append( ('%cladclang_cuda', config.clang + ' -std=c++17' + flags) )

config.substitutions.append( ('%cladclang', config.clang + '++ -DCLAD_NO_NUM_DIFF ' + flags) )
config.substitutions.append( ('%cladclang', config.clang + '++ -DCLAD_NO_NUM_DIFF ' + ' -std=c++11' + flags) )

config.substitutions.append( ('%cladlib', config.cladlib) )

0 comments on commit 431791f

Please sign in to comment.