Skip to content

Commit

Permalink
[QNN] Conv2D with dilation support. (apache#4796)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and alexwong committed Feb 26, 2020
1 parent 886e6d8 commit 0db7a1f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
18 changes: 11 additions & 7 deletions src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
}

/*
* \brief Fallback to simpler lowering for dilation or grouped conv.
* \brief Fallback to simpler lowering for dilation (when non-zero kernel point) or grouped conv.
* \param data The input expr.
* \param weight The weight expr.
* \param input_zero_point The input zero point expr.
* \param kernel_zero_point The kernel zero point expr.
* \param param The qnn conv2d attributes.
* \return The fallback lowered sequence of Relay expr.
* \note In case of dilation, normal lowering would require a dilated pool.
* Since, we don't have dilated pool, we fallback to a simpler sequence of
* Relay operations. This will potentially lead to performance degradation
* as the convolution is called on int32 tensors instead of int8 tensors.
* \note In case of dilation with non-zero kernel zero point, normal lowering would require a
* dilated pool. Since, we don't have dilated pool, we fallback to a simpler sequence of Relay
* operations. This will potentially lead to performance degradation as the convolution is called on
* int32 tensors instead of int8 tensors.
*/
Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
const Expr& kernel_zero_point, const Conv2DAttrs* param) {
Expand Down Expand Up @@ -598,12 +598,16 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);

// Fallback to int32 conv if there is dilation or grouped conv2d
// Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d
// For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to
// traverse the elements in dilated manner. Currently, we do not have strided pool. So, in case of
// dilated conv with non-zero kernel point, we fall back to simpler but slow lowering.

CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) ||
(param->groups != 1 && !is_depthwise(param))) {
return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param);
} else if (is_depthwise(param)) {
CHECK_NE(channel_multiplier, -1);
Expand Down
25 changes: 24 additions & 1 deletion tests/python/relay/test_op_qnn_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def test_padding():
def test_dilation():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):

# uint8 input
# Non-zero kernel point - fall back to simpler lowering.
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
Expand All @@ -518,6 +518,29 @@ def test_dilation():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

# Zero kernel point
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(2, 2),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)


def test_const_folding():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
Expand Down

0 comments on commit 0db7a1f

Please sign in to comment.