diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 89f4ca82f52ea..08624e1c02df3 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -290,6 +290,12 @@ def dense_grad(orig, grad): collapse_sum_like(data * transpose(grad), weight)] +@register_gradient("reshape") +def reshape_grad(orig, grad): + """Gradient of reshape""" + return [reshape_like(grad, orig.args[0])] + + @register_gradient("nn.batch_flatten") def batch_flatten_grad(orig, grad): """Returns grad reshaped to data dims""" diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 9324555b59dc4..cc57361538dfd 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm import relay @@ -58,6 +59,4 @@ def test_negative_grad(): if __name__ == "__main__": - test_clip() - test_transpose_grad() - test_negative_grad() + pytest.main() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index f1d91a255fbb8..03fe7f76a1634 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,7 +21,7 @@ import tvm from tvm import relay from tvm.relay import create_executor, transform -from tvm.relay.testing import ctx_list +from tvm.relay.testing import ctx_list, check_grad def run_infer_type(expr): mod = relay.Module.from_expr(expr) @@ -247,6 +247,7 @@ def verify_reshape(shape, newshape, oshape): assert zz.checked_type == relay.ty.TensorType(oshape, "float32") func = relay.Function([x], z) + check_grad(func) x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") ref_res = np.reshape(x_data, oshape) for target, ctx in ctx_list():