Enable computation of CUDA global kernels' derivatives in reverse mode #1059
@@ -38,16 +38,23 @@
   return count;
 }
 
+#ifdef __CUDACC__
+#define CUDA_ARGS bool CUDAkernel, dim3 grid, dim3 block,
+#define CUDA_REST_ARGS size_t shared_mem, cudaStream_t stream,
+#else
+#define CUDA_ARGS
+#define CUDA_REST_ARGS
+#endif
+
 /// Tape type used for storing values in reverse-mode AD inside loops.
-template <typename T>
-using tape = tape_impl<T>;
+template <typename T> using tape = tape_impl<T>;
 
 /// Add value to the end of the tape, return the same value.
 template <typename T, typename... ArgsT>
 CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
   to.emplace_back(std::forward<ArgsT>(val)...);
   return to.back();
 }
 
 /// Add value to the end of the tape, return the same value.
 /// A specialization for clad::array_ref types to use in reverse mode.
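For readers unfamiliar with the macro trick: under nvcc (where __CUDACC__ is defined) the two macros splice the kernel-launch parameters into a parameter list, and on non-CUDA builds they expand to nothing, so the same declaration serves both cases. A minimal sketch with a hypothetical function name `run` (not part of the patch):

// Sketch (hypothetical free function `run`) of how the macros expand.
template <typename... Args>
void run(CUDA_ARGS CUDA_REST_ARGS Args&&... args);
// nvcc (__CUDACC__ defined) sees:
//   void run(bool CUDAkernel, dim3 grid, dim3 block, size_t shared_mem,
//            cudaStream_t stream, Args&&... args);
// a plain host compiler sees:
//   void run(Args&&... args);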
@@ -115,17 +122,35 @@
           typename std::enable_if<EnablePadding, bool>::type = true>
 CUDA_HOST_DEVICE return_type_t<F>
 execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
-                          Args&&... args) {
-  return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
+                          CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
+#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
+  if (CUDAkernel) {
+    void* argPtrs[] = {(void*)&args..., (void*)static_cast<Rest>(nullptr)...};
+    cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
+  } else {
+    return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
+  }
+#else
+  return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
+#endif
 }
 
 template <bool EnablePadding, class... Rest, class F, class... Args,
           class... fArgTypes,
           typename std::enable_if<!EnablePadding, bool>::type = true>
-return_type_t<F> execute_with_default_args(list<Rest...>, F f,
-                                           list<fArgTypes...>,
-                                           Args&&... args) {
-  return f(static_cast<Args>(args)...);
+return_type_t<F>
+execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
+                          CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
+#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
+  if (CUDAkernel) {
+    void* argPtrs[] = {(void*)&args...};
+    cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
+  } else {
+    return f(static_cast<Args>(args)...);
+  }
+#else
+  return f(static_cast<Args>(args)...);
+#endif
 }
 
 // for executing member-functions
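The reason the helper builds `argPtrs` is that cudaLaunchKernel takes the kernel address plus an array of pointers to the argument values rather than the arguments themselves. A self-contained sketch of the two equivalent launch forms, using a hypothetical kernel `scale` that is not part of the patch:

#include <cuda_runtime.h>

__global__ void scale(float* v, float factor) { v[0] *= factor; }

void launch(float* d_v, float factor, cudaStream_t stream) {
  // Usual triple-chevron launch:
  scale<<<dim3(1), dim3(1), 0, stream>>>(d_v, factor);
  // Equivalent runtime-API launch, the form used in execute_with_default_args:
  void* argPtrs[] = {(void*)&d_v, (void*)&factor};
  cudaLaunchKernel((void*)scale, dim3(1), dim3(1), argPtrs,
                   /*sharedMem=*/0, stream);
}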
@@ -167,12 +192,13 @@
   CladFunctionType m_Function;
   char* m_Code;
   FunctorType *m_Functor = nullptr;
+  bool m_CUDAkernel = false;
 
 public:
-  CUDA_HOST_DEVICE CladFunction(CladFunctionType f,
-                                const char* code,
-                                FunctorType* functor = nullptr)
-      : m_Functor(functor) {
+  CUDA_HOST_DEVICE CladFunction(CladFunctionType f, const char* code,
+                                FunctorType* functor = nullptr,
+                                bool CUDAkernel = false)
+      : m_Functor(functor), m_CUDAkernel(CUDAkernel) {
     assert(f && "Must pass a non-0 argument.");
     if (size_t length = GetLength(code)) {
       m_Function = f;
@@ -210,9 +236,37 @@
       printf("CladFunction is invalid\n");
       return static_cast<return_type_t<F>>(return_type_t<F>());
     }
+    if (m_CUDAkernel) {
+      printf("Use execute_kernel() for global CUDA kernels\n");
Review thread on this diagnostic:
* "We should probably assert-out if users use"
* "Is that a programmatic error or a user error? We can use assert/abort if the error is a programmer error."
* "It is a user error basically. They should use the appropriate"
+      return static_cast<return_type_t<F>>(return_type_t<F>());
+    }
     // here static_cast is used to achieve perfect forwarding
+#ifdef __CUDACC__
+    return execute_helper(m_Function, m_CUDAkernel, dim3(0), dim3(0),
+                          static_cast<Args>(args)...);
+#else
     return execute_helper(m_Function, static_cast<Args>(args)...);
+#endif
   }
 
+#ifdef __CUDACC__
+  template <typename... Args, class FnType = CladFunctionType>
+  typename std::enable_if<!std::is_same<FnType, NoFunction*>::value,
+                          return_type_t<F>>::type
+  execute_kernel(dim3 grid, dim3 block, Args&&... args) CUDA_HOST_DEVICE {
+    if (!m_Function) {
+      printf("CladFunction is invalid\n");
+      return static_cast<return_type_t<F>>(return_type_t<F>());
+    }
+    if (!m_CUDAkernel) {
+      printf("Use execute() for non-global CUDA kernels\n");
+      return static_cast<return_type_t<F>>(return_type_t<F>());
+    }
+
+    return execute_helper(m_Function, m_CUDAkernel, grid, block,
+                          static_cast<Args>(args)...);
+  }
+#endif
 
   /// `Execute` overload to be used when derived function type cannot be
   /// deduced. One reason for this can be when user tries to differentiate
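From the caller's side the split looks roughly like this (sketch with placeholder names; the diagnostics are the printf messages above, and the full flow is exercised in the test added below):

auto grad = clad::gradient(kernel);           // kernel is a __global__ function
grad.execute_kernel(grid, block, d_a, d_da);  // correct entry point for kernels
grad.execute(d_a, d_da);  // rejected: "Use execute_kernel() for global CUDA kernels"

auto host_grad = clad::gradient(fn);          // ordinary host function
host_grad.execute(x, &dx);                    // correct entry point
host_grad.execute_kernel(grid, block, x, &dx);  // rejected: "Use execute() for non-global CUDA kernels"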
@@ -258,12 +312,39 @@
   /// Helper function for executing non-member derived functions.
   template <class Fn, class... Args>
   CUDA_HOST_DEVICE return_type_t<CladFunctionType>
-  execute_helper(Fn f, Args&&... args) {
+  execute_helper(Fn f, CUDA_ARGS Args&&... args) {
     // `static_cast` is required here for perfect forwarding.
+#if defined(__CUDACC__)
+    if constexpr (sizeof...(Args) >= 2) {
+      auto secondArg =
+          std::get<1>(std::forward_as_tuple(std::forward<Args>(args)...));
+      if constexpr (std::is_same<std::decay_t<decltype(secondArg)>,
+                                 cudaStream_t>::value) {
+        return [&](auto shared_mem, cudaStream_t stream, auto&&... args_) {
+          return execute_with_default_args<EnablePadding>(
+              DropArgs_t<sizeof...(Args) - 2, F>{}, f,
+              TakeNFirstArgs_t<sizeof...(Args) - 2, decltype(f)>{},
+              CUDAkernel, grid, block, shared_mem, stream,
+              static_cast<decltype(args_)>(args_)...);
+        }(static_cast<Args>(args)...);
+      } else {
+        return execute_with_default_args<EnablePadding>(
+            DropArgs_t<sizeof...(Args), F>{}, f,
+            TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{}, CUDAkernel,
+            grid, block, 0, nullptr, static_cast<Args>(args)...);
+      }
+    } else {
+      return execute_with_default_args<EnablePadding>(
+          DropArgs_t<sizeof...(Args), F>{}, f,
+          TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{}, CUDAkernel,
+          grid, block, 0, nullptr, static_cast<Args>(args)...);
+    }
+#else
     return execute_with_default_args<EnablePadding>(
         DropArgs_t<sizeof...(Args), F>{}, f,
         TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
         static_cast<Args>(args)...);
+#endif
   }
 
   /// Helper functions for executing member derived functions.
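The if constexpr cascade above peels off an optional (shared_mem, stream) pair: when the second forwarded argument is a cudaStream_t, the first two arguments are treated as launch extras rather than kernel arguments. A self-contained sketch of that compile-time detection idiom (hypothetical names, C++17, with a stand-in stream type so it builds without the CUDA headers):

#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

using stream_t = void*;  // stand-in for cudaStream_t in this sketch

template <typename... Args>
void dispatch(Args&&... args) {
  if constexpr (sizeof...(Args) >= 2) {
    auto&& second =
        std::get<1>(std::forward_as_tuple(std::forward<Args>(args)...));
    if constexpr (std::is_same<std::decay_t<decltype(second)>,
                               stream_t>::value) {
      // args == (shared_mem, stream, kernel args...): forward the extras.
      return;
    }
  }
  // Otherwise every argument is a kernel argument; use default extras (0, nullptr).
}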
@@ -393,10 +474,10 @@
                                  annotate("G"))) CUDA_HOST_DEVICE
 gradient(F f, ArgSpec args = "",
          DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
-         const char* code = "") {
-  assert(f && "Must pass in a non-0 argument");
-  return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
-      derivedFn /* will be replaced by gradient*/, code);
+         const char* code = "", bool CUDAkernel = false) {
+  assert(f && "Must pass in a non-0 argument");
+  return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
+      derivedFn /* will be replaced by gradient*/, code, nullptr, CUDAkernel);
 }
 
 /// Specialization for differentiating functors.
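Because the new flag is defaulted, existing clad::gradient call sites need no changes; only the call that clad rewrites for a __global__ function is expected to carry CUDAkernel = true. A sketch of the two shapes (the explicit trailing arguments are normally filled in by the clad plugin, and the values shown are placeholders):

auto g_host = clad::gradient(host_fn);  // CUDAkernel defaults to false
// Roughly what the rewritten call for a __global__ kernel carries:
auto g_kern = clad::gradient(kernel, "", nullptr, "", /*CUDAkernel=*/true);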
@@ -0,0 +1,79 @@
// RUN: %cladclang_cuda -I%S/../../include %s -fsyntax-only \
// RUN: %cudasmlevel --cuda-path=%cudapath -Xclang -verify 2>&1 | %filecheck %s

// RUN: %cladclang_cuda -I%S/../../include %s -xc++ %cudasmlevel \
// RUN: --cuda-path=%cudapath -L/usr/local/cuda/lib64 -lcudart_static \
// RUN: -L%cudapath/lib64/stubs \
// RUN: -ldl -lrt -pthread -lm -lstdc++ -lcuda -lnvrtc

// REQUIRES: cuda-runtime

// expected-no-diagnostics

// XFAIL: clang-15

#include "clad/Differentiator/Differentiator.h"

__global__ void kernel(int *a) {
  *a *= *a;
}

// CHECK: void kernel_grad(int *a, int *_d_a) {
//CHECK-NEXT:    int _t0 = *a;
//CHECK-NEXT:    *a *= *a;
//CHECK-NEXT:    {
//CHECK-NEXT:        *a = _t0;
//CHECK-NEXT:        int _r_d0 = *_d_a;
//CHECK-NEXT:        *_d_a = 0;
//CHECK-NEXT:        *_d_a += _r_d0 * *a;
//CHECK-NEXT:        *_d_a += *a * _r_d0;
//CHECK-NEXT:    }
//CHECK-NEXT: }

void fake_kernel(int *a) {
  *a *= *a;
}

int main(void) {
  int *a = (int*)malloc(sizeof(int));
  *a = 2;
  int *d_a;
  cudaMalloc(&d_a, sizeof(int));
  cudaMemcpy(d_a, a, sizeof(int), cudaMemcpyHostToDevice);

  int *asquare = (int*)malloc(sizeof(int));
  *asquare = 1;
  int *d_square;
  cudaMalloc(&d_square, sizeof(int));
  cudaMemcpy(d_square, asquare, sizeof(int), cudaMemcpyHostToDevice);

  auto test = clad::gradient(kernel);
  dim3 grid(1);
  dim3 block(1);
  cudaStream_t cudaStream;
  cudaStreamCreate(&cudaStream);
  test.execute_kernel(grid, block, 0, cudaStream, d_a, d_square);

  cudaDeviceSynchronize();

  cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
  printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 2, a^2 = 4

  auto error = clad::gradient(fake_kernel);
  error.execute_kernel(grid, block, d_a, d_square); // CHECK-EXEC: Use execute() for non-global CUDA kernels

  test.execute(d_a, d_square); // CHECK-EXEC: Use execute_kernel() for global CUDA kernels

  cudaMemset(d_a, 5, 1); // first byte is set to 5
  cudaMemset(d_square, 1, 1);

  test.execute_kernel(grid, block, d_a, d_square);
  cudaDeviceSynchronize();

  cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
  printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 5, a^2 = 10

  return 0;
}
Review comment: warning: do not call c-style vararg functions [cppcoreguidelines-pro-type-vararg]
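cppcoreguidelines-pro-type-vararg is a clang-tidy check that flags any call to a C-style variadic function, which here most likely means the printf diagnostics added in this patch. Since execute/execute_kernel are CUDA_HOST_DEVICE and device code has printf but not iostream, one common way to address the finding without changing behaviour is a targeted suppression (a sketch only, not part of the patch):

// NOLINT silences the named check on this line only.
printf("Use execute_kernel() for global CUDA kernels\n"); // NOLINT(cppcoreguidelines-pro-type-vararg)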