From 03905c072ea4113bac524ea95904b5b9be92161f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 8 Jul 2019 20:59:50 +0000 Subject: [PATCH] Relaxing convolution infer checks. - Weight dtype can be different than idtype. So, using the weight tensor to set the dtype of weight. - For conv2d NCHWc operator, the weight can be of any dimension. For int8 computation on Intel, it can be 7D. Relaxing the weight type checking. --- src/relay/op/nn/convolution.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 97cba79640005..4d443c3fb5ba0 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -83,8 +83,12 @@ bool Conv2DRel(const Array& types, channels = param->channels; dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); } else { // use weight to infer the conv shape. if (weight == nullptr) return false; @@ -701,7 +705,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout);