From 99efd46dc2abfff52bffe40ec7fc335cf64e0d0b Mon Sep 17 00:00:00 2001 From: kchristin Date: Sat, 21 Sep 2024 15:03:50 +0300 Subject: [PATCH] Fix appendage of nullptrs to args of a CUDA kernel --- include/clad/Differentiator/Differentiator.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 3a8f35faf..ce99c1b5c 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -125,8 +125,17 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { CUDA_ARGS CUDA_REST_ARGS Args&&... args) { #if defined(__CUDACC__) && !defined(__CUDA_ARCH__) if (CUDAkernel) { - void* argPtrs[] = {(void*)&args..., (void*)static_cast(nullptr)...}; - cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream); + constexpr size_t totalArgs = sizeof...(args) + sizeof...(Rest); + std::vector argPtrs; + argPtrs.reserve(totalArgs); + (argPtrs.push_back(static_cast(&args)), ...); + + void* null_param = nullptr; + for (size_t i = sizeof...(args); i < totalArgs; ++i) + argPtrs[i] = &null_param; + + cudaLaunchKernel((void*)f, grid, block, argPtrs.data(), shared_mem, stream); + return return_type_t(); } else { return f(static_cast(args)..., static_cast(nullptr)...); }