From 9db2f66b5f01218df929ba7f651e5e475d5c571d Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 13 Aug 2020 12:48:55 -0700 Subject: [PATCH 01/15] implementing upsampling op --- python/tvm/relay/op/nn/dyn/__init__.py | 20 +++ python/tvm/relay/op/nn/dyn/_make.py | 20 +++ python/tvm/relay/op/nn/dyn/_nn.py | 86 +++++++++++++ python/tvm/relay/op/nn/nn.py | 19 ++- python/tvm/te/hybrid/calls.py | 2 +- python/tvm/te/hybrid/runtime.py | 1 + python/tvm/topi/nn/upsampling.py | 22 +++- src/relay/op/dyn/nn/upsampling.cc | 121 ++++++++++++++++++ src/relay/op/dyn/nn/upsampling.h | 36 ++++++ src/relay/op/make_op.h | 2 + src/relay/transforms/dynamic_to_static.cc | 13 ++ src/tir/ir/data_layout.cc | 6 +- .../relay/dyn/test_dynamic_op_level2.py | 101 +++++++++++++++ .../relay/test_pass_dynamic_to_static.py | 25 +++- 14 files changed, 456 insertions(+), 18 deletions(-) create mode 100644 python/tvm/relay/op/nn/dyn/__init__.py create mode 100644 python/tvm/relay/op/nn/dyn/_make.py create mode 100644 python/tvm/relay/op/nn/dyn/_nn.py create mode 100644 src/relay/op/dyn/nn/upsampling.cc create mode 100644 src/relay/op/dyn/nn/upsampling.h create mode 100644 tests/python/relay/dyn/test_dynamic_op_level2.py diff --git a/python/tvm/relay/op/nn/dyn/__init__.py b/python/tvm/relay/op/nn/dyn/__init__.py new file mode 100644 index 000000000000..01a3a1bc0679 --- /dev/null +++ b/python/tvm/relay/op/nn/dyn/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay namespace containing dynamic ops.""" + +from . import _nn diff --git a/python/tvm/relay/op/nn/dyn/_make.py b/python/tvm/relay/op/nn/dyn/_make.py new file mode 100644 index 000000000000..711dc460726b --- /dev/null +++ b/python/tvm/relay/op/nn/dyn/_make.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relay.op.nn.dyn._make", __name__) diff --git a/python/tvm/relay/op/nn/dyn/_nn.py b/python/tvm/relay/op/nn/dyn/_nn.py new file mode 100644 index 000000000000..9ed7a887f23d --- /dev/null +++ b/python/tvm/relay/op/nn/dyn/_nn.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in +"""Backend compiler related feature registration""" + +from __future__ import absolute_import + +import tvm +from tvm import topi + +from tvm import te +from tvm.topi.util import get_const_tuple + +from tvm.runtime import convert +from tvm.te.hybrid import script +from tvm.tir import layout, bijective_layout +from ...op import register_shape_func, register_compute +from ...op import register_injective_schedule, register_broadcast_schedule +from .._nn import _pad_shape_func + +# upsampling +@register_compute("nn.dyn.upsampling") +def compute_upsampling(attrs, inputs, out_dtype): + data = inputs[0] + scale_h = inputs[1] + scale_w = inputs[2] + layout = attrs.layout + method = attrs.method + align_corners = attrs.align_corners + return [topi.nn.upsampling(data, scale_h, scale_w, layout, method, align_corners, out_dtype.shape)] + +register_injective_schedule("nn.dyn.upsampling") + +##################### +# Shape functions # +##################### + +# upsampling + +@script +def _upsampling_nhwc_shape_func(dshape, scale_h, scale_w, ndim): + out = output_tensor((ndim,), "int64") + batch_size = dshape[0] + in_height = dshape[1] + in_width = dshape[2] + channels = dshape[3] + out[0] = int64(batch_size) + out[1] = int64(round(in_height * scale_h[0])) + out[2] = int64(round(in_width * scale_w[0])) + out[3] = int64(channels) + return out + +@script +def _upsampling_nchw_shape_func(dshape, scale_h, scale_w, ndim): + out = output_tensor((ndim,), "int64") + batch_size = dshape[0] + channels = dshape[1] + in_height = dshape[2] + in_width = dshape[3] + out[0] = int64(batch_size) + out[1] = int64(channels) + out[2] = int64(round(in_height * scale_h[0])) + out[3] = int64(round(in_width * scale_w[0])) + return out + +@register_shape_func("nn.dyn.upsampling", True) +def upsampling_shape_func(attrs, inputs, _): + if (attrs.layout == "NHWC"): + return [_upsampling_nhwc_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(len(inputs[0].shape)))] + if (attrs.layout == "NCHW"): + return [_upsampling_nchw_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(len(inputs[0].shape)))] + diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b2df8505e691..a486531e1b08 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -16,10 +16,14 @@ # under the License. #pylint: disable=invalid-name, too-many-lines """Neural network operations.""" +import tvm +from tvm import relay from tvm.relay import expr from . import _make +from .dyn import _make as _dyn_make from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d +from ...expr import const, Expr def conv1d(data, @@ -1147,13 +1151,13 @@ def upsampling(data, Parameters ---------- - data : tvm.relay.Expr + data : tvm.relay.Expr or tuple or list The input data to the operator. - scale_h : tvm.relay.Expr + scale_h : tvm.relay.Expr or int or float The scale factor for height upsampling. - scale_w : tvm.relay.Expr + scale_w : tvm.relay.Expr or int or float The scale factor for width upsampling. layout : str, optional @@ -1170,7 +1174,14 @@ def upsampling(data, result : tvm.relay.Expr The computed result. """ - return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners) + if isinstance(scale_h, Expr) or isinstance(scale_w, Expr): + if not isinstance(scale_h, Expr): + scale_h = const(scale_h, "float64") + if not isinstance(scale_w, Expr): + scale_w = const(scale_w, "float64") + return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners) + else: + return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners) def upsampling3d(data, diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 78ed1dce3a44..88ade6e49294 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -73,7 +73,7 @@ def _math_intrin(func_id, args): from tvm.tir import op return getattr(op, func_id)(*args) -sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name +sqrt = log = exp = tanh = sigmoid = power = popcount = round = _math_intrin #pylint: disable=invalid-name def _min_max(func_id, args): diff --git a/python/tvm/te/hybrid/runtime.py b/python/tvm/te/hybrid/runtime.py index 7dcfc7c3966b..7987e46a4768 100644 --- a/python/tvm/te/hybrid/runtime.py +++ b/python/tvm/te/hybrid/runtime.py @@ -126,6 +126,7 @@ def max_num_threads(allow_none=True): 'exp' : numpy.exp, 'sigmoid' : sigmoid, 'popcount' : popcount, + 'round' : round, 'likely' : lambda cond: cond, 'uint8' : numpy.uint8, 'uint16' : numpy.uint16, diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index 96a13efc541a..e31aedf1e74a 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -21,7 +21,7 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', - align_corners=False): + align_corners=False, output_shape=None): """Perform upsampling on the data. Nearest neighbor and bilinear upsampling are supported. @@ -52,17 +52,25 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', """ base_layout = layout[0:4] if base_layout == "NCHW": - out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)), - simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype))) + if not output_shape: #static case + reshape_size = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)), + simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype))) + else: #dynamic case -- we don't need to scale; already done in shape func + reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)), + simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype))) elif layout == "NHWC": - out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)), - simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype))) + if not output_shape: #static case + reshape_size = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)), + simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype))) + else: #dynamic case + reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)), + simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype))) else: raise ValueError("not support this layout {} yet".format(layout)) coord_trans = "align_corners" if align_corners else "asymmetric" - return topi.image.resize(data, out_shape, layout=layout, - method=method, coordinate_transformation_mode=coord_trans) + return topi.image.resize(data, reshape_size, layout=layout, + method=method, coordinate_transformation_mode=coord_trans, output_shape=output_shape) def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor', diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc new file mode 100644 index 000000000000..2adb49aca457 --- /dev/null +++ b/src/relay/op/dyn/nn/upsampling.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file upsampling.cc + * \brief upsampling operator + */ +#include +#include +#include +#include + +#include + +#include "../../op_common.h" +#include "../nn/upsampling.h" + +namespace tvm { +namespace relay { +namespace dyn { + +bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types = [data_type, scale_h_type, scale_w_type, ret_type] + CHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + const auto* scale_h = types[1].as(); + const auto* scale_w = types[2].as(); + if (data == nullptr) return false; + if (scale_h == nullptr) return false; + if (scale_w == nullptr) return false; + + CHECK_EQ(data->shape.size(), 4); + CHECK_EQ(scale_h->shape.size(), 0); + CHECK_EQ(scale_w->shape.size(), 0); + static const Layout kNCHW("NCHW"); + + const UpSamplingAttrs* param = attrs.as(); + CHECK(param); + const Layout in_layout(param->layout); + + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); + CHECK(layout_converter.defined()) + << "UpSampling only supports input layouts that are convertible from NCHW." + << " But got " << in_layout; + + auto nchw_oshape = layout_converter.ForwardShape(data->shape); + + nchw_oshape.Set(2, Any()); + nchw_oshape.Set(3, Any()); + auto oshape = layout_converter.BackwardShape(nchw_oshape); + + reporter->Assign(types[3], TensorType(oshape, data->dtype)); + return true; +} + +// Positional relay function to create upsampling operator +// used by frontend FFI. +Expr MakeUpSampling(Expr data, Expr scale_h, Expr scale_w, String layout, String method, + bool align_corners) { + auto attrs = make_object(); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->align_corners = align_corners; + + static const Op& op = Op::Get("nn.dyn.upsampling"); + return Call(op, {data, scale_h, scale_w}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn.dyn._make.upsampling").set_body_typed(MakeUpSampling); + +RELAY_REGISTER_OP("nn.dyn.upsampling") + .describe( + R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. + +- **data**: data is 4D array of shape + (batch_size, channels, in_height, in_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +- **scale_h**: scale_h is an integer of the amount to scale height by + +- **scale_w**: scale_w is an integer of the amount to scale width by + +- **out**: Output is 4D array of shape + for layout NCHW + (batch_size, channels, in_height*scale, in_width*scale) + + for layout NHWC + (batch_size, in_height*scale, in_width*scale, channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("scale_h", "double", "The scale for the height.") + .add_argument("scale_w", "double", "The scale for the width.") + .set_support_level(2) + .add_type_rel("DynamicUpSampling", UpSamplingRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); + +} // namespace dyn +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/dyn/nn/upsampling.h b/src/relay/op/dyn/nn/upsampling.h new file mode 100644 index 000000000000..2f1e4a87eaa4 --- /dev/null +++ b/src/relay/op/dyn/nn/upsampling.h @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/relay/op.nn/upsampling.h + * \brief Header of the upsampling file for methods that need to be accessed by multiple files + */ + +namespace tvm { +namespace relay { + +template +Array > UpsamplingInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types); + +} +} diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 1e17bbe90692..5235572336d0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -74,6 +74,8 @@ Expr MakeTile(Expr data, Array reps); Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype); +Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method, bool align_corners); + Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude, bool unbiased); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 0ccc4c3d1269..b121bd3df5ee 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -124,6 +124,19 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, + {Op::Get("nn.dyn.upsampling"), + [](const CallNode* call_node) { + const ConstantNode* scale_h = call_node->args[1].as(); + const ConstantNode* scale_w = call_node->args[2].as(); + if (scale_h && scale_w) { + CHECK_EQ(scale_h->data->ndim, 0); + CHECK_EQ(scale_w->data->ndim, 0); + const UpSamplingAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data), ToScalar(scale_w->data), param->layout, param->method, param->align_corners); + } + return Expr(nullptr); + }}, }; } diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index bc777db55dbe..a2a70514b942 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -320,11 +320,7 @@ inline Array TransformShape(const Array& src_shape, if (!LayoutAxis::Get(axis).IsPrimal()) { result.push_back(axis->dom->extent); } else { - if (symbolic_var_set.count(i)) { - result.push_back(tir::Any()); - } else { - result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); - } + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } } return result; diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py new file mode 100644 index 000000000000..5fe4b428cb7a --- /dev/null +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Support level2 dynamic operator test cases. +""" + +import numpy as np +import tvm +from tvm import relay +from tvm import te +from tvm.relay.testing import ctx_list +import random +from test_dynamic_op_level3 import verify_func +import tvm.topi.testing +from tvm.relay.testing import run_infer_type + +def test_dyn_upsampling_run(): + def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=False): + + if layout == "NCHW": + (n, c, h, w) = dshape + x_data = np.random.uniform(size=(n, c, h, w)).astype("float32") + + elif layout == "NHWC": + (n, h, w, c) = dshape + x_data = np.random.uniform(size=(n, h, w, c)).astype("float32") + + + if method == "nearest_neighbor": + ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout) + else: + ref_res = tvm.topi.testing.bilinear_resize_python(x_data, (int(round(h*scale_h)), + int(round(w*scale_w))), layout) + x = relay.Var("x", relay.TensorType(dshape, "float32")) + scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) + scale_w_var = relay.var("scale_h", relay.TensorType((), "float32")) + + z = relay.nn.upsampling(x, scale_h_var, scale_w_var, method=method, layout=layout, align_corners=align_corners) + zz = run_infer_type(z) + func = relay.Function([x, scale_h_var, scale_w_var], z) + + for target, ctx in ctx_list(): + if "llvm" not in target: continue + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) + + verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NCHW", "nearest_neighbor") + verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NCHW", "bilinear", True) + verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "nearest_neighbor") + verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NHWC", "bilinear", True) + +# tests upsampling type inference with scale_h and scale_w as variables +def test_upsampling_infer_type(): + n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") + + data = relay.var("data", relay.TensorType((n, c, h, w), "int8")) + scale_h = relay.Var("scale_h", relay.TensorType((), "float32")) + scale_w = relay.Var("scale_w", relay.TensorType((), "float32")) + + z = relay.nn.upsampling(data, scale_h, scale_w) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") + + n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") + data = relay.var("data", relay.TensorType((n, h, w, c), "int8")) + z2 = relay.nn.upsampling(data, scale_h, scale_w, layout="NHWC") + zz2 = run_infer_type(z2) + assert zz2.checked_type == relay.TensorType((n, relay.Any(), relay.Any(), c), "int8") + +#tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable +def test_upsampling_infer_type_const(): + n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") + + data = relay.var("data", relay.TensorType((n, c, h, w), "int8")) + scale_h = relay.Var("scale_h", relay.TensorType((), "float32")) + scale_w = relay.Var("scale_w", relay.TensorType((), "float32")) + + z = relay.nn.upsampling(data, 2.0, scale_w) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") + +if __name__ == "__main__": + test_upsampling_infer_type() + test_upsampling_infer_type_const() + test_dyn_upsampling_run() diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index c61f169d53e0..32ed817e621e 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -312,7 +312,7 @@ def verify_full(fill_value, fill_shape, dtype): zz = func2.body assert isinstance(zz, relay.Call) - assert zz.checked_type == relay.TensorType(fill_shape, dtype) + assert zz.op == relay.op.get("full") ref_res = np.full(fill_shape, fill_value).astype(dtype) y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64') @@ -321,6 +321,28 @@ def verify_full(fill_value, fill_shape, dtype): verify_full(4, (1, 2, 3, 4), 'int32') verify_full(4.0, (1, 2, 8, 10), 'float32') +def test_dynamic_to_static_upsampling(): + def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): + x = relay.var("x", relay.TensorType(data_shape, dtype)) + scale_h = relay.const(scale_h_val) + scale_w = relay.const(scale_w_val) + z = relay.nn.upsampling(x, scale_h, scale_w) + + func = run_infer_type(relay.Function([x], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("nn.upsampling") + + x_data = np.random.uniform(size=data_shape).astype(dtype) + ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW") + verify_func(func2, [x_data], ref_res) + + verify_upsampling((1, 16, 32, 32), 2, 2, 'int8') + verify_upsampling((1, 16, 32, 32), 4, 4, 'int32') + + if __name__ == "__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() @@ -332,3 +354,4 @@ def verify_full(fill_value, fill_shape, dtype): test_dynamic_to_static_resize() test_dynamic_to_static_one_hot() test_dynamic_to_static_full() + test_dynamic_to_static_upsampling() From c0a534fed5f286a04266b76d98183cb78671a4e4 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 13 Aug 2020 13:53:11 -0700 Subject: [PATCH 02/15] fix lint --- python/tvm/relay/op/nn/dyn/_nn.py | 37 ++++++++----------- python/tvm/relay/op/nn/nn.py | 2 +- python/tvm/topi/nn/upsampling.py | 19 ++++++---- src/relay/op/dyn/nn/upsampling.cc | 7 ++-- src/relay/op/dyn/nn/upsampling.h | 7 +++- src/relay/op/make_op.h | 3 +- src/relay/transforms/dynamic_to_static.cc | 28 +++++++------- .../relay/dyn/test_dynamic_op_level2.py | 1 - .../relay/test_pass_dynamic_to_static.py | 4 +- 9 files changed, 57 insertions(+), 51 deletions(-) diff --git a/python/tvm/relay/op/nn/dyn/_nn.py b/python/tvm/relay/op/nn/dyn/_nn.py index 9ed7a887f23d..0954a4abea60 100644 --- a/python/tvm/relay/op/nn/dyn/_nn.py +++ b/python/tvm/relay/op/nn/dyn/_nn.py @@ -19,18 +19,12 @@ from __future__ import absolute_import -import tvm from tvm import topi -from tvm import te -from tvm.topi.util import get_const_tuple - from tvm.runtime import convert from tvm.te.hybrid import script -from tvm.tir import layout, bijective_layout from ...op import register_shape_func, register_compute -from ...op import register_injective_schedule, register_broadcast_schedule -from .._nn import _pad_shape_func +from ...op import register_injective_schedule # upsampling @register_compute("nn.dyn.upsampling") @@ -41,7 +35,8 @@ def compute_upsampling(attrs, inputs, out_dtype): layout = attrs.layout method = attrs.method align_corners = attrs.align_corners - return [topi.nn.upsampling(data, scale_h, scale_w, layout, method, align_corners, out_dtype.shape)] + return [topi.nn.upsampling(data, scale_h, scale_w, layout, + method, align_corners, out_dtype.shape)] register_injective_schedule("nn.dyn.upsampling") @@ -50,7 +45,6 @@ def compute_upsampling(attrs, inputs, out_dtype): ##################### # upsampling - @script def _upsampling_nhwc_shape_func(dshape, scale_h, scale_w, ndim): out = output_tensor((ndim,), "int64") @@ -66,21 +60,20 @@ def _upsampling_nhwc_shape_func(dshape, scale_h, scale_w, ndim): @script def _upsampling_nchw_shape_func(dshape, scale_h, scale_w, ndim): - out = output_tensor((ndim,), "int64") - batch_size = dshape[0] - channels = dshape[1] - in_height = dshape[2] - in_width = dshape[3] - out[0] = int64(batch_size) - out[1] = int64(channels) - out[2] = int64(round(in_height * scale_h[0])) - out[3] = int64(round(in_width * scale_w[0])) - return out + out = output_tensor((ndim,), "int64") + batch_size = dshape[0] + channels = dshape[1] + in_height = dshape[2] + in_width = dshape[3] + out[0] = int64(batch_size) + out[1] = int64(channels) + out[2] = int64(round(in_height * scale_h[0])) + out[3] = int64(round(in_width * scale_w[0])) + return out @register_shape_func("nn.dyn.upsampling", True) def upsampling_shape_func(attrs, inputs, _): - if (attrs.layout == "NHWC"): + if attrs.layout == "NHWC": return [_upsampling_nhwc_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(len(inputs[0].shape)))] - if (attrs.layout == "NCHW"): + if attrs.layout == "NCHW": return [_upsampling_nchw_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(len(inputs[0].shape)))] - diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index a486531e1b08..7790ba134638 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1180,7 +1180,7 @@ def upsampling(data, if not isinstance(scale_w, Expr): scale_w = const(scale_w, "float64") return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners) - else: + else: return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners) diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index e31aedf1e74a..db4af06e0694 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -53,24 +53,29 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', base_layout = layout[0:4] if base_layout == "NCHW": if not output_shape: #static case - reshape_size = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)), - simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype))) + scaled_h = data.shape[2] * scale_h + scaled_w = data.shape[3] * scale_w + reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)), + simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype))) else: #dynamic case -- we don't need to scale; already done in shape func reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)), - simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype))) + simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype))) elif layout == "NHWC": if not output_shape: #static case - reshape_size = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)), - simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype))) + scaled_h = data.shape[1] * scale_h + scaled_w = data.shape[2] * scale_w + reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)), + simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype))) else: #dynamic case reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)), - simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype))) + simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype))) else: raise ValueError("not support this layout {} yet".format(layout)) coord_trans = "align_corners" if align_corners else "asymmetric" return topi.image.resize(data, reshape_size, layout=layout, - method=method, coordinate_transformation_mode=coord_trans, output_shape=output_shape) + method=method, coordinate_transformation_mode=coord_trans, + output_shape=output_shape) def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor', diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc index 2adb49aca457..0f2e4b0f5b8b 100644 --- a/src/relay/op/dyn/nn/upsampling.cc +++ b/src/relay/op/dyn/nn/upsampling.cc @@ -21,6 +21,8 @@ * \file upsampling.cc * \brief upsampling operator */ +#include "../nn/upsampling.h" + #include #include #include @@ -29,7 +31,6 @@ #include #include "../../op_common.h" -#include "../nn/upsampling.h" namespace tvm { namespace relay { @@ -37,7 +38,7 @@ namespace dyn { bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // types = [data_type, scale_h_type, scale_w_type, ret_type] + // types = [data_type, scale_h_type, scale_w_type, ret_type] CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* scale_h = types[1].as(); @@ -45,7 +46,7 @@ bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, if (data == nullptr) return false; if (scale_h == nullptr) return false; if (scale_w == nullptr) return false; - + CHECK_EQ(data->shape.size(), 4); CHECK_EQ(scale_h->shape.size(), 0); CHECK_EQ(scale_w->shape.size(), 0); diff --git a/src/relay/op/dyn/nn/upsampling.h b/src/relay/op/dyn/nn/upsampling.h index 2f1e4a87eaa4..58c7c3bef8fe 100644 --- a/src/relay/op/dyn/nn/upsampling.h +++ b/src/relay/op/dyn/nn/upsampling.h @@ -23,6 +23,9 @@ * \brief Header of the upsampling file for methods that need to be accessed by multiple files */ +#ifndef TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ +#define TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ + namespace tvm { namespace relay { @@ -33,4 +36,6 @@ Array > UpsamplingInferCorrectLayout(const Attrs& attrs, const Array& old_in_types); } -} +} // namespace tvm + +#endif // TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 5235572336d0..4969d3701f0b 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -74,7 +74,8 @@ Expr MakeTile(Expr data, Array reps); Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype); -Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method, bool align_corners); +Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method, + bool align_corners); Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude, bool unbiased); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index b121bd3df5ee..5c54e158be20 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -124,19 +124,21 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, - {Op::Get("nn.dyn.upsampling"), - [](const CallNode* call_node) { - const ConstantNode* scale_h = call_node->args[1].as(); - const ConstantNode* scale_w = call_node->args[2].as(); - if (scale_h && scale_w) { - CHECK_EQ(scale_h->data->ndim, 0); - CHECK_EQ(scale_w->data->ndim, 0); - const UpSamplingAttrs* param = call_node->attrs.as(); - CHECK(param); - return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data), ToScalar(scale_w->data), param->layout, param->method, param->align_corners); - } - return Expr(nullptr); - }}, + {Op::Get("nn.dyn.upsampling"), + [](const CallNode* call_node) { + const ConstantNode* scale_h = call_node->args[1].as(); + const ConstantNode* scale_w = call_node->args[2].as(); + if (scale_h && scale_w) { + CHECK_EQ(scale_h->data->ndim, 0); + CHECK_EQ(scale_w->data->ndim, 0); + const UpSamplingAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data), + ToScalar(scale_w->data), param->layout, param->method, + param->align_corners); + } + return Expr(nullptr); + }}, }; } diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index 5fe4b428cb7a..637f47ae5e9b 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -37,7 +37,6 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa elif layout == "NHWC": (n, h, w, c) = dshape x_data = np.random.uniform(size=(n, h, w, c)).astype("float32") - if method == "nearest_neighbor": ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 32ed817e621e..2afb827af5c0 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -330,7 +330,7 @@ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): func = run_infer_type(relay.Function([x], z)) func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) - + zz = func2.body assert isinstance(zz, relay.Call) assert zz.op == relay.op.get("nn.upsampling") @@ -338,7 +338,7 @@ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): x_data = np.random.uniform(size=data_shape).astype(dtype) ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW") verify_func(func2, [x_data], ref_res) - + verify_upsampling((1, 16, 32, 32), 2, 2, 'int8') verify_upsampling((1, 16, 32, 32), 4, 4, 'int32') From d88029822ad22f11305ae211a3fec071bfc36a41 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 13 Aug 2020 15:11:19 -0700 Subject: [PATCH 03/15] fix lint again --- python/tvm/relay/op/nn/dyn/_nn.py | 11 ++++++++--- python/tvm/relay/op/nn/nn.py | 5 +---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/nn/dyn/_nn.py b/python/tvm/relay/op/nn/dyn/_nn.py index 0954a4abea60..89321756456b 100644 --- a/python/tvm/relay/op/nn/dyn/_nn.py +++ b/python/tvm/relay/op/nn/dyn/_nn.py @@ -74,6 +74,11 @@ def _upsampling_nchw_shape_func(dshape, scale_h, scale_w, ndim): @register_shape_func("nn.dyn.upsampling", True) def upsampling_shape_func(attrs, inputs, _): if attrs.layout == "NHWC": - return [_upsampling_nhwc_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(len(inputs[0].shape)))] - if attrs.layout == "NCHW": - return [_upsampling_nchw_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(len(inputs[0].shape)))] + shape_func = _upsampling_nhwc_shape_func(inputs[0].shape, inputs[1], inputs[2], + convert(len(inputs[0].shape))) + elif attrs.layout == "NCHW": + shape_func = _upsampling_nchw_shape_func(inputs[0].shape, inputs[1], inputs[2], + convert(len(inputs[0].shape))) + else: + assert false, "Layout passed to the upsampling shape func must be NCHW or NHWC" + return [shape_func] diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 7790ba134638..f233cd39d7c4 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -16,8 +16,6 @@ # under the License. #pylint: disable=invalid-name, too-many-lines """Neural network operations.""" -import tvm -from tvm import relay from tvm.relay import expr from . import _make @@ -1180,8 +1178,7 @@ def upsampling(data, if not isinstance(scale_w, Expr): scale_w = const(scale_w, "float64") return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners) - else: - return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners) + return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners) def upsampling3d(data, From 476b06fc1096c9694cee7e9e874307b9252c6ab6 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 13 Aug 2020 15:16:10 -0700 Subject: [PATCH 04/15] add doc to upsampling shape func --- python/tvm/relay/op/nn/dyn/_nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/op/nn/dyn/_nn.py b/python/tvm/relay/op/nn/dyn/_nn.py index 89321756456b..ee480efc3655 100644 --- a/python/tvm/relay/op/nn/dyn/_nn.py +++ b/python/tvm/relay/op/nn/dyn/_nn.py @@ -73,6 +73,7 @@ def _upsampling_nchw_shape_func(dshape, scale_h, scale_w, ndim): @register_shape_func("nn.dyn.upsampling", True) def upsampling_shape_func(attrs, inputs, _): + """Shape function for upsampling. Supports NCHW and NHWC layouts.""" if attrs.layout == "NHWC": shape_func = _upsampling_nhwc_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(len(inputs[0].shape))) From b2afd72975a233a435c945212e583c2479588406 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 13 Aug 2020 16:12:08 -0700 Subject: [PATCH 05/15] fix set attrs build problem --- src/relay/op/dyn/nn/upsampling.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc index 0f2e4b0f5b8b..8b65ca7343b3 100644 --- a/src/relay/op/dyn/nn/upsampling.cc +++ b/src/relay/op/dyn/nn/upsampling.cc @@ -21,8 +21,6 @@ * \file upsampling.cc * \brief upsampling operator */ -#include "../nn/upsampling.h" - #include #include #include @@ -31,6 +29,7 @@ #include #include "../../op_common.h" +#include "../nn/upsampling.h" namespace tvm { namespace relay { @@ -38,7 +37,7 @@ namespace dyn { bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // types = [data_type, scale_h_type, scale_w_type, ret_type] + // types = [data_type, scale_h_type, scale_w_type, ret_type] CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* scale_h = types[1].as(); @@ -114,7 +113,7 @@ RELAY_REGISTER_OP("nn.dyn.upsampling") .set_support_level(2) .add_type_rel("DynamicUpSampling", UpSamplingRel) .set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) + UpsamplingInferCorrectLayout) .set_attr("TOpPattern", kInjective); } // namespace dyn From 11140570eed99ef0f8429f2ebf128b7ad7573b21 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 13 Aug 2020 16:28:40 -0700 Subject: [PATCH 06/15] fixing imports --- src/relay/op/dyn/nn/upsampling.cc | 8 +++++--- src/relay/op/dyn/nn/upsampling.h | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc index 8b65ca7343b3..9bbe68f1c691 100644 --- a/src/relay/op/dyn/nn/upsampling.cc +++ b/src/relay/op/dyn/nn/upsampling.cc @@ -21,6 +21,9 @@ * \file upsampling.cc * \brief upsampling operator */ + +#include "../nn/upsampling.h" + #include #include #include @@ -29,7 +32,6 @@ #include #include "../../op_common.h" -#include "../nn/upsampling.h" namespace tvm { namespace relay { @@ -37,7 +39,7 @@ namespace dyn { bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // types = [data_type, scale_h_type, scale_w_type, ret_type] + // types = [data_type, scale_h_type, scale_w_type, ret_type] CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* scale_h = types[1].as(); @@ -113,7 +115,7 @@ RELAY_REGISTER_OP("nn.dyn.upsampling") .set_support_level(2) .add_type_rel("DynamicUpSampling", UpSamplingRel) .set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) + UpsamplingInferCorrectLayout) .set_attr("TOpPattern", kInjective); } // namespace dyn diff --git a/src/relay/op/dyn/nn/upsampling.h b/src/relay/op/dyn/nn/upsampling.h index 58c7c3bef8fe..c42543a11916 100644 --- a/src/relay/op/dyn/nn/upsampling.h +++ b/src/relay/op/dyn/nn/upsampling.h @@ -26,6 +26,11 @@ #ifndef TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ #define TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ +#include +#include + +#include "../../op_common.h" + namespace tvm { namespace relay { From 9d5e107dd53ef4d370e5f15999a2d6f6eac8c14d Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 14 Aug 2020 12:56:31 -0700 Subject: [PATCH 07/15] reverting data layout transform changes --- src/tir/ir/data_layout.cc | 6 +++++- .../python/relay/dyn/test_dynamic_op_level2.py | 18 ------------------ 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index a2a70514b942..bc777db55dbe 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -320,7 +320,11 @@ inline Array TransformShape(const Array& src_shape, if (!LayoutAxis::Get(axis).IsPrimal()) { result.push_back(axis->dom->extent); } else { - result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); + if (symbolic_var_set.count(i)) { + result.push_back(tir::Any()); + } else { + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); + } } } return result; diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index 637f47ae5e9b..e7af31faae5d 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -64,24 +64,6 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "nearest_neighbor") verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NHWC", "bilinear", True) -# tests upsampling type inference with scale_h and scale_w as variables -def test_upsampling_infer_type(): - n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") - - data = relay.var("data", relay.TensorType((n, c, h, w), "int8")) - scale_h = relay.Var("scale_h", relay.TensorType((), "float32")) - scale_w = relay.Var("scale_w", relay.TensorType((), "float32")) - - z = relay.nn.upsampling(data, scale_h, scale_w) - zz = run_infer_type(z) - assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") - - n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") - data = relay.var("data", relay.TensorType((n, h, w, c), "int8")) - z2 = relay.nn.upsampling(data, scale_h, scale_w, layout="NHWC") - zz2 = run_infer_type(z2) - assert zz2.checked_type == relay.TensorType((n, relay.Any(), relay.Any(), c), "int8") - #tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable def test_upsampling_infer_type_const(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") From faa6e738b01db1f7fb52adccea8bfeaec773ea50 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 17 Aug 2020 11:15:03 -0700 Subject: [PATCH 08/15] moved layout template to header file --- src/relay/op/dyn/nn/upsampling.cc | 2 +- src/relay/op/nn/upsampling.cc | 31 +++-------------------- src/relay/op/{dyn => }/nn/upsampling.h | 35 ++++++++++++++++++++------ 3 files changed, 33 insertions(+), 35 deletions(-) rename src/relay/op/{dyn => }/nn/upsampling.h (51%) diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc index 9bbe68f1c691..355edda4b3a1 100644 --- a/src/relay/op/dyn/nn/upsampling.cc +++ b/src/relay/op/dyn/nn/upsampling.cc @@ -22,7 +22,7 @@ * \brief upsampling operator */ -#include "../nn/upsampling.h" +#include "../../nn/upsampling.h" #include #include diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index cb20881c1c5f..bdf3090cefad 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -21,11 +21,15 @@ * \file upsampling.cc * \brief upsampling operator */ + +#include "upsampling.h" + #include #include #include #include +#include #include #include "../op_common.h" @@ -36,33 +40,6 @@ namespace relay { TVM_REGISTER_NODE_TYPE(UpSamplingAttrs); TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs); -template -Array > UpsamplingInferCorrectLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { - // NOTE: Discard "const" qualifier here. - T* params = const_cast(attrs.as()); - - if (new_in_layouts.defined()) { - CHECK_EQ(new_in_layouts.size(), 1); - - Layout raw_layout(params->layout); - Layout input = new_in_layouts[0]; - if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && - input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && - !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) && - (input.IndexOf(LayoutAxis::Get('D')) == -1 || - (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && - !input.Contains(LayoutAxis::Get('d'))))) { - params->layout = input.name(); // modify self to follow the input layout - } - } - - Layout inferred_layout(params->layout); - return Array >{{inferred_layout}, {inferred_layout}}; -} - bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); diff --git a/src/relay/op/dyn/nn/upsampling.h b/src/relay/op/nn/upsampling.h similarity index 51% rename from src/relay/op/dyn/nn/upsampling.h rename to src/relay/op/nn/upsampling.h index c42543a11916..e4e3bc9b1929 100644 --- a/src/relay/op/dyn/nn/upsampling.h +++ b/src/relay/op/nn/upsampling.h @@ -19,17 +19,17 @@ /*! * - * \file src/relay/op.nn/upsampling.h - * \brief Header of the upsampling file for methods that need to be accessed by multiple files + * \file src/relay/op/nn/upsampling.h + * \brief implementation of the InferCorrectLayout pass for upsampling */ -#ifndef TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ -#define TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ +#ifndef TVM_RELAY_OP_NN_UPSAMPLING_H_ +#define TVM_RELAY_OP_NN_UPSAMPLING_H_ #include #include -#include "../../op_common.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -38,9 +38,30 @@ template Array > UpsamplingInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array& old_in_types); + const Array& old_in_types) { + // NOTE: Discard "const" qualifier here. + T* params = const_cast(attrs.as()); + if (new_in_layouts.defined()) { + CHECK_EQ(new_in_layouts.size(), 1); + + Layout raw_layout(params->layout); + Layout input = new_in_layouts[0]; + if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && + input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && + !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) && + (input.IndexOf(LayoutAxis::Get('D')) == -1 || + (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && + !input.Contains(LayoutAxis::Get('d'))))) { + params->layout = input.name(); // modify self to follow the input layout + } + } + + Layout inferred_layout(params->layout); + return Array >{{inferred_layout}, {inferred_layout}}; } + +} // namespace relay } // namespace tvm -#endif // TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_ +#endif // TVM_RELAY_OP_NN_UPSAMPLING_H_ From 6509615ae7e94d09e53c9b9bfe9146b2056b85aa Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 17 Aug 2020 12:11:12 -0700 Subject: [PATCH 09/15] changing python module from nn.dyn to dyn.nn --- python/tvm/relay/op/{nn/dyn => dyn/nn}/__init__.py | 0 python/tvm/relay/op/{nn/dyn => dyn/nn}/_make.py | 2 +- python/tvm/relay/op/{nn/dyn => dyn/nn}/_nn.py | 6 +++--- python/tvm/relay/op/nn/nn.py | 2 +- src/relay/op/dyn/nn/upsampling.cc | 6 +++--- src/relay/transforms/dynamic_to_static.cc | 2 +- tests/python/relay/dyn/test_dynamic_op_level2.py | 5 ++--- 7 files changed, 11 insertions(+), 12 deletions(-) rename python/tvm/relay/op/{nn/dyn => dyn/nn}/__init__.py (100%) rename python/tvm/relay/op/{nn/dyn => dyn/nn}/_make.py (93%) rename python/tvm/relay/op/{nn/dyn => dyn/nn}/_nn.py (95%) diff --git a/python/tvm/relay/op/nn/dyn/__init__.py b/python/tvm/relay/op/dyn/nn/__init__.py similarity index 100% rename from python/tvm/relay/op/nn/dyn/__init__.py rename to python/tvm/relay/op/dyn/nn/__init__.py diff --git a/python/tvm/relay/op/nn/dyn/_make.py b/python/tvm/relay/op/dyn/nn/_make.py similarity index 93% rename from python/tvm/relay/op/nn/dyn/_make.py rename to python/tvm/relay/op/dyn/nn/_make.py index 711dc460726b..280fe72315ad 100644 --- a/python/tvm/relay/op/nn/dyn/_make.py +++ b/python/tvm/relay/op/dyn/nn/_make.py @@ -17,4 +17,4 @@ """Constructor APIs""" import tvm._ffi -tvm._ffi._init_api("relay.op.nn.dyn._make", __name__) +tvm._ffi._init_api("relay.op.dyn.nn._make", __name__) diff --git a/python/tvm/relay/op/nn/dyn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py similarity index 95% rename from python/tvm/relay/op/nn/dyn/_nn.py rename to python/tvm/relay/op/dyn/nn/_nn.py index ee480efc3655..dca49ced3a74 100644 --- a/python/tvm/relay/op/nn/dyn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -27,7 +27,7 @@ from ...op import register_injective_schedule # upsampling -@register_compute("nn.dyn.upsampling") +@register_compute("dyn.nn.upsampling") def compute_upsampling(attrs, inputs, out_dtype): data = inputs[0] scale_h = inputs[1] @@ -38,7 +38,7 @@ def compute_upsampling(attrs, inputs, out_dtype): return [topi.nn.upsampling(data, scale_h, scale_w, layout, method, align_corners, out_dtype.shape)] -register_injective_schedule("nn.dyn.upsampling") +register_injective_schedule("dyn.nn.upsampling") ##################### # Shape functions # @@ -71,7 +71,7 @@ def _upsampling_nchw_shape_func(dshape, scale_h, scale_w, ndim): out[3] = int64(round(in_width * scale_w[0])) return out -@register_shape_func("nn.dyn.upsampling", True) +@register_shape_func("dyn.nn.upsampling", True) def upsampling_shape_func(attrs, inputs, _): """Shape function for upsampling. Supports NCHW and NHWC layouts.""" if attrs.layout == "NHWC": diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index f233cd39d7c4..288ab15c5fec 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -19,7 +19,7 @@ from tvm.relay import expr from . import _make -from .dyn import _make as _dyn_make +from ..dyn.nn import _make as _dyn_make from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d from ...expr import const, Expr diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc index 355edda4b3a1..e2718481ac8c 100644 --- a/src/relay/op/dyn/nn/upsampling.cc +++ b/src/relay/op/dyn/nn/upsampling.cc @@ -81,13 +81,13 @@ Expr MakeUpSampling(Expr data, Expr scale_h, Expr scale_w, String layout, String attrs->method = std::move(method); attrs->align_corners = align_corners; - static const Op& op = Op::Get("nn.dyn.upsampling"); + static const Op& op = Op::Get("dyn.nn.upsampling"); return Call(op, {data, scale_h, scale_w}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn.dyn._make.upsampling").set_body_typed(MakeUpSampling); +TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling").set_body_typed(MakeUpSampling); -RELAY_REGISTER_OP("nn.dyn.upsampling") +RELAY_REGISTER_OP("dyn.nn.upsampling") .describe( R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 5c54e158be20..98e0139d3238 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -124,7 +124,7 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, - {Op::Get("nn.dyn.upsampling"), + {Op::Get("dyn.nn.upsampling"), [](const CallNode* call_node) { const ConstantNode* scale_h = call_node->args[1].as(); const ConstantNode* scale_w = call_node->args[2].as(); diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index e7af31faae5d..d4961ae1c1b3 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -65,7 +65,7 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NHWC", "bilinear", True) #tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable -def test_upsampling_infer_type_const(): +def test_dyn_upsampling_infer_type_const(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") data = relay.var("data", relay.TensorType((n, c, h, w), "int8")) @@ -77,6 +77,5 @@ def test_upsampling_infer_type_const(): assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") if __name__ == "__main__": - test_upsampling_infer_type() - test_upsampling_infer_type_const() + test_dyn_upsampling_infer_type_const() test_dyn_upsampling_run() From 3807bafcc0e20806971622ffc416698c826be878 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 17 Aug 2020 16:50:39 -0700 Subject: [PATCH 10/15] adding support for more layouts to upsampling --- python/tvm/relay/op/dyn/nn/_nn.py | 49 +++++++++++-------------------- 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py index dca49ced3a74..79a8009f4fe4 100644 --- a/python/tvm/relay/op/dyn/nn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -46,40 +46,25 @@ def compute_upsampling(attrs, inputs, out_dtype): # upsampling @script -def _upsampling_nhwc_shape_func(dshape, scale_h, scale_w, ndim): - out = output_tensor((ndim,), "int64") - batch_size = dshape[0] - in_height = dshape[1] - in_width = dshape[2] - channels = dshape[3] - out[0] = int64(batch_size) - out[1] = int64(round(in_height * scale_h[0])) - out[2] = int64(round(in_width * scale_w[0])) - out[3] = int64(channels) - return out - -@script -def _upsampling_nchw_shape_func(dshape, scale_h, scale_w, ndim): - out = output_tensor((ndim,), "int64") - batch_size = dshape[0] - channels = dshape[1] - in_height = dshape[2] - in_width = dshape[3] - out[0] = int64(batch_size) - out[1] = int64(channels) - out[2] = int64(round(in_height * scale_h[0])) - out[3] = int64(round(in_width * scale_w[0])) +def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis, channel_axis): + out = output_tensor((4,), "int64") + out[0] = int64(dshape[0]) + out[height_axis] = int64(round(dshape[height_axis] * scale_h[0])) + out[width_axis] = int64(round(dshape[width_axis] * scale_w[0])) + out[channel_axis] = int64(dshape[channel_axis]) return out @register_shape_func("dyn.nn.upsampling", True) def upsampling_shape_func(attrs, inputs, _): """Shape function for upsampling. Supports NCHW and NHWC layouts.""" - if attrs.layout == "NHWC": - shape_func = _upsampling_nhwc_shape_func(inputs[0].shape, inputs[1], inputs[2], - convert(len(inputs[0].shape))) - elif attrs.layout == "NCHW": - shape_func = _upsampling_nchw_shape_func(inputs[0].shape, inputs[1], inputs[2], - convert(len(inputs[0].shape))) - else: - assert false, "Layout passed to the upsampling shape func must be NCHW or NHWC" - return [shape_func] + layout = attrs.layout + height_axis = width_axis = 1 + for i, letter in enumerate(layout): + if letter == "H": + height_axis = i + if letter == "W": + width_axis = i + if letter == "C": + channel_axis = i + return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2], + convert(height_axis), convert(width_axis), convert(channel_axis))] From e5829fca43334f50fe1c3daf4117d97cb0d46b53 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 17 Aug 2020 16:56:44 -0700 Subject: [PATCH 11/15] fix lint --- python/tvm/relay/op/dyn/nn/_nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py index 79a8009f4fe4..ca9d17d1614f 100644 --- a/python/tvm/relay/op/dyn/nn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -67,4 +67,5 @@ def upsampling_shape_func(attrs, inputs, _): if letter == "C": channel_axis = i return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2], - convert(height_axis), convert(width_axis), convert(channel_axis))] + convert(height_axis), convert(width_axis), + convert(channel_axis))] From 2d21a5e9c7871aa5ff81fb97bca66f9587c61727 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 18 Aug 2020 12:48:24 -0700 Subject: [PATCH 12/15] fix upsampling doc --- python/tvm/relay/op/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 288ab15c5fec..11fb3385bf40 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1149,7 +1149,7 @@ def upsampling(data, Parameters ---------- - data : tvm.relay.Expr or tuple or list + data : tvm.relay.Expr The input data to the operator. scale_h : tvm.relay.Expr or int or float From 8c4d8cc6455d1192e1ae9f0194d28e9f5922b1b0 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 19 Aug 2020 10:24:34 -0700 Subject: [PATCH 13/15] change _nn.py doc --- python/tvm/relay/op/dyn/nn/_nn.py | 2 +- tests/python/relay/dyn/test_dynamic_op_level2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py index ca9d17d1614f..5184b747061d 100644 --- a/python/tvm/relay/op/dyn/nn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in -"""Backend compiler related feature registration""" +"""Backend compiler related feature registration for dynamic relay ops in nn namespace""" from __future__ import absolute_import diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index d4961ae1c1b3..935e463bbdb9 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -60,7 +60,7 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NCHW", "nearest_neighbor") - verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NCHW", "bilinear", True) + verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NCHW", "bilinear", True) verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "nearest_neighbor") verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NHWC", "bilinear", True) From 69971671076b94b9917abce7adf9474dafb1f99e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 19 Aug 2020 12:02:56 -0700 Subject: [PATCH 14/15] failed flakey test From 676dfae80bf2625a429417ef061788711b4599c8 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 20 Aug 2020 11:30:49 -0700 Subject: [PATCH 15/15] fix build after merge --- python/tvm/relay/op/dyn/nn/_nn.py | 1 + src/relay/transforms/dynamic_to_static.cc | 3 +++ 2 files changed, 4 insertions(+) diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py index 8e4d1ccf69c5..a263561006c8 100644 --- a/python/tvm/relay/op/dyn/nn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -70,6 +70,7 @@ def upsampling_shape_func(attrs, inputs, _): return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2], convert(height_axis), convert(width_axis), convert(channel_axis))] +# pad @script def _dyn_pad_shape_func(data, pad_width): ndim = len(data.shape) diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 66c14ad763eb..629f5afe9612 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -136,6 +136,9 @@ class DynamicToStaticMutator : public MixedModeMutator { return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data), ToScalar(scale_w->data), param->layout, param->method, param->align_corners); + } + return Expr(nullptr); + }}, {Op::Get("dyn.nn.pad"), [](const CallNode* call_node) { const ConstantNode* pad_width = call_node->args[1].as();