Skip to content

Commit

Permalink
[cherry-pick] remove if constexpr(), which is not supported on gcc54 (#50421)
Browse files Browse the repository at this point in the history

att, cherry-pick #48563
  • Loading branch information
zhangkaihuo authored Feb 10, 2023
1 parent eb61074 commit 913f40e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 54 deletions.
66 changes: 12 additions & 54 deletions paddle/phi/kernels/sparse/gpu/conv_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];

if constexpr (std::is_same<T, phi::dtype::float16>::value &&
std::is_same<IntT, int32_t>::value) {
fp16_gather_gemm_scatter gather_gemm_scatter =
getBestFp16Kernel(M, N, K);
gather_gemm_scatter(
dev_ctx,
reinterpret_cast<const cutlass::half_t*>(
x.non_zero_elements().data<T>()),
reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
M,
N,
K,
static_cast<const int32_t*>(gather_indices),
static_cast<const int32_t*>(scatter_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
}
if constexpr (std::is_same<T, float>::value &&
std::is_same<IntT, int32_t>::value) {
fp32_gather_gemm_scatter gather_gemm_scatter =
getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
if constexpr (std::is_same<T, double>::value &&
std::is_same<IntT, int32_t>::value) {
fp64_gather_gemm_scatter gather_gemm_scatter =
getBestFp64Kernel(M, N, K);
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
dispatchKernel(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
cutlass,
x.dtype());
}
} else {
#endif
Expand Down
61 changes: 61 additions & 0 deletions paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
namespace phi {
namespace sparse {
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
Expand Down Expand Up @@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
CUTLASS_CHECK(status);
gemm_op(dev_ctx.stream());
}
// Runtime (non-template) dispatcher for the gather-gemm-scatter kernels.
//
// Returns immediately when `cutlass` is false. Otherwise it selects a kernel
// through getBestFp16Kernel / getBestFp32Kernel / getBestFp64Kernel according
// to `type` and invokes it with `dev_ctx`. Any other `type` is silently
// ignored (no kernel is run and no error is raised).
//
// Parameters:
//   dev_ctx      - context forwarded to the selected kernel; its compute
//                  capability is also consulted when picking the fp32 kernel.
//   a, b, c, d   - type-erased buffers; cast to half_t / float / double
//                  pointers to match `type`. Only `d` is non-const.
//   m, n, k      - sizes forwarded to both the kernel selector and the kernel
//                  (presumably the GEMM problem dimensions - confirm).
//   a_indices    - always cast to `const int32_t*`, whatever index type the
//   c_d_indices    caller actually uses; no check is performed here.
//   cutlass      - when false, the function is a no-op.
//   type         - element type that selects the branch.
//
// NOTE(review): the last two arguments handed to each kernel are both the
// constant 1 in the element type - presumably the GEMM alpha/beta scalars;
// confirm against the *_gather_gemm_scatter typedefs.
//
// NOTE: the parameter `cutlass` has the same name as the `cutlass` namespace.
// `cutlass::half_t` below still names the namespace member, because lookup of
// a name that precedes `::` ignores variables.
static void dispatchKernel(const GPUContext& dev_ctx,
                           const void* const a,
                           const void* const b,
                           const void* const c,
                           void* const d,
                           const int m,
                           const int n,
                           const int k,
                           const void* a_indices,
                           const void* c_d_indices,
                           const bool cutlass,
                           const phi::DataType type) {
  if (!cutlass) return;

  // Both index buffers are viewed as int32 in every branch, so cast once.
  const int32_t* const gather_idx = static_cast<const int32_t*>(a_indices);
  const int32_t* const scatter_idx = static_cast<const int32_t*>(c_d_indices);

  switch (type) {
    case phi::DataType::FLOAT16: {
      using HalfT = cutlass::half_t;
      const fp16_gather_gemm_scatter run = getBestFp16Kernel(m, n, k);
      run(dev_ctx,
          static_cast<const HalfT*>(a),
          static_cast<const HalfT*>(b),
          static_cast<const HalfT*>(c),
          static_cast<HalfT*>(d),
          m,
          n,
          k,
          gather_idx,
          scatter_idx,
          static_cast<HalfT>(1),
          static_cast<HalfT>(1));
      break;
    }
    case phi::DataType::FLOAT32: {
      const fp32_gather_gemm_scatter run =
          getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
      run(dev_ctx,
          static_cast<const float*>(a),
          static_cast<const float*>(b),
          static_cast<const float*>(c),
          static_cast<float*>(d),
          m,
          n,
          k,
          gather_idx,
          scatter_idx,
          1.0f,
          1.0f);
      break;
    }
    case phi::DataType::FLOAT64: {
      const fp64_gather_gemm_scatter run = getBestFp64Kernel(m, n, k);
      run(dev_ctx,
          static_cast<const double*>(a),
          static_cast<const double*>(b),
          static_cast<const double*>(c),
          static_cast<double*>(d),
          m,
          n,
          k,
          gather_idx,
          scatter_idx,
          1.0,
          1.0);
      break;
    }
    default:
      // Unsupported element types fall through without running a kernel.
      break;
  }
}

struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
Expand Down

0 comments on commit 913f40e

Please sign in to comment.