[QNN] Use Int16 upcast in Fallback Conv2D. Fix test names. (apache#4329)
anijain2305 authored and Xingyu Zhou committed Nov 15, 2019
1 parent 9ab329e commit 53b48df
Showing 4 changed files with 269 additions and 269 deletions.
41 changes: 22 additions & 19 deletions src/relay/qnn/op/convolution.cc
@@ -106,25 +106,26 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
  * \brief Fallback to simpler lowering for dilation or depthwise conv.
  * \param data The input expr.
  * \param weight The weight expr.
- * \param zp_data The data zero point expr.
- * \param zp_kernel The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \return The fallback lowered sequence of Relay expr.
  * \note In case of dilation, normal lowering would require a dilated pool.
  * Since, we don't have dilated pool, we fallback to a simpler sequence of
  * Relay operations. This will potentially lead to performance degradation
  * as the convolution is called on int32 tensors instead of int8 tensors.
  */
-Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& zp_data,
-                    const Expr& zp_kernel, const QnnConv2DAttrs* param) {
-  auto shifted_data = data;
+Expr Conv2DFallBack(const Expr& data, const Expr& weight, const QnnConv2DAttrs* param) {
+  // Upcast the zero point to Int16.
+  auto zp_data = MakeConstantScalar(Int(16), param->input_zero_point);
+  auto zp_kernel = MakeConstantScalar(Int(16), param->kernel_zero_point);
+
+  auto shifted_data = Cast(data, Int(16));
   if (param->input_zero_point != 0) {
-    shifted_data = Subtract(Cast(data, Int(32)), zp_data);
+    shifted_data = Subtract(Cast(data, Int(16)), zp_data);
   }

-  auto shifted_kernel = weight;
+  auto shifted_kernel = Cast(weight, Int(16));
   if (param->kernel_zero_point != 0) {
-    shifted_kernel = Subtract(Cast(weight, Int(32)), zp_kernel);
+    shifted_kernel = Subtract(Cast(weight, Int(16)), zp_kernel);
   }

   return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
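For readers skimming the lowering, here is a minimal Relay sketch of roughly what this fallback produces; the shapes, zero-point values, and variable names below are made up for illustration and are not part of the commit:

# Editor's sketch (not from the commit): the int16 fallback written with the Python Relay API.
from tvm import relay

data = relay.var("data", shape=(1, 4, 8, 8), dtype="uint8")
weight = relay.var("weight", shape=(8, 4, 3, 3), dtype="uint8")

# Upcast to int16 and subtract the zero points; int16 is wide enough to hold
# uint8/int8 values shifted by a zero point without overflow.
shifted_data = relay.subtract(relay.cast(data, "int16"), relay.const(2, "int16"))
shifted_kernel = relay.subtract(relay.cast(weight, "int16"), relay.const(1, "int16"))

# Plain conv2d on the shifted int16 tensors, accumulating into int32.
out = relay.nn.conv2d(shifted_data, shifted_kernel, kernel_size=(3, 3),
                      dilation=(2, 2), out_dtype="int32")

Using int16 instead of int32 for the shifted operands is the point of the commit: the arithmetic stays exact while the fallback convolution runs on narrower tensors.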
@@ -186,7 +187,6 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
 /*
  * \brief Calculates the second term in the qnn.conv2d lowering sequence.
  * \param padded_data The padded data expr.
- * \param zp_kernel The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
@@ -200,8 +200,11 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
  * followed by a reduce on the C axis. Using avg_pool2d also gives an
  * opportunity to reuse alter_op_layout infrastructure.
  */
-Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnConv2DAttrs* param,
-                      int kernel_h, int kernel_w, int out_channels) {
+Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
+                      int kernel_w, int out_channels) {
+  // Constant Expr for the kernel zero point.
+  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);
+
   auto casted_t2 = Cast(padded_data, Int(32));

   // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
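As the surrounding comments explain, the second term multiplies the kernel zero point by a window sum of the padded input. In rough notation (my own, assuming NCHW layout and unit strides for brevity):

$$T_2(n, h, w) \;=\; zp_{kernel} \sum_{c=1}^{C} \sum_{r=1}^{R} \sum_{s=1}^{S} x_{padded}(n, c, h + r, w + s)$$

The window sum is recovered from avg_pool2d by multiplying back by $R \cdot S$ (avg_pool2d averages rather than sums), followed by a reduce over the C axis, exactly as the comment above describes; the resulting value is the same for every output channel.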
@@ -241,7 +244,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
 /*
  * \brief Calculates the third term in the qnn.conv2d lowering sequence.
  * \param weight The weight expr.
- * \param zp_data The data zero point expr.
  * \param param The qnn conv2d attributes.
  * \param batch_size The batch size.
  * \param out_channels The number of output channels.
@@ -254,8 +256,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
  * a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
  * format.
  */
-Expr Conv2DThirdTerm(const Expr& weight, const Expr& zp_data, const QnnConv2DAttrs* param,
-                     int batch_size, int out_channels) {
+Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size,
+                     int out_channels) {
+  // Constant expr for input zero point.
+  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
+
   // Find which dimensions are C, R, S.
   Array<Integer> axes_t3;
   if (param->kernel_layout == "OIHW") {
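The third term mirrors the second with the roles swapped: the input zero point times the weight summed over the (C, R, S) axes, one value per output channel (again my own notation):

$$T_3(o) \;=\; zp_{input} \sum_{c=1}^{C} \sum_{r=1}^{R} \sum_{s=1}^{S} w(o, c, r, s)$$

The resulting 1D tensor of length out_channels is then reshaped so it broadcasts against the NCHW/NHWC output, as the \note above says.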
@@ -415,21 +420,19 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   int batch_size, in_channels, out_channels, kernel_h, kernel_w;
   std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
       GetWorkload(arg_types, param);
-  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
-  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);

   // Fallback to int32 conv if there is dilation or depthwise conv2d
   CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
   if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
-    return Conv2DFallBack(data, weight, zp_data, zp_kernel, param);
+    return Conv2DFallBack(data, weight, param);
   }

   auto padded_data = Conv2DPadInput(data, param);
   auto term1 = Conv2DFirstTerm(padded_data, weight, param);
-  auto term2 = Conv2DSecondTerm(padded_data, zp_kernel, param, kernel_h, kernel_w, out_channels);
-  auto term3 = Conv2DThirdTerm(weight, zp_data, param, batch_size, out_channels);
+  auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
+  auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
   auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
   return Conv2DCombineTerms(term1, term2, term3, term4, param);
 }
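For orientation, the four Conv2D*Term helpers used here implement the standard expansion of a zero-point convolution; per output element, summing over the input-channel and kernel spatial axes:

$$\sum_{c,r,s} (x - zp_x)(w - zp_w) \;=\; \underbrace{\sum_{c,r,s} x\, w}_{T_1} \;-\; \underbrace{zp_w \sum_{c,r,s} x}_{T_2} \;-\; \underbrace{zp_x \sum_{c,r,s} w}_{T_3} \;+\; \underbrace{zp_x\, zp_w\, C\, R\, S}_{T_4}$$

Because each term can be computed on the original quantized tensors, the fast path never materializes zero-point-shifted data; the int16 upcast introduced by this commit is only needed on the Conv2DFallBack path.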
52 changes: 26 additions & 26 deletions tests/python/relay/test_op_qnn_conv2d.py
@@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
     qnn_output = get_output(qnn_func, golden_inputs)
     np.testing.assert_equal(qnn_output, golden_output)

-def no_zero_point_test():
+def test_no_zero_point():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)

-def kernel_zero_point_test():
+def test_kernel_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def kernel_zero_point_test():
            kernel_shape, kernel_dtype)


-def input_zero_point_test():
+def test_input_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)

-def both_zero_point_test():
+def test_both_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def both_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)

-def layout_test():
+def test_layout():
     # uint8 input
     data_shape = (2, 2, 4, 4) # NHWC
     data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def layout_test():



-def padding_test():
+def test_padding():
     # uint8 input
     data_shape = (1, 4, 2, 2)
     data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def padding_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)

-def dilation_test():
+def test_dilation():
     # uint8 input
     data_shape = (2, 4, 4, 4)
     data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
            kernel_shape, kernel_dtype)


-def const_folding_test():
+def test_const_folding():
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
     kernel_shape = (3, 4, 2, 2)
@@ -470,7 +470,7 @@ def const_folding_test():
     folded_func = folded_mod["main"]
     assert "reshape" not in folded_func.astext()

-def kernel_size_1x1_test():
+def test_kernel_size_1x1():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)

-def tflite_large_irregular_test():
+def test_tflite_large_irregular():
     # uint8 input
     data_shape = (1, 1024, 1, 1)
     data_dtype = 'uint8'
@@ -526,7 +526,7 @@ def tflite_large_irregular_test():
     golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
     np.testing.assert_equal(qnn_output, golden_output)

-def tflite_output_multiplier_greater_than_one():
+def test_tflite_output_multiplier_greater_than_one():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -570,7 +570,7 @@ def tflite_output_multiplier_greater_than_one():
                               0, 0)).reshape(2, 3, 1, 2)
     np.testing.assert_equal(qnn_output, golden_output)

-def tflite_anistropic_strides():
+def test_tflite_anistropic_strides():
     # uint8 input
     data_shape = (1, 1, 3, 6)
     data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
     golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
     np.testing.assert_equal(qnn_output, golden_output)

-def broadcast_layout_test():
+def test_broadcast_layout():
     # Test broadcast support for NHWC layout.
     data_shape = (1, 229, 229, 3) # NHWC
     data_dtype = 'uint8'
@@ -641,16 +641,16 @@ def broadcast_layout_test():
     graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")

 if __name__ == "__main__":
-    no_zero_point_test()
-    input_zero_point_test()
-    kernel_zero_point_test()
-    both_zero_point_test()
-    layout_test()
-    padding_test()
-    dilation_test()
-    const_folding_test()
-    kernel_size_1x1_test()
-    tflite_large_irregular_test()
-    tflite_output_multiplier_greater_than_one()
-    tflite_anistropic_strides()
-    broadcast_layout_test()
+    test_no_zero_point()
+    test_input_zero_point()
+    test_kernel_zero_point()
+    test_both_zero_point()
+    test_layout()
+    test_padding()
+    test_dilation()
+    test_const_folding()
+    test_kernel_size_1x1()
+    test_tflite_large_irregular()
+    test_broadcast_layout()
+    test_tflite_output_multiplier_greater_than_one()
+    test_tflite_anistropic_strides()
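The renames in this file move from a *_test suffix to the test_ prefix, which matches pytest's default collection rule for test functions, so these cases are now picked up automatically by the test harness. A minimal sketch of running just this file (the path is taken from the file header above):

# Run only the renamed QNN conv2d tests via pytest's default collection.
import pytest

pytest.main(["-q", "tests/python/relay/test_op_qnn_conv2d.py"])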
File renamed without changes.
