From e5f6f06a05bbf4f66131171745cfe3c956f7b0bb Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith
Date: Fri, 21 Aug 2020 09:06:53 -0700
Subject: [PATCH] [RELAY][DYN] Dynamic upsampling relay op (#6273)

* implementing upsampling op
* fix lint
* fix lint again
* add doc to upsampling shape func
* fix set attrs build problem
* fixing imports
* reverting data layout transform changes
* moved layout template to header file
* changing python module from nn.dyn to dyn.nn
* adding support for more layouts to upsampling
* fix lint
* fix upsampling doc
* change _nn.py doc
* failed flaky test
* fix build after merge
---
 python/tvm/relay/op/dyn/nn/_nn.py          |  49 ++++++-
 python/tvm/relay/op/nn/nn.py               |  10 +-
 python/tvm/topi/nn/upsampling.py           |  28 +++-
 src/relay/op/dyn/nn/upsampling.cc          | 123 ++++++++++++++++++
 src/relay/op/make_op.h                     |   3 +
 src/relay/op/nn/upsampling.cc              |  31 +----
 src/relay/op/nn/upsampling.h               |  67 ++++++++++
 src/relay/transforms/dynamic_to_static.cc  |  15 +++
 .../relay/dyn/test_dynamic_op_level2.py    |  51 ++++++++
 .../relay/test_pass_dynamic_to_static.py   |  23 +++-
 10 files changed, 359 insertions(+), 41 deletions(-)
 create mode 100644 src/relay/op/dyn/nn/upsampling.cc
 create mode 100644 src/relay/op/nn/upsampling.h
diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py
index 141fc22a1e808..a263561006c85 100644
--- a/python/tvm/relay/op/dyn/nn/_nn.py
+++ b/python/tvm/relay/op/dyn/nn/_nn.py
@@ -15,21 +15,62 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
-"""Backend compiler related feature registration"""
+"""Backend compiler related feature registration for dynamic relay ops in nn namespace"""
 from __future__ import absolute_import
+from tvm import topi
+
+from tvm.runtime import convert
 from tvm.te.hybrid import script
-from ...op import register_shape_func
-from ...op import register_broadcast_schedule
+from ...op import register_shape_func, register_compute
+from ...op import register_injective_schedule, register_broadcast_schedule
 
-# pad
+# upsampling
+@register_compute("dyn.nn.upsampling")
+def compute_upsampling(attrs, inputs, out_dtype):
+    data = inputs[0]
+    scale_h = inputs[1]
+    scale_w = inputs[2]
+    layout = attrs.layout
+    method = attrs.method
+    align_corners = attrs.align_corners
+    return [topi.nn.upsampling(data, scale_h, scale_w, layout,
+                               method, align_corners, out_dtype.shape)]
+
+register_injective_schedule("dyn.nn.upsampling")
 register_broadcast_schedule("dyn.nn.pad")
 
 #####################
 #  Shape functions  #
 #####################
 
+# upsampling
+@script
+def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis, channel_axis):
+    out = output_tensor((4,), "int64")
+    out[0] = int64(dshape[0])
+    out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
+    out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
+    out[channel_axis] = int64(dshape[channel_axis])
+    return out
+
+@register_shape_func("dyn.nn.upsampling", True)
+def upsampling_shape_func(attrs, inputs, _):
+    """Shape function for upsampling. Supports NCHW and NHWC layouts."""
+    layout = attrs.layout
+    height_axis = width_axis = channel_axis = 1
+    for i, letter in enumerate(layout):
+        if letter == "H":
+            height_axis = i
+        if letter == "W":
+            width_axis = i
+        if letter == "C":
+            channel_axis = i
+    return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2],
+                                   convert(height_axis), convert(width_axis),
+                                   convert(channel_axis))]
+
+# pad
 @script
 def _dyn_pad_shape_func(data, pad_width):
     ndim = len(data.shape)
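The hybrid-script shape function above is what lets the output shape be computed at runtime from the scale tensors. The same arithmetic in a minimal plain-Python mock (illustrative only; `upsampling_out_shape` is not part of the patch):

def upsampling_out_shape(dshape, scale_h, scale_w, layout="NCHW"):
    # Mirror _upsampling_shape_func: round(dim * scale) on the H and W axes only.
    height_axis, width_axis = layout.index("H"), layout.index("W")
    out = list(dshape)
    out[height_axis] = int(round(dshape[height_axis] * scale_h))
    out[width_axis] = int(round(dshape[width_axis] * scale_w))
    return out

assert upsampling_out_shape((1, 16, 32, 32), 2.0, 1.5) == [1, 16, 64, 48]
assert upsampling_out_shape((1, 32, 32, 16), 2.0, 2.0, layout="NHWC") == [1, 64, 64, 16]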
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index c04db3060f97d..6f6849a79a9a3 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1152,10 +1152,10 @@ def upsampling(data,
     data : tvm.relay.Expr
         The input data to the operator.
 
-    scale_h : tvm.relay.Expr
+    scale_h : tvm.relay.Expr or int or float
         The scale factor for height upsampling.
 
-    scale_w : tvm.relay.Expr
+    scale_w : tvm.relay.Expr or int or float
         The scale factor for width upsampling.
 
     layout : str, optional
@@ -1172,6 +1172,12 @@ def upsampling(data,
     result : tvm.relay.Expr
         The computed result.
     """
+    if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
+        if not isinstance(scale_h, Expr):
+            scale_h = const(scale_h, "float64")
+        if not isinstance(scale_w, Expr):
+            scale_w = const(scale_w, "float64")
+        return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
     return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
 
 
diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py
index d8da41f79c0aa..db4af06e06945 100644
--- a/python/tvm/topi/nn/upsampling.py
+++ b/python/tvm/topi/nn/upsampling.py
@@ -21,7 +21,7 @@
 
 
 def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
-               align_corners=False):
+               align_corners=False, output_shape=None):
     """Perform upsampling on the data.
        Nearest neighbor and bilinear upsampling are supported.
@@ -52,16 +52,30 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
     """
     base_layout = layout[0:4]
     if base_layout == "NCHW":
-        out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)),
-                     simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype)))
+        if not output_shape:  # static case
+            scaled_h = data.shape[2] * scale_h
+            scaled_w = data.shape[3] * scale_w
+            reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
+                            simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
+        else:  # dynamic case -- we don't need to scale; already done in shape func
+            reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
+                            simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype)))
     elif layout == "NHWC":
-        out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)),
-                     simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype)))
+        if not output_shape:  # static case
+            scaled_h = data.shape[1] * scale_h
+            scaled_w = data.shape[2] * scale_w
+            reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)),
+                            simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype)))
+        else:  # dynamic case
+            reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)),
+                            simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)))
+
     else:
         raise ValueError("not support this layout {} yet".format(layout))
     coord_trans = "align_corners" if align_corners else "asymmetric"
-    return topi.image.resize(data, out_shape, layout=layout,
-                             method=method, coordinate_transformation_mode=coord_trans)
+    return topi.image.resize(data, reshape_size, layout=layout,
+                             method=method, coordinate_transformation_mode=coord_trans,
+                             output_shape=output_shape)
 
 
 def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
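With the Python-side changes above, the single relay.nn.upsampling entry point routes between the two ops. A quick sketch of the dispatch, assuming this patch is applied (variable names are illustrative):

import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((1, 16, 32, 32), "float32"))

# Plain Python numbers keep the static op with a fully known output shape.
static_call = relay.nn.upsampling(x, 2.0, 2.0)
assert static_call.op.name == "nn.upsampling"

# A relay expression for either scale selects the new dynamic op.
scale = relay.var("scale", relay.TensorType((), "float32"))
dyn_call = relay.nn.upsampling(x, scale, 2.0)  # 2.0 is promoted to a const
assert dyn_call.op.name == "dyn.nn.upsampling"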
diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc
new file mode 100644
index 0000000000000..e2718481ac8c1
--- /dev/null
+++ b/src/relay/op/dyn/nn/upsampling.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file upsampling.cc
+ * \brief upsampling operator
+ */
+
+#include "../../nn/upsampling.h"
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/tir/op.h>
+
+#include <vector>
+
+#include "../../op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  // types = [data_type, scale_h_type, scale_w_type, ret_type]
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* scale_h = types[1].as<TensorTypeNode>();
+  const auto* scale_w = types[2].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  if (scale_h == nullptr) return false;
+  if (scale_w == nullptr) return false;
+
+  CHECK_EQ(data->shape.size(), 4);
+  CHECK_EQ(scale_h->shape.size(), 0);
+  CHECK_EQ(scale_w->shape.size(), 0);
+  static const Layout kNCHW("NCHW");
+
+  const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
+  CHECK(param);
+  const Layout in_layout(param->layout);
+
+  auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
+  CHECK(layout_converter.defined())
+      << "UpSampling only supports input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+
+  auto nchw_oshape = layout_converter.ForwardShape(data->shape);
+
+  nchw_oshape.Set(2, Any());
+  nchw_oshape.Set(3, Any());
+  auto oshape = layout_converter.BackwardShape(nchw_oshape);
+
+  reporter->Assign(types[3], TensorType(oshape, data->dtype));
+  return true;
+}
+
+// Positional relay function to create upsampling operator
+// used by frontend FFI.
+Expr MakeUpSampling(Expr data, Expr scale_h, Expr scale_w, String layout, String method,
+                    bool align_corners) {
+  auto attrs = make_object<UpSamplingAttrs>();
+  attrs->layout = std::move(layout);
+  attrs->method = std::move(method);
+  attrs->align_corners = align_corners;
+
+  static const Op& op = Op::Get("dyn.nn.upsampling");
+  return Call(op, {data, scale_h, scale_w}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling").set_body_typed(MakeUpSampling);
+
+RELAY_REGISTER_OP("dyn.nn.upsampling")
+    .describe(
+        R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
+
+- **data**: data is 4D array of shape
+            (batch_size, channels, in_height, in_width) for NCHW
+            (batch_size, in_height, in_width, channels) for NHWC
+
+- **scale_h**: scale_h is a double of the amount to scale height by
+
+- **scale_w**: scale_w is a double of the amount to scale width by
+
+- **out**: Output is 4D array of shape
+           for layout NCHW
+           (batch_size, channels, in_height*scale_h, in_width*scale_w)
+
+           for layout NHWC
+           (batch_size, in_height*scale_h, in_width*scale_w, channels)
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<UpSamplingAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("scale_h", "double", "The scale for the height.")
+    .add_argument("scale_w", "double", "The scale for the width.")
+    .set_support_level(2)
+    .add_type_rel("DynamicUpSampling", UpSamplingRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   UpsamplingInferCorrectLayout<UpSamplingAttrs>)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace dyn
+}  // namespace relay
+}  // namespace tvm
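The type relation above leans on BijectiveLayout to support any layout convertible from NCHW: it maps the input shape into NCHW, marks H and W as Any(), and maps back. The same machinery is exposed in Python, which makes the trick easy to see (illustrative sketch):

import tvm

bl = tvm.tir.bijective_layout("NHWC", "NCHW")
print(bl.forward_shape((1, 32, 32, 16)))   # [1, 16, 32, 32] -- NHWC mapped to NCHW
print(bl.backward_shape((1, 16, 64, 64)))  # [1, 64, 64, 16] -- and back after resizing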
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index c759be338cd93..fb3bf023140ee 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -74,6 +74,9 @@ Expr MakeTile(Expr data, Array<Integer> reps);
 
 Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);
 
+Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method,
+                    bool align_corners);
+
 Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
                   bool unbiased);
 
diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc
index cb20881c1c5f7..bdf3090cefad8 100644
--- a/src/relay/op/nn/upsampling.cc
+++ b/src/relay/op/nn/upsampling.cc
@@ -21,11 +21,15 @@
  * \file upsampling.cc
  * \brief upsampling operator
  */
+
+#include "upsampling.h"
+
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/tir/data_layout.h>
+#include <utility>
 #include <vector>
 
 #include "../op_common.h"
@@ -36,33 +40,6 @@ namespace relay {
 TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
 TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
 
-template <typename T>
-Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
-                                                   const Array<Layout>& new_in_layouts,
-                                                   const Array<Layout>& old_in_layouts,
-                                                   const Array<tvm::relay::Type>& old_in_types) {
-  // NOTE: Discard "const" qualifier here.
-  T* params = const_cast<T*>(attrs.as<T>());
-
-  if (new_in_layouts.defined()) {
-    CHECK_EQ(new_in_layouts.size(), 1);
-
-    Layout raw_layout(params->layout);
-    Layout input = new_in_layouts[0];
-    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
-        input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
-        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
-        (input.IndexOf(LayoutAxis::Get('D')) == -1 ||
-         (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
-          !input.Contains(LayoutAxis::Get('d'))))) {
-      params->layout = input.name();  // modify self to follow the input layout
-    }
-  }
-
-  Layout inferred_layout(params->layout);
-  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
-}
-
 bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 2);
diff --git a/src/relay/op/nn/upsampling.h b/src/relay/op/nn/upsampling.h
new file mode 100644
index 0000000000000..e4e3bc9b19298
--- /dev/null
+++ b/src/relay/op/nn/upsampling.h
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file src/relay/op/nn/upsampling.h
+ * \brief implementation of the InferCorrectLayout pass for upsampling
+ */
+
+#ifndef TVM_RELAY_OP_NN_UPSAMPLING_H_
+#define TVM_RELAY_OP_NN_UPSAMPLING_H_
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relay {
+
+template <typename T>
+Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
+                                                   const Array<Layout>& new_in_layouts,
+                                                   const Array<Layout>& old_in_layouts,
+                                                   const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard "const" qualifier here.
+  T* params = const_cast<T*>(attrs.as<T>());
+
+  if (new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 1);
+
+    Layout raw_layout(params->layout);
+    Layout input = new_in_layouts[0];
+    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
+        input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
+        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
+        (input.IndexOf(LayoutAxis::Get('D')) == -1 ||
+         (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
+          !input.Contains(LayoutAxis::Get('d'))))) {
+      params->layout = input.name();  // modify self to follow the input layout
+    }
+  }
+
+  Layout inferred_layout(params->layout);
+  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
+}
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_NN_UPSAMPLING_H_
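The template's layout check is easier to read as a predicate: a proposed layout is adopted only when H and W keep their original positions and are not split into sub-axes (the D clause extends the same rule to 3D). A hypothetical Python rendering of just the 2D part:

def adopts_new_layout(raw, new):
    # Mirrors the H/W conditions in UpsamplingInferCorrectLayout (2D case only).
    return (new.find("H") == raw.find("H") and new.find("W") == raw.find("W")
            and "h" not in new and "w" not in new)

assert adopts_new_layout("NCHW", "NCHW4c")      # splitting C is fine
assert not adopts_new_layout("NCHW", "NHWC")    # H/W moved: keep the old layout
assert not adopts_new_layout("NCHW", "NCHW4h")  # H split into a sub-axis: reject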
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 3de773eeed9f8..629f5afe9612d 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -124,6 +124,21 @@ class DynamicToStaticMutator : public MixedModeMutator {
           }
           return Expr(nullptr);
         }},
+        {Op::Get("dyn.nn.upsampling"),
+         [](const CallNode* call_node) {
+           const ConstantNode* scale_h = call_node->args[1].as<ConstantNode>();
+           const ConstantNode* scale_w = call_node->args[2].as<ConstantNode>();
+           if (scale_h && scale_w) {
+             CHECK_EQ(scale_h->data->ndim, 0);
+             CHECK_EQ(scale_w->data->ndim, 0);
+             const UpSamplingAttrs* param = call_node->attrs.as<UpSamplingAttrs>();
+             CHECK(param);
+             return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data),
+                                   ToScalar(scale_w->data), param->layout, param->method,
+                                   param->align_corners);
+           }
+           return Expr(nullptr);
+         }},
         {Op::Get("dyn.nn.pad"),
          [](const CallNode* call_node) {
           const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py
index 137febd19d1b2..e1a0d284d9bfc 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level2.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level2.py
@@ -27,6 +27,55 @@
 import tvm.topi.testing
 from tvm.relay.testing import run_infer_type
 
+def test_dyn_upsampling_run():
+    def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=False):
+
+        if layout == "NCHW":
+            (n, c, h, w) = dshape
+            x_data = np.random.uniform(size=(n, c, h, w)).astype("float32")
+
+        elif layout == "NHWC":
+            (n, h, w, c) = dshape
+            x_data = np.random.uniform(size=(n, h, w, c)).astype("float32")
+
+        if method == "nearest_neighbor":
+            ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout)
+        else:
+            ref_res = tvm.topi.testing.bilinear_resize_python(
+                x_data, (int(round(h * scale_h)), int(round(w * scale_w))), layout)
+        x = relay.Var("x", relay.TensorType(dshape, "float32"))
+        scale_h_var = relay.var("scale_h", relay.TensorType((), "float32"))
+        scale_w_var = relay.var("scale_w", relay.TensorType((), "float32"))
+
+        z = relay.nn.upsampling(x, scale_h_var, scale_w_var, method=method,
+                                layout=layout, align_corners=align_corners)
+        zz = run_infer_type(z)
+        func = relay.Function([x, scale_h_var, scale_w_var], z)
+
+        for target, ctx in ctx_list():
+            if "llvm" not in target:
+                continue
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"),
+                                          np.array(scale_w).astype("float32"))
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
+
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NCHW", "nearest_neighbor")
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NCHW", "bilinear", True)
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "nearest_neighbor")
+    verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "bilinear", True)
+
+# tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable
+def test_dyn_upsampling_infer_type_const():
+    n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+
+    data = relay.var("data", relay.TensorType((n, c, h, w), "int8"))
+    scale_w = relay.Var("scale_w", relay.TensorType((), "float32"))
+
+    z = relay.nn.upsampling(data, 2.0, scale_w)
+    zz = run_infer_type(z)
+    assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8")
+
 def test_dyn_pad():
     def verify_pad(dshape, pad_width, pad_val, dtype):
         x = relay.var("x", relay.TensorType(dshape, dtype))
@@ -66,3 +115,5 @@ def verify_pad_default_fill(dshape, pad_width, dtype):
 
 if __name__ == "__main__":
     test_dyn_pad()
+    test_dyn_upsampling_infer_type_const()
+    test_dyn_upsampling_run()
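The level-2 test exercises the op through the VM and debug interpreters; the same flow works standalone. A minimal end-to-end run, assuming this patch is applied (names are illustrative):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((1, 16, 32, 32), "float32"))
sh = relay.var("scale_h", relay.TensorType((), "float32"))
sw = relay.var("scale_w", relay.TensorType((), "float32"))
func = relay.Function([x, sh, sw], relay.nn.upsampling(x, sh, sw))
mod = tvm.ir.IRModule.from_expr(func)

# The VM supports the dynamic output shape computed by the shape function.
intrp = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
out = intrp.evaluate()(np.random.uniform(size=(1, 16, 32, 32)).astype("float32"),
                       np.array(2.0, dtype="float32"), np.array(2.0, dtype="float32"))
print(out.asnumpy().shape)  # (1, 16, 64, 64)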
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index ed9b94c5a9d2d..c47d9596c7868 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -320,6 +320,27 @@ def verify_full(fill_value, fill_shape, dtype):
     verify_full(4, (1, 2, 3, 4), 'int32')
     verify_full(4.0, (1, 2, 8, 10), 'float32')
 
+def test_dynamic_to_static_upsampling():
+    def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype):
+        x = relay.var("x", relay.TensorType(data_shape, dtype))
+        scale_h = relay.const(scale_h_val)
+        scale_w = relay.const(scale_w_val)
+        z = relay.nn.upsampling(x, scale_h, scale_w)
+
+        func = run_infer_type(relay.Function([x], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("nn.upsampling")
+
+        x_data = np.random.uniform(size=data_shape).astype(dtype)
+        ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW")
+        verify_func(func2, [x_data], ref_res)
+
+    verify_upsampling((1, 16, 32, 32), 2, 2, 'int8')
+    verify_upsampling((1, 16, 32, 32), 4, 4, 'int32')
+
 def test_dynamic_to_static_pad():
     def verify_pad(data_shape, pad_width, pad_val, dtype):
         x = relay.var("x", relay.TensorType(data_shape, dtype))
@@ -337,7 +358,6 @@ def verify_pad(data_shape, pad_width, pad_val, dtype):
     verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
     verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
 
-
 if __name__ == "__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()
@@ -349,4 +369,5 @@ def verify_pad(data_shape, pad_width, pad_val, dtype):
     test_dynamic_to_static_resize()
     test_dynamic_to_static_one_hot()
     test_dynamic_to_static_full()
+    test_dynamic_to_static_upsampling()
     test_dynamic_to_static_pad()
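Outside the test harness, the fold exercised by test_dynamic_to_static_upsampling is a two-pass rewrite. A minimal sketch of the round trip, assuming this patch is applied:

import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((1, 16, 32, 32), "float32"))
# Constant Expr scales still route through dyn.nn.upsampling at construction time...
z = relay.nn.upsampling(x, relay.const(2.0), relay.const(2.0))
mod = tvm.ir.IRModule.from_expr(relay.Function([x], z))
mod = relay.transform.InferType()(mod)

# ...but DynamicToStatic recognizes the constant scales and emits the static op,
# restoring a fully static (1, 16, 64, 64) output type.
mod = relay.transform.DynamicToStatic()(mod)
print(mod)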