Skip to content

Commit

Permalink
Relaxing convolution infer checks. (apache#3511)
Browse files Browse the repository at this point in the history
- Weight dtype can be different than idtype. So, using the weight tensor to set
the dtype of weight.
- For conv2d NCHWc operator, the weight can be of any dimension. For int8
computation on Intel, it can be 7D. Relaxing the weight type checking.
  • Loading branch information
anijain2305 authored and Wei Chen committed Jul 11, 2019
1 parent a89c2ad commit e00562d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,12 @@ bool Conv2DRel(const Array<Type>& types,
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
DataType weight_dtype = data->dtype;
if (weight != nullptr) {
weight_dtype = weight->dtype;
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
Expand Down Expand Up @@ -701,7 +705,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel)
.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);

Expand Down
11 changes: 10 additions & 1 deletion tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_conv2d_infer_type():
(2, 10, 3, 3), "float32")

# infer by shape of w, mixed precision

n, c, h, w = tvm.var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
Expand All @@ -55,6 +54,16 @@ def test_conv2d_infer_type():
assert yy.checked_type == relay.TensorType(
(n, 2, 222, 222), "int32")

# infer shape in case of different dtypes for input and weight.
n, c, h, w = tvm.var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
y = relay.nn.conv2d(x, w, out_dtype="int32")
assert "out_dtype=\"int32\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 2, 222, 222), "int32")

# Infer with a different layout
n, c, h, w = 4, 32, 224, 224
x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8"))
Expand Down

0 comments on commit e00562d

Please sign in to comment.