diff --git a/src/broadcast_kernel.cu b/src/broadcast_kernel.cu
index 2b02a14b..e7062f89 100644
--- a/src/broadcast_kernel.cu
+++ b/src/broadcast_kernel.cu
@@ -241,11 +241,11 @@ void BroadcastBackwardKernelGPU(
   // cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
 
   // Sort COO first
-  thrust::sort_by_key(thrust::device,  //
-                      d_out_map,       // key begin
-                      d_out_map + nnz, // key end
-                      d_in_map         // value begin
-  );
+  THRUST_CHECK(thrust::sort_by_key(thrust::device,  //
+                                   d_out_map,       // key begin
+                                   d_out_map + nnz, // key end
+                                   d_in_map         // value begin
+                                   ));
 
   cusparseSpMMAlg_t mm_alg;
 #if defined(CUDART_VERSION) && (CUDART_VERSION < 10010)
diff --git a/src/coordinate_map_gpu.cu b/src/coordinate_map_gpu.cu
index dda94a92..fb7325d1 100644
--- a/src/coordinate_map_gpu.cu
+++ b/src/coordinate_map_gpu.cu
@@ -973,7 +973,7 @@ CoordinateFieldMapGPU::kernel_map(
   CUDA_CHECK(cudaStreamSynchronize(0));
   LOG_DEBUG("Preallocated kernel map done");
 
-  kernel_map.decompose();
+  THRUST_CHECK(kernel_map.decompose());
   base_type::m_byte_allocator.deallocate(
       reinterpret_cast(d_p_count_per_thread),
       num_threads * sizeof(index_type));
@@ -1730,7 +1730,7 @@ CoordinateMapGPU::kernel_map(
   CUDA_CHECK(cudaMemcpy(kernel_map.out_maps.data(), d_p_valid_out_index,
                         valid_size * sizeof(index_type), cudaMemcpyDeviceToDevice));
 
-  kernel_map.decompose();
+  THRUST_CHECK(kernel_map.decompose());
 
   base_type::m_byte_allocator.deallocate(
       reinterpret_cast(d_p_valid_in_index),
@@ -1961,7 +1961,7 @@ CoordinateMapGPU::origin_map(
                                          m_coordinate_size);
   CUDA_CHECK(cudaStreamSynchronize(0));
 
-  kernel_map.decompose();
+  THRUST_CHECK(kernel_map.decompose());
   LOG_DEBUG("origin map decomposed");
 
   return kernel_map;
diff --git a/src/gpu.cuh b/src/gpu.cuh
index e00a6e1a..051307e8 100644
--- a/src/gpu.cuh
+++ b/src/gpu.cuh
@@ -155,6 +155,13 @@ namespace minkowski {
                 << __FILE__ << ":" << __LINE__);                              \
   }
 
+#define THRUST_CATCH                                                          \
+  catch (thrust::system_error e) {                                            \
+    throw std::runtime_error(Formatter()                                      \
+                             << "Thrust error: " << e.what() << " at "        \
+                             << __FILE__ << ":" << __LINE__);                 \
+  }
+
 // CUDA: library error reporting.
 const char *cublasGetErrorString(cublasStatus_t error);
 
diff --git a/src/kernel_map.cuh b/src/kernel_map.cuh
index 82251df4..877bddc5 100644
--- a/src/kernel_map.cuh
+++ b/src/kernel_map.cuh
@@ -314,15 +314,18 @@ public:
     LOG_DEBUG("Decomposing", kernels.end() - kernels.begin(), "elements");
     // the memory space must be initialized first!
     // sort
-    thrust::sort_by_key(thrust::device,            //
-                        kernels.begin(),           // key begin
-                        kernels.end(),             // key end
-                        thrust::make_zip_iterator( // value begin
-                            thrust::make_tuple(    //
-                                in_maps.begin(),   //
-                                out_maps.begin()   //
-                                )                  //
-                            ));
+    try {
+      thrust::sort_by_key(thrust::device,            //
+                          kernels.begin(),           // key begin
+                          kernels.end(),             // key end
+                          thrust::make_zip_iterator( // value begin
+                              thrust::make_tuple(    //
+                                  in_maps.begin(),   //
+                                  out_maps.begin()   //
+                                  )                  //
+                              ));
+    }
+    THRUST_CATCH;
 
 #ifdef DEBUG
     size_type map_size =
@@ -357,21 +360,25 @@ public:
     gpu_storage out_key_min(m_capacity);
     gpu_storage out_key_size(m_capacity);
 
-    auto end = thrust::reduce_by_key(
-        thrust::device,  // policy
-        kernels.begin(), // key begin
-        kernels.end(),   // key end
-        thrust::make_zip_iterator(
-            thrust::make_tuple(min_begin, size_begin)), // value begin
-        out_key.begin(),                                // key out begin
-        thrust::make_zip_iterator(thrust::make_tuple(
-            out_key_min.begin(), out_key_size.begin())), // value out begin
-        thrust::equal_to(),        // key equal binary predicate
-        detail::min_size_functor() // value binary operator
-    );
-
-    size_type num_unique_keys = end.first - out_key.begin();
-    LOG_DEBUG(num_unique_keys, "unique kernel map keys found");
+    size_type num_unique_keys;
+
+    try {
+      auto end = thrust::reduce_by_key(
+          thrust::device,  // policy
+          kernels.begin(), // key begin
+          kernels.end(),   // key end
+          thrust::make_zip_iterator(
+              thrust::make_tuple(min_begin, size_begin)), // value begin
+          out_key.begin(),                                // key out begin
+          thrust::make_zip_iterator(thrust::make_tuple(
+              out_key_min.begin(), out_key_size.begin())), // value out begin
+          thrust::equal_to(),        // key equal binary predicate
+          detail::min_size_functor() // value binary operator
+      );
+      num_unique_keys = end.first - out_key.begin();
+      LOG_DEBUG(num_unique_keys, "unique kernel map keys found");
+    }
+    THRUST_CATCH;
 
     auto const cpu_out_keys = out_key.to_vector(num_unique_keys);
     auto const cpu_out_offset = out_key_min.to_vector(num_unique_keys);
diff --git a/src/pooling_avg_kernel.cu b/src/pooling_avg_kernel.cu
index 9e62186a..449ac86e 100644
--- a/src/pooling_avg_kernel.cu
+++ b/src/pooling_avg_kernel.cu
@@ -214,10 +214,10 @@ void NonzeroAvgPoolingForwardKernelGPU(
   CUDA_CHECK(cudaMemcpy(sorted_col_ptr, kernel_map.in_maps.begin(),
                         sparse_nnzs * sizeof(Itype), cudaMemcpyDeviceToDevice));
 
-  thrust::sort_by_key(thrust::device,               //
-                      sorted_row_ptr,               // key begin
-                      sorted_row_ptr + sparse_nnzs, // key end
-                      sorted_col_ptr);
+  THRUST_CHECK(thrust::sort_by_key(thrust::device,               //
+                                   sorted_row_ptr,               // key begin
+                                   sorted_row_ptr + sparse_nnzs, // key end
+                                   sorted_col_ptr));
 
   //   +---------+  +---+
   //   | spm     |  | i |
@@ -280,16 +280,18 @@ void NonzeroAvgPoolingForwardKernelGPU(
       (Dtype *)allocator.allocate(sparse_nnzs * sizeof(Dtype));
 
   // reduce by key
-  auto end = thrust::reduce_by_key(thrust::device, // policy
-                                   sorted_row_ptr, // key begin
-                                   sorted_row_ptr + sparse_nnzs, // key end
-                                   d_ones,          // value begin
-                                   unique_row_ptr,  // key out begin
-                                   reduced_val_ptr  // value out begin
-  );
-
-  int num_unique_keys = end.first - unique_row_ptr;
-  LOG_DEBUG("Num unique keys:", num_unique_keys);
+  int num_unique_keys;
+  try {
+    auto end = thrust::reduce_by_key(thrust::device, // policy
+                                     sorted_row_ptr, // key begin
+                                     sorted_row_ptr + sparse_nnzs, // key end
+                                     d_ones,          // value begin
+                                     unique_row_ptr,  // key out begin
+                                     reduced_val_ptr  // value out begin
+    );
+    num_unique_keys = end.first - unique_row_ptr;
+    LOG_DEBUG("Num unique keys:", num_unique_keys);
+  } THRUST_CATCH;
 
 #ifdef DEBUG
   Itype *p_unique_row = (Itype *)std::malloc(num_unique_keys * sizeof(Itype));
diff --git a/src/pooling_max_kernel.cu b/src/pooling_max_kernel.cu
index 28f4be5a..6c93a1c9 100644
--- a/src/pooling_max_kernel.cu
+++ b/src/pooling_max_kernel.cu
@@ -147,7 +147,7 @@ void max_pool_forward_pointer_kernel_gpu(
   MapItype *d_reduced_out_map = d_scr + 2 * nmap + 2; // reduced output maps
 
   // create number of in_feat per out, and starting index
-  thrust::sequence(thrust::device, d_index, d_index + nmap);
+  THRUST_CHECK(thrust::sequence(thrust::device, d_index, d_index + nmap));
 
   ////////////////////////////////
   // Reduction
@@ -155,23 +155,26 @@ void max_pool_forward_pointer_kernel_gpu(
   // sort d_out_map and d_in_map with the d_out_map so that in_feat are
   // placed adjacent according to out_map
   if (!is_sorted)
-    thrust::sort_by_key(thrust::device, d_out_map, d_out_map + nmap, d_in_map);
+    THRUST_CHECK(thrust::sort_by_key(thrust::device, d_out_map,
+                                     d_out_map + nmap, d_in_map));
 
   thrust::equal_to equal_pred;
   thrust::minimum min_op;
-
-  auto reduction_pair =
-      thrust::reduce_by_key(thrust::device,    // execution policy
-                            d_out_map,         // key begin
-                            d_out_map + nmap,  // key end
-                            d_index,           // val begin
-                            d_reduced_out_map, // key out begin
-                            d_in_map_min,      // val out begin
-                            equal_pred,        // binary pred
-                            min_op);           // binary op
-  CUDA_CHECK(cudaStreamSynchronize(0));
-
-  size_t num_unique_out_map = reduction_pair.first - d_reduced_out_map;
+  size_t num_unique_out_map;
+
+  try {
+    auto reduction_pair =
+        thrust::reduce_by_key(thrust::device,    // execution policy
+                              d_out_map,         // key begin
+                              d_out_map + nmap,  // key end
+                              d_index,           // val begin
+                              d_reduced_out_map, // key out begin
+                              d_in_map_min,      // val out begin
+                              equal_pred,        // binary pred
+                              min_op);           // binary op
+    CUDA_CHECK(cudaStreamSynchronize(0));
+    num_unique_out_map = reduction_pair.first - d_reduced_out_map;
+  } THRUST_CATCH;
 
 #ifdef DEBUG
   std::cout << "num_unique_out_map: " << num_unique_out_map << "\n";
diff --git a/src/spmm.cu b/src/spmm.cu
index 00879aae..2a8404d1 100644
--- a/src/spmm.cu
+++ b/src/spmm.cu
@@ -235,15 +235,15 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
     CUDA_CHECK(cudaMemcpy(sorted_val_ptr, values_ptr, nnz * sizeof(scalar_t),
                           cudaMemcpyDeviceToDevice));
 
-    thrust::sort_by_key(thrust::device,            //
-                        sorted_row_ptr,            // key begin
-                        sorted_row_ptr + nnz,      // key end
-                        thrust::make_zip_iterator( // value begin
-                            thrust::make_tuple(    //
-                                sorted_col_ptr,    //
-                                sorted_val_ptr     //
-                                )                  //
-                            ));
+    THRUST_CHECK(thrust::sort_by_key(thrust::device,            //
+                                     sorted_row_ptr,            // key begin
+                                     sorted_row_ptr + nnz,      // key end
+                                     thrust::make_zip_iterator( // value begin
+                                         thrust::make_tuple(    //
+                                             sorted_col_ptr,    //
+                                             sorted_val_ptr     //
+                                             )                  //
+                                         )));
     LOG_DEBUG("sorted row", cudaDeviceSynchronize());
   } else {
     sorted_row_ptr = row_indices_ptr;
@@ -481,10 +481,10 @@ coo_spmm_average(torch::Tensor const &rows, torch::Tensor const &cols,
     CUDA_CHECK(cudaMemcpy(sorted_col_ptr, col_indices_ptr,
                           nnz * sizeof(th_int_type), cudaMemcpyDeviceToDevice));
 
-    thrust::sort_by_key(thrust::device,       //
-                        sorted_row_ptr,       // key begin
-                        sorted_row_ptr + nnz, // key end
-                        sorted_col_ptr);
+    THRUST_CHECK(thrust::sort_by_key(thrust::device,       //
+                                     sorted_row_ptr,       // key begin
+                                     sorted_row_ptr + nnz, // key end
+                                     sorted_col_ptr));
 
     /////////////////////////////////////////////////////////////////////////
     // Create vals
@@ -496,21 +496,20 @@ coo_spmm_average(torch::Tensor const &rows, torch::Tensor const &cols,
         (scalar_t *)c10::cuda::CUDACachingAllocator::raw_alloc(
             nnz * sizeof(scalar_t));
     torch::Tensor ones = at::ones({nnz}, mat2.options());
-
-    // reduce by key
-    auto end = thrust::reduce_by_key(
-        thrust::device,       // policy
-        sorted_row_ptr,       // key begin
-        sorted_row_ptr + nnz, // key end
-        reinterpret_cast(ones.data_ptr()), // value begin
-        unique_row_ptr,                    // key out begin
-        reduced_val_ptr                    // value out begin
-    );
-
-    int num_unique_keys = end.first - unique_row_ptr;
-    LOG_DEBUG("Num unique keys:", num_unique_keys);
-
-    // Create values
+    int num_unique_keys;
+    try {
+      // reduce by key
+      auto end = thrust::reduce_by_key(
+          thrust::device,       // policy
+          sorted_row_ptr,       // key begin
+          sorted_row_ptr + nnz, // key end
+          reinterpret_cast(ones.data_ptr()), // value begin
+          unique_row_ptr,                    // key out begin
+          reduced_val_ptr                    // value out begin
+      );
+      num_unique_keys = end.first - unique_row_ptr;
+      LOG_DEBUG("Num unique keys:", num_unique_keys);
+    } THRUST_CATCH;
 
     // Copy the results to the correct output inverse_val
 
diff --git a/tests/cpp/coordinate_map_gpu_test.cu b/tests/cpp/coordinate_map_gpu_test.cu
index 3fb62222..44a39ad2 100644
--- a/tests/cpp/coordinate_map_gpu_test.cu
+++ b/tests/cpp/coordinate_map_gpu_test.cu
@@ -178,8 +178,8 @@ coordinate_map_batch_find_test(const torch::Tensor &coordinates,
 
   std::vector cpu_firsts(NR);
   std::vector cpu_seconds(NR);
-  thrust::copy(firsts.cbegin(), firsts.cend(), cpu_firsts.begin());
-  thrust::copy(seconds.cbegin(), seconds.cend(), cpu_seconds.begin());
+  THRUST_CHECK(thrust::copy(firsts.cbegin(), firsts.cend(), cpu_firsts.begin()));
+  THRUST_CHECK(thrust::copy(seconds.cbegin(), seconds.cend(), cpu_seconds.begin()));
 
   return std::make_pair(cpu_firsts, cpu_seconds);
 }