Fix gradient OP for nn.conv2d #10439

Closed
wants to merge 1 commit into from
Conversation

@Lyken17 (Contributor) commented Mar 2, 2022

The current backward implementation raises an error for nn.conv2d, for both normal and depth-wise conv. See the code attached below.

```python
import numpy as np

import tvm
from tvm import relay
from tvm.contrib import graph_executor

depthwise_conv_code = """
fn (%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 1, 3, 3), float32], %v0_bias: Tensor[(3), float32]) {
  %0 = nn.conv2d(%input0, %v0_weight, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3]);
  nn.bias_add(%0, %v0_bias)
}
"""

normal_conv_code = """
fn (%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 3, 3, 3), float32], %v0_bias: Tensor[(3), float32]) {
  %0 = nn.conv2d(%input0, %v0_weight, padding=[1, 1, 1, 1], groups=1, channels=3, kernel_size=[3, 3]);
  nn.bias_add(%0, %v0_bias)
}
"""

SEMVER = '#[version = "0.0.5"]\n'
expr = tvm.parser.parse_expr(SEMVER + normal_conv_code)
fmod = tvm.IRModule.from_expr(expr)

mod = relay.transform.InferType()(fmod)
bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")

bwd_mod = tvm.IRModule.from_expr(bwd_expr)
bwd_mod = relay.transform.InferType()(bwd_mod)
```

The current implementation raises:

```
data types float32 and void do not match in BroadcastRel

  [bt] (2) 3   libtvm.dylib                        0x000000015e5f1fa0 tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 1588
  [bt] (1) 2   libtvm.dylib                        0x000000015d381ac0 tvm::DiagnosticContext::Render() + 484
  [bt] (0) 1   libtvm.dylib                        0x000000015d0aa848 tvm::runtime::detail::LogFatal::Entry::Finalize() + 84
  File "/Users/ligeng/Workspace/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
```

This PR rolls the implementation back to the previous version while also fixing the depth-wise bug (the previous backward did not work for depth-wise conv).

@Lyken17 (Contributor, Author) commented Mar 2, 2022

cc @masahi since it is related to #9954.

My environment is built from commit 111b2da, and the error can easily be reproduced with the program above.

While I agree that a dedicated op for conv2d_grad makes it easier for cuDNN / CUTLASS to accelerate the backward pass, there is something wrong with the current implementation.

@junrushao (Member)

CC @Hzfengsy @YuchenJin @ZihengJiang

@masahi (Member) commented Mar 2, 2022

You need to add `out_dtype="float32"` to your mod. Then it works:

```
def @main(%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 1, 3, 3), float32], %v0_bias: Tensor[(3), float32]) -> (Tensor[(1, 3, 32, 32), float32], (Tensor[(1, 3, 32, 32), float32], Tensor[(3, 1, 3, 3), float32], Tensor[(3), float32])) {
  let %x_0: Tensor[(1, 3, 32, 32), float32] = %input0;
  let %x_1: Tensor[(1, 3, 32, 32), float32] = zeros_like(%x_0) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  let %x_2: Tensor[(3, 1, 3, 3), float32] = %v0_weight;
  let %x_3: Tensor[(3, 1, 3, 3), float32] = zeros_like(%x_2) /* ty=Tensor[(3, 1, 3, 3), float32] */;
  let %x_4: Tensor[(3), float32] = %v0_bias;
  let %x_5: Tensor[(3), float32] = zeros_like(%x_4) /* ty=Tensor[(3), float32] */;
  let %x_6: Tensor[(1, 3, 32, 32), float32] = nn.conv2d(%x_0, %x_2, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 3, 32, 32), float32] */;
  let %x_7: Tensor[(1, 3, 32, 32), float32] = zeros_like(%x_6) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  let %x_8: Tensor[(1, 3, 32, 32), float32] = nn.bias_add(%x_6, %x_4) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  let %x_9: Tensor[(1, 3, 32, 32), float32] = zeros_like(%x_8) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  %0 = ones_like(%x_8) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  %1 = collapse_sum_like(%0, %x_6) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  %5 = (
    let %x_10: Tensor[(1, 3, 32, 32), float32] = add(%x_7, %1) /* ty=Tensor[(1, 3, 32, 32), float32] */;
    %2 = sum(%0, axis=[1], exclude=True) /* ty=Tensor[(3), float32] */;
    let %x_11: Tensor[(3), float32] = add(%x_5, %2) /* ty=Tensor[(3), float32] */;
    %3 = nn.conv2d_transpose(%x_10, %x_2, padding=[1, 1, 1, 1], groups=3, kernel_layout="IOHW") /* ty=Tensor[(1, 1, 32, 32), float32] */;
    let %x_12: Tensor[(1, 3, 32, 32), float32] = add(%x_1, %3) /* ty=Tensor[(1, 3, 32, 32), float32] */;
    %4 = nn.conv2d_backward_weight(%x_10, %x_0, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], kernel_layout="NCHW", out_layout="OIHW", out_dtype="float32") /* ty=Tensor[(3, 1, 3, 3), float32] */;
    let %x_13: Tensor[(3, 1, 3, 3), float32] = add(%x_3, %4) /* ty=Tensor[(3, 1, 3, 3), float32] */;
    (%x_12, %x_13, %x_11)
  );
  (%x_8, %5)
}
```
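Applied to the repro at the top of the thread, this amounts to one change in the Relay text (a sketch of the suggestion above; the `out_dtype` annotation is the only addition):

```python
normal_conv_code = """
fn (%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 3, 3, 3), float32], %v0_bias: Tensor[(3), float32]) {
  %0 = nn.conv2d(%input0, %v0_weight, padding=[1, 1, 1, 1], groups=1, channels=3, kernel_size=[3, 3], out_dtype="float32");
  nn.bias_add(%0, %v0_bias)
}
"""
```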

@masahi (Member) commented Mar 2, 2022

See

```cpp
const auto dw_dtype = param->out_dtype == DataType() ? grad->dtype : param->out_dtype;
```

Usually, if out_dtype is not provided, we can use the dtype from the inputs. But in your case, the parser somehow returns conv2d with out_dtype == void, which is weird, so the output dtype of the wgrad becomes void. The correct fix is therefore a change in the parser.

The reason out_dtype matters for wgrad is that if the input is fp16, we may want the output dtype to be fp32. Without out_dtype, an fp16 input means we end up computing wgrad in fp16 precision, which is probably not what we want.
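A minimal Python sketch of that fallback rule (illustrative only; `wgrad_dtype` is a hypothetical helper, not part of the TVM API):

```python
def wgrad_dtype(out_dtype, grad_dtype):
    # Mirror the C++ ternary above: use the annotated out_dtype when
    # present, otherwise fall back to the gradient's dtype.
    return grad_dtype if out_dtype in ("", "void", None) else out_dtype

# Unannotated conv: fall back to the gradient's dtype.
assert wgrad_dtype("", "float32") == "float32"
# fp16 inputs with an explicit fp32 out_dtype: accumulate wgrad in fp32.
assert wgrad_dtype("float32", "float16") == "float32"
```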

masahi closed this Mar 2, 2022
@Lyken17 (Contributor, Author) commented Mar 2, 2022

Thanks for the explanation. I agree it is important to annotate the return dtype in the gradient calculation. But such dtype information is missing in most TVM use cases. For example, when we load a model using tvm.relay.frontend.from_xxx, the current implementation cannot handle it properly:

```python
import numpy as np

import torch
import torch.nn as nn

import tvm
from tvm import relay
from tvm.contrib import graph_executor

net = nn.Sequential(
    nn.Conv2d(3, 3, 3, padding=1, groups=3)
)

input_shape = [1, 3, 32, 32]
input_data = torch.randn(input_shape)
input_name = "input0"
shape_list = [(input_name, input_data.shape)]

scripted_model = torch.jit.trace(net, input_data).eval()
fmod, params = relay.frontend.from_pytorch(scripted_model, shape_list, default_dtype="float32")

mod = relay.transform.InferType()(fmod)
bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")

bwd_mod = tvm.IRModule.from_expr(bwd_expr)
bwd_mod = relay.transform.InferType()(bwd_mod)
```

```
data types float32 and void do not match in BroadcastRel
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
```

@Lyken17 (Contributor, Author) commented Mar 2, 2022

The dtype is also not handled properly when I build the model with the Relay API directly:

```python
input = relay.var("input", shape=[1, 3, 32, 32], dtype="float32")
weight = relay.var("weight", shape=[3, 1, 3, 3], dtype="float32")
out = relay.nn.conv2d(input, weight, groups=3, channels=3)
fn = relay.Function([input, weight], out)
fmod = tvm.IRModule.from_expr(fn)

mod = relay.transform.InferType()(fmod)

bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_expr)
bwd_mod = relay.transform.InferType()(bwd_mod)
```

```
data types float32 and void do not match in BroadcastRel
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
```

There might be some missing piece that leaves out_dtype null in this case. For now, I would recommend falling back to the data's dtype as a default to avoid these errors. That is, the current code

```python
backward_data = _nn.conv2d_transpose(
    grad,
    weight,
    strides=attrs.strides,
    padding=attrs.padding,
    dilation=attrs.dilation,
    groups=attrs.groups,
    output_padding=output_padding,
)
backward_weight = _nn.conv2d_backward_weight(
    grad,
    data,
    strides=attrs.strides,
    padding=attrs.padding,
    dilation=attrs.dilation,
    groups=attrs.groups,
    channels=attrs.channels,
    kernel_size=(filter_h, filter_w),
    grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout,
    data_layout=attrs.data_layout,
    kernel_layout=attrs.kernel_layout,
    out_dtype=attrs.out_dtype,
)
```

should be changed to

```python
if attrs.out_dtype == "":
    # Fall back to the data's dtype when conv2d carries no out_dtype annotation.
    out_dtype = data.checked_type.dtype
else:
    out_dtype = attrs.out_dtype

backward_data = _nn.conv2d_transpose(
    grad,
    weight,
    strides=attrs.strides,
    padding=attrs.padding,
    dilation=attrs.dilation,
    groups=attrs.groups,
    output_padding=output_padding,
    out_dtype=out_dtype,
)

backward_weight = _nn.conv2d_backward_weight(
    grad,
    data,
    strides=attrs.strides,
    padding=attrs.padding,
    dilation=attrs.dilation,
    groups=attrs.groups,
    channels=attrs.channels,
    kernel_size=(filter_h, filter_w),
    grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout,
    data_layout=attrs.data_layout,
    kernel_layout=attrs.kernel_layout,
    out_dtype=out_dtype,
)
```
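As a side note, passing out_dtype explicitly when building the conv also avoids the error in the Relay-API repro above (a workaround sketch following masahi's earlier diagnosis; the explicit `out_dtype` is the only change):

```python
out = relay.nn.conv2d(input, weight, groups=3, channels=3, out_dtype="float32")
```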

@masahi (Member) commented Mar 2, 2022

We can update

```cpp
const auto dw_dtype = param->out_dtype == DataType() ? grad->dtype : param->out_dtype;
```

to

```cpp
const auto dw_dtype = (param->out_dtype == DataType() or param->out_dtype.is_void())
                          ? grad->dtype
                          : param->out_dtype;
```

I don't know why the default out dtype is void... but this should do the job. You're welcome to send a PR.

@masahi (Member) commented Mar 3, 2022

#10459
