From fd736c0589a7a2328118f4548f14f9bd71aba1c8 Mon Sep 17 00:00:00 2001 From: tianyan01 <124153989+tianyan01@users.noreply.github.com> Date: Tue, 23 Jan 2024 19:03:43 +0800 Subject: [PATCH] int8 ptq moe add grouped gemm (#112) * add & fix zeus int8 * ptq int8 moe add grouped gemm --- paddle/fluid/operators/fused/attn_gemm_int8.h | 39 +- .../fused_multi_transformer_moe_int8_op.cu | 286 ++++++---- .../fused/fused_multi_transformer_moe_op.h | 33 +- ...ed_multi_transformer_moe_weight_only_op.cu | 5 +- .../fluid/operators/fused/moe_expert_gemm.h | 74 +++ .../operators/fused/quant_dequant_kernel.h | 27 +- paddle/fluid/platform/flags.cc | 6 + paddle/phi/kernels/CMakeLists.txt | 3 +- .../epilogue/thread/linear_combination_ptq.h | 169 ++++++ .../b2b_default_epilogue_tensor_op.h | 239 ++++++++ .../epilogue/threadblock/b2b_epilogue.h | 525 ++++++++++++++++++ .../gemm/kernel/default_moe_gemm_grouped.h | 230 ++++++++ .../gemm/kernel/ptq_moe_kernel.h | 447 +++++++++++++++ .../cutlass_kernels/cutlass_heuristic.h | 29 +- .../fpA_intB_gemm/fpA_intB_gemm_template.cu | 6 +- .../moe_gemm/moe_gemm_kernels_template.cu | 8 +- .../ptq_moe_gemm/ptq_moe_gemm.h | 62 +++ .../ptq_moe_gemm/ptq_moe_gemm_fp16.cu | 21 + .../ptq_moe_gemm/ptq_moe_gemm_template.h | 505 +++++++++++++++++ .../incubate/nn/layer/fused_transformer.py | 77 ++- 20 files changed, 2600 insertions(+), 191 deletions(-) create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/linear_combination_ptq.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_default_epilogue_tensor_op.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_epilogue.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_moe_gemm_grouped.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/ptq_moe_kernel.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_fp16.cu create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_template.h diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index ce392e98ba606..c4e812e59ee33 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -133,27 +133,26 @@ class AttnMatmulINT8 { (void*)workspace->data(), workspace->numel()); - dequantize_kernel_launcher(output_tmp->data(), - output->data(), - m_, - n_, - dev_ctx_.stream(), - gpu_config_.get(), - quant_in_scale, - dequant_out_scale->data()); - if (compute_bias_) { - // bias_out = output + bias - std::vector ins = {output, bias}; - std::vector outs = {bias_out}; - phi::funcs::BroadcastKernel( - dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); - // PADDLE_ENFORCE_EQ(cudaGetLastError(), - // cudaSuccess, - // platform::errors::Fatal( - // "cuda error occured after computing bias. 
" - // "But it does not mean this error is caused by " - // "bias computing")); + dequantize_addbias_kernel_launcher(output_tmp->data(), + bias->data(), + output->data(), + m_, + n_, + dev_ctx_.stream(), + gpu_config_.get(), + quant_in_scale, + dequant_out_scale->data()); + } else { + dequantize_addbias_kernel_launcher(output_tmp->data(), + nullptr, + output->data(), + m_, + n_, + dev_ctx_.stream(), + gpu_config_.get(), + quant_in_scale, + dequant_out_scale->data()); } } diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu index 364f7f96ad226..ae47f02633704 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu @@ -13,14 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. */ // #define DEBUG_MOE_TMPROFILE_INT8 #include "paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h" +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" #include "paddle/fluid/operators/fused/layernorm_quant_dequant.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/fluid/operators/fused/moe_expert_gemm.h" #ifdef DEBUG_MOE_TMPROFILE_INT8 #include "paddle/fluid/platform/timer.h" #endif + +DECLARE_bool(enable_moe_gemm_cutlass); + namespace paddle { namespace operators { - using Tensor = phi::DenseTensor; // #define _DEBUG_FUSED_MULTI_TRANSFORMER @@ -65,6 +69,11 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { platform::Timer gate_nccl_tm, gather_tm, scatter_tm; all_tm.Start(); other_tm.Start(); +#endif +#ifndef PADDLE_WITH_CUTLASS + PADDLE_ENFORCE_EQ(FLAGS_enable_moe_gemm_cutlass, false, + "not support cutlass fused moe gemm please disable " + "FLAGS_enable_moe_gemm_cutlass"); #endif auto *time_step = ctx.Input("TimeStep"); // 0. input @@ -105,12 +114,16 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { ctx.MultiInput("ExpertWeight2OutScale"); auto *sequence_lengths = ctx.Input("SeqLengths"); + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); int beam_size = 1; if (beam_cache_offset) { beam_size = beam_cache_offset->dims()[1]; } + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + // 1. layer norm const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); if (!pre_layer_norm) { @@ -184,6 +197,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { if (!is_support_flash_attn) { qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *softmax_out_data = dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); @@ -216,7 +230,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { bias_dropout_residual_out_data = dev_ctx.Alloc(&bias_dropout_residual_out, bias_dropout_residual_out.numel() * sizeof(T)); - uint8_t *dropout_mask_out_data = nullptr; // 6. 
moe layer: gate / expert_w & b / some attrs auto gate_weights = ctx.MultiInput("GateWeight"); @@ -269,21 +282,25 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { local_expert_count.numel() * sizeof(int64_t)); dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); - // fwd_expert_count, fwd_batch_size - Tensor fwd_expert_count, fwd_batch_size; - Tensor fwd_expert_count_cpu, fwd_batch_size_cpu; + // fwd_expert_count + Tensor fwd_expert_count, fwd_expert_count_cumsum; + Tensor fwd_expert_count_cumsum_cpu; fwd_expert_count.Resize({{num_expert}}); - fwd_batch_size.Resize({{1}}); + fwd_expert_count_cumsum.Resize({{num_expert + 1}}); dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); - dev_ctx.Alloc(&fwd_batch_size, - fwd_batch_size.numel() * sizeof(int64_t)); + auto fwd_expert_count_cumsum_data = dev_ctx.Alloc(&fwd_expert_count_cumsum, + fwd_expert_count_cumsum.numel() * sizeof(int64_t)); + phi::funcs::set_constant( + dev_ctx, &fwd_expert_count_cumsum, static_cast(0)); // pos, temp pos Tensor pos, temp_pos; pos.Resize({{out_batch_size}}); temp_pos.Resize({{out_batch_size}}); dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); - dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + if (topk > 1) { + dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + } // cumsum Tensor lec_cum; lec_cum.Resize({{tot_expert}}); @@ -304,12 +321,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { Tensor all_gather_out; all_gather_out.Resize({{bsz_seq, dim_embed}}); dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); - // topk tensor - Tensor topk_tensor; - topk_tensor.Resize({{1}}); - dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); - phi::FullKernel( - dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); // moe nccl phi::NCCLMoECollective moe_pg(dev_ctx, moe_ring_id, num_expert); @@ -334,16 +345,13 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { cublaslt_workspace.numel() * sizeof(int8_t)); // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - Tensor buf0, moe_out; buf0.Resize({{bsz_seq, dim_embed}}); dev_ctx.Alloc(&buf0, buf0.numel() * sizeof(T)); moe_out.ShareDataWith(*out); moe_out.Resize({{bsz_seq, dim_embed}}); - const T *x_data = input_x->data(); + #ifdef DEBUG_MOE_TMPROFILE_INT8 dev_ctx.Wait(); other_tm.Pause(); @@ -632,52 +640,36 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { global_expert_count.dtype(), false, &fwd_expert_count); - // fwd batch size - phi::SumKernel( - dev_ctx, - fwd_expert_count, - phi::IntArray({}), // axis is None - fwd_expert_count.dtype(), - false, - &fwd_batch_size); + // fwd batch size, we dont compute this + phi::CumsumTensorValue( + dev_ctx, fwd_expert_count, &fwd_expert_count_cumsum, 1); // step4.3 cumsum & assign pos #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "moe, cumsum"; #endif - phi::CumsumKernel( - dev_ctx, local_expert_count, 0, false, false, false, &lec_cum); + phi::CumsumTensorValue(dev_ctx, local_expert_count, &lec_cum); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "moe, assign pos"; #endif - phi::AssignPosCompute( - dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "moe, floor divide"; -#endif - if (topk > 1) { - phi::FloorDivideKernel( - dev_ctx, pos, topk_tensor, &temp_pos); - } else { - temp_pos = pos; - } + 
phi::AssignInsAndPosCompute( + dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size, topk, &temp_pos); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "moe, tensor copy"; #endif framework::TensorCopy( - fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); - framework::TensorCopy( - fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + fwd_expert_count_cumsum, platform::CPUPlace(), &fwd_expert_count_cumsum_cpu); dev_ctx.Wait(); - int fwd_bsz = fwd_batch_size_cpu.data()[0]; + int fwd_bsz = fwd_expert_count_cumsum_cpu.data()[num_expert]; Tensor global_scatter_out; global_scatter_out.Resize({{fwd_bsz, dim_embed}}); - dev_ctx.Alloc(&global_scatter_out, - global_scatter_out.numel() * sizeof(T)); + auto global_scatter_out_data = dev_ctx.Alloc(&global_scatter_out, + global_scatter_out.numel() * sizeof(T)); Tensor all_expert_out; all_expert_out.Resize({{fwd_bsz, dim_embed}}); - dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); + auto all_expert_out_data = dev_ctx.Alloc(&all_expert_out, + all_expert_out.numel() * sizeof(T)); // step 5, MOEScatter // step 5.1, index select @@ -714,81 +706,130 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { VLOG(0) << "moe, Expert Computation"; #endif if (fwd_bsz != 0) { - int last_index = 0; - for (int idx = 0; idx < num_expert; idx++) { - int cur_expert_count = fwd_expert_count_cpu.data()[idx]; - if (cur_expert_count <= 0) { - continue; + if (FLAGS_enable_moe_gemm_cutlass) { + // grouped gemm + int expert_idx = i * num_expert; + // expert + Tensor expert_in_tmp, expert_out1; // int8_t, int32_t + expert_in_tmp.Resize({{fwd_bsz, dim_feedforward}}); + auto expert_in_tmp_data = dev_ctx.Alloc(&expert_in_tmp, + expert_in_tmp.numel() * sizeof(int8_t)); + + expert_out1.Resize({{fwd_bsz, dim_feedforward}}); + auto expert_out1_data = dev_ctx.Alloc(&expert_out1, + expert_out1.numel() * sizeof(T)); // dequant 输出, fp16 + // gemm1, do act + FusedGroupedMatMul(dev_ctx, + expert_weights1[expert_idx]->data(), + global_scatter_out_data, + &expert_in_tmp, + &expert_weight1_in_scale[expert_idx], + expert_biases1[expert_idx]->data(), + expert_out1_data, // dequant & bias & gelu output + expert_weight1_out_scales[expert_idx]->data(), + fwd_expert_count_cumsum_data, + fwd_expert_count_cumsum_cpu.data(), + fwd_bsz, + num_expert, + fwd_bsz, + dim_feedforward, + dim_embed, + true); + // gemm2, no act + FusedGroupedMatMul(dev_ctx, + expert_weights2[expert_idx]->data(), + expert_out1_data, + &expert_in_tmp, + &expert_weight2_in_scale[expert_idx], + expert_biases2[expert_idx]->data(), + all_expert_out_data, // dequant output + expert_weight2_out_scales[expert_idx]->data(), + fwd_expert_count_cumsum_data, + fwd_expert_count_cumsum_cpu.data(), + fwd_bsz, + num_expert, + fwd_bsz, + dim_embed, + dim_feedforward, + false); + } else { + int last_index = 0; + int64_t *csum_len = fwd_expert_count_cumsum_cpu.data(); + for (int idx = 0; idx < num_expert; idx++) { + int end = csum_len[idx + 1]; + int cur_expert_count = end - last_index; + if (cur_expert_count <= 0) { + continue; + } + + Tensor expert_in_tmp; // int8_t + expert_in_tmp.Resize({{cur_expert_count, dim_feedforward}}); + dev_ctx.Alloc(&expert_in_tmp, + expert_in_tmp.numel() * sizeof(int8_t)); + + Tensor expert_out1; // int32_t + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); + dev_ctx.Alloc(&expert_out1, + expert_out1.numel() * sizeof(int32_t)); + + // input is int32_t, output is int8_t + FusedDropoutHelper + fused_act_dropout_helper( + dev_ctx, cur_expert_count, 
dim_feedforward, dropout_param); + + Tensor tmp_inp = + global_scatter_out.Slice(last_index, end); // fp16, T + int expert_idx = i * num_expert + idx; + // T to int8_t, matmul, dont compute bias + MatMulTToINT8(dev_ctx, + expert_weights1[expert_idx], + expert_weight1_in_scale[expert_idx], + &tmp_inp, + &expert_in_tmp, + &expert_out1, + cur_expert_count, + dim_feedforward, + dim_embed, + &cublaslt_workspace, // maybe space not enough + quant_round_type, + quant_max_bound, + quant_min_bound); + // act bias, input is int32_t, output is int8_t + fused_act_dropout_helper.DropoutActBias( + dev_ctx, + expert_out1.data(), + expert_biases1[expert_idx]->data(), + "gelu", + expert_in_tmp.data(), // output + nullptr, + expert_weight1_in_scale[expert_idx], + expert_weight1_out_scales[expert_idx]->data(), + 0, // data offset + expert_weight2_in_scale[expert_idx], + quant_round_type, + quant_max_bound, + quant_min_bound, + approximate); + + // T(fp16) + Tensor expert_out2 = all_expert_out.Slice(last_index, end); + // linear2, int8_t to T + MatMulINT8ToT(dev_ctx, + expert_weights2[expert_idx], + expert_weight2_in_scale[expert_idx], + &expert_in_tmp, // input + expert_biases2[expert_idx], + &expert_out2, + &expert_out1, // output_tmp + &expert_out2, + expert_weight2_out_scales[expert_idx], + cur_expert_count, + dim_embed, + dim_feedforward, + true, + &cublaslt_workspace); + last_index = end; } - int end = cur_expert_count + last_index; - - Tensor expert_in_tmp; // int8_t - expert_in_tmp.Resize({{cur_expert_count, dim_feedforward}}); - dev_ctx.Alloc(&expert_in_tmp, - expert_in_tmp.numel() * sizeof(int8_t)); - - Tensor expert_out1; // int32_t - expert_out1.Resize({{cur_expert_count, dim_feedforward}}); - dev_ctx.Alloc(&expert_out1, - expert_out1.numel() * sizeof(int32_t)); - - // input is int32_t, output is int8_t - FusedDropoutHelper - fused_act_dropout_helper( - dev_ctx, cur_expert_count, dim_feedforward, dropout_param); - - Tensor tmp_inp = - global_scatter_out.Slice(last_index, end); // fp16, T - int expert_idx = i * num_expert + idx; - // T to int8_t, matmul, dont compute bias - MatMulTToINT8(dev_ctx, - expert_weights1[expert_idx], - expert_weight1_in_scale[expert_idx], - &tmp_inp, - &expert_in_tmp, - &expert_out1, - cur_expert_count, - dim_feedforward, - dim_embed, - &cublaslt_workspace, // maybe space not enough - quant_round_type, - quant_max_bound, - quant_min_bound); - // act bias, input is int32_t, output is int8_t - fused_act_dropout_helper.DropoutActBias( - dev_ctx, - expert_out1.data(), - expert_biases1[expert_idx]->data(), - "gelu", - expert_in_tmp.data(), // output - nullptr, - expert_weight1_in_scale[expert_idx], - expert_weight1_out_scales[expert_idx]->data(), - 0, // data offset - expert_weight2_in_scale[expert_idx], - quant_round_type, - quant_max_bound, - quant_min_bound, - approximate); - - // T(fp16) - Tensor expert_out2 = all_expert_out.Slice(last_index, end); - // linear2, int8_t to T - MatMulINT8ToT(dev_ctx, - expert_weights2[expert_idx], - expert_weight2_in_scale[expert_idx], - &expert_in_tmp, // input - expert_biases2[expert_idx], - &expert_out2, - &expert_out1, // output_tmp - &expert_out2, - expert_weight2_out_scales[expert_idx], - cur_expert_count, - dim_embed, - dim_feedforward, - true, - &cublaslt_workspace); - last_index = end; } } else { all_expert_out = global_scatter_out; @@ -924,5 +965,4 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( fused_multi_transformer_moe_int8, - ops::FusedMultiTransformerMoeINT8OpKernel, - 
ops::FusedMultiTransformerMoeINT8OpKernel); \ No newline at end of file + ops::FusedMultiTransformerMoeINT8OpKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h index df9f6c1bd255e..ee5249c88e70e 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h @@ -88,21 +88,26 @@ void MatMulINT8ToT(const phi::GPUContext& dev_ctx, (void*)workspace->data(), workspace->numel()); - dequantize_kernel_launcher(output_tmp->data(), - output->data(), - m, - n, - dev_ctx.stream(), - gpu_config.get(), - quant_in_scale, - dequant_out_scale->data()); - if (compute_bias) { - // bias_out = output + bias - std::vector ins = {output, bias}; - std::vector outs = {bias_out}; - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); + dequantize_addbias_kernel_launcher(output_tmp->data(), + bias->data(), + output->data(), + m, + n, + dev_ctx.stream(), + gpu_config.get(), + quant_in_scale, + dequant_out_scale->data()); + } else { + dequantize_addbias_kernel_launcher(output_tmp->data(), + nullptr, + output->data(), + m, + n, + dev_ctx.stream(), + gpu_config.get(), + quant_in_scale, + dequant_out_scale->data()); } } diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu index ef21ff0d0cef9..87a7e24566ffb 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_weight_only_op.cu @@ -21,9 +21,8 @@ limitations under the License. */ #include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h" #include "paddle/phi/kernels/weight_only_linear_kernel.h" -PADDLE_DEFINE_EXPORTED_bool(enable_moe_gemm_cutlass, - false, - "enable moe gemm cutlass ,default false"); +DECLARE_bool(enable_moe_gemm_cutlass); + namespace paddle { namespace operators { using Tensor = phi::DenseTensor; diff --git a/paddle/fluid/operators/fused/moe_expert_gemm.h b/paddle/fluid/operators/fused/moe_expert_gemm.h index 3fcf40c4a59e6..24ba632188c25 100644 --- a/paddle/fluid/operators/fused/moe_expert_gemm.h +++ b/paddle/fluid/operators/fused/moe_expert_gemm.h @@ -17,6 +17,7 @@ limitations under the License. 
*/
 #if defined(PADDLE_WITH_CUTLASS)
 #include "paddle/phi/common/datatype_traits.h"
 #include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
+#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h"
 #endif
 namespace paddle {
 namespace operators {
@@ -134,5 +135,78 @@ class MoeExpertGemmWeightOnly {
   bool is_uint4_ = false;
 };
 
+// for ptq
+template <typename T>
+void FusedGroupedMatMul(const phi::GPUContext& dev_ctx,
+                        const int8_t* weight,  // int8
+                        const T* input,  // fp16, shape is [fwd_bsz, k]
+                        Tensor* input_tmp,  // int8
+                        const float* quant_in_scale,
+                        const T* bias,  // fp16
+                        T* output_deq,
+                        const float* dequant_out_scale,  // fp32
+                        const int64_t* fwd_expert_count_cumsum,  // int64
+                        const int64_t* fwd_expert_count_cumsum_cpu,  // cpu
+                        int fwd_bsz,
+                        int num_expert,
+                        int m,
+                        int n,
+                        int k,
+                        const bool do_activation,
+                        const int quant_round_type = 1,
+                        const float quant_max_bound = 127.0,
+                        const float quant_min_bound = -127.0) {
+#if defined(PADDLE_WITH_CUTLASS)
+  int64_t offset = 0;
+  for (int i = 0; i < num_expert; ++i) {
+    int64_t cur_m = *(fwd_expert_count_cumsum_cpu + i + 1) - *(fwd_expert_count_cumsum_cpu + i);
+    if (cur_m == 0) {
+      continue;
+    }
+    quantize_kernel_launcher<T>(input + offset,
+                                input_tmp->data<int8_t>() + offset,
+                                *(quant_in_scale + i),
+                                cur_m,  // cur m
+                                k,
+                                quant_round_type,
+                                quant_max_bound,
+                                quant_min_bound,
+                                dev_ctx.stream());
+    offset += cur_m * k;
+  }
+
+  // group_quantize_kernel_launcher(input,
+  //                                input_tmp->data<int8_t>(),
+  //                                quant_in_scale,
+  //                                fwd_expert_count,
+  //                                num_expert,
+  //                                m,
+  //                                k,
+  //                                quant_round_type,
+  //                                quant_max_bound,
+  //                                quant_min_bound,
+  //                                dev_ctx.stream());
+
+  using half_dtype = typename phi::PDDataTypeTraits<T>::DataType;
+  auto moe_gemm_runner = phi::PTQMoeGemmRunner<half_dtype>();
+  // int8 gemm & dequant & add bias & act(optional)
+  moe_gemm_runner.moe_gemm_bias_act(input_tmp->data<int8_t>(),
+                                    weight,
+                                    dequant_out_scale,
+                                    reinterpret_cast<const half_dtype*>(bias),  // bias
+                                    reinterpret_cast<half_dtype*>(output_deq),
+                                    fwd_expert_count_cumsum,
+                                    fwd_bsz,
+                                    n,
+                                    k,
+                                    num_expert,
+                                    do_activation,
+                                    dev_ctx.stream());
+#else
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "FusedGroupedMatMul is not supported because Paddle was built without CUTLASS."));
+#endif
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/fused/quant_dequant_kernel.h b/paddle/fluid/operators/fused/quant_dequant_kernel.h
index bd490555a8b86..18ba470667a0e 100644
--- a/paddle/fluid/operators/fused/quant_dequant_kernel.h
+++ b/paddle/fluid/operators/fused/quant_dequant_kernel.h
@@ -158,9 +158,10 @@ void quantize_kernel_launcher(const T* input,
                               min_bound);
 }
 
-template <typename T, int VecSize>
+template <typename T, int VecSize, bool ComputeBias>
 __global__ void dequantize_kernel(T* output,
                                   const int32_t* input,
+                                  const T* bias,
                                   const int m,  // batch size
                                   const int n,  // hidden
                                   const float quant_in_scale,
@@ -172,16 +173,23 @@ __global__ void dequantize_kernel(T* output,
   phi::AlignedVector<int32_t, VecSize> in_vec;
   phi::AlignedVector<float, VecSize> out_scale_vec;
+  phi::AlignedVector<T, VecSize> bias_vec;
   phi::AlignedVector<T, VecSize> out_vec;
   for (; idx < numel; idx += stride) {
     phi::Load<int32_t, VecSize>(input + idx, &in_vec);
     phi::Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
+    if (ComputeBias) {
+      phi::Load<T, VecSize>(bias + col_id, &bias_vec);
+    }
 
 #pragma unroll
     for (int i = 0; i < VecSize; ++i) {
       out_vec[i] =
           static_cast<T>(static_cast<float>(in_vec[i]) * out_scale_vec[i]);
+      if (ComputeBias) {
+        out_vec[i] += bias_vec[i];
+      }
     }
 
     phi::Store<T, VecSize>(out_vec, output + idx);
@@ -199,7 +207,22 @@ void dequantize_kernel_launcher(const int32_t* input,
                                 const float* dequant_out_scale_data) {
dequantize_kernel <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( - output, input, m, n, quant_in_scale, dequant_out_scale_data); + output, input, nullptr, m, n, quant_in_scale, dequant_out_scale_data); +} + +template +void dequantize_addbias_kernel_launcher(const int32_t* input, + const T* bias, + T* output, + const int m, // m + const int n, // n + gpuStream_t stream, + GpuLaunchConfig* gpu_config, + const float quant_in_scale, + const float* dequant_out_scale_data) { + dequantize_kernel + <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( + output, input, bias, m, n, quant_in_scale, dequant_out_scale_data); } } // namespace operators diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 518aabbb09ead..2747bb0b0f82a 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -1020,3 +1020,9 @@ PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_string(jit_engine_type, "Predictor", "Choose default funciton type in JitLayer."); +/** + * CUTLASS related FLAG + */ +PADDLE_DEFINE_EXPORTED_bool(enable_moe_gemm_cutlass, + false, + "enable moe gemm cutlass ,default false"); diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index ba635694d7771..7005213b3817a 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -189,7 +189,8 @@ if(WITH_CUTLASS) "fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/*.cu" "fusion/cutlass/cutlass_kernels/fpA_intB_gemm/*.cu" "fusion/cutlass/cutlass_kernels/moe_gemm/autogen/*.cu" - "fusion/cutlass/cutlass_kernels/moe_gemm/*.cu") + "fusion/cutlass/cutlass_kernels/moe_gemm/*.cu" + "fusion/cutlass/cutlass_kernels/ptq_moe_gemm/*.cu") list(APPEND kernel_cu ${cutlass_cu}) endif() diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/linear_combination_ptq.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/linear_combination_ptq.h new file mode 100644 index 0000000000000..0814aff49dee1 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/linear_combination_ptq.h @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +template< + typename ElementCompute_, + typename ElementOutput_, + bool IsHeavy = false, + template class ActivationFunctor = cutlass::epilogue::thread::GELU_taylor +> +class LinearCombinationPTQ { +public: + using ElementCompute = ElementCompute_; + using ElementOutput = ElementOutput_; + using ElementSource = ElementOutput; // fp16, bias + using ElementAccumulator = int32_t; // int32 + + static int const kCount = 128 / cutlass::sizeof_bits::value; + static const ScaleType::Kind kScale = ScaleType::OnlyAlphaPerChannelScaling; + // static constexpr bool IsPerChannelScalingSupported = true; + static bool const kIsHeavy = IsHeavy; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; // int32 + using FragmentCompute = Array; + + static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest; + + /// Host-constructable parameters structure + struct Params + { + float beta; ///< scales source tensor + bool do_act; ///< if true, apply activation function + + CUTLASS_HOST_DEVICE + Params(): + beta(float(0)), + do_act(false) { } + + CUTLASS_HOST_DEVICE + Params(float _beta, bool _do_act): + beta(_beta), + do_act(_do_act) { } + }; + + // add new fun + // struct fp32multiply_fp16add { + // using A = FragmentMul; + // using B = A; + // using C = FragmentOutput; + // NumericArrayConverter converter; + // CUTLASS_HOST_DEVICE + // C operator()(A const &a, B const &b, C const &c) const { + // C res = converter(a * b); + // return res + c; + // } + // }; + +private: + + // + // Data members + // + float beta_ = float(0); + bool do_act_ = false; + +public: + + /// Constructs the function object + CUTLASS_HOST_DEVICE + LinearCombinationPTQ(Params const& params) { + beta_ = params.beta; + do_act_ = params.do_act; + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; // always need source + } + + CUTLASS_HOST_DEVICE + bool is_beta_vector() const { + return false; // beta always not a vector + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + } + + /// Computes linear scaling with source: D = act(scale * accumulator + bias) + /// scalar_beta is next in scale + CUTLASS_HOST_DEVICE + FragmentOutput 
operator()( + FragmentAccumulator const& accumulator, // A * B, int32 + FragmentCompute const& scale, // scale, float + FragmentSource const& bias, // fp16 + bool is_print_debug) const { + // conver accum from int32 to fp32 + NumericArrayConverter accumulator_converter; + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + FragmentOutput intermediate; // fp16 + multiplies mul; // dequant + plus add; // add bias + NumericArrayConverter out_converter; + intermediate = add(out_converter(mul(scale, converted_accumulator)), bias); + + ActivationFunctor activation; + intermediate = do_act_ ? activation(intermediate) : intermediate; + return intermediate; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_default_epilogue_tensor_op.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_default_epilogue_tensor_op.h new file mode 100644 index 0000000000000..99e79d7b1c76f --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_default_epilogue_tensor_op.h @@ -0,0 +1,239 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/linear_combination_ptq.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_epilogue.h" + +#include "cutlass/layout/permute.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +// namespace detail { + +// /// Partial specialization for half <= int32_t x 8 epilogues avoids shared +// /// memory bank conflicts. +// template +// struct DefaultIteratorsTensorOp { +// using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< +// WarpShape, +// InstructionShape, +// int32_t, +// 32, +// 16, +// 8, +// 8 +// >; + +// using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< +// ThreadMap, +// int32_t, +// 32, +// 16, +// 8, +// 8 +// >; + +// static int const kFragmentsPerIteration = 2; +// }; + +// } // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
+template < + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess, + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute +> +struct B2BDefaultEpilogueTensorOp { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; // 64/16 = 4 + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + // static bool const UseCUDAStore = platform::is_same::value; + static bool const UseCUDAStore = false; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput, // here we use fp16 + ScatterD, + PermuteDLayout, + UseCUDAStore + >; + + // for float scale + using ScaleTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + float, + kElementsPerAccess // 128/16 or 128/32 ? + >::Type; + + using ScaleTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + ScaleTileThreadMap, + float, + ScatterD, + PermuteDLayout, + UseCUDAStore + >; + + using AccumulatorFragmentIterator = typename platform::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC> >::type; + + /// Support several implementations depending on structure of epilogue + using DefaultIterators = detail::DefaultIteratorsTensorOp< + ElementOutput,//half + ElementAccumulator,//int32 + kElementsPerAccess, + Shape, + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename OutputTileThreadMap::CompactedThreadMap + >; + + using WarpTileIterator = typename DefaultIterators::WarpTileIterator; + using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; + + /// Hard-coded padding elements added + using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; + + static int const kFragmentsPerIteration = (kPartitionsK == 1 ? 
DefaultIterators::kFragmentsPerIteration : 1); + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::B2BEpilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + ScaleTileIterator, // we add a scale tile iterator + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding, + kFragmentsPerIteration + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_epilogue.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_epilogue.h new file mode 100644 index 0000000000000..bcf1bfc7186df --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_epilogue.h @@ -0,0 +1,525 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + + The shared memory resource is time-sliced across warps. 
+*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename ScaleTileIterator_, ///< Tile iterator reading and writing scale tensor + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class B2BEpilogue : + public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>, + public EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_> +{ + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using BaseStreamK = EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using ScaleTileIterator = ScaleTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// Number of warps per block + using WarpCount = typename Base::WarpCount; + + /// Number of threads per block + static int const kBlockThreads = 32 * WarpCount::kCount; + + /// Per-thread accumulator tile type + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Numerical accumulation element type + using ElementAccumulator = typename WarpMmaOperator::ElementC; + + /// Fragment type used by 
the accumulator tile's fragment iterator + using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; + + /// Output element + // using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + // using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + // using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + // using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Vector type used by the global output iterator + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Scale type used by the global scale iterator + using ScaleAccessType = Array< + typename ScaleTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Vector type used by the shared output iterator + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + +public: + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + + +public: + + /// Aspect for when epilogue source is needed + struct SourceAspectNeeded + { + OutputTileIterator source_iterator; + ScaleTileIterator scale_iterator; + + typename OutputTileIterator::Fragment source_fragment; + typename ScaleTileIterator::Fragment scale_fragment; + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + static void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment, + typename ScaleTileIterator::Fragment const &scale_fragment, + typename OutputTileIterator::Fragment const &source_fragment, + bool is_print_debug) + { + OutputAccessType *output_frag_ptr = + reinterpret_cast(&output_fragment); // output + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); // accum + + OutputAccessType const *source_frag_ptr = + reinterpret_cast(&source_fragment); // bias + + ScaleAccessType const *scale_frag_ptr = + reinterpret_cast(&scale_fragment); // scale + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) + { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i], + scale_frag_ptr[i], + source_frag_ptr[i], + is_print_debug); + } + } + + /// Constructor + CUTLASS_DEVICE + SourceAspectNeeded(ScaleTileIterator scale_iterator, + OutputTileIterator source_iterator) : + source_iterator(source_iterator), + scale_iterator(scale_iterator) + { + 
source_fragment.clear(); + scale_fragment.clear(); + } + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment, + bool is_print_debug) + { + // Load addend source fragment from global memory + source_iterator.load(source_fragment); + ++source_iterator; + + scale_iterator.load(scale_fragment); + ++scale_iterator; + + apply_output_operator(output_fragment, + output_op, + aligned_accum_fragment, + scale_fragment, + source_fragment, + is_print_debug); + } + }; + + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Thread index in the threadblock + int thread_idx; + + /// Warp index in the threadblock + int warp_idx; + bool is_print_debug = false; // DEBUG + +public: + + /// Constructor + CUTLASS_DEVICE + B2BEpilogue( + typename Base::SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx) ///< Id of thread within warp + : + Base(shared_storage, thread_idx, warp_idx, lane_idx), + BaseStreamK(thread_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx), + thread_idx(thread_idx), + warp_idx(warp_idx) + { + if (thread_idx == 0 && warp_idx == 0 && lane_idx == 0) { + is_print_debug = true; + } + } + + + /// Aggregates the accumulator sets shared by peer blocks in the global workspace, + /// performing epilogue computations, writing to output + CUTLASS_DEVICE + void reduce( + int peer_idx_begin, + int peer_idx_end, + int reduce_fragment_idx, + void *element_workspace, + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + ScaleTileIterator scale_iterator, ///< Tile iterator for scale + OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + { + // Redcuce peer accumulator fragments into one fragment + AccumulatorFragment accum_fragment; + BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); + + // Store fragment to shared memory + this->warp_tile_iterator_.store(accum_fragment); + + __syncthreads(); + + // Initialize/load source-fragment data + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + source_iterator += reduce_fragment_idx; + source_iterator.load(source_fragment); + + // typename OutputTileIterator::Fragment scale_fragment; + typename ScaleTileIterator::Fragment scale_fragment; + scale_fragment.clear(); + scale_iterator += reduce_fragment_idx; + scale_iterator.load(scale_fragment); + + // Load fragment from shared memory + typename SharedLoadIterator::Fragment aligned_accum_fragment; + shared_load_iterator_.load(aligned_accum_fragment); + + // Add fragments shared by other k partitions + if (kPartitionsK > 1) + { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + typename SharedLoadIterator::Fragment aligned_addend_fragment; + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_addend_fragment); + aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment); + } + } + + // Compute the output result + typename 
OutputTileIterator::Fragment output_fragment;
+
+    // Apply the output operator
+    SourceAspectNeeded::apply_output_operator(
+      output_fragment,
+      output_op,
+      aligned_accum_fragment,
+      scale_fragment,
+      source_fragment,
+      is_print_debug);
+
+    // Store the final result
+    destination_iterator += reduce_fragment_idx;
+    destination_iterator.store(output_fragment);
+  }
+
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements
+  /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
+  /// The source (i.e. the bias) is required by default.
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    ScaleTileIterator scale_iterator,             ///< Tile iterator for scale
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(scale_iterator, source_iterator));
+  }
+
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements a
+  /// single codepath, regardless of whether the output op requires addend data to be loaded
+  CUTLASS_DEVICE
+  void unified(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    ScaleTileIterator scale_iterator,             ///< Tile iterator for scale
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    // if (!output_op.is_source_needed())
+    // {
+    //   source_iterator.clear_mask();
+    //   __syncthreads(); // Dummy (CUDA 11.0)
+    // }
+
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(scale_iterator, source_iterator));
+  }
+
+
+  /// Streams the result to global memory
+  template <typename SourceAspect>
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    SourceAspect scale_source)
+  {
+    // Iterator over warp-level accumulator fragment
+    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
+
+    //
+    // Iterate over accumulator tile
+    //
+
+    #pragma unroll(IterationsUnroll ?
OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) + { + + // + // Convert and store fragment + // + + __syncthreads(); + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) + { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + this->warp_tile_iterator_.store(accum_fragment); + + if (p < Base::kFragmentsPerIteration - 1) { + this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + + + // + // Load fragments from shared memory + // + + __syncthreads(); + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) + { + typename SharedLoadIterator::Fragment aligned_accum_fragment; + shared_load_iterator_.load(aligned_accum_fragment); + + if (p < Base::kFragmentsPerIteration - 1) + { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } + else if (kPartitionsK > 1) + { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + typename SharedLoadIterator::Fragment aligned_accum_fragment_addend; + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment_addend); + aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + scale_source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment, is_print_debug); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_moe_gemm_grouped.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_moe_gemm_grouped.h new file mode 100644 index 0000000000000..fcb854bde92c9 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_moe_gemm_grouped.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/b2b_default_epilogue_tensor_op.h" + +#include "cutlass/layout/permute.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: 
GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, + ElementAccumulator>::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void + > +struct DefaultMoeGemmGrouped; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout +> +struct DefaultMoeGemmGrouped< + ElementA, + LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + GroupScheduleMode_, + Operator, + SharedMemoryClear, + PermuteDLayout, + typename platform::enable_if< ! cutlass::is_complex::value>::type +> { + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. 
+ static bool const kInternalTranspose = platform::is_same::value; + + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, OperatorClass, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, SharedMemoryClear, false, false>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + using RegularEpilogue = + typename cutlass::epilogue::threadblock::B2BDefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, + EpilogueOutputOp, EpilogueOutputOp::kCount, false, PermuteDLayout>::Epilogue; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmGrouped< + Mma, + RegularEpilogue, // here we use user-defined Epilogue + ThreadblockSwizzle, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/ptq_moe_kernel.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/ptq_moe_kernel.h new file mode 100644 index 0000000000000..ee33b58c53133 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/ptq_moe_kernel.h @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct PTQMoeGemm { +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; // fp16 + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = float; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; // scale + ElementC* ptr_C; // bias + ElementC* ptr_D; + + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): + problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + total_rows_before_expert(nullptr), + gemm_n(0), + gemm_k(0), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + const ElementA* ptr_A, + const ElementB* ptr_B, + const ElementScale* weight_scales, + const ElementC* ptr_C, + ElementC* ptr_D, + const int64_t* 
total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + GemmCoord* host_problem_sizes = nullptr): + problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(const_cast(ptr_A)), + ptr_B(const_cast(ptr_B)), + weight_scales(const_cast(weight_scales)), + // ptr_C(const_cast(ptr_C)), + ptr_C(const_cast(ptr_C)), + ptr_D(ptr_D), + total_rows_before_expert(const_cast(total_rows_before_expert)), + gemm_n(gemm_n), + gemm_k(gemm_k), + host_problem_sizes(nullptr) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0): + problem_visitor( + args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params( + args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + PTQMoeGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile this code using + // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in + // a namespace + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { + CUTLASS_NOT_IMPLEMENTED(); + } + }; + + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. 
+ // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using LayoutScale = typename Epilogue::ScaleTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; // must be 1 + static_assert(platform::is_same::value && kInterleave == 1, + "B must be rcol major & kInterleave == 1."); + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + typename LayoutA::LongIndex ldm_A = gemm_k; + typename LayoutB::LongIndex ldm_B = gemm_k; + + LayoutScale layout_scale(0); + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::ScaleTileIterator::Params params_scale(layout_scale); + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + typename Mma::FragmentC accumulators; + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump = params.problem_visitor.last_row_for_problem[problem_idx]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + + char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{ + 0, + threadblock_offset.n(), + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), + ptr_B, + {problem_size.k(), problem_size.n()}, + thread_idx, + tb_offset_B); + + // typename Mma::FragmentC accumulators; + + accumulators.clear(); // full of zeros + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. 
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementScale* scale_ptr = params.weight_scales + problem_idx * gemm_n; + ElementC* ptr_bias = params.ptr_C + problem_idx * gemm_n; + ElementC* ptr_D = params.ptr_D + rows_to_jump * gemm_n; + + // scale + typename Epilogue::ScaleTileIterator iterator_scale( + params_scale, scale_ptr, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + // bias + typename Epilogue::OutputTileIterator iterator_bias( + params_C, ptr_bias, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_scale, iterator_bias); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h index 66698b6c93270..07633109cb07c 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h @@ -39,6 +39,8 @@ struct TileShape { int n; }; +enum class CutlassGemmType : char { Default, WeightOnly, PTQ, Simt, Int8 }; + inline TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { switch (tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: @@ -107,19 +109,7 @@ inline bool is_valid_split_k_factor(const int64_t m, inline std::vector get_candidate_tiles( const int sm, - const bool is_weight_only, - const bool simt_configs_only, - const bool int8_configs_only) { - enum class CutlassGemmType : char { Default, WeightOnly, Simt, Int8 }; - - CutlassGemmType gemm_type = CutlassGemmType::Default; - if (simt_configs_only) { - gemm_type = CutlassGemmType::Simt; - } else if (is_weight_only) { - gemm_type = CutlassGemmType::WeightOnly; - } else if (int8_configs_only) { - gemm_type = CutlassGemmType::Int8; - } + CutlassGemmType gemm_type=CutlassGemmType::Default) { std::vector base_configs{ CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, @@ -148,6 +138,8 @@ inline std::vector get_candidate_tiles( CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, 
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + case CutlassGemmType::PTQ: + return {CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64}; default: return base_configs; } @@ -155,16 +147,13 @@ inline std::vector get_candidate_tiles( inline std::vector get_candidate_configs( int sm, - const bool is_weight_only, - const bool simt_configs_only, - const bool int8_configs_only, + CutlassGemmType gemm_type, const int max_split_k) { - std::vector tiles = get_candidate_tiles( - sm, is_weight_only, simt_configs_only, int8_configs_only); + std::vector tiles = get_candidate_tiles(sm, gemm_type); std::vector candidate_configs; - const int min_stages = int8_configs_only ? 3 : 2; - const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + const int min_stages = gemm_type == CutlassGemmType::Int8 ? 3 : 2; + const int max_stages = gemm_type == CutlassGemmType::Int8 ? 6 : (sm >= 80 ? 4 : 2); for (const auto& tile_config : tiles) { for (int stages = min_stages; stages <= max_stages; ++stages) { CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu index 8f6b7b0846fd5..17fcfe8eb4e56 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu @@ -431,8 +431,12 @@ void CutlassFpAIntBGemmRunner::run_gemm( if (it == config_cache_.end()) { static constexpr bool is_weight_only = !std::is_same::value; const bool is_weight_only_encoder = m >= 512 ? true : false; + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (is_weight_only) { + gemm_type = CutlassGemmType::WeightOnly; + } std::vector candidate_configs = get_candidate_configs( - sm_, is_weight_only, false, false, split_k_limit); + sm_, gemm_type, split_k_limit); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu index 787442795994d..1daf23982a554 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.cu @@ -527,8 +527,14 @@ void MoeGemmRunner::run_gemm( static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. 
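+  // Fold the legacy boolean switches (is_weight_only / only_simt_configs) into the new
+  // CutlassGemmType enum before asking the heuristic for candidate tile configurations.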
+ CutlassGemmType gemm_type = CutlassGemmType::Default; + if (is_weight_only) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (only_simt_configs) { + gemm_type = CutlassGemmType::Simt; + } std::vector candidate_configs = get_candidate_configs( - sm_, is_weight_only, only_simt_configs, false, split_k_limit); + sm_, gemm_type, split_k_limit); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h new file mode 100644 index 0000000000000..d97a0d8be195b --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm_configs.h" +#include "cuda_runtime_api.h" + +namespace phi { + +template /*The type used for scales/bias/compute*/ +class PTQMoeGemmRunner { +public: + PTQMoeGemmRunner(); + + void moe_gemm_bias_act(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + cudaStream_t stream); + +private: + void dispatch_to_arch(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + cudaStream_t stream, + int* occupancy = nullptr); + +private: + int sm_; + int multi_processor_count_; +}; + +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_fp16.cu b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_fp16.cu new file mode 100644 index 0000000000000..18a3f06c960f3 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_fp16.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_template.h" + +namespace phi { +template class PTQMoeGemmRunner; +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_template.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_template.h new file mode 100644 index 0000000000000..d0837383de384 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm_template.h @@ -0,0 +1,505 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_moe_gemm_grouped.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/ptq_moe_kernel.h" + +#pragma GCC diagnostic pop + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h" +#include "paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h" +#include +#include +#include +#include + +namespace phi { + +// ============================= Variable batched Gemm things =========================== +template +void generic_moe_gemm_kernelLauncher(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + const int multi_processor_count, + cudaStream_t stream, + int* kernel_occupancy = nullptr) { + if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { + throw std::runtime_error("[FT Error][MoeGemm] Grouped gemm does not support split-k"); + } + using InputType = int8_t; + using ElementAccumulator = int32_t; + using ComputeType = float; + using OutputType = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + + using EpilogueOutputOp = + cutlass::epilogue::thread::LinearCombinationPTQ; + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + // Finally, set up the kernel. 
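+  // DefaultMoeGemmGrouped composes the threadblock-scoped int8 Mma with the custom B2B
+  // epilogue; its Mma and Epilogue are then re-wrapped in PTQMoeGemm and launched through
+  // the device-level GemmGrouped below.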
+ using GemmKernel_ = typename cutlass::gemm::kernel::DefaultMoeGemmGrouped< + InputType, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 16, // 128 / 8 + InputType, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 16, // 128 / 8 + OutputType, // output type + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + arch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + Stages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::arch::OpMultiplyAddSaturate>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::PTQMoeGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) { + *kernel_occupancy = compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + if (occupancy == 0) { + throw std::runtime_error( + "[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); + } + const int threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOutputOp::Params epilogue_op(ComputeType(0.f), do_activation); + + typename GemmGrouped::Arguments args(num_experts, + threadblock_count, + epilogue_op, + A, + B, + weight_scales, + reinterpret_cast(biases), + reinterpret_cast(C), + total_rows_before_expert, + gemm_n, + gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); + } +} + +template +struct dispatch_stages { + static void dispatch(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + std::string err_msg = "Cutlass moe gemm. 
Not instantiates for arch " + + std::to_string(arch::kMinComputeCapability) + " with stages set to " + + std::to_string(Stages); + throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg); + } +}; + +template +struct dispatch_stages { + static void dispatch(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + do_activation, + gemm_config, + multi_processor_count, + stream, + occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> { + static void dispatch(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + do_activation, + gemm_config, + multi_processor_count, + stream, + occupancy); + } +}; + +template +void dispatch_gemm_config(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + do_activation, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + do_activation, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + do_activation, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + throw std::runtime_error("[FT Error][MoE][dispatch_gemm_config] " + err_msg); + break; + } +} + +template +void dispatch_moe_gemm_to_cutlass(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 64>>(A, + B, + weight_scales, + biases, + C, + 
total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + do_activation, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); + break; + } +} + +template +PTQMoeGemmRunner::PTQMoeGemmRunner() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + sm_ = getSMVersion(); + check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +void PTQMoeGemmRunner::dispatch_to_arch(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + CutlassGemmConfig gemm_config, + cudaStream_t stream, + int* occupancy) { + if (sm_ >= 80 && sm_ <= 90) { + dispatch_moe_gemm_to_cutlass(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + do_activation, + gemm_config, + sm_, + multi_processor_count_, + stream, + occupancy); + } else { + throw std::runtime_error("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + } +} + +template +void PTQMoeGemmRunner::moe_gemm_bias_act(const int8_t* A, + const int8_t* B, + const float* weight_scales, + const T* biases, + T* C, + const int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + bool do_activation, + cudaStream_t stream) { + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. 
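+  // Two-phase heuristic: first query the achievable occupancy of every PTQ candidate
+  // config, let estimate_best_config_from_occupancies pick the best one, then dispatch
+  // the actual grouped GEMM with the chosen config.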
+ std::vector candidate_configs = get_candidate_configs(sm_, CutlassGemmType::PTQ, split_k_limit); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + do_activation, + candidate_configs[ii], + stream, + &occupancies[ii]); + } + + CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs, + occupancies, + total_rows, + gemm_n, + gemm_k, + num_experts, + split_k_limit, + workspace_bytes, + multi_processor_count_, + false); + + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + do_activation, + chosen_config, + stream); +} + +} // namespace phi \ No newline at end of file diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 0618611521ac9..e7e6e1303ec6e 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1461,6 +1461,7 @@ def trans_to_fp16(l): trans_to_fp16(self.ffn2_biases) self._dtype = dtype + class FusedMultiTransformerWeightOnly(Layer): """ FusedMultiTransfor on weight quant @@ -1804,7 +1805,6 @@ def trans_to_int8(l): self._dtype = "int8" - class FusedMultiTransformerINT8(Layer): def __init__(self, embed_dim, @@ -2350,7 +2350,7 @@ def __init__( # origin fmt config self.normalize_before = normalize_before - self._dtype = self._helper.get_default_dtype() + self._dtype = "float16" self._epsilon = epsilon self._trans_qkvw = trans_qkvw self._ring_id = ring_id @@ -2895,7 +2895,6 @@ def get_attr(attrs, idx): expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) expert_weight1 = self.create_parameter( - # shape=[d_model, dim_feedforward], shape=[dim_feedforward, d_model], attr=expert_weight1_attr, dtype=self._dtype, @@ -2910,7 +2909,6 @@ def get_attr(attrs, idx): default_initializer=nn.initializer.Constant(value=0.0) ) expert_weight2 = self.create_parameter( - # shape=[dim_feedforward, d_model], shape=[d_model, dim_feedforward], attr=expert_weight2_attr, dtype=self._dtype, @@ -2940,6 +2938,8 @@ def get_attr(attrs, idx): expert_bias1.name = "expert_" + expert_bias1.name expert_weight2.name = "expert_" + expert_weight2.name expert_bias2.name = "expert_" + expert_bias2.name + expert_weight1_out_scale.name = "expert_" + expert_weight1_out_scale.name + expert_weight2_out_scale.name = "expert_" + expert_weight2_out_scale.name self.expert_weights1.append(expert_weight1) self.expert_biases1.append(expert_bias1) self.expert_weights2.append(expert_weight2) @@ -2951,6 +2951,8 @@ def get_attr(attrs, idx): self.name = name # int8 self._int8_decorate() + self._share_expert_param(num_layers, num_expert, dim_feedforward, d_model) + self._dtype = "int8" def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): """ @@ -3036,8 +3038,71 @@ def trans_to_int8(l): trans_to_int8(self.linear_weights) trans_to_int8(self.expert_weights1) trans_to_int8(self.expert_weights2) - self._dtype = "int8" + # self._dtype = "int8" + + def _share_expert_param(self, num_layers, num_expert, dim_feedforward, d_model): + """ + share_param + """ + def shard_tensor(dst_tensor, parent_tensor, pos): + tmp = parent_tensor.value().get_tensor()._slice(pos, pos + 1) + dst_tensor.value().get_tensor()._share_data_buffer(tmp, 
False) + + self.shared_weights1, self.shared_scales1, self.shared_biases1 = ParameterList(), ParameterList(), ParameterList() + self.shared_weights2, self.shared_scales2, self.shared_biases2 = ParameterList(), ParameterList(), ParameterList() + + for i in range(num_layers): + shared_weight1 = paddle.create_parameter( + # name=f"moe.expert.layer{i}.shared_weight1", + shape=[num_expert, dim_feedforward, d_model], + dtype="uint8", + default_initializer=nn.initializer.Constant(value=0)) + shared_scale1 = paddle.create_parameter( + # name=f"moe.expert.layer{i}.shared_scale1", + shape=[num_expert, dim_feedforward], + dtype="float32", + default_initializer=nn.initializer.Constant(value=0.0)) + shared_bias1 = paddle.create_parameter( + # name=f"moe.expert.layer{i}.shared_bias1", + shape=[num_expert, dim_feedforward], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(value=0.0)) + + shared_weight2 = paddle.create_parameter( + # name=f"moe.expert.layer{i}.shared_weight2", + shape=[num_expert, d_model, dim_feedforward], + dtype="uint8", + default_initializer=nn.initializer.Constant(value=0)) + shared_scale2 = paddle.create_parameter( + # name=f"moe.expert.layer{i}.shared_scale2", + shape=[num_expert, d_model], + dtype="float32", + default_initializer=nn.initializer.Constant(value=0.0)) + shared_bias2 = paddle.create_parameter( + # name=f"moe.expert.layer{i}.shared_bias2", + shape=[num_expert, d_model], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(value=0.0)) + _to_dtype(shared_weight1, "int8") + _to_dtype(shared_weight2, "int8") + + for j in range(self.num_expert): + expert_idx = j + i * self.num_expert + shard_tensor(self.expert_weights1[expert_idx], shared_weight1, j) + shard_tensor(self.expert_weight1_out_scales[expert_idx], shared_scale1, j) + shard_tensor(self.expert_biases1[expert_idx], shared_bias1, j) + shard_tensor(self.expert_weights2[expert_idx], shared_weight2, j) + shard_tensor(self.expert_weight2_out_scales[expert_idx], shared_scale2, j) + shard_tensor(self.expert_biases2[expert_idx], shared_bias2, j) + + self.shared_weights1.append(shared_weight1) + self.shared_scales1.append(shared_scale1) + self.shared_biases1.append(shared_bias1) + + self.shared_weights2.append(shared_weight2) + self.shared_scales2.append(shared_scale2) + self.shared_biases2.append(shared_bias2) class FusedMultiTransformerMoeWeightOnly(Layer): """ @@ -3476,4 +3541,4 @@ def shard_tensor(dst_tensor, parent_tensor, pos): self.shared_weights2.append(shared_weight2) self.shared_scales2.append(shared_scale2) - self.shared_biases2.append(shared_bias2) + self.shared_biases2.append(shared_bias2) \ No newline at end of file
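
For orientation, the following is a minimal host-side sketch of how the new PTQ grouped-GEMM runner added in this patch is expected to be driven. The buffer names, shapes, and the half element type are illustrative assumptions, not part of the patch; the real call sites are the expert FFN GEMMs in fused_multi_transformer_moe_int8_op.cu and moe_expert_gemm.h.

#include <cuda_fp16.h>
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/ptq_moe_gemm/ptq_moe_gemm.h"

// Hypothetical wrapper: runs one int8 expert GEMM (e.g. FFN1) for all experts at once.
void run_ptq_expert_gemm(const int8_t* quantized_act,       // [total_rows, gemm_k], rows grouped by expert
                         const int8_t* expert_weights,      // per-expert int8 weights (column-major per expert)
                         const float* dequant_scales,       // [num_experts, gemm_n] dequantization scales
                         const half* expert_biases,         // [num_experts, gemm_n] bias added in the epilogue
                         half* expert_out,                  // [total_rows, gemm_n] fp16 output
                         const int64_t* rows_before_expert, // per-expert row offsets, as expected by the problem visitor
                         int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
                         int num_experts, cudaStream_t stream) {
  phi::PTQMoeGemmRunner<half> runner;
  // do_activation = true asks the fused epilogue to also apply the FFN activation.
  runner.moe_gemm_bias_act(quantized_act, expert_weights, dequant_scales,
                           expert_biases, expert_out, rows_before_expert,
                           total_rows, gemm_n, gemm_k, num_experts,
                           /*do_activation=*/true, stream);
}

Internally the runner profiles the PTQ candidate tile configurations once, picks the best one from the computed occupancies, and launches the grouped CUTLASS kernel on the given stream.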