Move workspace memory-allocation to PyTorch #661

Merged (5 commits, Jan 13, 2021)
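In outline, this PR stops the C++ Context from cudaMalloc-ing its own scratch buffer and instead has the extension allocate the workspace as a torch::Tensor on each call, handing only the raw pointer to the Context. A condensed sketch of the new pattern, with a placeholder element count where the real code calls get_workspace_size<T>(...) (see the diffs below):

    // Sketch only: num_elements stands in for get_workspace_size<T>(...)
    int64_t num_elements = 1 << 20;
    auto options = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto workspace = torch::empty({num_elements}, options);   // memory comes from PyTorch's caching allocator
    Context::Instance().SetWorkSpace(workspace.data_ptr());   // Context only stores the pointer

Because the buffer is an ordinary tensor, it participates in PyTorch's CUDA memory pooling and accounting instead of living outside it as a bare cudaMalloc allocation.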
14 changes: 3 additions & 11 deletions csrc/includes/context.h
@@ -64,17 +64,10 @@ class Context {
         return _ctx;
     }

-    void GenWorkSpace(size_t size)
+    void SetWorkSpace(void* workspace)
     {
-        if (!_workspace) {
-            assert(_workspace == nullptr);
-            cudaMalloc(&_workspace, size);
-        } else if (_workSpaceSize < size) {
-            cudaFree(_workspace);
-            cudaMalloc(&_workspace, size);
-        }
-
-        _workSpaceSize = size;
+        if (!workspace) { throw std::runtime_error("Workspace is null."); }
+        _workspace = workspace;
     }

     void* GetWorkSpace() { return _workspace; }
@@ -172,6 +165,5 @@ class Context {
     void* _workspace;
     uint64_t _seed;
     uint64_t _curr_offset;
-    size_t _workSpaceSize;
     std::vector<std::array<int, 3>> _gemm_algos;
 };
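The accessor is untouched, so code that consumes the workspace presumably keeps fetching the pointer the same way; only the ownership moves from the Context (which used to cudaMalloc/cudaFree the buffer) to the tensor created on the extension side. A one-line sketch of the consumer side, assuming a kernel that wants a T-typed scratch buffer:

    T* buf = static_cast<T*>(Context::Instance().GetWorkSpace());  // points into the torch::Tensor registered via SetWorkSpace()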
5 changes: 4 additions & 1 deletion csrc/includes/ds_transformer_cuda.h (file mode 100644 → 100755)
@@ -130,10 +130,13 @@ class BertTransformerLayer {
     inline int GetBatchSize() const { return _batch_size; }
     inline int GetNumHeads() const { return _heads; }
     inline int GetSeqLength() const { return _seq_length; }
+    inline int GetIntermediateSize() const { return _intermediate_size; }

-    void SetSeqLength(int seq_len, int bsz);
+    void SetSeqLength(int seq_len);
     inline int GetHiddenSize() const { return _hidden_size; }
     void SetTrainingMode(bool training);
+    inline bool IsTrainingMode() const { return _training; }
+    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }

 private:
     void Initialize();
40 changes: 24 additions & 16 deletions csrc/transformer/ds_transformer_cuda.cpp
Expand Up @@ -33,7 +33,7 @@ size_t get_workspace_size(int maxBatchSize,
2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
}
return workSpacesize * sizeof(T);
return workSpacesize; // * sizeof(T);
}

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
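Note the change in meaning here: get_workspace_size<T>() now returns an element count rather than a byte count, because the buffer is created as a typed torch::Tensor whose dtype supplies sizeof(T), instead of being passed to cudaMalloc. A hedged illustration of the two sizing conventions (variable names are placeholders, not from the PR):

    size_t elems = get_workspace_size<__half>(bsz, seq_len, hidden, inter, heads, training, gelu_ckpt);
    // Old scheme: the returned value was already in bytes, suitable for cudaMalloc(&ptr, bytes).
    // New scheme: the value counts elements; the tensor's dtype (kHalf here) supplies the element size.
    auto ws = torch::empty({static_cast<int64_t>(elems)}, torch::dtype(torch::kHalf).device(torch::kCUDA));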
@@ -123,7 +123,6 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                                               gemm_algos[4]))
 {
     assert(_hidden_size % _heads == 0);
-    assert(_seq_length <= 1024);

     Initialize();
 }
@@ -136,14 +135,6 @@ BertTransformerLayer<T>::~BertTransformerLayer()
 template <typename T>
 void BertTransformerLayer<T>::Initialize()
 {
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(_batch_size,
-                                                           _seq_length,
-                                                           _hidden_size,
-                                                           _intermediate_size,
-                                                           _heads,
-                                                           _training,
-                                                           _gelu_checkpoint));
-
     if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
 }

@@ -574,17 +565,14 @@ void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_
 }

 template <typename T>
-void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
+void BertTransformerLayer<T>::SetSeqLength(int seq_len)
 {
     _seq_length = seq_len;

     _softmax.SetSeqLength(_seq_length);
     _attn_prob_dropout.SetDimension(_seq_length);
     _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
     _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
-
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(
-        bsz, _seq_length, _hidden_size, _intermediate_size, _heads, _training, _gelu_checkpoint));
 }

 template <typename T>
@@ -707,9 +695,19 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     int seq_len = layer->GetSeqLength();
     if (input.size(1) != seq_len) {
         seq_len = input.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }

+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                          seq_len,
+                                                          layer->GetHiddenSize(),
+                                                          layer->GetIntermediateSize(),
+                                                          layer->GetNumHeads(),
+                                                          layer->IsTrainingMode(),
+                                                          layer->GeluCheckpoint())},
+                                   options);
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
     auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
     auto attn_o_inp = torch::empty_like(input);
@@ -877,9 +875,19 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
     int seq_len = layer->GetSeqLength();
     if (g_output.size(1) != seq_len) {
         seq_len = g_output.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }

+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                          seq_len,
+                                                          layer->GetHiddenSize(),
+                                                          layer->GetIntermediateSize(),
+                                                          layer->GetNumHeads(),
+                                                          layer->IsTrainingMode(),
+                                                          layer->GeluCheckpoint())},
+                                   grad_output.options());
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto grad_input = torch::empty_like(input);
     auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
     auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
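Both the forward and backward paths now size and allocate their own scratch tensor per call, so the workspace is recycled by PyTorch's CUDA caching allocator rather than held as a long-lived cudaMalloc block. One lifetime detail, stated as an assumption rather than something the diff spells out: the Context only keeps a raw pointer, so the workspace tensor must outlive the kernels launched in the same call, which it does here by being a local in the same function. A minimal sketch of that contract:

    {
        auto workspace = torch::empty({n_elems},   // n_elems is a placeholder element count
                                      torch::dtype(torch::kHalf).device(torch::kCUDA));
        Context::Instance().SetWorkSpace(workspace.data_ptr());
        // ... launch kernels that read the pointer via Context::Instance().GetWorkSpace() ...
    }   // workspace is returned to the caching allocator here; the stored pointer is stale afterwards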