Skip to content

Commit

Permalink
Use cuda::mr::memory_resource instead of raw `device_memory_resourc…
Browse files Browse the repository at this point in the history
…e` (#1095)

This introduces `cuda::mr::{async_}resource_ref` as a type erased safe resource wrapper that is meant to replace uses of `{host, device}_memory_resource`

We provide both async and classic allocate functions that delegate back to the original resource used to construct the `cuda::mr::{async_}resource_ref`

In comparison to `{host, device}_memory_resource` the new feature provides additional compile time checks that will help users avoid common pitfalls with heterogeneous memory allocations.

As a first step we provide the properties `cuda::mr::host_accessible` and `cuda::mr::device_accessible`. These properties can be added to an internal or even external type through a free function `get_property`
```cpp
// For a user defined resource
struct my_resource {
  friend void get_property(my_resource const&, cuda::mr::device_accessible) noexcept {}
};

// For an external resource
void get_property(some_external_resource const&, cuda::mr::device_accessible) noexcept {}
```

The advantage is that we can constrain interfaces based on these properties
```cpp
void do_some_computation_on_device(cuda::mr::async_resource_ref<cuda::mr::device_accessible> mr, ...) { ... }
```
This function will fail to compile if it is passed any resource that does not support async allocations or is not tagged as providing device accessible memory. In the same way the following function will only compile if the provided resource provides the classic allocate / deallocate interface and is tagged to provide host accessible memory
```cpp
void do_some_computation_on_host(cuda::mr::resource_ref<cuda::mr::host_accessible> mr, ...) { ... }
```

The property system is highly flexible and can easily be user provided to add their own properties as needed. That gives it both the flexibility of an inheritance based implementation and the security of a strictly type checked interface

Authors:
  - Michael Schellenberger Costa (https://github.com/miscco)
  - Bradley Dice (https://github.com/bdice)
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Jake Hemstad (https://github.com/jrhemstad)
  - Mark Harris (https://github.com/harrism)
  - Bradley Dice (https://github.com/bdice)

URL: #1095
  • Loading branch information
miscco authored Nov 17, 2023
1 parent ba99ff4 commit 6acae3c
Show file tree
Hide file tree
Showing 28 changed files with 1,664 additions and 89 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ rapids_cpm_init()

include(cmake/thirdparty/get_fmt.cmake)
include(cmake/thirdparty/get_spdlog.cmake)
include(cmake/thirdparty/get_libcudacxx.cmake)
include(cmake/thirdparty/get_thrust.cmake)

# ##################################################################################################
Expand All @@ -89,11 +90,13 @@ else()
target_link_libraries(rmm INTERFACE CUDA::cudart)
endif()

target_link_libraries(rmm INTERFACE libcudacxx::libcudacxx)
target_link_libraries(rmm INTERFACE rmm::Thrust)
target_link_libraries(rmm INTERFACE fmt::fmt-header-only)
target_link_libraries(rmm INTERFACE spdlog::spdlog_header_only)
target_link_libraries(rmm INTERFACE dl)
target_compile_features(rmm INTERFACE cxx_std_17 $<BUILD_INTERFACE:cuda_std_17>)
target_compile_definitions(rmm INTERFACE LIBCUDACXX_ENABLE_EXPERIMENTAL_MEMORY_RESOURCE)

# ##################################################################################################
# * tests and benchmarks ---------------------------------------------------------------------------
Expand Down
23 changes: 23 additions & 0 deletions cmake/thirdparty/get_libcudacxx.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# Use CPM to find or clone libcudacxx
function(find_and_configure_libcudacxx)

include(${rapids-cmake-dir}/cpm/libcudacxx.cmake)
rapids_cpm_libcudacxx(BUILD_EXPORT_SET rmm-exports INSTALL_EXPORT_SET rmm-exports)

endfunction()

find_and_configure_libcudacxx()
16 changes: 16 additions & 0 deletions include/rmm/cuda_stream_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <cuda_runtime_api.h>

#include <cuda/stream_ref>

#include <atomic>
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -58,6 +60,13 @@ class cuda_stream_view {
*/
constexpr cuda_stream_view(cudaStream_t stream) noexcept : stream_{stream} {}

/**
* @brief Implicit conversion from stream_ref.
*
* @param stream The underlying stream for this view
*/
constexpr cuda_stream_view(cuda::stream_ref stream) noexcept : stream_{stream.get()} {}

/**
* @brief Get the wrapped stream.
*
Expand All @@ -72,6 +81,13 @@ class cuda_stream_view {
*/
constexpr operator cudaStream_t() const noexcept { return value(); }

/**
* @brief Implicit conversion to stream_ref.
*
* @return stream_ref The underlying stream referenced by this cuda_stream_view
*/
constexpr operator cuda::stream_ref() const noexcept { return value(); }

/**
* @briefreturn{true if the wrapped stream is the CUDA per-thread default stream}
*/
Expand Down
26 changes: 15 additions & 11 deletions include/rmm/device_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <stdexcept>
#include <utility>

#include <cuda/memory_resource>

namespace rmm {
/**
* @addtogroup data_containers
Expand Down Expand Up @@ -80,6 +82,8 @@ namespace rmm {
*```
*/
class device_buffer {
using async_resource_ref = cuda::mr::async_resource_ref<cuda::mr::device_accessible>;

public:
// The copy constructor and copy assignment operator without a stream are deleted because they
// provide no way to specify an explicit stream
Expand Down Expand Up @@ -107,7 +111,7 @@ class device_buffer {
*/
explicit device_buffer(std::size_t size,
cuda_stream_view stream,
mr::device_memory_resource* mr = mr::get_current_device_resource())
async_resource_ref mr = mr::get_current_device_resource())
: _stream{stream}, _mr{mr}
{
cuda_set_device_raii dev{_device};
Expand Down Expand Up @@ -136,7 +140,7 @@ class device_buffer {
device_buffer(void const* source_data,
std::size_t size,
cuda_stream_view stream,
mr::device_memory_resource* mr = mr::get_current_device_resource())
async_resource_ref mr = rmm::mr::get_current_device_resource())
: _stream{stream}, _mr{mr}
{
cuda_set_device_raii dev{_device};
Expand Down Expand Up @@ -167,7 +171,7 @@ class device_buffer {
*/
device_buffer(device_buffer const& other,
cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
async_resource_ref mr = rmm::mr::get_current_device_resource())
: device_buffer{other.data(), other.size(), stream, mr}
{
}
Expand Down Expand Up @@ -245,7 +249,6 @@ class device_buffer {
{
cuda_set_device_raii dev{_device};
deallocate_async();
_mr = nullptr;
_stream = cuda_stream_view{};
}

Expand Down Expand Up @@ -407,18 +410,19 @@ class device_buffer {
void set_stream(cuda_stream_view stream) noexcept { _stream = stream; }

/**
* @briefreturn{Pointer to the memory resource used to allocate and deallocate}
* @briefreturn{The async_resource_ref used to allocate and deallocate}
*/
[[nodiscard]] mr::device_memory_resource* memory_resource() const noexcept { return _mr; }
[[nodiscard]] async_resource_ref memory_resource() const noexcept { return _mr; }

private:
void* _data{nullptr}; ///< Pointer to device memory allocation
std::size_t _size{}; ///< Requested size of the device memory allocation
std::size_t _capacity{}; ///< The actual size of the device memory allocation
cuda_stream_view _stream{}; ///< Stream to use for device memory deallocation
mr::device_memory_resource* _mr{
mr::get_current_device_resource()}; ///< The memory resource used to
///< allocate/deallocate device memory

async_resource_ref _mr{
rmm::mr::get_current_device_resource()}; ///< The memory resource used to
///< allocate/deallocate device memory
cuda_device_id _device{get_current_cuda_device()};

/**
Expand All @@ -434,7 +438,7 @@ class device_buffer {
{
_size = bytes;
_capacity = bytes;
_data = (bytes > 0) ? memory_resource()->allocate(bytes, stream()) : nullptr;
_data = (bytes > 0) ? _mr.allocate_async(bytes, stream()) : nullptr;
}

/**
Expand All @@ -448,7 +452,7 @@ class device_buffer {
*/
void deallocate_async() noexcept
{
if (capacity() > 0) { memory_resource()->deallocate(data(), capacity(), stream()); }
if (capacity() > 0) { _mr.deallocate_async(data(), capacity(), stream()); }
_size = 0;
_capacity = 0;
_data = nullptr;
Expand Down
21 changes: 11 additions & 10 deletions include/rmm/device_uvector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <cstddef>
#include <vector>

#include <cuda/memory_resource>

namespace rmm {
/**
* @addtogroup data_containers
Expand Down Expand Up @@ -72,6 +74,7 @@ namespace rmm {
*/
template <typename T>
class device_uvector {
using async_resource_ref = cuda::mr::async_resource_ref<cuda::mr::device_accessible>;
static_assert(std::is_trivially_copyable<T>::value,
"device_uvector only supports types that are trivially copyable.");

Expand Down Expand Up @@ -121,10 +124,9 @@ class device_uvector {
* @param stream The stream on which to perform the allocation
* @param mr The resource used to allocate the device storage
*/
explicit device_uvector(
std::size_t size,
cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
explicit device_uvector(std::size_t size,
cuda_stream_view stream,
async_resource_ref mr = rmm::mr::get_current_device_resource())
: _storage{elements_to_bytes(size), stream, mr}
{
}
Expand All @@ -138,10 +140,9 @@ class device_uvector {
* @param stream The stream on which to perform the copy
* @param mr The resource used to allocate device memory for the new vector
*/
explicit device_uvector(
device_uvector const& other,
cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
explicit device_uvector(device_uvector const& other,
cuda_stream_view stream,
async_resource_ref mr = rmm::mr::get_current_device_resource())
: _storage{other._storage, stream, mr}
{
}
Expand Down Expand Up @@ -524,9 +525,9 @@ class device_uvector {
[[nodiscard]] bool is_empty() const noexcept { return size() == 0; }

/**
* @briefreturn{Pointer to underlying resource used to allocate and deallocate the device storage}
* @briefreturn{The async_resource_ref used to allocate and deallocate the device storage}
*/
[[nodiscard]] mr::device_memory_resource* memory_resource() const noexcept
[[nodiscard]] async_resource_ref memory_resource() const noexcept
{
return _storage.memory_resource();
}
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/callback_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ class callback_memory_resource final : public device_memory_resource {
throw std::runtime_error("cannot get free / total memory");
}

[[nodiscard]] virtual bool supports_streams() const noexcept { return false; }
[[nodiscard]] virtual bool supports_get_mem_info() const noexcept { return false; }
[[nodiscard]] bool supports_streams() const noexcept override { return false; }
[[nodiscard]] bool supports_get_mem_info() const noexcept override { return false; }

allocate_callback_t allocate_callback_;
deallocate_callback_t deallocate_callback_;
Expand Down
Loading

0 comments on commit 6acae3c

Please sign in to comment.