Skip to content

Commit

Permalink
[RFC] Improve quantized convolution performance for armv8 architectur…
Browse files Browse the repository at this point in the history
…es (apache#5754)

* Improve quantized conv2d performance for armv8

Signed-off-by: Giuseppe Rossini <[email protected]>
Change-Id: I3a3d29f5332dd9b3354e8e0dfb24677a521f9c8f

* Add ASF header to conv2d_gemm.py

Change-Id: I33853279e39c849ae1b555a9c91d7557985a0a35

* Run clang-format-10 on c++ files

Change-Id: Ieee22f032e595dabfc1616ab33466fcbf8d94365

* Fix pylint errors/warnings

Change-Id: I435d4d7bca7500db99547f4401fdc0d0995a1ff4

* Fix pylint errors/warnings in topi

Change-Id: I2fc1ad8453e9020072ab967c849df5390c2967b5

* Fix legalizations tests for aarch64

Change-Id: I0a67a49a7849f52ef7d57b9292ce9125bbb7cb2c

* Reintroduce conv2d_nhwc_spatial_pack.arm_cpu and int16 cast

Change-Id: I91b67fabd475e90a9b75f2dd5ecfee851265e0bb

* Switch type of legalization depending on the strategy used

Change-Id: I9a03040a8c40a6cd2658ed14c3751e05a8e19f2b

* Revert last commit

Change-Id: Ice34101e358e3ce8ebfb12c58f73e910ba5de8e8

* Fix the auto-tuner by registering the correct schedules

Change-Id: Id9273688b2620e1ea849ab01b4c46af8fbf37fd0

* Address review comments

Change-Id: Ia1755a0af7b6d159072d9f0c93c932c481101e48

* Improve usability and readability of conv2d_gemm_weight_transform

Change-Id: I3333186bbc2fe4054b58ce15d910e3be7b315482

* Change variable name to weight in Conv2DGemmWeightTransformRel

Change-Id: Ifb5f1f33af7512fe67c6b049b20a42a0bb2d26c9

* Fix clang-10 linting errors

Change-Id: I25ccc844d9cee23766096e1daddb6180abc413a6

* Trigger tests

Change-Id: Id37706fb7cf77a87a3cc817ecf8046297d9ca95a
  • Loading branch information
Giuseppe Rossini authored and Trevor Morris committed Jun 30, 2020
1 parent 2df7a6f commit c01fa93
Show file tree
Hide file tree
Showing 14 changed files with 1,065 additions and 14 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,17 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode<ConvWinogradWeig
}
};

/*! \brief Attributes used in gemm weight transformation operators */
struct ConvGemmWeightTransformAttrs : public tvm::AttrsNode<ConvGemmWeightTransformAttrs> {
int tile_rows;
int tile_cols;

TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") {
TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm.");
TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm.");
}
};

/*! \brief Attributes used in convolution operators with winograd algorithm */
struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
int tile_size;
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,23 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

# conv2d_gemm related operators
reg.register_strategy("nn.contrib_conv2d_gemm_without_weight_transform",
strategy.conv2d_gemm_without_weight_transform_strategy)
reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
"""Compute definition of contrib_conv2d_gemm_weight_transform"""
out = topi.nn.conv2d_gemm_weight_transform(
inputs[0], attrs.tile_rows, attrs.tile_cols)
return [out]

reg.register_schedule("nn.contrib_conv2d_gemm_weight_transform",
strategy.schedule_conv2d_gemm_weight_transform)
reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
Expand Down
91 changes: 91 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2046,6 +2046,74 @@ def contrib_conv2d_winograd_without_weight_transform(data,
kernel_layout, out_layout, out_dtype)


def contrib_conv2d_gemm_without_weight_transform(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""2D convolution with gemm algorithm.
The basic parameters are the same as the ones in vanilla conv2d.
It assumes the weight is pre-transformed by nn.contrib_conv2d_gemm_weight_transform
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.contrib_conv2d_gemm_without_weight_transform(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)


def contrib_conv2d_nchwc(data,
kernel,
strides=(1, 1),
Expand Down Expand Up @@ -2204,6 +2272,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)


def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
r"""Weight Transformation part for 2D convolution with gemm algorithm.
We separate this as a single op to enable pre-compute for inference.
Use this together with nn.contrib_conv2d_gemm_without_weight_transform
Parameters
----------
weights : tvm.relay.Expr
The weight expressions.
tile_rows: int
Tile rows of the weight transformation for ConvGemm.
tile_cols: int
Tile columns of the weight transformation for ConvGemm.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols)


def contrib_conv3d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 3D convolution with winograd algorithm.
Expand Down
42 changes: 42 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd),
name='conv2d_direct_simd.micro_dev')
elif kernel_layout == "HWIO":
is_aarch64 = "aarch64" in str(isa.target)

if is_aarch64 and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized.arm_cpu")

strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
Expand Down Expand Up @@ -246,6 +254,40 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out
format(layout))
return strategy

def wrap_compute_conv2d_gemm(topi_compute):
"""wrap topi compute for conv2d_gemm"""

def _compute_conv2d_gemm(attrs, inputs, out_type):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_dtype = attrs.get_str("out_dtype")
channels = attrs['channels']
kernel_size = attrs['kernel_size']
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
return [topi_compute(inputs[0], inputs[1], strides, padding,
dilation, out_dtype, kernel_size, channels)]

return _compute_conv2d_gemm

@conv2d_gemm_without_weight_transform_strategy.register("arm_cpu")
def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d_winograd_without_weight_transfrom arm cpu strategy"""
layout = attrs.data_layout
data = inputs[0]
strategy = _op.OpStrategy()

if layout == "NHWC" and data.dtype in ['int8', 'uint8']:
strategy.add_implementation(
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized_without_transform.arm_cpu")
else:
raise RuntimeError(
"Unsupported conv2d_gemm_without_weight_transform layout {0} with datatype {1}".
format(layout, data.dtype))
return strategy

@conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"])
def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d_transpose arm cpu strategy"""
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, t
"""conv2d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform")

# conv2d_gemm_without_weight_transform
@override_native_generic_func("conv2d_gemm_without_weight_transform_strategy")
def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target):
"""conv2d_gemm_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform")

# conv2d_winograd_weight_transform
@generic_func
def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
Expand All @@ -280,6 +286,13 @@ def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)

# conv2d_gemm_weight_transform
@generic_func
def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
"""Schedule conv2d_gemm_weight_transform"""
with target:
return topi.generic.schedule_conv2d_gemm_weight_transform(outs)

# deformable_conv2d
def wrap_compute_deformable_conv2d(topi_compute):
"""wrap deformable_conv2d topi compute"""
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,23 @@ def is_fast_int8_on_arm():
target = tvm.target.Target.current(allow_none=False)
return '+v8.2a,+dotprod' in ' '.join(target.options)

def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
return 'aarch64' in ' '.join(target.options)

########################
# ARM CPU legalizations.
########################

@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
# ARM prefers the dtypes to be same.
if is_fast_int8_on_arm():
if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm():
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)


@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
# ARM prefers the dtypes to be same.
Expand Down
82 changes: 82 additions & 0 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,41 @@ Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> st
return Call(op, {data, weight}, Attrs(attrs), {});
}

template <typename T>
Expr MakeConvGemm(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilation, int groups, IndexExpr channels,
Array<IndexExpr> kernel_size, std::string data_layout, std::string kernel_layout,
std::string out_layout, DataType out_dtype, std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
const Op& op = Op::Get(op_name);
return Call(op, {data, weight}, Attrs(attrs), {});
}

Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) {
auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
attrs->tile_size = tile_size;
const Op& op = Op::Get(op_name);
return Call(op, {weight}, Attrs(attrs), {});
}

Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) {
auto attrs = make_object<ConvGemmWeightTransformAttrs>();
attrs->tile_rows = tile_rows;
attrs->tile_cols = tile_cols;
const Op& op = Op::Get(op_name);
return Call(op, {weight}, Attrs(attrs), {});
}

template <typename T>
Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilation, int groups, IndexExpr channels,
Expand Down Expand Up @@ -504,6 +532,60 @@ weight transformation in advance.
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);

// relay.nn.contrib_conv2d_gemm_without_weight_transform
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform")
.set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilation, int groups, IndexExpr channels,
Array<IndexExpr> kernel_size, std::string data_layout,
std::string kernel_layout, std::string out_layout, DataType out_dtype) {
return MakeConvGemm<Conv2DAttrs>(
data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform");
});

RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform")
.describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_gemm_weight_transform.
- **data**: Input is 4D array of shape (batch_size, height, width, in_channels)
- **weight**: Any shape
We do not check the shape for this input tensor. Since different backend
has different layout strategy.
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DGemm", Conv2DGemmRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

// relay.nn.contrib_conv2d_gemm_weight_transform

TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);

TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform")
.set_body_typed([](Expr weights, int tile_rows, int tile_cols) {
return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols,
"nn.contrib_conv2d_gemm_weight_transform");
});

RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_weight_transform")
.describe(R"code(Weight transformation of GEMM convolution algorithm.
Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.
)code" TVM_ADD_FILELINE)
.set_attrs_type<ConvGemmWeightTransformAttrs>()
.set_num_inputs(1)
.add_argument("weights", "Tensor", "The weights tensor.")
.set_support_level(10)
.add_type_rel("Conv2DGemmWeightTransform", Conv2DGemmWeightTransformRel);

// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
Expand Down
Loading

0 comments on commit c01fa93

Please sign in to comment.