diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 93110e313642..0b81cb9c7ec6 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -72,7 +72,15 @@ def _np_randn_from_type(t, scale=1, mean=0): def check_grad( - func, inputs=None, test_inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0 + func, + inputs=None, + test_inputs=None, + eps=1e-6, + atol=1e-5, + rtol=1e-3, + scale=None, + mean=0, + mode="higher_order", ): """Perform numerical gradient checking given a relay function. @@ -112,7 +120,7 @@ def check_grad( """ fwd_func = run_infer_type(func) - bwd_func = run_infer_type(gradient(fwd_func)) + bwd_func = run_infer_type(gradient(fwd_func, mode=mode)) if scale is None: scale = 10 * eps diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 80a567d9cb65..bcf75de7915b 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -168,13 +168,6 @@ def test_global_avg_pool2d_grad(): def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode="higher_order"): - try: - import torch - import torch.nn.functional as F - except ImportError: - print("Skip because pytorch is not installed") - return - dtype = "float32" data = relay.var("data", shape=dshape, dtype=dtype) weight = relay.var("weight", shape=wshape, dtype=dtype) @@ -182,49 +175,7 @@ def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mod data, weight, strides=strides, padding=padding, dilation=dilation, groups=groups ) fwd_func = relay.Function([data, weight], conv) - fwd_func = run_infer_type(fwd_func) - bwd_func = run_infer_type(gradient(fwd_func, mode=mode)) - - data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True) - weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True) - out_pt = F.conv2d( - data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation, groups=groups - ) - grad_output_pt = torch.ones(out_pt.shape) - grad_input_pt = ( - F.grad.conv2d_input( - dshape, - weight_pt, - grad_output_pt, - stride=strides, - padding=padding, - dilation=dilation, - groups=groups, - ) - .detach() - .numpy() - ) - grad_weight_pt = ( - F.grad.conv2d_weight( - data_pt, - wshape, - grad_output_pt, - stride=strides, - padding=padding, - dilation=dilation, - groups=groups, - ) - .detach() - .numpy() - ) - - for target, ctx in tvm.testing.enabled_targets(): - data = tvm.nd.array(data_pt.detach().numpy(), ctx) - weight = tvm.nd.array(weight_pt.detach().numpy(), ctx) - intrp = relay.create_executor(ctx=ctx, target=target) - op_res, (grad_input, grad_weight) = intrp.evaluate(bwd_func)(data, weight) - np.testing.assert_allclose(grad_input.asnumpy(), grad_input_pt, rtol=1e-4, atol=1e-4) - np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4) + check_grad(fwd_func, mode=mode) @tvm.testing.uses_gpu