-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] GradientCell Relay Pass #5039
Conversation
had you rebased? |
Just to be clear, this pass save 50% memory for reverse mode ad. |
1783a45
to
807c8e8
Compare
* module must have TypeDefinition of GradCell (defined in gradient.rly) | ||
*/ | ||
Constructor getGradCellConstructor(IRModule module, std::string name_hint) { | ||
TypeData gradCell = module->LookupTypeDef("GradCell"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you refactor this function into module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, added it as GetConstructor
function although it may be more appropriate to be a separate PR
@@ -219,6 +219,18 @@ def DeadCodeElimination(inline_once=False): | |||
""" | |||
return _ffi_api.DeadCodeElimination(inline_once) | |||
|
|||
def GradientCell(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is gradient cell really the right name for this given it is just making 1/0s lazy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
renamed to LazyGradientInit
, highlight that this pass should be used after reverse mode ad to lazily initiate tensor gradients
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some style comments on the PR, please fix then I'll merge.
It would be great if you can add some information on why this pass is useful / needed. I understand that its important to conserve memory when computing gradients, but converting all tensors to a new type just lazily evaluate constant fills seems somewhat drastic. Can you give some numbers on the impact of this pass for typical functions? |
@jwfromm this pass should be used in conjunction with the gradient pass. Since tensors for reverse-mode ad are instantiated as 0-filled tensors but are actually not useful during the forward pass, this PR should reduce memory allocation by 50% during the forward pass. I don't have any data on this but mathematically this would make sense |
5456d56
to
0d78212
Compare
thanks @jroesch ! |
* save * gradient.rly * fix * NOT WORKING: gradient cell pass * test gradient pass * fixed basic call ops * more tests * fix bug * transform calls to one ones_like zero zero_like * maintenance stuff * fix linting * linting * linting * throw default * remove unrelated changes * import gradent.rly in pass * comment * linting * remove changes to test files * move gradient_cell.cc to transforms * revert change * update files with new commits * type * wrapper function to main outermost function type * fix linting * fix unsigned and signed int comparison * review * GetConstructor definition in module and change op comparison * update node instantiations * increase code readability Co-authored-by: Marisa Kirisame <[email protected]>
* save * gradient.rly * fix * NOT WORKING: gradient cell pass * test gradient pass * fixed basic call ops * more tests * fix bug * transform calls to one ones_like zero zero_like * maintenance stuff * fix linting * linting * linting * throw default * remove unrelated changes * import gradent.rly in pass * comment * linting * remove changes to test files * move gradient_cell.cc to transforms * revert change * update files with new commits * type * wrapper function to main outermost function type * fix linting * fix unsigned and signed int comparison * review * GetConstructor definition in module and change op comparison * update node instantiations * increase code readability Co-authored-by: Marisa Kirisame <[email protected]>
Add GradientCell relay pass which introduces the GradCell datatype. This pass can delay memory allocation and potentially improve memory usage and performance by delaying instantiations of zero-filled/one-filled tensors until necessary.