Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Sep 5, 2019
1 parent 57cd27f commit 51d704d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
6 changes: 6 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,12 @@ def dense_grad(orig, grad):
collapse_sum_like(data * transpose(grad), weight)]


@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
return [reshape_like(grad, orig.args[0])]


@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
"""Returns grad reshaped to data dims"""
Expand Down
5 changes: 2 additions & 3 deletions tests/python/relay/test_op_grad_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest

import tvm
from tvm import relay
Expand Down Expand Up @@ -58,6 +59,4 @@ def test_negative_grad():


if __name__ == "__main__":
test_clip()
test_transpose_grad()
test_negative_grad()
pytest.main()
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tvm
from tvm import relay
from tvm.relay import create_executor, transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, check_grad

def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
Expand Down Expand Up @@ -247,6 +247,7 @@ def verify_reshape(shape, newshape, oshape):
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

func = relay.Function([x], z)
check_grad(func)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
for target, ctx in ctx_list():
Expand Down

0 comments on commit 51d704d

Please sign in to comment.