support dynamic sequence length in transformer kernels (microsoft#424)
Co-authored-by: Conglong Li <[email protected]>
RezaYazdaniAminabadi and conglongli authored Sep 22, 2020
1 parent 71f7df3 commit f0f2a70
Showing 17 changed files with 789 additions and 891 deletions.
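As a rough, hedged illustration of the pattern this commit enables (class and method names below are hypothetical, not the DeepSpeed API): the kernels now size their work from each call's token count, so one layer instance can serve batches whose sequence length changes from step to step.

#include <cstdio>
#include <utility>
#include <vector>

// Hypothetical sketch, not the DeepSpeed API: the layer remembers the sequence
// length it was last configured for and adapts when a batch with a different
// length arrives, instead of requiring one layer instance per length.
class DynamicSeqLayer {
public:
    explicit DynamicSeqLayer(int seq_len) : _seq_len(seq_len) {}

    void Forward(int batch_size, int seq_len)
    {
        if (seq_len != _seq_len) {
            _seq_len = seq_len;  // reconfigure only when the length changes
        }
        const int tokens = batch_size * _seq_len;  // flattened work size
        std::printf("forward over %d tokens\n", tokens);
    }

private:
    int _seq_len;
};

int main()
{
    DynamicSeqLayer layer(512);
    const std::vector<std::pair<int, int>> shapes = {{8, 512}, {8, 128}, {4, 384}};
    for (const auto& s : shapes) { layer.Forward(s.first, s.second); }
    return 0;
}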
2 changes: 1 addition & 1 deletion csrc/includes/context.h
100644 → 100755
@@ -29,7 +29,7 @@
 for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
 
 #define DS_CUDA_NUM_THREADS 512
-#define DS_MAXIMUM_NUM_BLOCKS 4096
+#define DS_MAXIMUM_NUM_BLOCKS 262144
 
 inline int DS_GET_BLOCKS(const int N)
 {
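The body of DS_GET_BLOCKS is collapsed in the hunk above. Assuming it is the usual ceil-divide of the element count by the block size, clamped to the maximum block count, the old cap of 4096 would throttle the grid once the flattened token count (batch * sequence length) times the hidden width grows, which is presumably why the limit is raised to 262144. A minimal host-side sketch under that assumption:

#include <algorithm>
#include <cassert>

constexpr int kCudaNumThreads = 512;   // mirrors DS_CUDA_NUM_THREADS
constexpr int kMaxNumBlocks = 262144;  // mirrors the new DS_MAXIMUM_NUM_BLOCKS

// Assumed shape of the helper: ceil(N / threads), clamped to [1, kMaxNumBlocks].
inline int GetBlocksSketch(const int N)
{
    return std::max(std::min((N + kCudaNumThreads - 1) / kCudaNumThreads, kMaxNumBlocks), 1);
}

int main()
{
    // 16 sequences x 512 tokens x 1024 hidden units = ~8.4M elements,
    // which already needs more than the old cap of 4096 blocks.
    const int elements = 16 * 512 * 1024;
    assert(GetBlocksSketch(elements) > 4096);
    return 0;
}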
23 changes: 5 additions & 18 deletions csrc/includes/custom_cuda_layers.h
@@ -29,15 +29,13 @@ void launch_bias_gelu(const T* input,
 T* output,
 int intermediate_size,
 int batch_size,
-int sequence_length,
 cudaStream_t stream);
 
 template <typename T>
 void launch_gelu(const T* input,
 T* output,
 int intermediate_size,
 int batch_size,
-int sequence_length,
 cudaStream_t stream);
 
 template <typename T>
@@ -46,7 +44,6 @@ void launch_d_gelu(T* d_output,
 const T* bias,
 int intermediate_size,
 int batch_size,
-int sequence_length,
 cudaStream_t stream);
 
 // Custom fused bias add with layer normalization
@@ -57,14 +54,12 @@ void launch_bias_residual_layer_norm(T* vals,
 const T* beta,
 float epsilon,
 int batch_size,
-int sequence_length,
 int hidden_dim,
 cudaStream_t stream,
 bool preLayerNorm,
-bool training = false,
-T* vars = nullptr,
-T* means = nullptr,
-T* vals_hat = nullptr);
+bool training,
+T* vars,
+T* means);
 
 template <typename T>
 void launch_bias_residual_layer_norm(T* vals,
@@ -73,14 +68,11 @@ void launch_bias_residual_layer_norm(T* vals,
 const T* beta,
 float epsilon,
 int batch_size,
-int sequence_length,
 int hidden_dim,
 cudaStream_t stream,
 bool preLayerNorm,
-bool training = false,
-T* vars = nullptr,
-T* vals_hat = nullptr,
-bool save_vals = false);
+bool training,
+T* vars);
 
 template <typename T>
 void launch_layerNorm_backward_fused_add(const T* out_grad1,
@@ -93,7 +85,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
 T* betta_grad,
 T* inp_grad,
 int batch_size,
-int sequence_length,
 int hidden_dim,
 cudaStream_t stream[2]);
 template <typename T>
@@ -106,7 +97,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
 T* betta_grad,
 T* inp_grad,
 int batch_size,
-int sequence_length,
 int hidden_dim,
 cudaStream_t stream[2],
 bool invertible = false,
@@ -122,7 +112,6 @@ void launch_layerNorm_backward(const T* out_grad,
 T* betta_grad,
 T* inp_grad,
 int batch_size,
-int sequence_length,
 int hidden_dim,
 cudaStream_t stream[2]);
 
@@ -135,7 +124,6 @@ void launch_layerNorm_backward(const T* out_grad,
 T* betta_grad,
 T* inp_grad,
 int batch_size,
-int sequence_length,
 int hidden_dim,
 cudaStream_t stream[2],
 bool invertible = false,
@@ -153,7 +141,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad,
 T* betta_grad,
 T* inp_grad,
 int batch_size,
-int sequence_length,
 int hidden_dim,
 cudaStream_t stream[2]);
 
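With sequence_length dropped from the launcher signatures above, the per-call token count has to be folded into the batch argument; the Gelu wrapper later in this diff passes its bsz argument straight through to launch_bias_gelu. The CUDA sketch below is an assumption-level illustration of that convention, not the DeepSpeed kernel: the grid is derived from bsz * intermediate_size alone, so no separate sequence length is needed.

#include <cuda_runtime.h>
#include <algorithm>
#include <cmath>

namespace sketch {

constexpr int kThreads = 512;
constexpr int kMaxBlocks = 262144;

// Toy bias + GeLU kernel: a grid-stride loop over the flattened activations.
__global__ void bias_gelu(const float* input, const float* bias, float* output,
                          int intermediate_size, long long total)
{
    for (long long i = blockIdx.x * (long long)blockDim.x + threadIdx.x; i < total;
         i += (long long)blockDim.x * gridDim.x) {
        const float x = input[i] + bias[i % intermediate_size];
        output[i] = 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)));
    }
}

// bsz is the flattened token count (batch * current seq_len), mirroring the
// slimmed-down launch_bias_gelu signature above in spirit only.
void launch_bias_gelu(const float* input, const float* bias, float* output,
                      int intermediate_size, int bsz, cudaStream_t stream)
{
    const long long total = (long long)bsz * intermediate_size;
    const int blocks = (int)std::min<long long>((total + kThreads - 1) / kThreads, kMaxBlocks);
    bias_gelu<<<blocks, kThreads, 0, stream>>>(input, bias, output, intermediate_size, total);
}

}  // namespace sketch

int main()
{
    const int inter = 8, batch = 2, seq = 3;
    const int bsz = batch * seq;  // token count folded in by the caller
    float *in, *bias, *out;
    cudaMallocManaged(&in, sizeof(float) * bsz * inter);
    cudaMallocManaged(&bias, sizeof(float) * inter);
    cudaMallocManaged(&out, sizeof(float) * bsz * inter);
    for (int i = 0; i < bsz * inter; ++i) in[i] = 0.01f * i;
    for (int i = 0; i < inter; ++i) bias[i] = 0.0f;
    sketch::launch_bias_gelu(in, bias, out, inter, bsz, 0);
    cudaDeviceSynchronize();
    cudaFree(in);
    cudaFree(bias);
    cudaFree(out);
    return 0;
}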
10 changes: 5 additions & 5 deletions csrc/includes/dropout.h
@@ -9,15 +9,13 @@ class Dropout {
 public:
 struct Config {
 float ratio;
-uint32_t batch, dim;
+uint32_t dim;
 bool training;
 
-Config(float r, uint32_t batch, uint32_t dim)
-: ratio(r), batch(batch), dim(dim), training(true)
-{
-}
+Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
 
 float RATIO() const { return training ? ratio : 0.0; }
+inline void SetDim(uint32_t d) { dim = d; }
 };
 
 Dropout(const Config& config) : _config(config), _mask(nullptr) {}
@@ -70,6 +68,8 @@ class Dropout {
 
 Config GetConfig() const { return _config; }
 
+inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
+
 private:
 uint8_t* _mask;
 Config _config;
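Config now carries a ratio and a single dim, and the new SetDim/SetDimension hooks let that dimension change after construction, presumably so dropouts whose width tracks the sequence length (such as the attention-probability dropout) can be re-dimensioned when the length changes. A stand-in mirroring only the fragment shown above:

#include <cstdint>
#include <cstdio>

// Stand-in for the Config fragment shown above, not the full Dropout class.
struct DropoutConfigSketch {
    float ratio;
    uint32_t dim;
    bool training;

    DropoutConfigSketch(float r, uint32_t d) : ratio(r), dim(d), training(true) {}

    float RATIO() const { return training ? ratio : 0.0f; }
    void SetDim(uint32_t d) { dim = d; }
};

int main()
{
    DropoutConfigSketch cfg(0.1f, 512);  // e.g. width = current sequence length
    std::printf("ratio=%.2f dim=%u\n", cfg.RATIO(), (unsigned)cfg.dim);

    cfg.SetDim(128);       // a shorter batch arrives: only the dim changes
    cfg.training = false;  // in eval mode the effective ratio becomes 0
    std::printf("ratio=%.2f dim=%u\n", cfg.RATIO(), (unsigned)cfg.dim);
    return 0;
}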
12 changes: 9 additions & 3 deletions csrc/includes/ds_transformer_cuda.h
100755 → 100644
@@ -121,11 +121,17 @@ class BertTransformerLayer {
 
 void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
 uint8_t* attn_output_dropout_mask_ptr,
-uint8_t* layer_output_dropout_mask_ptr);
+uint8_t* layer_output_dropout_mask_ptr,
+T* layer_norm_var,
+T* layer_norm_mean,
+T* attn_layer_norm_var,
+T* attn_layer_norm_mean);
 
 inline int GetBatchSize() const { return _batch_size; }
 inline int GetNumHeads() const { return _heads; }
 inline int GetSeqLength() const { return _seq_length; }
+
+void SetSeqLength(int seq_len, int bsz);
 inline int GetHiddenSize() const { return _hidden_size; }
 void SetTrainingMode(bool training);
 
@@ -150,8 +156,8 @@
 // layers
 FeedForward<T> _qkv_linear;
 FeedForward<T> _attn_out_linear;
-Normalize_Layer<T> _norm_layer2;
-Normalize_Layer<T> _norm_layer3;
+Normalize_Layer<T> _attn_layer_norm;
+Normalize_Layer<T> _layer_norm;
 Normalize_Layer<T>* _last_normalize;
 FeedForward<T> _ff1, _ff2;
 Softmax<T> _softmax;
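SetSeqLength is only declared here; its definition lives in an implementation file outside these hunks. Purely as a hypothetical illustration (members beyond this header are invented), a sequence-length change has to be pushed into every sub-module whose geometry depends on it, for example through the SetDimension hook added to Dropout above:

#include <cstdint>

// Hypothetical only: the real BertTransformerLayer::SetSeqLength is defined in
// a source file that is not part of these hunks. The point is the propagation
// pattern, not the exact member list.
struct DropoutStub {
    void SetDimension(uint32_t dim) { _dim = dim; }
    uint32_t _dim = 0;
};

class TransformerLayerSketch {
public:
    TransformerLayerSketch(int seq_len, int bsz) : _seq_length(seq_len), _batch_size(bsz) {}

    void SetSeqLength(int seq_len, int bsz)
    {
        _seq_length = seq_len;
        _batch_size = bsz;
        // Attention probabilities are seq_len wide per query, so their dropout
        // dimension must follow the new length.
        _attn_prob_dropout.SetDimension(static_cast<uint32_t>(seq_len));
    }

    int GetSeqLength() const { return _seq_length; }

private:
    int _seq_length;
    int _batch_size;
    DropoutStub _attn_prob_dropout;
};

int main()
{
    TransformerLayerSketch layer(512, 8);
    layer.SetSeqLength(128, 8);  // shorter batch: reconfigure in place
    return layer.GetSeqLength() == 128 ? 0 : 1;
}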
13 changes: 3 additions & 10 deletions csrc/includes/gelu.h
@@ -9,13 +9,8 @@ template <typename T>
 class Gelu {
 public:
 struct Config {
-uint32_t batch_size;
-uint32_t seq_length;
 uint32_t intermediate_size;
-Config(uint32_t batch, uint32_t seq, uint32_t inter_size)
-: batch_size(batch), seq_length(seq), intermediate_size(inter_size)
-{
-}
+Config(uint32_t inter_size) : intermediate_size(inter_size) {}
 };
 
 Gelu(const Config& config) : _config(config) {}
@@ -28,14 +23,12 @@ class Gelu {
 T* output,
 cudaStream_t stream)
 {
-launch_bias_gelu<T>(
-input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream);
+launch_bias_gelu<T>(input_buf, bias, output, _config.intermediate_size, bsz, stream);
 }
 
 void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream)
 {
-launch_d_gelu<T>(
-d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream);
+launch_d_gelu<T>(d_output, input_buf, bias, _config.intermediate_size, bsz, stream);
 }
 
 private:
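With Gelu::Config reduced to intermediate_size, the batch and sequence dimensions arrive per call through bsz. A small usage sketch with a stand-in class; the forward method name is invented, only the Config shape comes from the diff above:

#include <cstdint>
#include <cstdio>

// Stand-in for illustration, not the real Gelu<T> interface.
class GeluStub {
public:
    struct Config {
        uint32_t intermediate_size;
        Config(uint32_t inter_size) : intermediate_size(inter_size) {}
    };

    explicit GeluStub(const Config& config) : _config(config) {}

    void ForwardSketch(int bsz) const
    {
        // bsz already folds in the sequence length, so the same object
        // serves batches of any length.
        std::printf("GeLU over %d x %u activations\n", bsz, (unsigned)_config.intermediate_size);
    }

private:
    Config _config;
};

int main()
{
    GeluStub gelu(GeluStub::Config(4096));
    gelu.ForwardSketch(8 * 512);  // 8 sequences of 512 tokens
    gelu.ForwardSketch(8 * 128);  // shorter sequences, same layer object
    return 0;
}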