From d6dc2be1b782f01d3a449734647ff225d922a27a Mon Sep 17 00:00:00 2001 From: Eyal Rozenberg Date: Mon, 26 Feb 2024 01:53:35 +0200 Subject: [PATCH] Fixes #591: Some work on the memory copy functions --- src/cuda/api/memory.hpp | 57 +++++++++++++++++++-- src/cuda/api/multi_wrapper_impls/memory.hpp | 26 ++++++++++ 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/src/cuda/api/memory.hpp b/src/cuda/api/memory.hpp index d93f1633..11bb33b1 100644 --- a/src/cuda/api/memory.hpp +++ b/src/cuda/api/memory.hpp @@ -445,13 +445,24 @@ template inline void copy(region_t destination, const T(&source)[N]) { #ifndef NDEBUG - if (destination.size() < N) { + if (destination.size() < N * sizeof(T)) { throw ::std::logic_error("Source size exceeds destination size"); } #endif return copy(destination.start(), source, sizeof(T) * N); } +template +inline void copy(span destination, const T(&source)[N]) +{ +#ifndef NDEBUG + if (destination.size() < N) { + throw ::std::logic_error("Source size exceeds destination size"); + } +#endif + return copy(destination.data(), source, sizeof(T) * N); +} + /** * @param destination A region of memory to which to copy the data in @p source, * of size at least that of @p source. @@ -471,6 +482,19 @@ inline void copy(T(&destination)[N], const_region_t source) return copy(destination, source.start(), sizeof(T) * N); } +template +inline void copy(T(&destination)[N], span source) +{ +#ifndef NDEBUG + if (source.size() > N) { + throw ::std::invalid_argument( + "Attempt to copy a span of " + ::std::to_string(source.size()) + + " elements into an array of " + ::std::to_string(N) + " elements"); + } +#endif + return copy(destination, source.start(), sizeof(T) * N); +} + template inline void copy(void* destination, T (&source)[N]) { @@ -669,6 +693,19 @@ void copy(const array_t& destination, const T *source) copy(destination, context_of(source), source); } +template +void copy(const array_t& destination, span source) +{ +#ifndef NDEBUG + if (destination.size() < source.size()) { + throw ::std::invalid_argument( + "Attempt to copy a span of " + ::std::to_string(source.size()) + + " elements into a CUDA array of " + ::std::to_string(destination.size()) + " elements"); + } +#endif + copy(destination, source.data()); +} + /** * Synchronously copies data into a CUDA array from non-array memory. * @@ -710,6 +747,19 @@ void copy(T *destination, const array_t& source) copy(context_of(destination), destination, source); } +template +void copy(span destination, const array_t& source) +{ +#ifndef NDEBUG + if (destination.size() < source.size()) { + throw ::std::invalid_argument( + "Attempt to copy a CUDA array of " + ::std::to_string(source.size()) + + " elements into a span of " + ::std::to_string(destination.size()) + " elements"); + } +#endif + copy(destination.data(), source); +} + template void copy(const array_t& destination, const array_t& source) { @@ -742,7 +792,7 @@ void copy(const array_t& destination, const_region_t source) if (destination.size_bytes() < source.size()) { throw ::std::logic_error("Attempt to copy into an array from a source region larger than the array's size"); } - copy(destination, source.start()); + copy(destination, static_cast(source.start())); } /** @@ -854,7 +904,6 @@ status_t multidim_copy( return multidim_copy_in_current_context(::std::integral_constant{}, params, stream_handle); } - // Assumes the array and the stream share the same context, and that the destination is // accessible from that context (e.g. allocated within it, or being managed memory, etc.) template @@ -1021,7 +1070,7 @@ void copy(array_t& destination, const_region_t source, const s " bytes into an array of size " + ::std::to_string(required_size) + " bytes"); } #endif - copy(destination, source.start(), stream); + copy(destination, static_cast(source.start()), stream); } /** diff --git a/src/cuda/api/multi_wrapper_impls/memory.hpp b/src/cuda/api/multi_wrapper_impls/memory.hpp index 8038fda3..56a9d00f 100644 --- a/src/cuda/api/multi_wrapper_impls/memory.hpp +++ b/src/cuda/api/multi_wrapper_impls/memory.hpp @@ -43,6 +43,19 @@ inline void copy(array_t& destination, const T* source, const detail_::copy(destination, source, stream.handle()); } +template +inline void copy(array_t& destination, span source, const stream_t& stream) +{ +#ifndef NDEBUG + if (source.size() != destination.size()) { + throw ::std::invalid_argument( + "Attempt to copy " + ::std::to_string(source.size()) + + " elements into an array of " + ::std::to_string(destination.size()) + " elements"); + } +#endif + detail_::copy(destination, source.data(), stream.handle()); +} + // Note: Assumes the destination, source and stream are all usable on the same content template inline void copy(T* destination, const array_t& source, const stream_t& stream) @@ -55,6 +68,19 @@ inline void copy(T* destination, const array_t& source, const detail_::copy(destination, source, stream.handle()); } +template +inline void copy(span destination, const array_t& source, const stream_t& stream) +{ +#ifndef NDEBUG + if (destination.size() != source.size()) { + throw ::std::invalid_argument( + "Attempt to copy " + ::std::to_string(source.size()) + + " elements into an array of " + ::std::to_string(destination.size()) + " elements"); + } +#endif + copy(destination.data(), source, stream); +} + template inline void copy_single(T& destination, const T& source, const stream_t& stream) {