[cherry-pick] remove if constexpr(), which is not supported on gcc54 #50421

Merged 4 commits on Feb 10, 2023
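Background: `if constexpr` is a C++17 construct; GCC implements it starting with GCC 7, while the GCC 5.4 toolchain this cherry-pick targets supports at most C++14. The construct also discards the untaken branch at compile time, which the original kernel relied on, since each branch only type-checks for one element/index type pair. A minimal sketch, not Paddle code, of the construct that breaks the GCC 5.4 build:

```cpp
// Minimal sketch (not Paddle code): `if constexpr` needs -std=c++17 and
// GCC >= 7; GCC 5.4 rejects it at parse time.
#include <cstdio>
#include <string>
#include <type_traits>

template <typename T>
void describe(const T& v) {
  // The untaken branch is discarded at compile time, so v.c_str() is never
  // instantiated for arithmetic T. A plain `if` would not compile here.
  if constexpr (std::is_same<T, std::string>::value) {
    std::printf("string: %s\n", v.c_str());
  } else {
    std::printf("number: %g\n", static_cast<double>(v));
  }
}

int main() {
  describe(std::string("hi"));  // prints "string: hi"
  describe(3.5);                // prints "number: 3.5"
}
```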
66 changes: 12 additions & 54 deletions paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
         const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
         const IntT* scatter_indices =
             rulebook_ptr + rulebook_len + h_offsets_ptr[i];
-
-        if constexpr (std::is_same<T, phi::dtype::float16>::value &&
-                      std::is_same<IntT, int32_t>::value) {
-          fp16_gather_gemm_scatter gather_gemm_scatter =
-              getBestFp16Kernel(M, N, K);
-          gather_gemm_scatter(
-              dev_ctx,
-              reinterpret_cast<const cutlass::half_t*>(
-                  x.non_zero_elements().data<T>()),
-              reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
-              reinterpret_cast<cutlass::half_t*>(out_values_ptr),
-              reinterpret_cast<cutlass::half_t*>(out_values_ptr),
-              M,
-              N,
-              K,
-              static_cast<const int32_t*>(gather_indices),
-              static_cast<const int32_t*>(scatter_indices),
-              static_cast<cutlass::half_t>(1),
-              static_cast<cutlass::half_t>(1));
-        }
-        if constexpr (std::is_same<T, float>::value &&
-                      std::is_same<IntT, int32_t>::value) {
-          fp32_gather_gemm_scatter gather_gemm_scatter =
-              getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
-          gather_gemm_scatter(dev_ctx,
-                              x.non_zero_elements().data<T>(),
-                              tmp_kernel_ptr,
-                              out_values_ptr,
-                              out_values_ptr,
-                              M,
-                              N,
-                              K,
-                              gather_indices,
-                              scatter_indices,
-                              static_cast<T>(1),
-                              static_cast<T>(1));
-        }
-        if constexpr (std::is_same<T, double>::value &&
-                      std::is_same<IntT, int32_t>::value) {
-          fp64_gather_gemm_scatter gather_gemm_scatter =
-              getBestFp64Kernel(M, N, K);
-          gather_gemm_scatter(dev_ctx,
-                              x.non_zero_elements().data<T>(),
-                              tmp_kernel_ptr,
-                              out_values_ptr,
-                              out_values_ptr,
-                              M,
-                              N,
-                              K,
-                              gather_indices,
-                              scatter_indices,
-                              static_cast<T>(1),
-                              static_cast<T>(1));
-        }
+        dispatchKernel(dev_ctx,
+                       x.non_zero_elements().data<T>(),
+                       tmp_kernel_ptr,
+                       out_values_ptr,
+                       out_values_ptr,
+                       M,
+                       N,
+                       K,
+                       gather_indices,
+                       scatter_indices,
+                       cutlass,
+                       x.dtype());
       }
     } else {
 #endif
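Note that simply dropping the `constexpr` keyword would not fix the build: with a plain runtime `if`, every branch must still compile for every instantiation, and each branch above only type-checks for its matching `T`/`IntT` pair (for example, the fp32 branch cannot accept a `phi::dtype::float16*` payload). The new `dispatchKernel` helper in the header below sidesteps this by erasing the element type behind `void*` and branching on a runtime `phi::DataType` tag, so the template emits a single opaque call. A minimal sketch of that type-erasure pattern, with hypothetical names:

```cpp
// Minimal sketch (hypothetical names, not Paddle's API) of the type-erasure
// pattern dispatchKernel uses: void* plus a runtime tag instead of
// if constexpr.
#include <cstdio>

void run_fp32(const float* a, float* d, int n) {
  for (int i = 0; i < n; ++i) d[i] = 2.0f * a[i];
}
void run_fp64(const double* a, double* d, int n) {
  for (int i = 0; i < n; ++i) d[i] = 2.0 * a[i];
}

enum class DType { kF32, kF64 };

// Compiles once under C++14: the concrete types are recovered with
// static_cast inside each runtime branch, so no caller ever has to
// instantiate a branch written for a different type.
void dispatch(const void* a, void* d, int n, DType type) {
  if (type == DType::kF32) {
    run_fp32(static_cast<const float*>(a), static_cast<float*>(d), n);
  } else if (type == DType::kF64) {
    run_fp64(static_cast<const double*>(a), static_cast<double*>(d), n);
  }
}

int main() {
  float in[2] = {1.f, 2.f}, out[2];
  dispatch(in, out, 2, DType::kF32);
  std::printf("%g %g\n", out[0], out[1]);  // 2 4
}
```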
61 changes: 61 additions & 0 deletions paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
@@ -23,6 +23,7 @@
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
namespace phi {
namespace sparse {
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
   CUTLASS_CHECK(status);
   gemm_op(dev_ctx.stream());
 }
+static void dispatchKernel(const GPUContext& dev_ctx,
+                           const void* const a,
+                           const void* const b,
+                           const void* const c,
+                           void* const d,
+                           const int m,
+                           const int n,
+                           const int k,
+                           const void* a_indices,
+                           const void* c_d_indices,
+                           const bool cutlass,
+                           const phi::DataType type) {
+  if (!cutlass) return;
+
+  if (type == phi::DataType::FLOAT16) {
+    fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k);
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const cutlass::half_t*>(a),
+                        static_cast<const cutlass::half_t*>(b),
+                        static_cast<const cutlass::half_t*>(c),
+                        static_cast<cutlass::half_t*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<cutlass::half_t>(1),
+                        static_cast<cutlass::half_t>(1));
+  } else if (type == phi::DataType::FLOAT32) {
+    fp32_gather_gemm_scatter gather_gemm_scatter =
+        getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const float*>(a),
+                        static_cast<const float*>(b),
+                        static_cast<const float*>(c),
+                        static_cast<float*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<float>(1),
+                        static_cast<float>(1));
+  } else if (type == phi::DataType::FLOAT64) {
+    fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k);
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const double*>(a),
+                        static_cast<const double*>(b),
+                        static_cast<const double*>(c),
+                        static_cast<double*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<double>(1),
+                        static_cast<double>(1));
+  }
+}
+
 struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
   using Gemm = cutlass::gemm::device::GemmUniversal<
       cutlass::half_t,
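For context, the `fp16_gather_gemm_scatter` / `fp32_...` / `fp64_...` typedefs earlier in this header are plain function pointers, and each `getBest*Kernel(m, n, k, ...)` selector returns the CUTLASS instantiation judged best for the problem shape; `cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8` above is one such instantiation. A rough sketch of that select-by-shape pattern, with invented kernels and an invented threshold:

```cpp
// Rough sketch of the getBest*Kernel pattern (the kernel names and the size
// threshold are invented for illustration): choose a tile configuration from
// the problem shape and hand it back as a plain function pointer.
typedef void (*fp32_gemm)(
    const float* a, const float* b, float* d, int m, int n, int k);

static void gemm_tile_128x64(
    const float*, const float*, float*, int, int, int) { /* big-tile GEMM */ }
static void gemm_tile_64x32(
    const float*, const float*, float*, int, int, int) { /* small-tile GEMM */ }

static fp32_gemm get_best_fp32_kernel(int m, int n, int k) {
  // Large problems amortize the bigger tile; small ones waste fewer cycles
  // on the smaller one. The real selectors also consult compute capability.
  const long long work = 2LL * m * n * k;
  return work > (1LL << 24) ? gemm_tile_128x64 : gemm_tile_64x32;
}
```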