[Relay] Add gradient for reshape #3901

Merged · 1 commit · Sep 6, 2019
6 changes: 6 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -290,6 +290,12 @@ def dense_grad(orig, grad):
             collapse_sum_like(data * transpose(grad), weight)]
 
 
+@register_gradient("reshape")
+def reshape_grad(orig, grad):
+    """Gradient of reshape"""
+    return [reshape_like(grad, orig.args[0])]
+
+
 @register_gradient("nn.batch_flatten")
 def batch_flatten_grad(orig, grad):
     """Returns grad reshaped to data dims"""
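The gradient rule is a one-liner because reshape never mixes or scales values: it is a pure re-indexing, hence a linear map whose adjoint is the inverse re-indexing. The backward pass therefore only has to reshape the incoming gradient back to the input's shape. A standalone NumPy sketch of that identity (illustrative shapes, not part of this PR), checked against a finite-difference estimate:

import numpy as np

# y = reshape(x) is linear in x, so its vector-Jacobian product is the
# upstream gradient reshaped back to x's shape -- the NumPy analogue of
# reshape_like(grad, orig.args[0]) above.
x = np.random.rand(2, 3, 4)
newshape = (6, 4)
grad_y = np.random.rand(*newshape)   # upstream gradient dL/dy
grad_x = grad_y.reshape(x.shape)     # claimed dL/dx

# The directional derivative of L(x) = sum(reshape(x) * grad_y) along a
# random direction v must equal <grad_x, v>.
v = np.random.rand(*x.shape)
eps = 1e-6
numeric = (np.sum((x + eps * v).reshape(newshape) * grad_y)
           - np.sum(x.reshape(newshape) * grad_y)) / eps
analytic = np.sum(grad_x * v)
assert abs(numeric - analytic) < 1e-4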
5 changes: 2 additions & 3 deletions tests/python/relay/test_op_grad_level3.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 
 import tvm
 from tvm import relay
@@ -58,6 +59,4 @@ def test_negative_grad():


 if __name__ == "__main__":
-    test_clip()
-    test_transpose_grad()
-    test_negative_grad()
+    pytest.main()
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level3.py
@@ -21,7 +21,7 @@
 import tvm
 from tvm import relay
 from tvm.relay import create_executor, transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, check_grad
 
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
@@ -247,6 +247,7 @@ def verify_reshape(shape, newshape, oshape):
     assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
 
     func = relay.Function([x], z)
+    check_grad(func)
Contributor:
I think it might be a good idea to separate this test out, since the gradient is implemented separately from the op.

Contributor (Author):
This allows us to get the same coverage in both test_op_level3 and test_op_grad_level3. I think we should move toward this style.

Contributor:
True. The grad implementations could also use some refactoring; implementation-wise, it would be more convenient to keep each gradient next to its forward op implementation.

Contributor (Author):
That also sounds good. _tensor_grad.py is getting big.

     x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
     ref_res = np.reshape(x_data, oshape)
     for target, ctx in ctx_list():
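For context, check_grad (imported from tvm.relay.testing in the diff above) differentiates the function with Relay's automatic differentiation and compares the result against a numerical finite-difference estimate on random inputs. A minimal usage sketch in the spirit of this PR's test; the exact keyword arguments accepted by check_grad may vary between TVM versions:

import tvm
from tvm import relay
from tvm.relay.testing import check_grad

# Build the same kind of function the test constructs: a bare reshape.
x = relay.var("x", shape=(4, 9), dtype="float32")
func = relay.Function([x], relay.reshape(x, newshape=(3, 12)))

# Raises if the AD gradient and the numeric estimate disagree.
check_grad(func)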