Skip to content

Commit

Permalink
[QNN] Lowering for Depthwise Convolution.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Nov 16, 2019
1 parent 5d66e7a commit 9d45f03
Show file tree
Hide file tree
Showing 6 changed files with 421 additions and 53 deletions.
2 changes: 2 additions & 0 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {

Expr MakeConcatenate(Expr data, int axis);

Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);

Expr MakeStack(Expr data, int axis);
Expand Down
208 changes: 183 additions & 25 deletions src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ namespace qnn {
// relay.op.qnn.conv2d
TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs);

bool QnnConv2DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
Expand All @@ -50,17 +48,22 @@ bool QnnConv2DRel(const Array<Type>& types,
const auto* param = attrs.as<QnnConv2DAttrs>();
CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32))
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
return Conv2DRel<QnnConv2DAttrs>(types, num_inputs, attrs, reporter);
}

// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w
using WorkloadType = std::tuple<int, int, int, int, int>;
bool is_depthwise(const QnnConv2DAttrs* param) {
return param->channels.defined() && tvm::ir::Equal(param->channels, param->groups) &&
param->groups != 1;
}

// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier
using WorkloadType = std::tuple<int, int, int, int, int, int>;

/*
* \brief Get the conv parameters like batch_size, kernel_height etc.
Expand All @@ -84,26 +87,38 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv

const auto kernel_shape = get_shape(arg_types[1]);
int out_channels, kernel_h, kernel_w;
int channel_multiplier = -1;
if (param->kernel_layout == "OIHW") {
out_channels = get_const_int(kernel_shape[0]);
kernel_h = get_const_int(kernel_shape[2]);
kernel_w = get_const_int(kernel_shape[3]);
if (is_depthwise(param)) {
channel_multiplier = get_const_int(kernel_shape[1]);
}
} else if (param->kernel_layout == "HWIO") {
kernel_h = get_const_int(kernel_shape[0]);
kernel_w = get_const_int(kernel_shape[1]);
out_channels = get_const_int(kernel_shape[3]);
if (is_depthwise(param)) {
channel_multiplier = get_const_int(kernel_shape[2]);
}
} else if (param->kernel_layout == "HWOI") {
kernel_h = get_const_int(kernel_shape[0]);
kernel_w = get_const_int(kernel_shape[1]);
out_channels = get_const_int(kernel_shape[2]);
if (is_depthwise(param)) {
channel_multiplier = get_const_int(kernel_shape[3]);
}
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
}
return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w);

return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w,
channel_multiplier);
}

/*
* \brief Fallback to simpler lowering for dilation or depthwise conv.
* \brief Fallback to simpler lowering for dilation or grouped conv.
* \param data The input expr.
* \param weight The weight expr.
* \param param The qnn conv2d attributes.
Expand Down Expand Up @@ -166,6 +181,130 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
return padded_data;
}

/*
* \brief Calculates the second term in the qnn.conv2d depthwise lowering sequence.
* \param padded_data The padded data expr.
* \param param The qnn conv2d attributes.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \param channel_multiplier The channel/depth multiplier.
* \return The sequence of Relay operators for term2.
* \note The term2 looks like this
*
* Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s)
*
* Second term is not directly representable by one Relay operator.
* However, deeper analysis shows that we can reduce r,s using avg_pool2d,
* followed by repeat on the C axis by cm times.
*/
Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
int kernel_w, int channel_multiplier) {
// Constant Expr for the kernel zero point.
auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);

auto casted_t2 = Cast(padded_data, Int(32));

// We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
// Since, this is integer division (floor), we can first multiply the data by the pool_size and
// then perform avg_pool2d. Reversing this causes inaccuracy due to floor division.
auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
Array<IndexExpr> padding({0, 0});

// If the pool_size is 1x1, we don't need avg_pool2d.
auto reduced_t2 = scaled_hw_t2;
if (kernel_h * kernel_w != 1) {
reduced_t2 =
AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout,
false, // ceil_mode
false); // count_include_pad
}

auto multiplied_t2 = reduced_t2;
if (param->kernel_zero_point != 1) {
multiplied_t2 = Multiply(zp_kernel, reduced_t2);
}

// Reduce the C dimension. Find the dimension.
int axis_t2 = 0;
if (param->data_layout == "NCHW") {
axis_t2 = 1;
} else if (param->data_layout == "NHWC") {
axis_t2 = 3;
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
}
auto repeated_t2 = multiplied_t2;
if (channel_multiplier != 1) {
repeated_t2 = MakeRepeat(multiplied_t2, channel_multiplier, axis_t2);
}
return repeated_t2;
}

/*
* \brief Calculates the third term in the qnn.conv2d depthwise lowering sequence.
* \param weight The weight expr.
* \param param The qnn conv2d attributes.
* \param out_channels The number of output channels.
* \param channel_multiplier The channel/depth multiplier.
* \return The sequence of Relay operatos for term3.
* \note The term3 looks like this
*
* Sigma(r, s) zp_a * Qw(oc/m, oc%m, r, s)
*
* This can be achieved by calling reduce on r and s axis. The tensor can be then reshaped to
* (1, oc, 1, 1) as (oc/m, oc%m) are just contiguous memory locations.
*/
Expr DepthwiseConv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels,
int channel_multiplier) {
// Constant expr for input zero point.
auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);

// Find which dimensions are R, S.
Array<Integer> axes_t3;
if (param->kernel_layout == "OIHW") {
// For OIHW kernel layout, HW are reduce axis
axes_t3 = {2, 3};
} else if (param->kernel_layout == "HWIO") {
axes_t3 = {0, 1};
} else if (param->kernel_layout == "HWOI") {
axes_t3 = {0, 1};
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
}
auto reduced_t3 = Sum(Cast(weight, Int(32)), axes_t3, false, false);

// Find the newshape depending on NCHW/NHWC layout.
Array<Integer> newshape;
if (param->data_layout == "NCHW") {
newshape = {1, out_channels * channel_multiplier, 1, 1};
} else if (param->data_layout == "NHWC") {
newshape = {1, 1, 1, out_channels * channel_multiplier};
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
}
auto reshaped_t3 = Reshape(reduced_t3, newshape);

if (param->input_zero_point == 1) {
return reshaped_t3;
}
return Multiply(zp_data, reshaped_t3);
}

/*
* \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence.
* \param param The qnn conv2d attributes.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(r, s) zp_a * zp_w
*/
Expr DepthwiseConv2DFourthTerm(const QnnConv2DAttrs* param, int kernel_h, int kernel_w) {
int scalar_term4 = param->input_zero_point * param->kernel_zero_point * kernel_h * kernel_w;
return MakeConstantScalar(Int(32), scalar_term4);
}

/*
* \brief Calculates the first term in the qnn.conv2d lowering sequence.
* \param data The input expr.
Expand Down Expand Up @@ -245,7 +384,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
* \brief Calculates the third term in the qnn.conv2d lowering sequence.
* \param weight The weight expr.
* \param param The qnn conv2d attributes.
* \param batch_size The batch size.
* \param out_channels The number of output channels.
* \return The sequence of Relay operatos for term3.
* \note The term3 looks like this
Expand All @@ -256,8 +394,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
* a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
* format.
*/
Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size,
int out_channels) {
Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels) {
// Constant expr for input zero point.
auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);

Expand All @@ -278,9 +415,9 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
// Find the newshape depending on NCHW/NHWC layout.
Array<Integer> newshape;
if (param->data_layout == "NCHW") {
newshape = {batch_size, out_channels, 1, 1};
newshape = {1, out_channels, 1, 1};
} else if (param->data_layout == "NHWC") {
newshape = {batch_size, 1, 1, out_channels};
newshape = {1, 1, 1, out_channels};
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
}
Expand All @@ -295,7 +432,6 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
/*
* \brief Calculates the fourth term in the qnn.conv2d lowering sequence.
* \param param The qnn conv2d attributes.
* \param batch_size The batch size.
* \param in_channels The number of input channels.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
Expand All @@ -305,8 +441,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
* Sigma(c,r,s) zp_a * zp_w
*
*/
Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int batch_size, int in_channels, int kernel_h,
int kernel_w) {
Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int in_channels, int kernel_h, int kernel_w) {
int scalar_term4 =
param->input_zero_point * param->kernel_zero_point * in_channels * kernel_h * kernel_w;
return MakeConstantScalar(Int(32), scalar_term4);
Expand Down Expand Up @@ -391,7 +526,20 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
* gives an opportunity to reuse alter_op_layout infrastructure.
* 3) For dilated conv, in current lowering, we need dilated pool. So as
* a workaround, we fall back to simpler lowering using int32 conv if
* the conv is dilated. We fallback also in case of depthwise conv.
* the conv is dilated. We fallback also in case of grouped conv.
*
* For depthwise, we can similarly unroll the computation. The intial compute is as follows
* wehere cm = channel_multiplier
*
* Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/m, oc%/m, r, s) - zp_w)
* * (Qa(n, oc/cm, oh + r, ow + s) - zp_a)
*
* This can be written as
*
* Sigma(r, s) Qw(oc/m, oc%/m, r, s) * Qa(n, oc/cm, oh + r, ow + s)
* - Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s)
* - Sigma(r, s) zp_a * Qw(oc/m, oc%m, r, s)
* - Sigma(r, s) zp_a * zp_w
*
* The whole process can be broken down into following steps
* * Assertion checks for existing support, fallback if necessary
Expand All @@ -417,23 +565,33 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
param->kernel_layout == "HWOI")
<< "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";

int batch_size, in_channels, out_channels, kernel_h, kernel_w;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
GetWorkload(arg_types, param);

// Fallback to int32 conv if there is dilation or depthwise conv2d
// Fallback to int32 conv if there is dilation or grouped conv2d

CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
return Conv2DFallBack(data, weight, param);
} else if (is_depthwise(param)) {
CHECK_NE(channel_multiplier, -1);
auto padded_data = Conv2DPadInput(data, param);
auto term1 = Conv2DFirstTerm(padded_data, weight, param);
auto term2 =
DepthwiseConv2DSecondTerm(padded_data, param, kernel_h, kernel_w, channel_multiplier);
auto term3 = DepthwiseConv2DThirdTerm(weight, param, out_channels, channel_multiplier);
auto term4 = DepthwiseConv2DFourthTerm(param, kernel_h, kernel_w);
return Conv2DCombineTerms(term1, term2, term3, term4, param);
}

auto padded_data = Conv2DPadInput(data, param);
auto term1 = Conv2DFirstTerm(padded_data, weight, param);
auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
auto term3 = Conv2DThirdTerm(weight, param, out_channels);
auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w);
return Conv2DCombineTerms(term1, term2, term3, term4, param);
}

Expand Down
Loading

0 comments on commit 9d45f03

Please sign in to comment.