diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index a703d928ba5f..c2ae572de818 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -210,6 +210,28 @@ class Layout : public NodeRef { return ct; } + /*! + * \brief Returns a new layout where the dims have been expanded to match the primal dimensions. + * Primal axes of dst_layout that are missing from the current layout are prepended, + * in dst_layout order, and the current layout follows unchanged. + * E.g. expanding "C" against "NCHW16c" yields "NHWC". + * \param dst_layout The dst layout to which current layout has to be expanded. + * \return The expanded Layout. + */ + inline Layout ExpandPrimal(const Layout& dst_layout) const { + // 1) Collect the primal axes of dst_layout that are missing in the current layout. + std::string new_src_layout_str; + for (auto dst_axis : dst_layout->axes) { + const auto& axis = LayoutAxis::Get(dst_axis); + if (axis.IsPrimal() && !this->Contains(axis)) { + new_src_layout_str += dst_axis->var->name_hint; + } + } + // 2) The missing axes form the prefix; the current layout follows unchanged. + new_src_layout_str += this->name(); + return Layout(new_src_layout_str); + } + /*! * \brief return the index of the input axis. 
* If it is not found in the layout or the layout is undefined, diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 0002390be809..3f371f2f0b0b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -37,6 +37,7 @@ #include "../op_common.h" #include "../../../arithmetic/compute_expr.h" #include "../../pass/alter_op_layout.h" +#include "../../pass/pattern_util.h" #include "transform.h" namespace tvm { diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 9142c0eee80e..23a480b4e42f 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -38,6 +38,7 @@ #include #include "alter_op_layout.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -45,19 +46,35 @@ namespace relay { namespace alter_op_layout { // Make a transform CallNode +/* Performs 2 operations + * 1) If src_layout has fewer primal dims than dst_layout, expand_dims is inserted to match. + * For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC. + * 2) Call layout transform with new src layout. + */ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { - if (src_layout.Equals(dst_layout)) { return raw; } - CHECK(src_layout.defined() && dst_layout.defined()) - << "Cannot insert layout transform because there are undefined layouts"; - CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined()) - << "Cannot insert layout transform because there are inconvertible layouts: " - << src_layout << " v.s. " << dst_layout; - static auto &transform_op = Op::Get("layout_transform"); - NodePtr attrs = make_node(); - attrs->src_layout = src_layout.name(); - attrs->dst_layout = dst_layout.name(); - Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs}); - return std::move(transform); + if (src_layout.Equals(dst_layout)) { + return raw; + } + + // 1) Check if the shape lengths are different. If yes, expand dims. 
+ Expr input_expr = raw; + Layout new_src_layout = src_layout; + if (src_layout.ndim_primal() < dst_layout.ndim_primal()) { + int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal(); + new_src_layout = src_layout.ExpandPrimal(dst_layout); + input_expr = MakeExpandDims(input_expr, 0, num_new_axis); + if (new_src_layout.Equals(dst_layout)) { + return input_expr; + } + } + + // 2) Insert layout transform on the transformed src. + CHECK(new_src_layout.defined() && dst_layout.defined()) + << "Cannot insert layout transform because there are undefined layouts"; + CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined()) + << "Cannot insert layout transform because there are inconvertible layouts: " + << new_src_layout << " v.s. " << dst_layout; + return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); } // Memorize layout transform so we can reuse internal transformed nodes diff --git a/src/relay/pass/alter_op_layout.h b/src/relay/pass/alter_op_layout.h index 80593a521f25..350cedeb98ad 100644 --- a/src/relay/pass/alter_op_layout.h +++ b/src/relay/pass/alter_op_layout.h @@ -30,10 +30,57 @@ #include #include +#include namespace tvm { namespace relay { +/*! + * \brief Returns a new layout whose subordinate factors are adjusted based on the tensor shape. + * \param src_layout The layout whose subordinate factors are adjusted. + * \param old_layout The old layout before any transformation. + * \param old_shape The shape of the original tensor. + * \return The adjusted Layout. + */ +inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout, + const Array& old_shape) { + // For each subordinate axis + // 1) Find the corresponding dual axis. + // 2) Find the index of this dual axis in old_layout. + // 3) Find the shape of that axis in old_shape. + // 4) a) Adjust factor to 1, if that shape is 1. b) Else retain the factor. 
+ std::string new_layout; + for (auto axis : src_layout->axes) { + if (!LayoutAxis::Get(axis).IsPrimal()) { + // 1) Find the corresponding dual axis (first character of the primal axis name) + auto dual_axis = LayoutAxis::Get(axis).ToPrimal().name()[0]; + + // 2) Find the index of this dual axis in old_layout. NOTE(review): not-found is unchecked. + int old_axis = old_layout.IndexOf(LayoutAxis::Get(dual_axis)); + + // 3) Find the shape of this index in old_shape + auto shape_val = old_shape[old_axis]; + + // 4) a) If this shape element is the integer 1, emit factor 1. + bool is_shape_one = false; + if (auto* shape_int = shape_val.as()) { + if (shape_int->value == 1) { + new_layout += "1"; + is_shape_one = true; + } + } + + // 4) b) If shape is not 1, retain the factor of src_layout. + if (!is_shape_one) { + auto new_shape_val = src_layout.FactorOf(LayoutAxis::Get(dual_axis)); + new_layout += std::to_string(new_shape_val); + } + } + new_layout += LayoutAxis::Get(axis).name(); + } + return Layout(new_layout); +} + /*! * \brief Infer & correct function of node layout. See \p Layout for layout convention * \param attrs The attribute of the node. @@ -111,28 +158,39 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, int scalar = layouts[0].ndim() == 0 ? 0 : 1; return Array >{layouts, {layouts[1-scalar]}}; } else { - // try to broadcast the tensors to the larger dimension + // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims + // while transforming layout. int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 
0 : 1; int small_idx = 1 - large_idx; Layout ret = layouts[large_idx]; - // extract common part - size_t i = layouts[large_idx].ndim(); - for (; i != 0; --i) { - const auto& axis = layouts[large_idx][i-1]; - if (!layouts[small_idx].Contains(axis.ToPrimal())) { - break; - } - } - - Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i); - if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) { - // not convertible - return Array > {{Layout::Undef()}, {Layout::Undef()}}; + if (old_in_layouts[0].Equals(old_in_layouts[1])) { + // Support scenarios where original operands were of type [N, H, W, C] and [N, H, W, 1]. + // In this case, we might have NCHW16c coming for one operand. However, the other operand does + // not have enough C dimension. To reuse broadcasting, we would want to use NCHW1c for the + // second operand. The following section of code walks through the layouts and shapes to + // perform that operation. + // a in NCHW16c + // b in NHW1 + // b = layout_transform(b) from NHW1 -> NCHW1c + // add(a, b) + auto old_small_shape = old_in_shapes[small_idx]; + auto old_small_layout = old_in_layouts[small_idx]; + auto new_small_layout = + AdjustSubordinateFactors(layouts[large_idx], old_small_layout, old_small_shape); + layouts.Set(small_idx, new_small_layout); + } else { + // Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case, + // while transforming the layout, we expand dims to make C go to NHWC, and then use the + // modified layout of the larger operand to call the layout transform. E.g. 
+ // a in NCHW16c + // b in C + // b = expand_dims(b) from C -> NHWC + // b = layout_transform(b) from NHWC -> NCHW16c + // add(a, b) + layouts.Set(small_idx, ret); } - - layouts.Set(small_idx, common_part); - return Array > {layouts, {ret}}; + return Array>{layouts, {ret}}; } } diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index bf9621bb404a..988b13c20361 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -505,6 +505,8 @@ Expr MakeSqueeze(Expr data, Array axis); Expr MakeExpandDims(Expr data, int axis, int num_newaxis); +Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout); + Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index deac4e672964..a73a65804e98 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -242,19 +242,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC if (param->kernel_zero_point != 1) { multiplied_t2 = Multiply(zp_kernel, reduced_t2); } - - // Replicate to go back to NHWC/NCHW. This is not necessarily needed, but it fails AlterOpLayout. 
- // We can remove this once AlterOpLayout refactoring completes - - // https://github.com/dmlc/tvm/issues/3670 - Array reps; - if (param->data_layout == "NCHW") { - reps = {1, out_channels, 1, 1}; - } else if (param->data_layout == "NHWC") { - reps = {1, 1, 1, out_channels}; - } else { - LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout"; - } - return Tile(multiplied_t2, reps); + return multiplied_t2; } /* diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index c8e479d99ee4..b4e8bfd71b62 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -607,6 +607,39 @@ def tflite_anistropic_strides(): golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2) np.testing.assert_equal(qnn_output, golden_output) +def broadcast_layout_test(): + # Build-only test of broadcast support for NHWC layout; no output values are checked. + data_shape = (1, 229, 229, 3) # NHWC + data_dtype = 'uint8' + kernel_shape = (7, 7, 3, 64) # HWIO + kernel_dtype = 'int8' + _, qnn_func = get_funcs(data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=8, + kernel_zero_point=3, + kernel_size=(7, 7), + padding=(1, 1), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32") + func = qnn_func['main'].body + bias = relay.var("bias", shape=(64,), dtype="int32") + bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32") + + # Check broadcast support on both lhs and rhs, for a (1, 225, 225, 1) and a (64,) operand + func = relay.add(func, bias2) + func = relay.add(bias2, func) + func = relay.add(bias, func) + func = relay.add(func, bias) + func = relay.Function(relay.analysis.free_vars(func), func) + mod = relay.Module.from_expr(func) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") + if __name__ == "__main__": no_zero_point_test() input_zero_point_test() @@ -620,3 +653,4 @@ def 
tflite_anistropic_strides(): tflite_large_irregular_test() tflite_output_multiplier_greater_than_one() tflite_anistropic_strides() + broadcast_layout_test() diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 6b31eed8f166..cc668d7d1366 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -134,7 +134,8 @@ def expected(): kernel_layout="OIHW16i", data_layout="NCHW16c") b = relay.expand_dims(bias, axis=1, num_newaxis=2) - b = relay.layout_transform(b, "CHW", "CHW16c") + b = relay.expand_dims(b, axis=0, num_newaxis=1) + b = relay.layout_transform(b, "NCHW", "NCHW16c") y = relay.add(y, b) y = relay.nn.relu(y) @@ -304,8 +305,10 @@ def expected(): weight = relay.var("weight") x = relay.layout_transform(x, "NCHW", "NCHW16c") bias = relay.expand_dims(bias, 1, 2) - bias = relay.layout_transform(bias, "CHW", "CHW16c") - scale = relay.layout_transform(scale, "CHW", "CHW16c") + bias = relay.expand_dims(bias, 0, 1) + bias = relay.layout_transform(bias, "NCHW", "NCHW16c") + scale = relay.expand_dims(scale, 0, 1) + scale = relay.layout_transform(scale, "NCHW", "NCHW16c") y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c") y = relay.add(y, bias) # test broadcasting to lhs