add & fix zeus int8 (PaddlePaddle#100)
tianyan01 authored Nov 30, 2023
1 parent eaf9b14 commit e1c65ba
Showing 4 changed files with 314 additions and 25 deletions.
27 changes: 7 additions & 20 deletions paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu
@@ -20,6 +20,7 @@ limitations under the License. */
 
 namespace paddle {
 namespace operators {
+// #define _DEBUG_FUSED_MULTI_TRANSFORMER
 
 template <typename T>
 static void PrintMatrix(const T* mat_d, int num, std::string name) {
@@ -77,7 +78,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
     auto ffn1_out_scales = ctx.MultiInput<phi::DenseTensor>("FFN1OutScale");
     auto ffn2_out_scales = ctx.MultiInput<phi::DenseTensor>("FFN2OutScale");
 
-    bool remove_padding = false;
+    bool remove_padding = false;
     auto *sequence_lengths = ctx.Input<phi::DenseTensor>("SeqLengths");
     if (sequence_lengths) {
       remove_padding = true;
@@ -190,25 +191,10 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
     }
 
     auto out_seq_len = seq_len;
+    int time_step_cpu = 0;
     if (time_step) {
-      PADDLE_ENFORCE_EQ(time_step->place(),
-                        platform::CPUPlace(),
-                        platform::errors::PreconditionNotMet(
-                            "The place of input(TimeStep) must be CPUPlace."));
-      // cache_seq_len
-      int time_step_value = time_step->data<int>()[0];
-      PADDLE_ENFORCE_GT(time_step_value,
-                        0,
-                        platform::errors::PreconditionNotMet(
-                            "The value of time_step must > 0, but now is %d",
-                            time_step_value));
-      PADDLE_ENFORCE_EQ(
-          seq_len,
-          1,
-          platform::errors::PreconditionNotMet(
-              "In decode stage, the seq_len of input must be 1, but now is %d",
-              seq_len));
-      out_seq_len += time_step_value;
+      time_step_cpu = src_mask->dims()[3] - 1;
+      out_seq_len += time_step_cpu;
     } else {
       out_seq_len += cache_offset;
     }
@@ -513,7 +499,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
           max_seq_len,
           num_head,
           dim_head,
-          src_mask->dims()[3] - 1,
+          time_step_cpu,
           rotary_emb_dims,
           1. / sqrt(dim_head));
     } else if (cache_kv_out) {  // generation context stage
@@ -966,6 +952,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
             dropout_mask_out_data,
             ffn2_in_scale[i],
             ffn2_out_scales[i]->data<float>(),
+            0,
             1.0);
       }
     } else {
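A note on the main kernel change above: the decode-stage time step is no longer read from a CPU-placed TimeStep tensor (dropping the place, positivity, and seq_len checks), but is derived once on the host as src_mask->dims()[3] - 1, cached in time_step_cpu, and reused both for out_seq_len and for the attention call. Below is a minimal sketch of the shape relationship this relies on (plain NumPy, not Paddle code; the mask layout is an assumption inferred from the expression in the diff):

    import numpy as np

    def time_step_from_mask(src_mask: np.ndarray) -> int:
        # Assumed layout: [batch, 1, cur_seq_len, cached_len + cur_seq_len].
        # In the decode stage cur_seq_len == 1, so the last mask dimension is
        # cached_len + 1 and shape[3] - 1 recovers the cached sequence length.
        return src_mask.shape[3] - 1

    # Decoding one new token with 12 tokens already in the KV cache:
    mask = np.zeros((2, 1, 1, 13), dtype=np.float16)
    assert time_step_from_mask(mask) == 12  # replaces the old TimeStep value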
6 changes: 2 additions & 4 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -183,7 +183,8 @@ def pure_fp16_initialize(models):
     for idx in range(len(models)):
         for layer in models[idx].sublayers(include_self=True):
             layer._casted_by_pure_fp16 = True
-            if isinstance(layer, paddle.incubate.nn.FusedMultiTransformerMoeINT8):
+            if isinstance(layer, (paddle.incubate.nn.FusedMultiTransformerMoeINT8,
+                                  paddle.incubate.nn.FusedMultiTransformerINT8)):
                 continue
             if (layer._dtype == 'float16') or isinstance(
                     layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D,
@@ -198,9 +199,6 @@ def pure_fp16_initialize(models):
                             paddle.incubate.nn.FusedMoELayer)):
                 layer._amp_decorate(dtype='float16')
                 continue
-            # if isinstance(layer, paddle.incubate.nn.FusedMultiTransformerMoeINT8):
-            #     layer._amp_decorate(dtype='int8')
-            #     continue
             layer._to_impl(dtype='float16',
                            include_sublayers=False,
                            floating_only=True)
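This change widens the early `continue` in pure_fp16_initialize so that FusedMultiTransformerINT8, like the MoE INT8 variant, is skipped during pure-fp16 decoration rather than having its int8 weights and float scales cast to float16; the dead commented-out int8-decorate branch is also removed. A hedged sketch of the user-visible effect (build_model is a stand-in, not a real API):

    import paddle

    model = build_model()  # assumed to contain FusedMultiTransformerINT8 blocks
    # AMP level 'O2' triggers pure_fp16_initialize under the hood.
    model = paddle.amp.decorate(models=model, level='O2')
    # INT8 transformer sublayers are left untouched (weights stay int8);
    # every other sublayer is cast to float16 as before.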
2 changes: 2 additions & 0 deletions python/paddle/incubate/nn/__init__.py
@@ -16,6 +16,7 @@
 from .layer.fused_transformer import FusedFeedForward  # noqa: F401
 from .layer.fused_transformer import FusedTransformerEncoderLayer  # noqa: F401
 from .layer.fused_transformer import FusedMultiTransformer  # noqa: F401
+from .layer.fused_transformer import FusedMultiTransformerINT8  # noqa: F401
 from .layer.fused_linear import FusedLinear  # noqa: F401
 from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm  # noqa: F401
 from .layer.fused_transformer import FusedMoELayer  # noqa: F401
@@ -27,6 +28,7 @@
     'FusedFeedForward',
     'FusedTransformerEncoderLayer',
     'FusedMultiTransformer',
+    'FusedMultiTransformerINT8',
     'FusedMultiTransformerMoe',
     'FusedMultiTransformerMoeINT8',
     'FusedLinear',
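With these two lines, the new layer becomes part of the public incubate API. A quick check (only the import is demonstrated; the constructor signature lives in the fourth changed file, whose diff is not shown here):

    import paddle
    from paddle.incubate.nn import FusedMultiTransformerINT8

    # The class is now exported alongside the existing fused transformers.
    assert 'FusedMultiTransformerINT8' in paddle.incubate.nn.__all__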
[Diff for the fourth changed file did not load; given the new import above, it is likely python/paddle/incubate/nn/layer/fused_transformer.py, which carries the bulk of the 314 additions.]
