Fix gradient OP for nn.conv2d #10439
Conversation
The current backward implementation raises an error for `nn.Conv2d`, for both normal and depthwise convolutions. See the code attached below. (Note: the weight shape `(3, 3, 3, 3)` with `groups=1` is the normal convolution, and `(3, 1, 3, 3)` with `groups=3` is the depthwise one.)

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

normal_conv_code = """
fn (%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 3, 3, 3), float32], %v0_bias: Tensor[(3), float32]) {
  %0 = nn.conv2d(%input0, %v0_weight, padding=[1, 1, 1, 1], groups=1, channels=3, kernel_size=[3, 3]);
  nn.bias_add(%0, %v0_bias)
}
"""

depthwise_conv_code = """
fn (%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 1, 3, 3), float32], %v0_bias: Tensor[(3), float32]) {
  %0 = nn.conv2d(%input0, %v0_weight, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3]);
  nn.bias_add(%0, %v0_bias)
}
"""

SEMVER = '#[version = "0.0.5"]\n'
expr = tvm.parser.parse_expr(SEMVER + normal_conv_code)
fmod = tvm.IRModule.from_expr(expr)
mod = relay.transform.InferType()(fmod)
bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_expr)
bwd_mod = relay.transform.InferType()(bwd_mod)
```

This PR rolls the implementation back to the previous version while also fixing the bug for depthwise convolution (the previous backward did not work for the depthwise case).
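As background for the normal-versus-depthwise distinction the PR deals with, here is a small NumPy sketch (independent of TVM; the `conv2d` helper is illustrative, not TVM's implementation): with `groups` equal to the channel count, each output channel sees exactly one `(1, kh, kw)` filter and one input channel, while a normal convolution mixes all input channels.

```python
import numpy as np

def conv2d(x, w, groups=1):
    # x: (C_in, H, W); w: (C_out, C_in // groups, kh, kw); stride 1, no padding
    c_out, cpg, kh, kw = w.shape
    _, h, wi = x.shape
    out = np.zeros((c_out, h - kh + 1, wi - kw + 1))
    for o in range(c_out):
        g = o // (c_out // groups)          # group this output channel belongs to
        xs = x[g * cpg:(g + 1) * cpg]       # input-channel slice for that group
        for i in range(out.shape[1]):
            for j in range(out.shape[2]):
                out[o, i, j] = np.sum(xs[:, i:i + kh, j:j + kw] * w[o])
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8, 8))
w_dw = rng.standard_normal((3, 1, 3, 3))   # depthwise: groups=3, weight (3, 1, 3, 3)
w_nm = rng.standard_normal((3, 3, 3, 3))   # normal:    groups=1, weight (3, 3, 3, 3)

y_dw = conv2d(x, w_dw, groups=3)
y_nm = conv2d(x, w_nm, groups=1)

# Depthwise output channel c depends only on input channel c:
ref = np.stack([conv2d(x[c:c + 1], w_dw[c:c + 1], groups=1)[0] for c in range(3)])
assert np.allclose(y_dw, ref)
```

A backward implementation must handle both layouts: the weight gradient has shape `(3, 1, 3, 3)` in the depthwise case, not `(3, 3, 3, 3)`.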
You need to add …

See `tvm/src/relay/op/nn/convolution.cc`, line 1803 (commit be17697).

Usually, if … And the reason …
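The dtype-defaulting rule under discussion can be sketched in plain Python (a hypothetical helper for illustration only, not the actual C++ type relation): when the conv attribute's `out_dtype` is left unset (an empty string, which TVM prints as `void`), type inference should fall back to the input tensor's dtype instead of propagating `void` into the gradient graph.

```python
def resolve_out_dtype(out_dtype: str, data_dtype: str) -> str:
    # Illustrative sketch of the defaulting rule: an unset out_dtype
    # (empty string, printed as "void") falls back to the input dtype
    # rather than propagating "void" into downstream type checks.
    if out_dtype in ("", "void"):
        return data_dtype
    return out_dtype

assert resolve_out_dtype("", "float32") == "float32"
assert resolve_out_dtype("float16", "float32") == "float16"
```

Without such a fallback, any type relation comparing the gradient's dtype against `float32` fails, which matches the `BroadcastRel` errors reported below.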
Thanks for the explanation. I agree it is important to annotate the return dtype in the gradient calculation. But such dtype information is missing in most TVM use cases. For example, when we load a model through the PyTorch frontend:

```python
import numpy as np
import torch
import torch.nn as nn
import tvm
from tvm import relay
from tvm.contrib import graph_executor

net = nn.Sequential(
    nn.Conv2d(3, 3, 3, padding=1, groups=3)
)
input_shape = [1, 3, 32, 32]
input_data = torch.randn(input_shape)
input_name = "input0"
shape_list = [(input_name, input_data.shape)]
scripted_model = torch.jit.trace(net, input_data).eval()
fmod, params = relay.frontend.from_pytorch(scripted_model, shape_list, default_dtype="float32")
mod = relay.transform.InferType()(fmod)
bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_expr)
bwd_mod = relay.transform.InferType()(bwd_mod)
```

this fails with:

```
data types float32 and void do not match in BroadcastRel
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
```
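To make the expected backward behaviour concrete, here is a NumPy finite-difference check (a sketch independent of TVM; `dw_conv2d` is an illustrative helper defined inline): the weight gradient of a depthwise conv must keep the `(C, 1, kh, kw)` weight shape, and for a linear-in-weights output the analytic gradient is the sum of input patches per channel.

```python
import numpy as np

def dw_conv2d(x, w):
    # Depthwise conv, stride 1, no padding; x: (C, H, W), w: (C, 1, kh, kw)
    c, kh, kw = w.shape[0], w.shape[2], w.shape[3]
    h, wi = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((c, h, wi))
    for ch in range(c):
        for i in range(h):
            for j in range(wi):
                out[ch, i, j] = np.sum(x[ch, i:i + kh, j:j + kw] * w[ch, 0])
    return out

rng = np.random.default_rng(1)
x = rng.standard_normal((3, 6, 6))
w = rng.standard_normal((3, 1, 3, 3))

# Analytic gradient of sum(out) w.r.t. w: sum of input patches, per channel.
grad = np.zeros_like(w)
for ch in range(3):
    for i in range(4):
        for j in range(4):
            grad[ch, 0] += x[ch, i:i + 3, j:j + 3]

# Finite-difference check on a single weight entry.
eps = 1e-6
w2 = w.copy()
w2[0, 0, 1, 1] += eps
num = (dw_conv2d(x, w2).sum() - dw_conv2d(x, w).sum()) / eps

assert grad.shape == w.shape             # gradient keeps the (3, 1, 3, 3) shape
assert abs(num - grad[0, 0, 1, 1]) < 1e-4
```

A backward rule that instead produces a `(3, 3, 3, 3)` weight gradient (or a tensor with an unresolved dtype) will fail exactly the kind of type check reported above.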
Also, when I build the model directly with the Relay API, the dtype still cannot be handled properly:

```python
input = relay.var("input", shape=[1, 3, 32, 32], dtype="float32")
weight = relay.var("weight", shape=[3, 1, 3, 3], dtype="float32")
out = relay.nn.conv2d(input, weight, groups=3, channels=3)
fn = relay.Function([input, weight], out)
fmod = tvm.IRModule.from_expr(fn)
mod = relay.transform.InferType()(fmod)
bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_expr)
bwd_mod = relay.transform.InferType()(bwd_mod)
```

```
data types float32 and void do not match in BroadcastRel
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
```

There might be some missing part that makes this fail. `tvm/python/tvm/relay/op/_tensor_grad.py`, lines 417 to 440 (commit 8f6fa8f), should be changed to …
We can update `tvm/src/relay/op/nn/convolution.cc`, line 1803 (commit be17697), to …

I don't know why the default out dtype is …