diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 97cba7964000..4d443c3fb5ba 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -83,8 +83,12 @@ bool Conv2DRel(const Array& types,
     channels = param->channels;
     dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    DataType weight_dtype = data->dtype;
+    if (weight != nullptr) {
+      weight_dtype = weight->dtype;
+    }
     // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
@@ -701,7 +705,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(10)
-.add_type_rel("Conv2D", Conv2DRel)
+.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel)
 .set_attr("FInferCorrectLayout",
         Conv2DInferCorrectLayout);
 
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 9f49f61c0d5f..cb1b0854d27f 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -45,7 +45,6 @@ def test_conv2d_infer_type():
         (2, 10, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
-
     n, c, h, w = tvm.var("n"), 10, 224, 224
     x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
     w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
@@ -55,6 +54,16 @@ def test_conv2d_infer_type():
     assert yy.checked_type == relay.TensorType(
         (n, 2, 222, 222), "int32")
 
+    # infer shape in case of different dtypes for input and weight.
+    n, c, h, w = tvm.var("n"), 10, 224, 224
+    x = relay.var("x", relay.TensorType((n, c, h, w), "uint8"))
+    w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
+    y = relay.nn.conv2d(x, w, out_dtype="int32")
+    assert "out_dtype=\"int32\"" in y.astext()
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, 2, 222, 222), "int32")
+
     # Infer with a different layout
     n, c, h, w = 4, 32, 224, 224
     x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8"))