Skip to content

Commit

Permalink
[AutoScheduler] Add layout rewrite support for dense and batch matmul…
Browse files Browse the repository at this point in the history
… on CPU (#7161)

* [AutoScheduler] Add layout rewrite for dense and batch_matmul

* Fix test & Address comments

* Fix shape inference

* fix test
  • Loading branch information
merrymercy authored Dec 25, 2020
1 parent e27ad08 commit 7dcafb0
Show file tree
Hide file tree
Showing 22 changed files with 276 additions and 92 deletions.
8 changes: 8 additions & 0 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,14 @@ class ComputeDAG : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};

/*!
* \brief Get the original shape from a rewritten layout string.
* \param rewritten_layout The layout after auto-scheduler's layout rewrite.
* \param axis_names Specify the names of axes.
* \return shape The original shape.
*/
Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names);

} // namespace auto_scheduler
} // namespace tvm

Expand Down
10 changes: 9 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
std::string auto_scheduler_rewritten_layout;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
Expand Down Expand Up @@ -924,6 +924,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
Expand All @@ -936,6 +937,13 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for batch matmul operator */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite

TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {}
};

/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
bool sparse_lhs;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from . import workload_registry

# Shortcut
from .compute_dag import ComputeDAG, LayoutRewriteOption
from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .measure import (
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,21 @@ def __setstate__(self, state):
# Since we always use tensors to recover the ComputeDAG, we do not support
# (de)serialization of the ComputeDAG constructed by a schedule.
self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, LoadJSON(state["tensors"]), None)


def get_shape_from_rewritten_layout(rewritten_layout, axis_names):
    """Recover the original shape from a rewritten layout string.

    Parameters
    ----------
    rewritten_layout: str
        The layout after rewrite
    axis_names: List[str]
        Names of the axes; they specify the order of axes in the result

    Returns
    -------
    shape: List[PrimExpr]
        The original shape
    """
    # Shape inference is delegated entirely to the FFI implementation.
    original_shape = _ffi_api.GetShapeFromRewrittenLayout(rewritten_layout, axis_names)
    return original_shape
17 changes: 11 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def _compute_conv2d(attrs, inputs, out_type):
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs)
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
Expand All @@ -210,7 +209,7 @@ def _compute_conv2d(attrs, inputs, out_type):
args.append(out_layout)
args.append(out_dtype)
if need_auto_scheduler_layout:
args.append(auto_scheduler_rewritten_layout)
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]

return _compute_conv2d
Expand Down Expand Up @@ -684,14 +683,17 @@ def dilation2d_strategy(attrs, inputs, out_type, target):


# dense
def wrap_compute_dense(topi_compute):
def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
"""wrap dense topi compute"""

def _compute_dense(attrs, inputs, out_type):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
args = [inputs[0], inputs[1], None, out_dtype]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]

return _compute_dense

Expand All @@ -710,11 +712,14 @@ def dense_strategy(attrs, inputs, out_type, target):


# batch_matmul
def wrap_compute_batch_matmul(topi_compute):
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
"""wrap batch_matmul topi compute"""

def _compute_batch_matmul(attrs, inputs, out_type):
return [topi_compute(inputs[0], inputs[1], out_type.shape)]
args = [inputs[0], inputs[1], out_type.shape]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]

return _compute_batch_matmul

Expand Down
15 changes: 12 additions & 3 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,15 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
name="dense_nopack.x86",
plevel=10,
)

if is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True),
naive_schedule,
name="dense.generic",
plevel=11,
)

if "cblas" in target.libs:
with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
strategy.add_implementation(
Expand All @@ -350,7 +359,7 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
plevel=15,
)
with SpecializedCondition(m >= 16):
# this implementation may not be well-optimized, so use plevel=8 for now.
# this implementation may not be well-optimized, so use plevel=5 for now.
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_pack),
wrap_topi_schedule(topi.x86.schedule_dense_pack),
Expand All @@ -364,9 +373,9 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
"""batch_matmul x86 strategy"""
strategy = _op.OpStrategy()
if is_dynamic(out_type):
if is_dynamic(out_type) or is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
name="batch_matmul.generic",
plevel=10,
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_something():
import os
import sys
import time
import threading
import pytest
import numpy as np
import tvm
Expand Down Expand Up @@ -742,4 +743,21 @@ def terminate_self():
sys.exit(-1)


class PropagatingThread(threading.Thread):
    """A thread that propagates the exception of its target to the joining thread.

    ``run`` stores the target's return value in ``ret`` or, when the target
    raises, the exception object in ``exc``.  ``join`` re-raises a stored
    exception in the caller and otherwise returns the stored return value.
    """

    # Class-level defaults keep ``join`` safe when it returns before ``run``
    # has stored anything, e.g. when ``join(timeout=...)`` expires while the
    # target is still running.  Without them ``join`` raised AttributeError.
    exc = None
    ret = None

    def run(self):
        """Invoke the target, recording its result or the exception it raised."""
        self.exc = None
        self.ret = None
        try:
            self.ret = self._target(*self._args, **self._kwargs)
        except BaseException as e:  # pylint: disable=broad-except
            # Deliberately broad: every failure must reach the joining thread.
            self.exc = e

    def join(self, timeout=None):
        """Wait for the thread, then re-raise its exception or return its result.

        Parameters
        ----------
        timeout: Optional[float]
            Forwarded to ``threading.Thread.join``.

        Returns
        -------
        ret: Any
            The value returned by the target, or None if the target has not
            finished (timeout expired) or returned None.
        """
        super().join(timeout)
        # Compare against None so an exception object that happens to be
        # falsy is still propagated.
        if self.exc is not None:
            raise self.exc
        return self.ret


tvm._ffi._init_api("testing", __name__)
30 changes: 24 additions & 6 deletions python/tvm/topi/nn/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Binary Neural Network (BNN) Operators"""
"""Batch matrix multiplication"""
# pylint: disable=invalid-name
from tvm import te
from tvm import te, auto_scheduler
from ..utils import get_const_tuple


def batch_matmul(x, y, oshape=None):
def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch. Supports broadcasting for batch dimension.
Expand All @@ -36,14 +36,25 @@ def batch_matmul(x, y, oshape=None):
Explicit intended output shape of the computation. Can be useful in cases
with dynamic input shapes.
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
x_shape = get_const_tuple(x.shape)
y_shape = get_const_tuple(y.shape)
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
y_shape = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["b", "j", "k"]
)
auto_scheduler.remove_index_check(y)
else:
y_shape = get_const_tuple(y.shape)
assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"

XB = x_shape[0]
YB = y_shape[0]
_, M, K = x.shape
Expand All @@ -54,8 +65,15 @@ def batch_matmul(x, y, oshape=None):
batch = te.max(XB, YB)
N = y.shape[1]
oshape = (batch, M, N)
return te.compute(

output = te.compute(
oshape,
lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
tag="batch_matmul",
attrs={"layout_free_placeholders": [y]},
)

if auto_scheduler_rewritten_layout:
output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)

return output
37 changes: 9 additions & 28 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ def conv2d_nhwc(
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
out_dtype: str = "float32",
The type of output tensor
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
Returns
-------
output : tvm.te.Tensor
Expand All @@ -381,34 +387,9 @@ def conv2d_nhwc(

if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
# todo(merrymercy): wrap this with a more general interface.
if len(Filter.shape) == 17:
# For mali.
# GPU tile structure is SSSRRSRS
# You could refer function comment of DoMultiLevelTiling
# in the utils.h to see more detail explanation.
kernel_h = Filter.shape[6] * Filter.shape[9] * Filter.shape[13]
kernel_w = Filter.shape[7] * Filter.shape[10] * Filter.shape[14]
channel = Filter.shape[8] * Filter.shape[11] * Filter.shape[15]
num_filter = Filter.shape[12] * Filter.shape[16]
for i in range(6):
num_filter *= Filter.shape[i]
elif len(Filter.shape) >= 10:
# For cpu tile structure SSRSRS
base = len(Filter.shape) - 10
kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base]
kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base]
channel = Filter.shape[4 + base] * Filter.shape[8 + base]
num_filter = Filter.shape[5 + base] * Filter.shape[9 + base]
for i in range(base + 2):
num_filter *= Filter.shape[i]
elif len(Filter.shape) == 4:
num_filter, kernel_h, kernel_w, channel = Filter.shape
else:
raise ValueError(
"Don't know how to infer the layout for filter shape: %s. "
"Please add a new branch to handle this case." % str(Filter)
)
kernel_h, kernel_w, channel, num_filter = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["ry", "rx", "rc", "ff"]
)
auto_scheduler.remove_index_check(Filter)
else:
kernel_h, kernel_w, channel, num_filter = Filter.shape
Expand Down
30 changes: 24 additions & 6 deletions python/tvm/topi/nn/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""TVM operator fully connected compute."""
from tvm import te
from tvm import te, auto_scheduler
from .. import tag


def dense(data, weight, bias=None, out_dtype=None):
def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
"""The default implementation of dense in topi.
Parameters
Expand All @@ -30,35 +30,53 @@ def dense(data, weight, bias=None, out_dtype=None):
weight : tvm.te.Tensor
2-D with shape [out_dim, in_dim]
bias : tvm.te.Tensor, optional
bias : Optional[tvm.te.Tensor]
1-D with shape [out_dim]
out_dtype : str
out_dtype : Optional[str]
The output type. This is used for mixed precision.
auto_scheduler_rewritten_layout: str = ""
The layout after auto-scheduler's layout rewrite pass.
Returns
-------
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
assert len(data.shape) == 2, "only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = data.shape
out_dim, _ = weight.shape

if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
auto_scheduler_rewritten_layout, ["j", "k"]
)
auto_scheduler.remove_index_check(weight)
else:
out_dim, red_dim = weight.shape
assert in_dim == red_dim

k = te.reduce_axis((0, in_dim), name="k")
matmul = te.compute(
(batch, out_dim),
lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
name="T_dense",
tag="dense",
attrs={"layout_free_placeholders": [weight]},
)
if bias is not None:
matmul = te.compute(
(batch, out_dim),
lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST,
)

if auto_scheduler_rewritten_layout:
matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)

return matmul
Loading

0 comments on commit 7dcafb0

Please sign in to comment.