int8 ptq moe add grouped gemm (PaddlePaddle#112)
* add & fix zeus int8

* ptq int8 moe add grouped gemm
tianyan01 authored Jan 23, 2024
1 parent a20e3e2 commit fd736c0
Showing 20 changed files with 2,600 additions and 191 deletions.
39 changes: 19 additions & 20 deletions paddle/fluid/operators/fused/attn_gemm_int8.h
@@ -133,27 +133,26 @@ class AttnMatmulINT8 {
(void*)workspace->data<int8_t>(),
workspace->numel());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>());

if (compute_bias_) {
// bias_out = output + bias
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
// PADDLE_ENFORCE_EQ(cudaGetLastError(),
// cudaSuccess,
// platform::errors::Fatal(
// "cuda error occured after computing bias. "
// "But it does not mean this error is caused by "
// "bias computing"));
dequantize_addbias_kernel_launcher<T, true>(output_tmp->data<int32_t>(),
bias->data<T>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>());
} else {
dequantize_addbias_kernel_launcher<T, false>(output_tmp->data<int32_t>(),
nullptr,
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>());
}
}

286 changes: 163 additions & 123 deletions paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu

Large diffs are not rendered by default.

33 changes: 19 additions & 14 deletions paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h
@@ -88,21 +88,26 @@ void MatMulINT8ToT(const phi::GPUContext& dev_ctx,
(void*)workspace->data<int8_t>(),
workspace->numel());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m,
n,
dev_ctx.stream(),
gpu_config.get(),
quant_in_scale,
dequant_out_scale->data<float>());

if (compute_bias) {
// bias_out = output + bias
std::vector<const Tensor*> ins = {output, bias};
std::vector<Tensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor<T>());
dequantize_addbias_kernel_launcher<T, true>(output_tmp->data<int32_t>(),
bias->data<T>(),
output->data<T>(),
m,
n,
dev_ctx.stream(),
gpu_config.get(),
quant_in_scale,
dequant_out_scale->data<float>());
} else {
dequantize_addbias_kernel_launcher<T, false>(output_tmp->data<int32_t>(),
nullptr,
output->data<T>(),
m,
n,
dev_ctx.stream(),
gpu_config.get(),
quant_in_scale,
dequant_out_scale->data<float>());
}
}

@@ -21,9 +21,8 @@ limitations under the License. */
#include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h"
#include "paddle/phi/kernels/weight_only_linear_kernel.h"

PADDLE_DEFINE_EXPORTED_bool(enable_moe_gemm_cutlass,
false,
"enable moe gemm cutlass ,default false");
DECLARE_bool(enable_moe_gemm_cutlass);

namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
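The CUTLASS MoE GEMM switch is now defined once in paddle/fluid/platform/flags.cc (see the flags.cc hunk below) and only declared in the operator file above, so every translation unit that reads the flag shares a single definition. A minimal sketch of the intended pattern follows; the DispatchMoeGemm wrapper and the branch bodies are illustrative assumptions, and only the two flag macros come from this commit.

// paddle/fluid/platform/flags.cc: single exported definition of the flag.
PADDLE_DEFINE_EXPORTED_bool(enable_moe_gemm_cutlass,
                            false,
                            "enable moe gemm cutlass, default false");

// Any .cc/.cu that reads the flag declares it and branches on FLAGS_*.
DECLARE_bool(enable_moe_gemm_cutlass);

void DispatchMoeGemm(/* hypothetical caller, for illustration only */) {
  if (FLAGS_enable_moe_gemm_cutlass) {
    // take the CUTLASS grouped-GEMM path
  } else {
    // fall back to the default MoE GEMM path
  }
}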
74 changes: 74 additions & 0 deletions paddle/fluid/operators/fused/moe_expert_gemm.h
@@ -17,6 +17,7 @@ limitations under the License. */
#if defined(PADDLE_WITH_CUTLASS)
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h"
#endif
namespace paddle {
namespace operators {
@@ -134,5 +135,78 @@ class MoeExpertGemmWeightOnly {
bool is_uint4_ = false;
};

// for ptq
template <typename T>
void FusedGroupedMatMul(const phi::GPUContext& dev_ctx,
const int8_t* weight, // int8
const T* input, // fp16, shape is [fwd_bsz, k]
Tensor* input_tmp, // int8
const float* quant_in_scale,
const T* bias, // fp16
T* output_deq,
const float* dequant_out_scale, // fp32
const int64_t* fwd_expert_count_cumsum, // int64
const int64_t* fwd_expert_count_cumsum_cpu, // cpu
int fwd_bsz,
int num_expert,
int m,
int n,
int k,
const bool do_activation,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
#if defined(PADDLE_WITH_CUTLASS)
int64_t offset = 0;
for (int i = 0; i < num_expert; ++i) {
int64_t cur_m = *(fwd_expert_count_cumsum_cpu + i + 1) - *(fwd_expert_count_cumsum_cpu + i);
if (cur_m == 0) {
continue;
}
quantize_kernel_launcher<T>(input + offset,
input_tmp->data<int8_t>() + offset,
*(quant_in_scale + i),
cur_m, // cur m
k,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx.stream());
offset += cur_m * k;
}

// group_quantize_kernel_launcher(input,
// input_tmp->data<int8_t>(),
// quant_in_scale,
// fwd_expert_count,
// num_expert,
// m,
// k,
// quant_round_type,
// quant_max_bound,
// quant_min_bound,
// dev_ctx.stream());

using half_dtype = typename phi::PDDataTypeTraits<paddle::platform::float16>::DataType;
auto moe_gemm_runner = phi::PTQMoeGemmRunner<half_dtype>();
// int8 GEMM & dequant & add bias & act (optional)
moe_gemm_runner.moe_gemm_bias_act(input_tmp->data<int8_t>(),
weight,
dequant_out_scale,
reinterpret_cast<const half_dtype*>(bias), // bias
reinterpret_cast<half_dtype*>(output_deq),
fwd_expert_count_cumsum,
fwd_bsz,
n,
k,
num_expert,
do_activation,
dev_ctx.stream());
#else
PADDLE_THROW(platform::errors::InvalidArgument(
"FusedGroupedMatMul requires PaddlePaddle to be built with CUTLASS "
"(PADDLE_WITH_CUTLASS), which this build does not support."));
#endif
}

} // namespace operators
} // namespace paddle
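For orientation, a naive host-side reference of what the grouped path computes: expert i owns the contiguous rows [cumsum[i], cumsum[i+1]) of the already-quantized input, and each slice goes through an int8 GEMM, per-output-channel dequantization, a bias add, and an optional activation. This is a sketch under stated assumptions, not the CUTLASS implementation: the k x n row-major per-expert weight layout, the [num_expert, n] scale and bias layouts, and the ReLU stand-in for the activation are all illustrative.

#include <algorithm>
#include <cstdint>

void GroupedGemmReference(const int8_t* input_i8, const int8_t* weight_i8,
                          const float* dequant_out_scale, const float* bias,
                          float* output, const int64_t* cumsum_cpu,
                          int num_expert, int n, int k, bool do_activation) {
  int64_t row0 = 0;
  for (int e = 0; e < num_expert; ++e) {
    const int64_t rows = cumsum_cpu[e + 1] - cumsum_cpu[e];  // this expert's row count
    for (int64_t r = 0; r < rows; ++r) {
      for (int j = 0; j < n; ++j) {
        int32_t acc = 0;  // int8 x int8 products accumulated in int32
        for (int kk = 0; kk < k; ++kk) {
          acc += static_cast<int32_t>(input_i8[(row0 + r) * k + kk]) *
                 static_cast<int32_t>(weight_i8[static_cast<int64_t>(e) * k * n +
                                                static_cast<int64_t>(kk) * n + j]);
        }
        // dequantize with a per-expert, per-channel scale, then add the bias
        float val = static_cast<float>(acc) * dequant_out_scale[e * n + j] +
                    bias[e * n + j];
        output[(row0 + r) * n + j] = do_activation ? std::max(val, 0.0f) : val;
      }
    }
    row0 += rows;  // rows are laid out expert by expert
  }
}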
27 changes: 25 additions & 2 deletions paddle/fluid/operators/fused/quant_dequant_kernel.h
@@ -158,9 +158,10 @@ void quantize_kernel_launcher(const T* input,
min_bound);
}

template <typename T, int VecSize>
template <typename T, int VecSize, bool ComputeBias=false>
__global__ void dequantize_kernel(T* output,
const int32_t* input,
const T* bias,
const int m, // batch size
const int n, // hidden
const float quant_in_scale,
@@ -172,16 +173,23 @@ __global__ void dequantize_kernel(T* output,

phi::AlignedVector<int32_t, VecSize> in_vec;
phi::AlignedVector<float, VecSize> out_scale_vec;
phi::AlignedVector<T, VecSize> bias_vec;
phi::AlignedVector<T, VecSize> out_vec;

for (; idx < numel; idx += stride) {
phi::Load<int32_t, VecSize>(input + idx, &in_vec);
phi::Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
if (ComputeBias) {
phi::Load<T, VecSize>(bias + col_id, &bias_vec);
}

#pragma unroll
for (int i = 0; i < VecSize; ++i) {
out_vec[i] =
static_cast<T>(static_cast<float>(in_vec[i]) * out_scale_vec[i]);
if (ComputeBias) {
out_vec[i] += bias_vec[i];
}
}

phi::Store<T, VecSize>(out_vec, output + idx);
@@ -199,7 +207,22 @@ void dequantize_kernel_launcher(const int32_t* input,
const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data);
output, input, nullptr, m, n, quant_in_scale, dequant_out_scale_data);
}

template <typename T, bool ComputeBias>
void dequantize_addbias_kernel_launcher(const int32_t* input,
const T* bias,
T* output,
const int m, // m
const int n, // n
gpuStream_t stream,
GpuLaunchConfig* gpu_config,
const float quant_in_scale,
const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize, ComputeBias>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, bias, m, n, quant_in_scale, dequant_out_scale_data);
}

} // namespace operators
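The templated kernel above folds the optional bias add into the dequantize pass, so the previous dequantize-then-broadcast-add sequence becomes a single launch, and with ComputeBias=false the bias pointer may be nullptr because the load and the add are never taken. A small host-side reference of the math, useful for sanity-checking the CUDA kernel (a sketch only; on the GPU, T is phi's half/bfloat16 type and the per-channel scale layout follows the kernel above):

#include <cstdint>

template <typename T, bool ComputeBias>
void DequantizeAddBiasReference(const int32_t* input, const T* bias, T* output,
                                int m, int n, const float* dequant_out_scale) {
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      // per-output-channel dequant scale, then (optionally) a per-channel bias
      // added after the cast to T, matching out_vec[i] += bias_vec[i] above
      T val = static_cast<T>(static_cast<float>(input[row * n + col]) *
                             dequant_out_scale[col]);
      if (ComputeBias) {
        val += bias[col];
      }
      output[row * n + col] = val;
    }
  }
}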
6 changes: 6 additions & 0 deletions paddle/fluid/platform/flags.cc
@@ -1020,3 +1020,9 @@ PADDLE_DEFINE_EXPORTED_bool(
PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor",
"Choose default funciton type in JitLayer.");
/**
* CUTLASS related FLAG
*/
PADDLE_DEFINE_EXPORTED_bool(enable_moe_gemm_cutlass,
false,
"enable moe gemm cutlass ,default false");
3 changes: 2 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
@@ -189,7 +189,8 @@ if(WITH_CUTLASS)
"fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/*.cu"
"fusion/cutlass/cutlass_kernels/fpA_intB_gemm/*.cu"
"fusion/cutlass/cutlass_kernels/moe_gemm/autogen/*.cu"
"fusion/cutlass/cutlass_kernels/moe_gemm/*.cu")
"fusion/cutlass/cutlass_kernels/moe_gemm/*.cu"
"fusion/cutlass/cutlass_kernels/ptq_moe_gemm/*.cu")
list(APPEND kernel_cu ${cutlass_cu})
endif()
