From 7dcafb017a05ac0d5ecd7cfe8d8741d33a24bbad Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 24 Dec 2020 16:57:24 -0800
Subject: [PATCH] [AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU (#7161)

* [AutoScheduler] Add layout rewrite for dense and batch_matmul

* Fix test & Address comments

* Fix shape inference

* fix test
---
 include/tvm/auto_scheduler/compute_dag.h      |  8 ++++
 include/tvm/relay/attrs/nn.h                  | 10 +++-
 python/tvm/auto_scheduler/__init__.py         |  2 +-
 python/tvm/auto_scheduler/compute_dag.py      | 18 +++++++
 python/tvm/relay/op/strategy/generic.py       | 17 ++++---
 python/tvm/relay/op/strategy/x86.py           | 15 ++++--
 python/tvm/testing.py                         | 18 +++++++
 python/tvm/topi/nn/batch_matmul.py            | 30 +++++++++---
 python/tvm/topi/nn/conv2d.py                  | 37 ++++-----------
 python/tvm/topi/nn/dense.py                   | 30 +++++++++---
 src/auto_scheduler/compute_dag.cc             | 26 ++++++++++
 src/relay/op/make_op.h                        |  2 +
 src/relay/op/nn/nn.cc                         | 34 ++++++++++----
 src/relay/op/nn/nn.h                          | 10 +++-
 .../auto_scheduler_layout_rewrite.cc          | 16 ++++++-
 .../combine_parallel_batch_matmul.cc          |  7 ++-
 .../transforms/combine_parallel_dense.cc      | 16 +++++++
 .../transforms/combine_parallel_op_batch.h    |  2 +-
 .../test_auto_scheduler_layout_rewrite.py     | 47 ++++++++++++++++++-
 .../relay/test_pass_combine_parallel_dense.py |  2 -
 .../unittest/test_auto_scheduler_common.py    | 18 -------
 .../test_auto_scheduler_search_policy.py      |  3 +-
 22 files changed, 276 insertions(+), 92 deletions(-)

diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index bdb6489e7f0b..1e3f09721279 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -303,6 +303,14 @@ class ComputeDAG : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
 };
 
+/*!
+ * \brief Get the original shape from a rewritten layout string.
+ * \param rewritten_layout The layout after auto-scheduler's layout rewrite.
+ * \param axis_names Specify the names of axes.
+ * \return shape The original shape.
+ */
+Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names);
+
 }  // namespace auto_scheduler
 }  // namespace tvm
 
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 5ffca99d36d7..7bfd58080521 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -120,7 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   tvm::String data_layout;
   tvm::String kernel_layout;
   tvm::String out_layout;
-  std::string auto_scheduler_rewritten_layout;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
@@ -924,6 +924,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
 /*! \brief Attributes for dense operator */
 struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   IndexExpr units;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
@@ -936,6 +937,13 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   }
 };
 
+/*! \brief Attributes for batch matmul operator */
+struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+
+  TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {}
+};
+
 /*! \brief Attributes for sparse_dense operator */
 struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
   bool sparse_lhs;
 
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 4926b88e4658..a03e156cc10f 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -31,7 +31,7 @@
 from . import workload_registry
 
 # Shortcut
-from .compute_dag import ComputeDAG, LayoutRewriteOption
+from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
 from .cost_model import RandomModel, XGBModel
 from .dispatcher import DispatchContext, ApplyHistoryBest
 from .measure import (
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index 94cb640f3516..d8a242260285 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -234,3 +234,21 @@ def __setstate__(self, state):
         # Since we always use tensors to recover the ComputeDAG, we do not support
         # (de)serialization of the ComputeDAG constructed by a schedule.
         self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, LoadJSON(state["tensors"]), None)
+
+
+def get_shape_from_rewritten_layout(rewritten_layout, axis_names):
+    """Get the original shape from a rewritten layout string.
+
+    Parameters
+    ----------
+    rewritten_layout: str
+        The layout after rewrite
+    axis_names: List[str]
+        Specify the order of axes by names
+
+    Returns
+    -------
+    shape: List[PrimExpr]
+        The original shape
+    """
+    return _ffi_api.GetShapeFromRewrittenLayout(rewritten_layout, axis_names)
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 9fc6089fed97..95b5d6ad4ff9 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -199,7 +199,6 @@ def _compute_conv2d(attrs, inputs, out_type):
         data_layout = attrs.get_str("data_layout")
         out_layout = attrs.get_str("out_layout")
         out_dtype = attrs.out_dtype
-        auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs)
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         args = [inputs[0], inputs[1], strides, padding, dilation]
         if has_groups:
@@ -210,7 +209,7 @@ def _compute_conv2d(attrs, inputs, out_type):
             args.append(out_layout)
         args.append(out_dtype)
         if need_auto_scheduler_layout:
-            args.append(auto_scheduler_rewritten_layout)
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
         return [topi_compute(*args)]
 
     return _compute_conv2d
@@ -684,14 +683,17 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
 # dense
-def wrap_compute_dense(topi_compute):
+def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
     """wrap dense topi compute"""
 
     def _compute_dense(attrs, inputs, out_type):
         """Compute definition of dense"""
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-        return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
+        args = [inputs[0], inputs[1], None, out_dtype]
+        if need_auto_scheduler_layout:
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
+        return [topi_compute(*args)]
 
     return _compute_dense
 
@@ -710,11 +712,14 @@ def dense_strategy(attrs, inputs, out_type, target):
 # batch_matmul
-def wrap_compute_batch_matmul(topi_compute):
+def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
     """wrap batch_matmul topi compute"""
 
     def _compute_batch_matmul(attrs, inputs, out_type):
-        return [topi_compute(inputs[0], inputs[1], out_type.shape)]
+        args = [inputs[0], inputs[1], out_type.shape]
+        if need_auto_scheduler_layout:
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
+        return [topi_compute(*args)]
 
     return _compute_batch_matmul
 
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 5dfeca65e5c3..841213a517bc 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -325,6 +325,15 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
         name="dense_nopack.x86",
         plevel=10,
     )
+
+    if is_auto_scheduler_enabled():
+        strategy.add_implementation(
+            wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True),
+            naive_schedule,
+            name="dense.generic",
+            plevel=11,
+        )
+
     if "cblas" in target.libs:
         with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
             strategy.add_implementation(
@@ -350,7 +359,7 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
                 plevel=15,
             )
     with SpecializedCondition(m >= 16):
-        # this implementation may not be well-optimized, so use plevel=8 for now.
+        # this implementation may not be well-optimized, so use plevel=5 for now.
         strategy.add_implementation(
             wrap_compute_dense(topi.x86.dense_pack),
             wrap_topi_schedule(topi.x86.schedule_dense_pack),
@@ -364,9 +373,9 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
 def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     """batch_matmul x86 strategy"""
     strategy = _op.OpStrategy()
-    if is_dynamic(out_type):
+    if is_dynamic(out_type) or is_auto_scheduler_enabled():
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.nn.batch_matmul),
+            wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True),
             wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
             name="batch_matmul.generic",
             plevel=10,
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 8311a63d0749..32307a99e65a 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -58,6 +58,7 @@ def test_something():
 import os
 import sys
 import time
+import threading
 import pytest
 import numpy as np
 import tvm
@@ -742,4 +743,21 @@ def terminate_self():
     sys.exit(-1)
 
 
+class PropagatingThread(threading.Thread):
+    """A thread that propagates the exception to the main thread"""
+
+    def run(self):
+        self.exc = None
+        try:
+            self.ret = self._target(*self._args, **self._kwargs)
+        except BaseException as e:
+            self.exc = e
+
+    def join(self, timeout=None):
+        super(PropagatingThread, self).join(timeout)
+        if self.exc:
+            raise self.exc
+        return self.ret
+
+
 tvm._ffi._init_api("testing", __name__)
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 98acc2d4848e..9ca2df7c46e1 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -14,13 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Binary Neural Network (BNN) Operators"""
+"""Batch matrix multiplication"""
 # pylint: disable=invalid-name
-from tvm import te
+from tvm import te, auto_scheduler
 
 from ..utils import get_const_tuple
 
 
-def batch_matmul(x, y, oshape=None):
+def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch. Supports broadcasting for batch dimension.
 
@@ -36,14 +36,25 @@ def batch_matmul(x, y, oshape=None):
         Explicit intended output shape of the computation.
         Can be useful in cases with dynamic input shapes.
 
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
     Returns
     -------
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     x_shape = get_const_tuple(x.shape)
-    y_shape = get_const_tuple(y.shape)
+    if auto_scheduler_rewritten_layout:
+        # Infer shape for the rewritten layout
+        y_shape = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["b", "j", "k"]
+        )
+        auto_scheduler.remove_index_check(y)
+    else:
+        y_shape = get_const_tuple(y.shape)
+    assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"
+
     XB = x_shape[0]
     YB = y_shape[0]
     _, M, K = x.shape
@@ -54,8 +65,15 @@ def batch_matmul(x, y, oshape=None):
         batch = te.max(XB, YB)
         N = y.shape[1]
         oshape = (batch, M, N)
-    return te.compute(
+
+    output = te.compute(
         oshape,
         lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
         tag="batch_matmul",
+        attrs={"layout_free_placeholders": [y]},
     )
+
+    if auto_scheduler_rewritten_layout:
+        output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)
+
+    return output
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index ead9f16a256f..e2384c4aafa5 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -361,6 +361,12 @@ def conv2d_nhwc(
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
 
+    out_dtype: str = "float32",
+        The type of output tensor
+
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -381,34 +387,9 @@ def conv2d_nhwc(
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
-        # todo(merrymercy): wrap this with a more general interface.
-        if len(Filter.shape) == 17:
-            # For mali.
-            # GPU tile structure is SSSRRSRS
-            # You could refer function comment of DoMultiLevelTiling
-            # in the utils.h to see more detail explanation.
-            kernel_h = Filter.shape[6] * Filter.shape[9] * Filter.shape[13]
-            kernel_w = Filter.shape[7] * Filter.shape[10] * Filter.shape[14]
-            channel = Filter.shape[8] * Filter.shape[11] * Filter.shape[15]
-            num_filter = Filter.shape[12] * Filter.shape[16]
-            for i in range(6):
-                num_filter *= Filter.shape[i]
-        elif len(Filter.shape) >= 10:
-            # For cpu tile structure SSRSRS
-            base = len(Filter.shape) - 10
-            kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base]
-            kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base]
-            channel = Filter.shape[4 + base] * Filter.shape[8 + base]
-            num_filter = Filter.shape[5 + base] * Filter.shape[9 + base]
-            for i in range(base + 2):
-                num_filter *= Filter.shape[i]
-        elif len(Filter.shape) == 4:
-            num_filter, kernel_h, kernel_w, channel = Filter.shape
-        else:
-            raise ValueError(
-                "Don't know how to infer the layout for filter shape: %s. "
-                "Please add a new branch to handle this case." % str(Filter)
-            )
+        kernel_h, kernel_w, channel, num_filter = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["ry", "rx", "rc", "ff"]
+        )
         auto_scheduler.remove_index_check(Filter)
     else:
         kernel_h, kernel_w, channel, num_filter = Filter.shape
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 0ce0f9ea1299..474fea42a7cb 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM operator fully connected compute."""
-from tvm import te
+from tvm import te, auto_scheduler
 
 from .. import tag
 
 
-def dense(data, weight, bias=None, out_dtype=None):
+def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
     """The default implementation of dense in topi.
 
     Parameters
@@ -30,30 +30,44 @@ def dense(data, weight, bias=None, out_dtype=None):
     weight : tvm.te.Tensor
         2-D with shape [out_dim, in_dim]
 
-    bias : tvm.te.Tensor, optional
+    bias : Optional[tvm.te.Tensor]
         1-D with shape [out_dim]
 
-    out_dtype : str
+    out_dtype : Optional[str]
         The output type. This is used for mixed precision.
 
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
     Returns
    -------
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
+    assert len(data.shape) == 2, "only support 2-dim dense"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
     batch, in_dim = data.shape
-    out_dim, _ = weight.shape
+
+    if auto_scheduler_rewritten_layout:
+        # Infer shape for the rewritten layout
+        out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["j", "k"]
+        )
+        auto_scheduler.remove_index_check(weight)
+    else:
+        out_dim, red_dim = weight.shape
+    assert in_dim == red_dim
+
     k = te.reduce_axis((0, in_dim), name="k")
     matmul = te.compute(
         (batch, out_dim),
         lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
         name="T_dense",
         tag="dense",
+        attrs={"layout_free_placeholders": [weight]},
     )
     if bias is not None:
         matmul = te.compute(
@@ -61,4 +75,8 @@ def dense(data, weight, bias=None, out_dtype=None):
             lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
             tag=tag.BROADCAST,
         )
+
+    if auto_scheduler_rewritten_layout:
+        matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)
+
     return matmul
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index af45f2df8b04..64114c8331b8 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -33,6 +33,7 @@
 #include <tvm/te/schedule_pass.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/topi/transform.h>
 
 #include <algorithm>
 #include <cstdint>
@@ -1410,6 +1411,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << ss.str();
     });
 
+Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names) {
+  Array<PrimExpr> shape;
+  std::vector<std::string> extracted_names;
+  topi::parse_auto_scheduler_layout(rewritten_layout, &shape, &extracted_names);
+
+  Array<PrimExpr> ret(axis_names.size(), 1);
+
+  size_t ct = 0;
+  for (size_t i = 0; i < axis_names.size(); ++i) {
+    for (size_t j = 0; j < extracted_names.size(); ++j) {
+      if (axis_names[i] == extracted_names[j]) {
+        ret.Set(i, ret[i] * shape[j]);
+        ct++;
+      }
+    }
+  }
+
+  CHECK_EQ(ct, extracted_names.size()) << "The number or names of axes do not match";
+
+  return ret;
+}
+
 TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")
     .set_body_typed([](Optional<Array<te::Tensor>> tensors, Optional<te::Schedule> sch) {
       if (sch) {
@@ -1452,5 +1475,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
       return index_rewriter.Rewrite(body);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
+    .set_body_typed(GetShapeFromRewrittenLayout);
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index d2fb6aa2b9c3..2b05290b270c 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -46,6 +46,8 @@ Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);
 
+Expr MakeBatchMatmul(Expr lhs, Expr rhs);
+
 Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
 
 Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 816b98038e46..fbb6204faed4 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -24,6 +24,7 @@
 
 #include "nn.h"
 
+#include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/relay/attrs/image.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/vision.h>
@@ -845,37 +846,49 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,).
     .add_type_rel("GroupNorm", GroupNormRel);
 
 // relay.nn.batch_matmul
+TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs);
+
 bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 3);
   const auto* x = types[0].as<TensorTypeNode>();
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
-  ICHECK(x->shape.size() == 3 && y->shape.size() == 3);
+
+  const auto* param = attrs.as<BatchMatmulAttrs>();
+  Array<PrimExpr> y_shape;
+  if (param->auto_scheduler_rewritten_layout.size() == 0) {
+    y_shape = y->shape;
+  } else {
+    y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
+                                                          {"b", "j", "k"});
+  }
+
+  ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
   bool is_dyn = false;
   Array<IndexExpr> oshape;
   for (size_t i = 0; i < 3; ++i) {
-    if (x->shape[i].as<tir::AnyNode>() != nullptr || y->shape[i].as<tir::AnyNode>() != nullptr) {
+    if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
      is_dyn = true;
       oshape.push_back(Any());
     } else {
       if (i == 0) {
-        oshape.push_back(max(x->shape[i], y->shape[i]));
+        oshape.push_back(max(x->shape[i], y_shape[i]));
       } else {
         oshape.push_back(x->shape[i]);
       }
     }
   }
   if (!is_dyn) {
-    ICHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
-           reporter->AssertEQ(y->shape[0], 1))
+    ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
+           reporter->AssertEQ(y_shape[0], 1))
         << "BatchDot: batch dimensions don't match, "
-        << " x shape=" << x->shape << ", y shape=" << y->shape;
-    ICHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
+        << " x shape=" << x->shape << ", y shape=" << y_shape;
+    ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2]))
        << "BatchDot: shapes of x and y is inconsistent, "
-        << " x shape=" << x->shape << ", y shape=" << y->shape;
+        << " x shape=" << x->shape << ", y shape=" << y_shape;
 
-    oshape.Set(2, y->shape[1]);
+    oshape.Set(2, y_shape[1]);
   }
 
   // assign output type
@@ -885,8 +898,9 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
 // Positional relay function to create batch_matmul operator used by frontend FFI.
 Expr MakeBatchMatmul(Expr x, Expr y) {
+  auto attrs = make_object<BatchMatmulAttrs>();
   static const Op& op = Op::Get("nn.batch_matmul");
-  return Call(op, {x, y}, Attrs(), {});
+  return Call(op, {x, y}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 30ef3079e565..9b9cff2dba81 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -57,7 +57,15 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       // data dtype as the weight dtype. However if weight dtype is explicitly
       // present we will use that.
       auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
-      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+      if (param->auto_scheduler_rewritten_layout.size() == 0) {
+        // Normal case: assign result to reporter
+        reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+      } else {
+        // If the layout is rewritten by auto-scheduler,
+        // we just forcibly apply the layout provided by auto-scheduler and
+        // skip the normal inference logic.
+        {}  // do nothing
+      }
       oshape.Set((oshape.size() - 1), param->units);
     } else {
       if (weight == nullptr) return false;
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index c9875ef5d718..53e7a0256e5e 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -83,6 +83,12 @@ class FuncMutator : public ExprMutator {
       Attrs updated_attrs;
       if (auto pattr = call->attrs.as<Conv2DAttrs>()) {
         updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+      } else if (auto pattr = call->attrs.as<DenseAttrs>()) {
+        updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+      } else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) {
+        updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+      } else {
+        LOG(FATAL) << "Unhandled attribute: " << call->attrs;
       }
       new_n = Call(call->op, updated_args, updated_attrs);
     }
@@ -93,7 +99,7 @@ class FuncMutator : public ExprMutator {
   std::deque<std::string> ori_layouts_queue_;
   std::deque<std::string> new_layouts_queue_;
 
-  std::vector<std::string> target_ops_{"nn.conv2d"};
+  std::vector<std::string> target_ops_{"nn.conv2d", "nn.dense", "nn.batch_matmul"};
 };
 
 Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
@@ -150,8 +156,14 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
     .set_body_typed([](const Attrs& attrs) {
       if (attrs->IsInstance<Conv2DAttrs>()) {
         return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout;
+      } else if (attrs->IsInstance<DenseAttrs>()) {
+        return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
+      } else if (attrs->IsInstance<BatchMatmulAttrs>()) {
+        return attrs.as<BatchMatmulAttrs>()->auto_scheduler_rewritten_layout;
+      } else {
+        LOG(FATAL) << "Unhandled attribute: " << attrs;
       }
-      return std::string();
+      return tvm::String();
     });
diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc
index 5b56504602a9..20a7c7ff7815 100644
--- a/src/relay/transforms/combine_parallel_batch_matmul.cc
+++ b/src/relay/transforms/combine_parallel_batch_matmul.cc
@@ -70,16 +70,15 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
   }
 
   Call MakeCombinedOp(const Group& branches) {
-    const Op& batch_matmul = Op::Get("nn.batch_matmul");
     Expr data = branches[0][0]->args[0];
 
     Array<Expr> weights;
     for (const auto& branch : branches) {
-      auto batch_matmul = branch[0];
-      weights.push_back(batch_matmul->args[1]);
+      auto call = branch[0];
+      weights.push_back(call->args[1]);
     }
     Expr new_weight = MakeConcatenate(Tuple(weights), 1);
-    return Call(batch_matmul, {data, new_weight}, {}, {});
+    return Downcast<Call>(MakeBatchMatmul(data, new_weight));
   }
 
   bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; }
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index 6d4c8c000f31..d9ca4bf2042e 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -57,6 +57,22 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner {
       : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}
 
  protected:
+  Call MakeCombinedOp(const Group& branches) {
+    Array<Expr> new_args;
+    size_t num_args = branches[0][0]->args.size();
+    for (size_t i = 0; i < num_args; i++) {
+      Array<Expr> arg_from_all_branches;
+      for (const auto& branch : branches) {
+        arg_from_all_branches.push_back(branch[0]->args[i]);
+      }
+
+      new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0));
+    }
+
+    CHECK_EQ(num_args, 2);
+    return Downcast<Call>(MakeBatchMatmul(new_args[0], new_args[1]));
+  }
+
   virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
     StructuralEqual eq;
     const auto* attrs_a = a->attrs.as<DenseAttrs>();
diff --git a/src/relay/transforms/combine_parallel_op_batch.h b/src/relay/transforms/combine_parallel_op_batch.h
index 7a518e9ac370..db4734bffcf8 100644
--- a/src/relay/transforms/combine_parallel_op_batch.h
+++ b/src/relay/transforms/combine_parallel_op_batch.h
@@ -95,7 +95,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner {
    * \param branches branches that are to be combined
    * \return new call with branches combined as batch op by stacking args
   */
-  Call MakeCombinedOp(const Group& branches) final;
+  virtual Call MakeCombinedOp(const Group& branches);
 
   /*
    * \brief Checks if argument of op following combined ops are able to be combined
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
index 299fcb8ebb2c..66d40bac0af5 100644
--- a/tests/python/relay/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
@@ -23,6 +23,7 @@
 from tvm import relay, auto_scheduler
 from tvm.contrib import graph_runtime
 import tvm.testing
+from tvm.testing import PropagatingThread
 
 
 def get_np_array(var, dtype):
@@ -70,6 +71,28 @@ def get_relay_conv2d(
     return mod, data, weight
 
 
+def get_relay_dense(m=128, n=128, k=128):
+    dtype = "float32"
+    d = relay.var("data", shape=(m, k), dtype=dtype)
+    w = relay.var("weight", shape=(n, k), dtype=dtype)
+    y = relay.nn.dense(d, w, units=n)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([d, w], y)
+    data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+    return mod, data, weight
+
+
+def get_relay_batchmm(batch=4, m=128, n=128, k=128):
+    dtype = "float32"
+    d = relay.var("data", shape=(batch, m, k), dtype=dtype)
+    w = relay.var("weight", shape=(batch, n, k), dtype=dtype)
+    y = relay.nn.batch_matmul(d, w)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([d, w], y)
+    data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+    return mod, data, weight
+
+
 def tune_and_check(mod, data, weight):
     # Extract tasks from a relay program
     target = tvm.target.Target("llvm")
@@ -109,13 +132,33 @@ def compile_and_run(disabled_pass={}):
 
     actual_output = compile_and_run()
     expected_output = compile_and_run(disabled_pass={"AutoSchedulerLayoutRewrite"})
-    tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4)
+    tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
 
 
 def test_conv2d():
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
     mod, data, weight = get_relay_conv2d(kh=1, kw=1)
-    tune_and_check(mod, data, weight)
+    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+    t.start()
+    t.join()
+
+
+def test_dense():
+    mod, data, weight = get_relay_dense()
+    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+    t.start()
+    t.join()
+
+
+def test_batch_matmul():
+    mod, data, weight = get_relay_batchmm()
+    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+    t.start()
+    t.join()
 
 
 if __name__ == "__main__":
     test_conv2d()
+    test_dense()
+    test_batch_matmul()
diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py
index a8c9782953bb..cd946ab593bf 100644
--- a/tests/python/relay/test_pass_combine_parallel_dense.py
+++ b/tests/python/relay/test_pass_combine_parallel_dense.py
@@ -286,8 +286,6 @@ def check(i, j, k, bias_shape1, bias_shape2):
         y = run_opt_pass(y_before, combine_pass)
         y_expected = expected(x, w1, w2, b1, b2, j, bias_shape1, bias_shape2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        print(y.astext(False))
-        print(y_expected.astext(False))
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
 
     check(3, 5, 4, (), ())
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 87814f28ad72..a037b680e2e1 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -16,9 +16,6 @@
 # under the License.
 
 """Common functions for auto_scheduler test cases"""
-
-import threading
-
 import tvm
 from tvm import te, auto_scheduler
 from tvm import topi
@@ -251,18 +248,3 @@ def get_tiled_matmul():
     )
 
     return dag, s0
-
-
-class PropagatingThread(threading.Thread):
-    def run(self):
-        self.exc = None
-        try:
-            self.ret = self._target(*self._args, **self._kwargs)
-        except BaseException as e:
-            self.exc = e
-
-    def join(self):
-        super(PropagatingThread, self).join()
-        if self.exc:
-            raise self.exc
-        return self.ret
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 6d4fb6884ff9..5bc7c2af21f8 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -24,9 +24,10 @@
 
 import tvm
 import tvm.testing
+from tvm.testing import PropagatingThread
 from tvm import auto_scheduler
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingThread
+from test_auto_scheduler_common import matmul_auto_scheduler_test
 import multiprocessing
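
A short usage sketch of the new shape helper (illustrative only, not part of the diff above). The layout string below is hypothetical; it assumes the "<factor><axis-name>" token format consumed by topi::parse_auto_scheduler_layout. As the GetShapeFromRewrittenLayout implementation in this patch shows, all factors tagged with the same axis name are multiplied together, which is how dense and batch_matmul recover the original weight extents after the rewrite pass has split their axes:

    # Hypothetical example of recovering an original shape from a rewritten
    # layout string. "2j_4k_8j_16k" is an assumed instance of the
    # "<factor><axis-name>" format; real strings come from the
    # AutoSchedulerLayoutRewrite pass.
    import tvm
    from tvm import auto_scheduler

    # Suppose the rewrite split axis "j" into 2 x 8 and axis "k" into 4 x 16.
    rewritten = "2j_4k_8j_16k"
    shape = auto_scheduler.get_shape_from_rewritten_layout(rewritten, ["j", "k"])
    # Factors with a matching axis name are multiplied, so this prints
    # [16, 64], i.e. the original [out_dim, in_dim] of the dense weight.
    print(shape)

The same recovery is what DenseRel and BatchMatmulRel rely on on the C++ side when the weight's type can no longer be read off its rewritten placeholder.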