supporting different intermediate sizes other than 4 * hidden_dim #389

Merged Sep 11, 2020 (4 commits)
Showing changes from 2 commits
20 changes: 13 additions & 7 deletions csrc/transformer/ds_transformer_cuda.cpp
100755 → 100644
@@ -20,13 +20,14 @@ template <typename T>
size_t get_workspace_size(int maxBatchSize,
int seq_len,
int hidden_size,
int intermediate_size,
int heads,
bool training,
bool gelu_checkpoint)
{
size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
if (training) {
workSpacesize += (std::max((4 * size_t(maxBatchSize) * seq_len * hidden_size),
workSpacesize += (std::max((size_t(maxBatchSize) * seq_len * intermediate_size),
2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
}
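For reference, here is the updated sizing rule restated as a standalone Python sketch. It is not part of the patch; it only mirrors the element count the C++ above computes, so the effect of intermediate_size on the training workspace is easy to see:

def workspace_elements(batch_size, seq_len, hidden_size, intermediate_size,
                       heads, training=True, gelu_checkpoint=False):
    # Base buffers: 4 x (batch * seq * hidden), unchanged by this PR.
    size = 4 * batch_size * seq_len * hidden_size
    if training:
        # The feed-forward scratch term now scales with intermediate_size
        # instead of the hard-coded 4 * hidden_size.
        size += max(batch_size * seq_len * intermediate_size,
                    2 * batch_size * heads * seq_len * seq_len)
        if gelu_checkpoint:
            size += 2 * batch_size * seq_len * hidden_size
    return size

# Example: batch 8, seq 128, hidden 1024, intermediate 3072, 16 heads.
print(workspace_elements(8, 128, 1024, 3072, 16))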
@@ -92,15 +93,15 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
false,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
4 * hidden_size,
intermediate_size,
hidden_size,
gemm_algos[1])),
_ff2(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
4 * hidden_size,
intermediate_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
_gelu(typename Gelu<T>::Config(_batch_size, _seq_length, _intermediate_size)),
_gelu(typename Gelu<T>::Config(_batch_size, _seq_length, intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio,
_batch_size * _heads * _seq_length,
_seq_length)),
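Shape-wise, the two feed-forward GEMMs configured above now map hidden_size to intermediate_size and back. A plain PyTorch restatement with illustrative sizes (not code from the patch, and assuming the usual row-major weight layout):

import torch
import torch.nn.functional as F

tokens, hidden, intermediate = 8 * 128, 1024, 3072  # batch*seq, hidden_size, intermediate_size
x = torch.randn(tokens, hidden)
w1 = torch.randn(intermediate, hidden)  # _ff1: hidden_size -> intermediate_size
w2 = torch.randn(hidden, intermediate)  # _ff2: intermediate_size -> hidden_size

h = F.gelu(x @ w1.t())  # [tokens, intermediate], consistent with the Gelu config above
y = h @ w2.t()          # [tokens, hidden]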
@@ -143,8 +144,13 @@ BertTransformerLayer<T>::~BertTransformerLayer()
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
Context::Instance().GenWorkSpace(get_workspace_size<T>(
_batch_size, _seq_length, _hidden_size, _heads, _training, _gelu_checkpoint));
Context::Instance().GenWorkSpace(get_workspace_size<T>(_batch_size,
_seq_length,
_hidden_size,
_intermediate_size,
_heads,
_training,
_gelu_checkpoint));

if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
}
@@ -343,7 +349,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
T* buf_2 = buf_1 + small_buf_size;
T* buf_3 = buf_2 + small_buf_size;

T* ff2_buf = buf_3 + (_gelu_checkpoint ? 3 : 1) * small_buf_size;
T* ff2_buf = (_gelu_checkpoint ? (buf_2 + _intermediate_size) : (buf_3 + small_buf_size));
T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

cudaStream_t streams[2] = {_stream, _stream};
12 changes: 8 additions & 4 deletions deepspeed/ops/transformer/transformer.py
100644 → 100755
@@ -18,6 +18,7 @@ def __init__(self,
batch_size,
max_seq_length,
hidden_size,
intermediate_size,
heads,
attn_dropout_ratio,
hidden_dropout_ratio,
@@ -26,6 +27,7 @@
self.layer_id = -1
self.batch_size = batch_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size if intermediate_size > 0 else 4 * hidden_size
self.max_seq_length = max_seq_length
self.heads = heads
self.attn_dropout_ratio = attn_dropout_ratio
@@ -88,6 +90,7 @@ def __init__(self,
batch_size=-1,
max_seq_length=-1,
hidden_size=-1,
intermediate_size=-1,
heads=-1,
attn_dropout_ratio=-1,
hidden_dropout_ratio=-1,
@@ -106,6 +109,7 @@
self).__init__(batch_size,
max_seq_length,
hidden_size,
intermediate_size,
heads,
attn_dropout_ratio,
hidden_dropout_ratio,
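As a usage sketch of the Python-side change (values are invented; any argument left out keeps the -1 default shown above, and intermediate_size <= 0 falls back to 4 * hidden_size; the import path assumes the config class is exposed as deepspeed.ops.transformer.DeepSpeedTransformerConfig):

from deepspeed.ops.transformer import DeepSpeedTransformerConfig

config = DeepSpeedTransformerConfig(batch_size=8,
                                    max_seq_length=128,
                                    hidden_size=1024,
                                    intermediate_size=3072,  # no longer forced to 4 * hidden_size
                                    heads=16,
                                    attn_dropout_ratio=0.1,
                                    hidden_dropout_ratio=0.1)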
@@ -432,12 +436,12 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.inter_w = nn.Parameter(
torch.Tensor(4 * self.config.hidden_size,
torch.Tensor(self.config.intermediate_size,
self.config.hidden_size))
self.inter_b = nn.Parameter(torch.Tensor(4 * self.config.hidden_size))
self.inter_b = nn.Parameter(torch.Tensor(self.config.intermediate_size))
self.output_w = nn.Parameter(
torch.Tensor(self.config.hidden_size,
4 * self.config.hidden_size))
self.config.intermediate_size))
self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
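The resulting feed-forward parameter shapes, pulled out into a standalone sketch with example sizes and the same style of initialization the unit tests use below (illustrative only):

import torch
import torch.nn as nn

hidden_size, intermediate_size = 1024, 3072  # example sizes

inter_w = nn.Parameter(torch.Tensor(intermediate_size, hidden_size))   # hidden -> intermediate
inter_b = nn.Parameter(torch.Tensor(intermediate_size))
output_w = nn.Parameter(torch.Tensor(hidden_size, intermediate_size))  # intermediate -> hidden
output_b = nn.Parameter(torch.Tensor(hidden_size))

inter_w.data.normal_(mean=0.0, std=0.02)   # std stands in for config.initializer_range
output_w.data.normal_(mean=0.0, std=0.02)
inter_b.data.zero_()
output_b.data.zero_()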
@@ -485,7 +489,7 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
self.config.batch_size,
self.config.hidden_size,
self.config.heads,
4 * self.config.hidden_size,
self.config.intermediate_size,
self.config.max_seq_length,
self.config.attn_dropout_ratio,
self.config.hidden_dropout_ratio,
9 changes: 5 additions & 4 deletions tests/unit/test_cuda_backward.py
@@ -142,7 +142,7 @@ def create_models(ds_config):
hidden_size=ds_config.hidden_size,
num_hidden_layers=ds_config.num_hidden_layers,
num_attention_heads=ds_config.heads,
intermediate_size=4 * ds_config.hidden_size,
intermediate_size=ds_config.intermediate_size,
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
@@ -162,12 +162,12 @@ def create_models(ds_config):
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[4].data.fill_(1.0)
weights.append(
nn.Parameter(torch.Tensor(4 * ds_config.hidden_size,
nn.Parameter(torch.Tensor(ds_config.intermediate_size,
ds_config.hidden_size)))
weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size,
4 * ds_config.hidden_size)))
ds_config.intermediate_size)))
weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[7].data.fill_(1.0)
@@ -177,7 +177,7 @@ def create_models(ds_config):
for i in range(4):
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[i + 1].data.zero_()
biases.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size)))
biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
biases[5].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[6].data.zero_()
@@ -274,6 +274,7 @@ def test_backward(batch_size,
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
57 changes: 30 additions & 27 deletions tests/unit/test_cuda_forward.py
@@ -109,7 +109,7 @@ def create_models(ds_config):
num_hidden_layers=ds_config.num_hidden_layers,
num_attention_heads=ds_config.heads,
batch_size=ds_config.batch_size,
intermediate_size=4 * ds_config.hidden_size,
intermediate_size=ds_config.intermediate_size,
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
@@ -130,12 +130,12 @@ def create_models(ds_config):
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[4].data.fill_(1.0)
weights.append(
nn.Parameter(torch.Tensor(4 * ds_config.hidden_size,
nn.Parameter(torch.Tensor(ds_config.intermediate_size,
ds_config.hidden_size)))
weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size,
4 * ds_config.hidden_size)))
ds_config.intermediate_size)))
weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[7].data.fill_(1.0)
@@ -145,7 +145,7 @@ def create_models(ds_config):
for i in range(4):
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[i + 1].data.zero_()
biases.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size)))
biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
biases[5].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[6].data.zero_()
@@ -207,24 +207,24 @@ def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None):
# FP16 test cases can only run on the devices support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True),
(8,1024,384,16,3,True,False),
(8,1024,384,16,3,True,True),
(8,1024,512,16,3,True,False),
(8,1024,512,16,3,True,True),
(64,1024,128,16,3,False,False),
(64,1024,128,16,3,False,True),
(8,1024,384,16,3,False,False),
(8,1024,384,16,3,False,True),
(8,1024,512,16,3,False,False),
(8,1024,512,16,3,False,True),
(8,1536,128,24,3,False,False),
(8,1536,128,24,3,False,True),
(8,2048,128,32,3,False,False),
(8,2048,128,32,3,False,True),
(8,2560,128,40,3,False,False),
(8,2560,128,40,3,False,True),
# (64,1024,128,16,3,True,False),
# (64,1024,128,16,3,True,True),
# (8,1024,384,16,3,True,False),
# (8,1024,384,16,3,True,True),
# (8,1024,512,16,3,True,False),
# (8,1024,512,16,3,True,True),
# (64,1024,128,16,3,False,False),
# (64,1024,128,16,3,False,True),
# (8,1024,384,16,3,False,False),
# (8,1024,384,16,3,False,True),
# (8,1024,512,16,3,False,False),
# (8,1024,512,16,3,False,True),
# (8,1536,128,24,3,False,False),
# (8,1536,128,24,3,False,True),
# (8,2048,128,32,3,False,False),
# (8,2048,128,32,3,False,True),
Contributor: Why disable these tests?

Contributor Author: I put them back!

# (8,2560,128,40,3,False,False),
# (8,2560,128,40,3,False,True),
]) # yapf: disable
def test_forward(batch_size,
hidden_size,
@@ -242,6 +242,7 @@ def test_forward(batch_size,
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
@@ -256,11 +257,11 @@

@pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(8,3,1024,512,16,3,True,False),
(8,7,1024,512,16,3,True,True),
(8,3,1024,512,16,3,False,False),
(8,7,1024,512,16,3,False,True),
]) # yapf: disable
# (8,3,1024,512,16,3,True,False),
# (8,7,1024,512,16,3,True,True),
# (8,3,1024,512,16,3,False,False),
# (8,7,1024,512,16,3,False,True),
]) ## yapf: disable
def test_forward_with_small_bsz(batch_size,
small_bsz,
hidden_size,
@@ -278,6 +279,7 @@ def test_forward_with_small_bsz(batch_size,
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
Expand Down Expand Up @@ -312,6 +314,7 @@ def test_forward_stochastic(batch_size,
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0