diff --git a/.Rbuildignore b/.Rbuildignore index 9db1ba2361..4268cddaf6 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -34,7 +34,7 @@ tools/torchgen/.Rbuildignore # ^vignettes/using-autograd\.Rmd # uncomment below for CRAN submission -^inst/deps/.* +#^inst/deps/.* ^doc$ ^Meta$ ^vignettes/rsconnect* diff --git a/NAMESPACE b/NAMESPACE index 04c2d0d846..77c010c24d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -122,6 +122,7 @@ export(backends_openmp_is_available) export(contrib_sort_vertices) export(cuda_current_device) export(cuda_device_count) +export(cuda_get_device_capability) export(cuda_is_available) export(dataloader) export(dataloader_make_iter) diff --git a/R/RcppExports.R b/R/RcppExports.R index eca848f4a8..d442d58e66 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -165,6 +165,10 @@ cpp_cuda_current_device <- function() { .Call('_torch_cpp_cuda_current_device', PACKAGE = 'torchpkg') } +cpp_cuda_get_device_capability <- function(device) { + .Call('_torch_cpp_cuda_get_device_capability', PACKAGE = 'torchpkg', device) +} + cpp_device_type_to_string <- function(device) { .Call('_torch_cpp_device_type_to_string', PACKAGE = 'torchpkg', device) } diff --git a/R/cuda.R b/R/cuda.R index fad95c65e8..320d30ac52 100644 --- a/R/cuda.R +++ b/R/cuda.R @@ -17,4 +17,18 @@ cuda_current_device <- function() { #' @export cuda_device_count <- function() { cpp_cuda_device_count() +} + +#' Returns the major and minor CUDA capability of `device` +#' +#' @param device Integer value of the CUDA device to return capabilities of. 
+#' +#' @export +cuda_get_device_capability <- function(device = cuda_current_device()) { + if(device < 0 || device >= cuda_device_count()) { + stop(paste("device must be an integer between 0 and the number of devices minus 1")) + } + res <- as.integer(cpp_cuda_get_device_capability(device)) + names(res) <- c("Major", "Minor") + res } \ No newline at end of file diff --git a/lantern/CMakeLists.txt b/lantern/CMakeLists.txt index f7dbd9dd61..134c3b5eb6 100644 --- a/lantern/CMakeLists.txt +++ b/lantern/CMakeLists.txt @@ -129,7 +129,8 @@ if(DEFINED ENV{CUDA}) src/Contrib/SortVertices/sort_vert_kernel.cu src/Contrib/SortVertices/sort_vert.cpp ) - + + set_source_files_properties(src/Cuda.cpp PROPERTIES COMPILE_DEFINITIONS __NVCC__) cuda_add_library(lantern SHARED ${LANTERN_SRC}) else() set(LANTERN_SRC diff --git a/lantern/include/lantern/lantern.h b/lantern/include/lantern/lantern.h index bd8a3c3730..844f3fcf9d 100644 --- a/lantern/include/lantern/lantern.h +++ b/lantern/include/lantern/lantern.h @@ -2163,6 +2163,15 @@ HOST_API void* lantern_optional_tensor_value (void* x) return ret; } +LANTERN_API void* (LANTERN_PTR _lantern_cuda_get_device_capability) (int64_t device); +HOST_API void* lantern_cuda_get_device_capability (int64_t device) +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_cuda_get_device_capability(device); + LANTERN_HOST_HANDLER; + return ret; +} + /* Autogen Headers -- Start */ LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking); HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; } @@ -7749,6 +7758,7 @@ LOAD_SYMBOL(_lantern_OptionalTensorList_size); LOAD_SYMBOL(_lantern_OptionalTensorList_at); LOAD_SYMBOL(_lantern_OptionalTensorList_at_is_null); LOAD_SYMBOL(_lantern_optional_tensor_value); +LOAD_SYMBOL(_lantern_cuda_get_device_capability); /* Autogen Symbols -- Start 
*/ LOAD_SYMBOL(_lantern__cast_byte_tensor_bool) LOAD_SYMBOL(_lantern__cast_char_tensor_bool) diff --git a/lantern/src/Cuda.cpp b/lantern/src/Cuda.cpp index ae70e485b1..b78f733cb7 100644 --- a/lantern/src/Cuda.cpp +++ b/lantern/src/Cuda.cpp @@ -3,7 +3,9 @@ #define LANTERN_BUILD #include "lantern/lantern.h" - +#ifdef __NVCC__ +#include <ATen/cuda/CUDAContext.h> +#endif #include <torch/torch.h> #include "utils.hpp" @@ -35,3 +37,16 @@ void _lantern_cuda_show_config() std::cout << at::detail::getCUDAHooks().showConfig() << std::endl; LANTERN_FUNCTION_END_VOID } + +void* _lantern_cuda_get_device_capability(int64_t device) +{ + LANTERN_FUNCTION_START + #ifdef __NVCC__ + cudaDeviceProp * devprop = at::cuda::getDeviceProperties(device); + std::vector<int64_t> cap = {devprop->major, devprop->minor}; + return (void*) new LanternObject<std::vector<int64_t>>(cap); + #else + throw std::runtime_error("`cuda_get_device_capability` is only supported on CUDA runtimes."); + #endif + LANTERN_FUNCTION_END +} diff --git a/man/cuda_get_device_capability.Rd b/man/cuda_get_device_capability.Rd new file mode 100644 index 0000000000..599231c417 --- /dev/null +++ b/man/cuda_get_device_capability.Rd @@ -0,0 +1,14 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cuda.R +\name{cuda_get_device_capability} +\alias{cuda_get_device_capability} +\title{Returns the major and minor CUDA capability of \code{device}} +\usage{ +cuda_get_device_capability(device = cuda_current_device()) +} +\arguments{ +\item{device}{Integer value of the CUDA device to return capabilities of.} +} +\description{ +Returns the major and minor CUDA capability of \code{device} +} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 2e64e4af6d..be6405d01f 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -478,6 +478,17 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// cpp_cuda_get_device_capability +XPtrTorchvector_int64_t cpp_cuda_get_device_capability(int64_t device); +RcppExport SEXP _torch_cpp_cuda_get_device_capability(SEXP deviceSEXP) { +BEGIN_RCPP + Rcpp::RObject 
rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< int64_t >::type device(deviceSEXP); + rcpp_result_gen = Rcpp::wrap(cpp_cuda_get_device_capability(device)); + return rcpp_result_gen; +END_RCPP +} // cpp_device_type_to_string std::string cpp_device_type_to_string(Rcpp::XPtr device); RcppExport SEXP _torch_cpp_device_type_to_string(SEXP deviceSEXP) { @@ -35172,6 +35183,7 @@ static const R_CallMethodDef CallEntries[] = { {"_torch_cpp_cuda_is_available", (DL_FUNC) &_torch_cpp_cuda_is_available, 0}, {"_torch_cpp_cuda_device_count", (DL_FUNC) &_torch_cpp_cuda_device_count, 0}, {"_torch_cpp_cuda_current_device", (DL_FUNC) &_torch_cpp_cuda_current_device, 0}, + {"_torch_cpp_cuda_get_device_capability", (DL_FUNC) &_torch_cpp_cuda_get_device_capability, 1}, {"_torch_cpp_device_type_to_string", (DL_FUNC) &_torch_cpp_device_type_to_string, 1}, {"_torch_cpp_device_index_to_int", (DL_FUNC) &_torch_cpp_device_index_to_int, 1}, {"_torch_cpp_torch_device", (DL_FUNC) &_torch_cpp_torch_device, 2}, diff --git a/src/cuda.cpp b/src/cuda.cpp index b08b0928af..34704f38ea 100644 --- a/src/cuda.cpp +++ b/src/cuda.cpp @@ -15,3 +15,8 @@ int cpp_cuda_device_count () { int64_t cpp_cuda_current_device() { return lantern_cuda_current_device(); } + +// [[Rcpp::export]] +XPtrTorchvector_int64_t cpp_cuda_get_device_capability(int64_t device) { + return lantern_cuda_get_device_capability(device); +} diff --git a/src/lantern/lantern.h b/src/lantern/lantern.h index bd8a3c3730..844f3fcf9d 100644 --- a/src/lantern/lantern.h +++ b/src/lantern/lantern.h @@ -2163,6 +2163,15 @@ HOST_API void* lantern_optional_tensor_value (void* x) return ret; } +LANTERN_API void* (LANTERN_PTR _lantern_cuda_get_device_capability) (int64_t device); +HOST_API void* lantern_cuda_get_device_capability (int64_t device) +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_cuda_get_device_capability(device); + LANTERN_HOST_HANDLER; + return ret; +} + /* Autogen Headers -- Start */ LANTERN_API void* 
(LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking); HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; } @@ -7749,6 +7758,7 @@ LOAD_SYMBOL(_lantern_OptionalTensorList_size); LOAD_SYMBOL(_lantern_OptionalTensorList_at); LOAD_SYMBOL(_lantern_OptionalTensorList_at_is_null); LOAD_SYMBOL(_lantern_optional_tensor_value); +LOAD_SYMBOL(_lantern_cuda_get_device_capability); /* Autogen Symbols -- Start */ LOAD_SYMBOL(_lantern__cast_byte_tensor_bool) LOAD_SYMBOL(_lantern__cast_char_tensor_bool) diff --git a/tests/testthat/test-cuda.R b/tests/testthat/test-cuda.R index 552f4aee5d..de4ef6df4e 100644 --- a/tests/testthat/test-cuda.R +++ b/tests/testthat/test-cuda.R @@ -6,6 +6,11 @@ test_that("cuda", { expect_true(cuda_device_count() > 0) expect_true(cuda_current_device() >= 0) expect_true(cuda_is_available()) + + capability <- cuda_get_device_capability(cuda_current_device()) + expect_type(capability, "integer") + expect_length(capability, 2) + expect_error(cuda_get_device_capability(cuda_device_count() + 1), "device must be an integer between 0 and") }) test_that("cuda tensors", {