[Relay] GradientCell Relay Pass #5039

Merged: 30 commits merged on Mar 24, 2020
Conversation

@hypercubestart
Contributor

Add the GradientCell Relay pass, which introduces the GradCell datatype. The pass can reduce memory allocation and potentially improve memory usage and performance by delaying the instantiation of zero-filled and one-filled tensors until they are actually needed.
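For readers new to the idea, here is a minimal plain-Python sketch of what a GradCell-like wrapper does; this is only an illustration of the concept, not the actual Relay ADT defined in gradient.rly:

import numpy as np

class GradCell:
    """Illustrative only: hold either a real tensor or a recipe for a constant fill."""

    def __init__(self, kind, tensor=None, shape=None, dtype="float32"):
        self.kind = kind        # "raw", "zero", or "one"
        self.tensor = tensor
        self.shape = shape
        self.dtype = dtype

    @classmethod
    def raw(cls, tensor):
        return cls("raw", tensor=tensor)

    @classmethod
    def zero(cls, shape, dtype="float32"):
        return cls("zero", shape=shape, dtype=dtype)   # nothing allocated yet

    @classmethod
    def one(cls, shape, dtype="float32"):
        return cls("one", shape=shape, dtype=dtype)    # nothing allocated yet

    def materialize(self):
        # Allocate the constant-filled tensor only when it is actually read.
        if self.kind == "raw":
            return self.tensor
        fill = np.zeros if self.kind == "zero" else np.ones
        return fill(self.shape, dtype=self.dtype)

In the actual pass, GradCell is a Relay algebraic data type (defined in gradient.rly), and the transformation rewrites calls such as zeros, ones, zeros_like, and ones_like into the lazy form.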

@MarisaKirisame
Contributor

Have you rebased?
You should ask some people to help review. Since you are new, I will help you this time.
@junrushao1994 @slyubomirsky @altanh @wweic can you help review?

@MarisaKirisame
Contributor

Just to be clear, this pass saves about 50% of memory for reverse-mode AD.

@hypercubestart changed the title from "GradientCell Relay Pass" to "[Relay] GradientCell Relay Pass" on Mar 11, 2020
src/relay/pass/gradient.cc (outdated review thread, resolved)
tests/python/relay/test_pass_gradient_cell.py (outdated review thread, resolved)
src/relay/transforms/gradient_cell.cc (3 outdated review threads, resolved)
 * module must have TypeDefinition of GradCell (defined in gradient.rly)
 */
Constructor getGradCellConstructor(IRModule module, std::string name_hint) {
  TypeData gradCell = module->LookupTypeDef("GradCell");
Contributor

Can you refactor this function into the module?

Contributor Author

Yeah, added it as a GetConstructor function, although it may be more appropriate as a separate PR.

@@ -219,6 +219,18 @@ def DeadCodeElimination(inline_once=False):
"""
return _ffi_api.DeadCodeElimination(inline_once)

def GradientCell():
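The excerpt cuts off at the new definition. Presumably the binding follows the same FFI pattern as DeadCodeElimination shown above; a minimal sketch, living in the same transform.py module where _ffi_api is already imported (the docstring wording and the exact _ffi_api symbol are assumptions):

def GradientCell():
    """Delay instantiation of zero-filled/one-filled tensors by wrapping them in GradCell.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered GradientCell pass.
    """
    return _ffi_api.GradientCell()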
Member

Is GradientCell really the right name for this, given that it is just making ones/zeros lazy?

Contributor Author

Renamed to LazyGradientInit to highlight that this pass should be used after reverse-mode AD to lazily initialize tensor gradients.

@jroesch (Member) left a comment
I left some style comments on the PR; please fix them and then I'll merge.

@jwfromm
Contributor

jwfromm commented Mar 23, 2020

It would be great if you could add some information on why this pass is useful/needed. I understand that it's important to conserve memory when computing gradients, but converting all tensors to a new type just to lazily evaluate constant fills seems somewhat drastic. Can you give some numbers on the impact of this pass for typical functions?

@hypercubestart
Contributor Author

@jwfromm this pass should be used in conjunction with the gradient pass. The adjoint tensors for reverse-mode AD are instantiated as zero-filled tensors but are not actually used during the forward pass, so this pass should reduce memory allocation by roughly 50% during the forward pass. I don't have measurements for this yet, but mathematically it makes sense.
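A minimal usage sketch of that pipeline, assuming a TVM build where the renamed pass is exposed as relay.transform.LazyGradientInit and reverse-mode AD as relay.transform.gradient (exact names and signatures may differ between versions):

import tvm
from tvm import relay
from tvm.relay import transform

shape, dtype = (10, 10), "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
fwd = relay.Function([x], x * x)

mod = tvm.IRModule.from_expr(fwd)
mod = transform.InferType()(mod)

# Reverse-mode AD: the backward function instantiates zero-filled adjoint tensors.
mod["main"] = transform.gradient(mod["main"], mod=mod)

# Rewrite those zeros/ones so they are only materialized if actually read.
mod = transform.LazyGradientInit()(mod)
print(mod["main"])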

@hypercubestart force-pushed the ad branch 2 times, most recently from 5456d56 to 0d78212 on March 24, 2020 03:16
@jroesch merged commit e6dd8e1 into apache:master on Mar 24, 2020
@hypercubestart
Contributor Author

Thanks @jroesch!

@MarisaKirisame deleted the ad branch on March 24, 2020 19:06
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
* save

* gradient.rly

* fix

* NOT WORKING: gradient cell pass

* test gradient pass

* fixed basic call ops

* more tests

* fix bug

* transform calls to one ones_like zero zero_like

* maintenance stuff

* fix linting

* linting

* linting

* throw default

* remove unrelated changes

* import gradent.rly in pass

* comment

* linting

* remove changes to test files

* move gradient_cell.cc to transforms

* revert change

* update files with new commits

* type

* wrapper function to main outermost function type

* fix linting

* fix unsigned and signed int comparison

* review

* GetConstructor definition in module and change op comparison

* update node instantiations

* increase code readability

Co-authored-by: Marisa Kirisame <[email protected]>
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020 (same commit message as above)