Move workspace memory-allocation to PyTorch #661
csrc/includes/context.h
Outdated
@@ -64,17 +64,18 @@ class Context {
         return _ctx;
     }

-    void GenWorkSpace(size_t size)
+    void GenWorkSpace(void* workspace)  // (size_t size)
Rename it to SetWorkSpace()?
Thanks Elton, will do! :)
csrc/includes/context.h
Outdated
             assert(_workspace == nullptr);
             cudaMalloc(&_workspace, size);
         } else if (_workSpaceSize < size) {
             cudaFree(_workspace);
             cudaMalloc(&_workspace, size);
         }

-        _workSpaceSize = size;
+        _workSpaceSize = size;*/
Can _workSpaceSize be deleted from the Context class?
Yes, agreed!
csrc/includes/context.h
Outdated
-        if (!_workspace) {
+        if (!workspace) { throw std::runtime_error("Workspace is null."); }
+        _workspace = workspace;
+        /*if (!_workspace) {
Remove the commented-out code if it is no longer needed.
I have only commented it out to verify that the new approach works properly. After some tests I will remove the commented-out code.
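For reference, here is a minimal sketch of how the setter might look once the points raised in this thread are applied: the rename to SetWorkSpace(), the removal of _workSpaceSize, and the removal of the commented-out cudaMalloc/cudaFree logic. The class below is a trimmed-down, hypothetical stand-in for the real Context in csrc/includes/context.h, and the GetWorkSpace() accessor is an assumption added for illustration. It is not the code that was merged.

```cpp
#include <stdexcept>

// Hypothetical, trimmed-down stand-in for the Context class discussed above.
class Context {
public:
    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    // Stores a caller-owned buffer (for example, memory backing a PyTorch tensor).
    // The context neither allocates nor frees it, so no size bookkeeping is needed.
    void SetWorkSpace(void* workspace)
    {
        if (!workspace) { throw std::runtime_error("Workspace is null."); }
        _workspace = workspace;
    }

    // Assumed accessor, included only so the stored pointer can be read back.
    void* GetWorkSpace() const { return _workspace; }

private:
    Context() = default;
    void* _workspace = nullptr;
};
```

With this shape the caller keeps ownership of the buffer and must keep it alive for as long as the context may hand the pointer out.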
_training,
_gelu_checkpoint));

// Context::Instance().GenWorkSpace(get_workspace_size<T>(_batch_size,
remove it?
layer->GetNumHeads(),
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Can we use g_output.options() here instead of creating a new options object?
good point 👍
Tested BERT with sequence length 512 and it worked well. I think it is good to go after applying Elton's comments.