From df73c8aad6ea9a49da5ae07275ca27008541df28 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Wed, 23 Oct 2019 21:18:05 +0000 Subject: [PATCH 1/8] :add scale2 for upsample --- include/tvm/expr_operator.h | 3 +++ include/tvm/relay/attrs/nn.h | 5 ++++- python/tvm/relay/frontend/onnx.py | 4 ++-- python/tvm/relay/op/nn/_nn.py | 3 ++- python/tvm/relay/op/nn/nn.py | 3 ++- src/relay/op/nn/upsampling.cc | 11 +++++++---- topi/python/topi/nn/upsampling.py | 8 +++++--- 7 files changed, 25 insertions(+), 12 deletions(-) diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 007ae58ad4ba..adc77a8d0f0b 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -700,6 +700,9 @@ inline Expr make_zero(Type t) { } \ inline Expr Name(const Expr& a, int b) { \ return Name(a, make_const(a.type(), b)); \ + } \ + inline Expr Name(const Expr& a, double b) { \ + return Name(a, make_const(Float(64), b)); \ } #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 793b43ad2bb3..953a7a5f0c17 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -387,7 +387,8 @@ struct FIFOBufferAttrs : public tvm::AttrsNode { /*! \brief Attributes for upsampling operator */ struct UpSamplingAttrs : public tvm::AttrsNode { - int scale; + double scale; + double scale2; std::string layout; std::string method; bool align_corners; @@ -395,6 +396,8 @@ struct UpSamplingAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { TVM_ATTR_FIELD(scale) .describe("Should be true to preserve the values at the corner pixels"); + TVM_ATTR_FIELD(scale2) + .describe("Should be true to preserve the values at the corner pixels"); TVM_ATTR_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b007b41e61fe..1779d06a3cd1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -581,7 +581,7 @@ def _impl_v9(cls, inputs, attr, params): assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs)) scales = params[inputs[1].name_hint].asnumpy() inputs = inputs[:1] - assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3] + assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 mode = attr.get('mode') if mode == b'nearest': method = "nearest_neighbor" @@ -590,7 +590,7 @@ def _impl_v9(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW', 'align_corners':True} + attr = {'scale':scales[-2], 'scale2':scales[-1], 'method':method, 'layout':'NCHW', 'align_corners':True} return AttrCvt('upsampling')(inputs, attr) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 5786c228abc0..a7daa2d34a4c 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -410,10 +410,11 @@ def schedule_upsampling(_, outs, target): @reg.register_compute("nn.upsampling") def compute_upsampling(attrs, inputs, out_dtype, target): scale = attrs.scale + scale2 = attrs.scale2 layout = attrs.layout method = attrs.method align_corners = attrs.align_corners - return [topi.nn.upsampling(inputs[0], scale, layout, method, align_corners)] + return [topi.nn.upsampling(inputs[0], scale, scale2, layout, method, align_corners)] # pad reg.register_schedule("nn.pad", schedule_broadcast) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 1f289d1bd27a..1e7d74fee878 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -484,6 +484,7 @@ def global_avg_pool2d(data, def upsampling(data, scale=1, + scale2=1, layout="NCHW", method="nearest_neighbor", align_corners=False): @@ -519,7 +520,7 @@ def upsampling(data, result : tvm.relay.Expr The computed result. """ - return _make.upsampling(data, scale, layout, method, align_corners) + return _make.upsampling(data, scale, scale2, layout, method, align_corners) def batch_flatten(data): diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index c473f86a39ca..b3a2d8cd05b1 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -29,6 +29,7 @@ #include #include #include "../op_common.h" +#include namespace tvm { namespace relay { @@ -80,9 +81,9 @@ bool UpSamplingRel(const Array& types, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - - oshape.Set(2, oshape[2] * param->scale); - oshape.Set(3, oshape[3] * param->scale); + oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scale))); + oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scale2))); + // assign output type reporter->Assign(types[1], @@ -95,7 +96,8 @@ bool UpSamplingRel(const Array& types, // Positional relay function to create upsampling operator // used by frontend FFI. Expr MakeUpSampling(Expr data, - int scale, + double scale, + double scale2, std::string layout, std::string method, bool align_corners) { @@ -103,6 +105,7 @@ Expr MakeUpSampling(Expr data, attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->scale = scale; + attrs->scale2 = scale2; attrs->align_corners = align_corners; static const Op& op = Op::Get("nn.upsampling"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 609213637cf4..9cc2e79c5359 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -17,10 +17,11 @@ """TVM operator upsampling compute.""" from __future__ import absolute_import import topi +import tvm from ..util import simplify -def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corners=False): +def upsampling(data, scale, scale2, layout="NCHW", method='nearest_neighbor', align_corners=False): """Perform upsampling on the data. Nearest neighbor and bilinear upsampling are supported. @@ -48,9 +49,10 @@ def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corn """ base_layout = layout[0:4] if base_layout == "NCHW": - out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale)) + out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scale), data.shape[2].dtype)), simplify(topi.cast(tvm.round(data.shape[3] * scale2), data.shape[3].dtype))) elif layout == "NHWC": - out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale)) + out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scale), data.shape[1].dtype)), simplify(topi.cast(tvm.round(data.shape[2] * scale2), data.shape[2].dtype))) + else: raise ValueError("not support this layout {} yet".format(layout)) return topi.image.resize(data, out_shape, layout=layout, From f65574e5b36f4c96c568aaf2d54a17115ab1dad8 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Fri, 25 Oct 2019 05:10:14 +0000 Subject: [PATCH 2/8] update unit test for upsampling --- include/tvm/relay/attrs/nn.h | 8 ++--- python/tvm/relay/frontend/onnx.py | 2 +- python/tvm/relay/op/nn/_nn.py | 6 ++-- python/tvm/relay/op/nn/nn.py | 15 ++++---- src/relay/op/nn/upsampling.cc | 12 +++---- tests/python/relay/test_op_level2.py | 22 ++++++------ .../python/relay/test_pass_alter_op_layout.py | 4 +-- tests/python/relay/test_pass_fuse_ops.py | 8 ++--- topi/python/topi/nn/upsampling.py | 15 ++++---- topi/python/topi/testing/upsampling_python.py | 8 ++--- topi/tests/python/test_topi_upsampling.py | 34 ++++++++++--------- 11 files changed, 72 insertions(+), 62 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 953a7a5f0c17..c3170499bd20 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -387,16 +387,16 @@ struct FIFOBufferAttrs : public tvm::AttrsNode { /*! \brief Attributes for upsampling operator */ struct UpSamplingAttrs : public tvm::AttrsNode { - double scale; - double scale2; + double scaleH; + double scaleW; std::string layout; std::string method; bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { - TVM_ATTR_FIELD(scale) + TVM_ATTR_FIELD(scaleH) .describe("Should be true to preserve the values at the corner pixels"); - TVM_ATTR_FIELD(scale2) + TVM_ATTR_FIELD(scaleW) .describe("Should be true to preserve the values at the corner pixels"); TVM_ATTR_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1779d06a3cd1..ae05b1207161 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -590,7 +590,7 @@ def _impl_v9(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scale':scales[-2], 'scale2':scales[-1], 'method':method, 'layout':'NCHW', 'align_corners':True} + attr = {'scaleH':scales[-2], 'scaleW':scales[-1], 'method':method, 'layout':'NCHW', 'align_corners':True} return AttrCvt('upsampling')(inputs, attr) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index a7daa2d34a4c..4507d16770a8 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -409,12 +409,12 @@ def schedule_upsampling(_, outs, target): @reg.register_compute("nn.upsampling") def compute_upsampling(attrs, inputs, out_dtype, target): - scale = attrs.scale - scale2 = attrs.scale2 + scaleH = attrs.scaleH + scaleW = attrs.scaleW layout = attrs.layout method = attrs.method align_corners = attrs.align_corners - return [topi.nn.upsampling(inputs[0], scale, scale2, layout, method, align_corners)] + return [topi.nn.upsampling(inputs[0], scaleH, scaleW, layout, method, align_corners)] # pad reg.register_schedule("nn.pad", schedule_broadcast) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 1e7d74fee878..d57e1b6e1791 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -483,8 +483,8 @@ def global_avg_pool2d(data, def upsampling(data, - scale=1, - scale2=1, + scaleH=1, + scaleW=1, layout="NCHW", method="nearest_neighbor", align_corners=False): @@ -493,7 +493,7 @@ def upsampling(data, This operator takes data as input and does 2D scaling to the given scale factor. In the default case, where the data_layout is `NCHW` with data of shape (n, c, h, w) - out will have a shape (n, c, h*scale, w*scale) + out will have a shape (n, c, h*scaleH, w*scaleW) method indicates the algorithm to be used while calculating the out value and method can be one of ("bilinear", "nearest_neighbor", "bicubic") @@ -503,8 +503,11 @@ def upsampling(data, data : tvm.relay.Expr The input data to the operator. - scale : tvm.relay.Expr - The scale factor for upsampling. + scaleH : tvm.relay.Expr + The scale factor for height upsampling. + + scaleW : tvm.relay.Expr + The scale factor for width upsampling. layout : str, optional Layout of the input. @@ -520,7 +523,7 @@ def upsampling(data, result : tvm.relay.Expr The computed result. """ - return _make.upsampling(data, scale, scale2, layout, method, align_corners) + return _make.upsampling(data, scaleH, scaleW, layout, method, align_corners) def batch_flatten(data): diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index b3a2d8cd05b1..f65d8e7961a7 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -81,8 +81,8 @@ bool UpSamplingRel(const Array& types, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scale))); - oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scale2))); + oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scaleH))); + oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scaleW))); // assign output type @@ -96,16 +96,16 @@ bool UpSamplingRel(const Array& types, // Positional relay function to create upsampling operator // used by frontend FFI. Expr MakeUpSampling(Expr data, - double scale, - double scale2, + double scaleH, + double scaleW, std::string layout, std::string method, bool align_corners) { auto attrs = make_node(); attrs->layout = std::move(layout); attrs->method = std::move(method); - attrs->scale = scale; - attrs->scale2 = scale2; + attrs->scaleH = scaleH; + attrs->scaleW = scaleW; attrs->align_corners = align_corners; static const Op& op = Op::Get("nn.upsampling"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 9236d6e55fa0..94ff69049c7f 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -232,14 +232,15 @@ def test_conv2d_transpose_run(): def test_upsampling_infer_type(): n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + scale = tvm.const(2.0, "float64") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear") + y = relay.nn.upsampling(x, scaleH=2, scaleW=2, layout="NCHW", method="bilinear") "method=\"BINLINEAR\"" in y.astext() yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32") + assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)), tvm.expr.Cast("int32", tvm.round(w*scale))), "float32") n, c = tvm.var("n"), tvm.var("c") x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) - y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear") + y = relay.nn.upsampling(x, scaleH=2, scaleW=2, layout="NCHW", method="bilinear") yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") @@ -504,29 +505,30 @@ def test_batch_flatten(): def _test_upsampling(layout, method, align_corners=False): n, c, h, w = tvm.var("n"), 16, 32, 32 - scale = 2 + scaleH = 2.0 + scaleW = 2.0 dtype = "float32" def get_shape(): if layout == "NCHW": - return (c, h, w), (c, h*scale, w*scale) + return (c, h, w), (c, int(round(h*scaleH)), int(round(w*scaleW))) else: - return (h, w, c), (h*scale, w*scale, c) + return (h, w, c), (int(round(h*scaleH)), int(round(w*scaleW)), c) ishape, oshape = get_shape() x = relay.var("x", relay.TensorType((n,) + ishape, dtype)) - y = relay.nn.upsampling(x, scale=scale, layout=layout, + y = relay.nn.upsampling(x, scaleH=scaleH, scaleW=scaleW, layout=layout, method=method, align_corners=align_corners) yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n,) + oshape, dtype) dshape = (1,) + ishape x = relay.var("x", shape=dshape) - y = relay.nn.upsampling(x, scale=scale, layout=layout, + y = relay.nn.upsampling(x, scaleH=scaleH, scaleW=scaleW, layout=layout, method=method, align_corners=align_corners) func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) if method == "nearest_neighbor": - ref = topi.testing.upsampling_python(data, (scale, scale), layout) + ref = topi.testing.upsampling_python(data, (scaleH, scaleW), layout) else: - ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout) + ref = topi.testing.bilinear_resize_python(data, (int(round(h*scaleH)), int(round(w*scaleW))), layout) for target, ctx in ctx_list(): executor = relay.create_executor("graph", ctx=ctx, target=target) out = executor.evaluate(func)(data) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 8ae7777057f3..510f4e5ec064 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -487,7 +487,7 @@ def before(): x = relay.var("x", shape=(1, 32, 28, 28)) weight = relay.var('weight', shape=(32, 32, 3, 3)) y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) - y = relay.nn.upsampling(y, scale=2) + y = relay.nn.upsampling(y, scaleH=2, scaleW=2) y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2)) y = relay.Function(analysis.free_vars(y), y) return y @@ -506,7 +506,7 @@ def expected(): x = relay.layout_transform(x, "NCHW", "NCHW16c") y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c") - y = relay.nn.upsampling(y, scale=2, layout="NCHW16c") + y = relay.nn.upsampling(y, scaleH=2, scaleW=2, layout="NCHW16c") y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c') y = relay.layout_transform(y, "NCHW16c", "NCHW") y = relay.Function(analysis.free_vars(y), y) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 45faa14549ee..8c3641bab60f 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -126,7 +126,7 @@ def test_concatenate(): def before(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) - upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW") + upsampled = relay.nn.upsampling(pooled, scaleH=2, scaleW=2, layout="NCHW") concat = relay.concatenate((upsampled, x), axis=1) out = relay.add(concat, relay.const(1, "float32")) return relay.Function(relay.analysis.free_vars(out), out) @@ -138,7 +138,7 @@ def expected(dshape): p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p1 = relay.var("p1", shape=dshape) - upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW") + upsampled = relay.nn.upsampling(p0, scaleH=2, scaleW=2, layout="NCHW") concat = relay.concatenate((upsampled, p1), axis=1) out = relay.add(concat, relay.const(1, "float32")) f1 = relay.Function([p0, p1], out) @@ -164,7 +164,7 @@ def test_tuple_root(): def before(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) - upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW") + upsampled = relay.nn.upsampling(pooled, scaleH=2, scaleW=2, layout="NCHW") out = relay.Tuple((upsampled, x)) return relay.Function(relay.analysis.free_vars(out), out) @@ -174,7 +174,7 @@ def expected(dshape): f0 = relay.Function([x], pooled) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) - upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW") + upsampled = relay.nn.upsampling(p0, scaleH=2, scaleW=2, layout="NCHW") f1 = relay.Function([p0], upsampled) x = relay.var("x", shape=dshape) diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 9cc2e79c5359..8692ef249bbb 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -21,7 +21,7 @@ from ..util import simplify -def upsampling(data, scale, scale2, layout="NCHW", method='nearest_neighbor', align_corners=False): +def upsampling(data, scaleH, scaleW, layout="NCHW", method='nearest_neighbor', align_corners=False): """Perform upsampling on the data. Nearest neighbor and bilinear upsampling are supported. @@ -32,8 +32,11 @@ def upsampling(data, scale, scale2, layout="NCHW", method='nearest_neighbor', al [batch, channel, in_height, in_width] or [batch, in_height, in_width, channel] - scale : int - Scaling factor + scaleH : float + Scaling factor for height + + scaleW : float + Scaling factor for width layout : string, optional either "NCHW" or "NHWC" @@ -44,14 +47,14 @@ def upsampling(data, scale, scale2, layout="NCHW", method='nearest_neighbor', al Returns ------- output : tvm.Tensor - 4-D with shape [batch, channel, in_height*scale, in_width*scale] + 4-D with shape [batch, channel, in_height*scaleH, in_width*scaleW] or [batch, in_height*scale, in_width*scale, channel] """ base_layout = layout[0:4] if base_layout == "NCHW": - out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scale), data.shape[2].dtype)), simplify(topi.cast(tvm.round(data.shape[3] * scale2), data.shape[3].dtype))) + out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scaleH), data.shape[2].dtype)), simplify(topi.cast(tvm.round(data.shape[3] * scaleW), data.shape[3].dtype))) elif layout == "NHWC": - out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scale), data.shape[1].dtype)), simplify(topi.cast(tvm.round(data.shape[2] * scale2), data.shape[2].dtype))) + out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scaleH), data.shape[1].dtype)), simplify(topi.cast(tvm.round(data.shape[2] * scaleW), data.shape[2].dtype))) else: raise ValueError("not support this layout {} yet".format(layout)) diff --git a/topi/python/topi/testing/upsampling_python.py b/topi/python/topi/testing/upsampling_python.py index 167fdfc7f227..99f1e4a483b3 100644 --- a/topi/python/topi/testing/upsampling_python.py +++ b/topi/python/topi/testing/upsampling_python.py @@ -22,8 +22,8 @@ def upsample_nearest(arr, scale): """ Populate the array by scale factor""" h, w = arr.shape - out_h = math.floor(h * scale[0]) - out_w = math.floor(w * scale[1]) + out_h = int(round(h * scale[0])) + out_w = int(round(w * scale[1])) out = np.empty((out_h, out_w)) for y in range(out_h): for x in range(out_w): @@ -37,14 +37,14 @@ def upsampling_python(data, scale, layout='NCHW'): ishape = data.shape if layout == 'NCHW': - oshape = (ishape[0], ishape[1], math.floor(ishape[2]*scale[0]), math.floor(ishape[3]*scale[1])) + oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])), int(round(ishape[3]*scale[1]))) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[1]): output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) return output_np if layout == 'NHWC': - oshape = (ishape[0], math.floor(ishape[1]*scale[0]), math.floor(ishape[1]*scale[1]), ishape[3]) + oshape = (ishape[0], int(round(ishape[1]*scale[0])), int(round(ishape[2]*scale[1])), ishape[3]) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[3]): diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index f878c23aed92..416299365b31 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -23,30 +23,28 @@ from common import get_all_backend -def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="nearest_neighbor"): - - +def verify_upsampling(batch, in_channel, in_height, in_width, scaleH, scaleW, layout='NCHW', method="nearest_neighbor"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') dtype = A.dtype - out_shape = (batch, in_channel, in_height*scale, in_width*scale) + out_shape = (batch, in_channel, int(round(in_height*scaleH)), int(round(in_width*scaleW))) a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype) elif layout == 'NHWC': A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A') dtype = A.dtype - out_shape = (batch, in_height*scale, in_width*scale, in_channel) + out_shape = (batch, int(round(in_height*scaleH)), int(round(in_width*scaleW)), in_channel) a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype) else: raise NotImplementedError( 'Layout not supported {} '.format(layout)) - B = topi.nn.upsampling(A, scale, layout=layout, method=method, align_corners=False) + B = topi.nn.upsampling(A, scaleH, scaleW, layout=layout, method=method, align_corners=False) if method == "bilinear": - out_size = (in_height*scale, in_width*scale) + out_size = (int(round(in_height*scaleH)), int(round(in_width*scaleW))) b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False) else: - b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout) + b_np = topi.testing.upsampling_python(a_np, (scaleH, scaleW), layout) def check_device(device): ctx = tvm.context(device, 0) @@ -68,20 +66,24 @@ def check_device(device): def test_upsampling(): # nearest_neighbor - NCHW - verify_upsampling(8, 16, 32, 32, 2) - verify_upsampling(2, 32, 64, 64, 3) + verify_upsampling(8, 16, 32, 32, 2.0, 2.0) + verify_upsampling(2, 32, 64, 64, 3.0, 3.0) + verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0) ## nearest_neighbor - NHWC - verify_upsampling(8, 16, 32, 32, 2, layout="NHWC") - verify_upsampling(2, 32, 64, 64, 3, layout="NHWC") + verify_upsampling(8, 16, 32, 32, 2.0, 2.0, layout="NHWC") + verify_upsampling(2, 32, 64, 64, 3.0, 3.0, layout="NHWC") + verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, layout="NHWC") # bilinear - NCHW - verify_upsampling(2, 2, 32, 32, 2, method="bilinear") - verify_upsampling(2, 2, 32, 32, 3, method="bilinear") + verify_upsampling(2, 2, 32, 32, 2.0, 2.0, method="bilinear") + verify_upsampling(2, 2, 32, 32, 3.0, 3.0, method="bilinear") + verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, method="bilinear") # bilinear - NHWC - verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="bilinear") - verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="bilinear") + verify_upsampling(2, 2, 32, 32, 2.0, 2.0, layout="NHWC", method="bilinear") + verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear") + verify_upsampling(1, 64, 22, 32, 3.0, 3.0, layout="NHWC", method="bilinear") if __name__ == "__main__": test_upsampling() From 56876d193ae44b7d57de64b8b3f160600fa4be52 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Fri, 25 Oct 2019 05:47:32 +0000 Subject: [PATCH 3/8] support latest upsample op for multiple frontend --- python/tvm/relay/frontend/caffe2.py | 2 +- python/tvm/relay/frontend/coreml.py | 2 +- python/tvm/relay/frontend/darknet.py | 5 +++-- python/tvm/relay/frontend/keras.py | 8 +++++--- python/tvm/relay/frontend/nnvm_common.py | 2 +- src/relay/op/nn/upsampling.cc | 1 - 6 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index ac16a6bf13b6..429d3b744d99 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -280,7 +280,7 @@ def _impl(cls, inputs, args, params): assert width_scale == height_scale return _op.nn.upsampling( - inputs[0], scale=int(width_scale), method="NEAREST_NEIGHBOR") + inputs[0], scaleH=int(width_scale), scaleW=int(width_scale), method="NEAREST_NEIGHBOR") class Sum(Caffe2OpConverter): diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 8b158ca0dec2..198356aa03dc 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -313,7 +313,7 @@ def _UpsampleLayerParams(op, inexpr, etab): raise tvm.error.OpAttributeUnimplemented( 'Upsample height and width must be equal.') interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear' - return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode) + return _op.nn.upsampling(inexpr, scaleH=op.scalingFactor[0], scaleW=op.scalingFactor[1], method=interpolationMode) def _L2NormalizeLayerParams(op, inexpr, etab): diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 982bceaafd36..8b11e04766b0 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -129,7 +129,7 @@ def _darknet_shortcut(inputs, params, attrs, prefix): if input_0_size > input_1_size: scale = int(input_0_size/input_1_size) - input_1 = get_relay_op('upsampling')(input_1, scale=scale) + input_1 = get_relay_op('upsampling')(input_1, scaleH=scale, scaleW=scale) elif input_0_size < input_1_size: stride = int(input_1_size/input_0_size) @@ -196,7 +196,8 @@ def _darknet_reshape(inputs, params, attrs, prefix): def _darknet_upsampling(inputs, params, attrs, prefix): """Process the upsampling operation.""" new_attrs = {} - new_attrs['scale'] = attrs.get('scale', 1) + new_attrs['scaleH'] = attrs.get('scale', 1) + new_attrs['scaleW'] = attrs.get('scale', 1) return get_relay_op('upsampling')(*inputs, **new_attrs) def _darknet_l2normalize(inputs, params, attrs, prefix): diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index cc092f380c5c..61e7086a6426 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -398,13 +398,14 @@ def _convert_upsample(inexpr, keras_layer, _): params = {} if upsample_type == 'UpSampling1D': h = keras_layer.size - params['scale'] = h + params['scaleH'] = h elif upsample_type == 'UpSampling2D': h, w = keras_layer.size if h != w: raise tvm.error.OpAttributeInvalid( 'Height must equal width for operator Upsample.') - params['scale'] = h + params['scaleH'] = h + params['scaleW'] = h if hasattr(keras_layer, 'interpolation'): interpolation = keras_layer.interpolation @@ -418,7 +419,8 @@ def _convert_upsample(inexpr, keras_layer, _): if h != w or w != d: raise tvm.error.OpAttributeInvalid( 'Height, width, and depth must all be equal for operator Upsample.') - params['scale'] = h + params['scaleH'] = h + params['scaleW'] = h else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend Keras.'.format(upsample_type)) diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index ef2b81c1d2b8..f9233ce16c27 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -112,7 +112,7 @@ def _transpose(inputs, attrs): def _upsampling(inputs, attrs): scale = attrs.get_int("scale") - return _op.nn.upsampling(inputs[0], scale=scale) + return _op.nn.upsampling(inputs[0], scaleH=scale, scaleW=scale) def _elemwise_sum(inputs, _, _dtype='float32'): diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index f65d8e7961a7..56aa1e00b9bc 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -29,7 +29,6 @@ #include #include #include "../op_common.h" -#include namespace tvm { namespace relay { From 9077f6510a63209d3f89dbcd3eae811c20898c01 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Fri, 25 Oct 2019 18:17:34 +0000 Subject: [PATCH 4/8] fix lint --- src/relay/op/nn/upsampling.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 56aa1e00b9bc..00e4e6bf4fd3 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -82,7 +82,6 @@ bool UpSamplingRel(const Array& types, auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scaleH))); oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scaleW))); - // assign output type reporter->Assign(types[1], From 068e4f289c7c477a15666b68abbb217e02b4f958 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Fri, 25 Oct 2019 18:39:16 +0000 Subject: [PATCH 5/8] fix lint --- python/tvm/relay/frontend/coreml.py | 3 ++- python/tvm/relay/frontend/onnx.py | 3 ++- tests/python/relay/test_op_level2.py | 6 ++++-- topi/python/topi/nn/upsampling.py | 6 ++++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 198356aa03dc..7b94729d4495 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -313,7 +313,8 @@ def _UpsampleLayerParams(op, inexpr, etab): raise tvm.error.OpAttributeUnimplemented( 'Upsample height and width must be equal.') interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear' - return _op.nn.upsampling(inexpr, scaleH=op.scalingFactor[0], scaleW=op.scalingFactor[1], method=interpolationMode) + return _op.nn.upsampling(inexpr, scaleH=op.scalingFactor[0], + scaleW=op.scalingFactor[1], method=interpolationMode) def _L2NormalizeLayerParams(op, inexpr, etab): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ae05b1207161..8b696989a8b3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -590,7 +590,8 @@ def _impl_v9(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scaleH':scales[-2], 'scaleW':scales[-1], 'method':method, 'layout':'NCHW', 'align_corners':True} + attr = {'scaleH':scales[-2], 'scaleW':scales[-1], 'method':method, + 'layout':'NCHW', 'align_corners':True} return AttrCvt('upsampling')(inputs, attr) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 94ff69049c7f..5df3573d0214 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -237,7 +237,8 @@ def test_upsampling_infer_type(): y = relay.nn.upsampling(x, scaleH=2, scaleW=2, layout="NCHW", method="bilinear") "method=\"BINLINEAR\"" in y.astext() yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)), tvm.expr.Cast("int32", tvm.round(w*scale))), "float32") + assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)), + tvm.expr.Cast("int32", tvm.round(w*scale))), "float32") n, c = tvm.var("n"), tvm.var("c") x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) y = relay.nn.upsampling(x, scaleH=2, scaleW=2, layout="NCHW", method="bilinear") @@ -528,7 +529,8 @@ def get_shape(): if method == "nearest_neighbor": ref = topi.testing.upsampling_python(data, (scaleH, scaleW), layout) else: - ref = topi.testing.bilinear_resize_python(data, (int(round(h*scaleH)), int(round(w*scaleW))), layout) + ref = topi.testing.bilinear_resize_python(data, (int(round(h*scaleH)), + int(round(w*scaleW))), layout) for target, ctx in ctx_list(): executor = relay.create_executor("graph", ctx=ctx, target=target) out = executor.evaluate(func)(data) diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 8692ef249bbb..d8237568f0b1 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -52,9 +52,11 @@ def upsampling(data, scaleH, scaleW, layout="NCHW", method='nearest_neighbor', a """ base_layout = layout[0:4] if base_layout == "NCHW": - out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scaleH), data.shape[2].dtype)), simplify(topi.cast(tvm.round(data.shape[3] * scaleW), data.shape[3].dtype))) + out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scaleH), data.shape[2].dtype)), + simplify(topi.cast(tvm.round(data.shape[3] * scaleW), data.shape[3].dtype))) elif layout == "NHWC": - out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scaleH), data.shape[1].dtype)), simplify(topi.cast(tvm.round(data.shape[2] * scaleW), data.shape[2].dtype))) + out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scaleH), data.shape[1].dtype)), + simplify(topi.cast(tvm.round(data.shape[2] * scaleW), data.shape[2].dtype))) else: raise ValueError("not support this layout {} yet".format(layout)) From 1a0240546a0a8687fb8d4b8e5e2da2334037cb5d Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Fri, 25 Oct 2019 19:51:36 +0000 Subject: [PATCH 6/8] fix lint --- include/tvm/relay/attrs/nn.h | 8 +++---- python/tvm/relay/frontend/caffe2.py | 2 +- python/tvm/relay/frontend/coreml.py | 4 ++-- python/tvm/relay/frontend/darknet.py | 6 ++--- python/tvm/relay/frontend/keras.py | 10 ++++----- python/tvm/relay/frontend/nnvm_common.py | 2 +- python/tvm/relay/frontend/onnx.py | 2 +- python/tvm/relay/op/nn/_nn.py | 6 ++--- python/tvm/relay/op/nn/nn.py | 12 +++++----- src/relay/op/nn/upsampling.cc | 12 +++++----- tests/python/relay/test_op_level2.py | 22 +++++++++---------- .../python/relay/test_pass_alter_op_layout.py | 4 ++-- tests/python/relay/test_pass_fuse_ops.py | 8 +++---- topi/python/topi/nn/upsampling.py | 16 +++++++------- topi/tests/python/test_topi_upsampling.py | 12 +++++----- 15 files changed, 63 insertions(+), 63 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index c3170499bd20..78597ffa26ca 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -387,16 +387,16 @@ struct FIFOBufferAttrs : public tvm::AttrsNode { /*! \brief Attributes for upsampling operator */ struct UpSamplingAttrs : public tvm::AttrsNode { - double scaleH; - double scaleW; + double scale_h; + double scale_w; std::string layout; std::string method; bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { - TVM_ATTR_FIELD(scaleH) + TVM_ATTR_FIELD(scale_h) .describe("Should be true to preserve the values at the corner pixels"); - TVM_ATTR_FIELD(scaleW) + TVM_ATTR_FIELD(scale_w) .describe("Should be true to preserve the values at the corner pixels"); TVM_ATTR_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index 429d3b744d99..456d782e521f 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -280,7 +280,7 @@ def _impl(cls, inputs, args, params): assert width_scale == height_scale return _op.nn.upsampling( - inputs[0], scaleH=int(width_scale), scaleW=int(width_scale), method="NEAREST_NEIGHBOR") + inputs[0], scale_h=int(width_scale), scale_w=int(width_scale), method="NEAREST_NEIGHBOR") class Sum(Caffe2OpConverter): diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 7b94729d4495..a24043df135d 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -313,8 +313,8 @@ def _UpsampleLayerParams(op, inexpr, etab): raise tvm.error.OpAttributeUnimplemented( 'Upsample height and width must be equal.') interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear' - return _op.nn.upsampling(inexpr, scaleH=op.scalingFactor[0], - scaleW=op.scalingFactor[1], method=interpolationMode) + return _op.nn.upsampling(inexpr, scale_h=op.scalingFactor[0], + scale_w=op.scalingFactor[1], method=interpolationMode) def _L2NormalizeLayerParams(op, inexpr, etab): diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 8b11e04766b0..a2a72eaf57ca 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -129,7 +129,7 @@ def _darknet_shortcut(inputs, params, attrs, prefix): if input_0_size > input_1_size: scale = int(input_0_size/input_1_size) - input_1 = get_relay_op('upsampling')(input_1, scaleH=scale, scaleW=scale) + input_1 = get_relay_op('upsampling')(input_1, scale_h=scale, scale_w=scale) elif input_0_size < input_1_size: stride = int(input_1_size/input_0_size) @@ -196,8 +196,8 @@ def _darknet_reshape(inputs, params, attrs, prefix): def _darknet_upsampling(inputs, params, attrs, prefix): """Process the upsampling operation.""" new_attrs = {} - new_attrs['scaleH'] = attrs.get('scale', 1) - new_attrs['scaleW'] = attrs.get('scale', 1) + new_attrs['scale_h'] = attrs.get('scale', 1) + new_attrs['scale_w'] = attrs.get('scale', 1) return get_relay_op('upsampling')(*inputs, **new_attrs) def _darknet_l2normalize(inputs, params, attrs, prefix): diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 61e7086a6426..15f7440c3b42 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -398,14 +398,14 @@ def _convert_upsample(inexpr, keras_layer, _): params = {} if upsample_type == 'UpSampling1D': h = keras_layer.size - params['scaleH'] = h + params['scale_h'] = h elif upsample_type == 'UpSampling2D': h, w = keras_layer.size if h != w: raise tvm.error.OpAttributeInvalid( 'Height must equal width for operator Upsample.') - params['scaleH'] = h - params['scaleW'] = h + params['scale_h'] = h + params['scale_w'] = h if hasattr(keras_layer, 'interpolation'): interpolation = keras_layer.interpolation @@ -419,8 +419,8 @@ def _convert_upsample(inexpr, keras_layer, _): if h != w or w != d: raise tvm.error.OpAttributeInvalid( 'Height, width, and depth must all be equal for operator Upsample.') - params['scaleH'] = h - params['scaleW'] = h + params['scale_h'] = h + params['scale_w'] = h else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend Keras.'.format(upsample_type)) diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index f9233ce16c27..5f24fa0a504e 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -112,7 +112,7 @@ def _transpose(inputs, attrs): def _upsampling(inputs, attrs): scale = attrs.get_int("scale") - return _op.nn.upsampling(inputs[0], scaleH=scale, scaleW=scale) + return _op.nn.upsampling(inputs[0], scale_h=scale, scale_w=scale) def _elemwise_sum(inputs, _, _dtype='float32'): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8b696989a8b3..1d74a01b1860 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -590,7 +590,7 @@ def _impl_v9(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scaleH':scales[-2], 'scaleW':scales[-1], 'method':method, + attr = {'scale_h':scales[-2], 'scale_w':scales[-1], 'method':method, 'layout':'NCHW', 'align_corners':True} return AttrCvt('upsampling')(inputs, attr) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 4507d16770a8..891548036017 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -409,12 +409,12 @@ def schedule_upsampling(_, outs, target): @reg.register_compute("nn.upsampling") def compute_upsampling(attrs, inputs, out_dtype, target): - scaleH = attrs.scaleH - scaleW = attrs.scaleW + scale_h = attrs.scale_h + scale_w = attrs.scale_w layout = attrs.layout method = attrs.method align_corners = attrs.align_corners - return [topi.nn.upsampling(inputs[0], scaleH, scaleW, layout, method, align_corners)] + return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)] # pad reg.register_schedule("nn.pad", schedule_broadcast) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index d57e1b6e1791..6488eab0d1d8 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -483,8 +483,8 @@ def global_avg_pool2d(data, def upsampling(data, - scaleH=1, - scaleW=1, + scale_h=1, + scale_w=1, layout="NCHW", method="nearest_neighbor", align_corners=False): @@ -493,7 +493,7 @@ def upsampling(data, This operator takes data as input and does 2D scaling to the given scale factor. In the default case, where the data_layout is `NCHW` with data of shape (n, c, h, w) - out will have a shape (n, c, h*scaleH, w*scaleW) + out will have a shape (n, c, h*scale_h, w*scale_w) method indicates the algorithm to be used while calculating the out value and method can be one of ("bilinear", "nearest_neighbor", "bicubic") @@ -503,10 +503,10 @@ def upsampling(data, data : tvm.relay.Expr The input data to the operator. - scaleH : tvm.relay.Expr + scale_h : tvm.relay.Expr The scale factor for height upsampling. - scaleW : tvm.relay.Expr + scale_w : tvm.relay.Expr The scale factor for width upsampling. layout : str, optional @@ -523,7 +523,7 @@ def upsampling(data, result : tvm.relay.Expr The computed result. """ - return _make.upsampling(data, scaleH, scaleW, layout, method, align_corners) + return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners) def batch_flatten(data): diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 00e4e6bf4fd3..e044722380ce 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -80,8 +80,8 @@ bool UpSamplingRel(const Array& types, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scaleH))); - oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scaleW))); + oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scale_h))); + oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scale_w))); // assign output type reporter->Assign(types[1], @@ -94,16 +94,16 @@ bool UpSamplingRel(const Array& types, // Positional relay function to create upsampling operator // used by frontend FFI. Expr MakeUpSampling(Expr data, - double scaleH, - double scaleW, + double scale_h, + double scale_w, std::string layout, std::string method, bool align_corners) { auto attrs = make_node(); attrs->layout = std::move(layout); attrs->method = std::move(method); - attrs->scaleH = scaleH; - attrs->scaleW = scaleW; + attrs->scale_h = scale_h; + attrs->scale_w = scale_w; attrs->align_corners = align_corners; static const Op& op = Op::Get("nn.upsampling"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 5df3573d0214..24a29ed8eaa4 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -234,14 +234,14 @@ def test_upsampling_infer_type(): n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") scale = tvm.const(2.0, "float64") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - y = relay.nn.upsampling(x, scaleH=2, scaleW=2, layout="NCHW", method="bilinear") + y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear") "method=\"BINLINEAR\"" in y.astext() yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)), tvm.expr.Cast("int32", tvm.round(w*scale))), "float32") n, c = tvm.var("n"), tvm.var("c") x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) - y = relay.nn.upsampling(x, scaleH=2, scaleW=2, layout="NCHW", method="bilinear") + y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear") yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") @@ -506,31 +506,31 @@ def test_batch_flatten(): def _test_upsampling(layout, method, align_corners=False): n, c, h, w = tvm.var("n"), 16, 32, 32 - scaleH = 2.0 - scaleW = 2.0 + scale_h = 2.0 + scale_w = 2.0 dtype = "float32" def get_shape(): if layout == "NCHW": - return (c, h, w), (c, int(round(h*scaleH)), int(round(w*scaleW))) + return (c, h, w), (c, int(round(h*scale_h)), int(round(w*scale_w))) else: - return (h, w, c), (int(round(h*scaleH)), int(round(w*scaleW)), c) + return (h, w, c), (int(round(h*scale_h)), int(round(w*scale_w)), c) ishape, oshape = get_shape() x = relay.var("x", relay.TensorType((n,) + ishape, dtype)) - y = relay.nn.upsampling(x, scaleH=scaleH, scaleW=scaleW, layout=layout, + y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout, method=method, align_corners=align_corners) yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n,) + oshape, dtype) dshape = (1,) + ishape x = relay.var("x", shape=dshape) - y = relay.nn.upsampling(x, scaleH=scaleH, scaleW=scaleW, layout=layout, + y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout, method=method, align_corners=align_corners) func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) if method == "nearest_neighbor": - ref = topi.testing.upsampling_python(data, (scaleH, scaleW), layout) + ref = topi.testing.upsampling_python(data, (scale_h, scale_w), layout) else: - ref = topi.testing.bilinear_resize_python(data, (int(round(h*scaleH)), - int(round(w*scaleW))), layout) + ref = topi.testing.bilinear_resize_python(data, (int(round(h*scale_h)), + int(round(w*scale_w))), layout) for target, ctx in ctx_list(): executor = relay.create_executor("graph", ctx=ctx, target=target) out = executor.evaluate(func)(data) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 510f4e5ec064..f1200ec62a32 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -487,7 +487,7 @@ def before(): x = relay.var("x", shape=(1, 32, 28, 28)) weight = relay.var('weight', shape=(32, 32, 3, 3)) y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) - y = relay.nn.upsampling(y, scaleH=2, scaleW=2) + y = relay.nn.upsampling(y, scale_h=2, scale_w=2) y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2)) y = relay.Function(analysis.free_vars(y), y) return y @@ -506,7 +506,7 @@ def expected(): x = relay.layout_transform(x, "NCHW", "NCHW16c") y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c") - y = relay.nn.upsampling(y, scaleH=2, scaleW=2, layout="NCHW16c") + y = relay.nn.upsampling(y, scale_h=2, scale_w=2, layout="NCHW16c") y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c') y = relay.layout_transform(y, "NCHW16c", "NCHW") y = relay.Function(analysis.free_vars(y), y) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 8c3641bab60f..7ec21eab12df 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -126,7 +126,7 @@ def test_concatenate(): def before(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) - upsampled = relay.nn.upsampling(pooled, scaleH=2, scaleW=2, layout="NCHW") + upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW") concat = relay.concatenate((upsampled, x), axis=1) out = relay.add(concat, relay.const(1, "float32")) return relay.Function(relay.analysis.free_vars(out), out) @@ -138,7 +138,7 @@ def expected(dshape): p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p1 = relay.var("p1", shape=dshape) - upsampled = relay.nn.upsampling(p0, scaleH=2, scaleW=2, layout="NCHW") + upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW") concat = relay.concatenate((upsampled, p1), axis=1) out = relay.add(concat, relay.const(1, "float32")) f1 = relay.Function([p0, p1], out) @@ -164,7 +164,7 @@ def test_tuple_root(): def before(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) - upsampled = relay.nn.upsampling(pooled, scaleH=2, scaleW=2, layout="NCHW") + upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW") out = relay.Tuple((upsampled, x)) return relay.Function(relay.analysis.free_vars(out), out) @@ -174,7 +174,7 @@ def expected(dshape): f0 = relay.Function([x], pooled) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) - upsampled = relay.nn.upsampling(p0, scaleH=2, scaleW=2, layout="NCHW") + upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW") f1 = relay.Function([p0], upsampled) x = relay.var("x", shape=dshape) diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index d8237568f0b1..cfe0935df906 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -21,7 +21,7 @@ from ..util import simplify -def upsampling(data, scaleH, scaleW, layout="NCHW", method='nearest_neighbor', align_corners=False): +def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', align_corners=False): """Perform upsampling on the data. Nearest neighbor and bilinear upsampling are supported. @@ -32,10 +32,10 @@ def upsampling(data, scaleH, scaleW, layout="NCHW", method='nearest_neighbor', a [batch, channel, in_height, in_width] or [batch, in_height, in_width, channel] - scaleH : float + scale_h : float Scaling factor for height - scaleW : float + scale_w : float Scaling factor for width layout : string, optional @@ -47,16 +47,16 @@ def upsampling(data, scaleH, scaleW, layout="NCHW", method='nearest_neighbor', a Returns ------- output : tvm.Tensor - 4-D with shape [batch, channel, in_height*scaleH, in_width*scaleW] + 4-D with shape [batch, channel, in_height*scale_h, in_width*scale_w] or [batch, in_height*scale, in_width*scale, channel] """ base_layout = layout[0:4] if base_layout == "NCHW": - out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scaleH), data.shape[2].dtype)), - simplify(topi.cast(tvm.round(data.shape[3] * scaleW), data.shape[3].dtype))) + out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scale_h), data.shape[2].dtype)), + simplify(topi.cast(tvm.round(data.shape[3] * scale_w), data.shape[3].dtype))) elif layout == "NHWC": - out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scaleH), data.shape[1].dtype)), - simplify(topi.cast(tvm.round(data.shape[2] * scaleW), data.shape[2].dtype))) + out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scale_h), data.shape[1].dtype)), + simplify(topi.cast(tvm.round(data.shape[2] * scale_w), data.shape[2].dtype))) else: raise ValueError("not support this layout {} yet".format(layout)) diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index 416299365b31..63a3e6efa574 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -23,28 +23,28 @@ from common import get_all_backend -def verify_upsampling(batch, in_channel, in_height, in_width, scaleH, scaleW, layout='NCHW', method="nearest_neighbor"): +def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, layout='NCHW', method="nearest_neighbor"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') dtype = A.dtype - out_shape = (batch, in_channel, int(round(in_height*scaleH)), int(round(in_width*scaleW))) + out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w))) a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype) elif layout == 'NHWC': A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A') dtype = A.dtype - out_shape = (batch, int(round(in_height*scaleH)), int(round(in_width*scaleW)), in_channel) + out_shape = (batch, int(round(in_height*scale_h)), int(round(in_width*scale_w)), in_channel) a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype) else: raise NotImplementedError( 'Layout not supported {} '.format(layout)) - B = topi.nn.upsampling(A, scaleH, scaleW, layout=layout, method=method, align_corners=False) + B = topi.nn.upsampling(A, scale_h, scale_w, layout=layout, method=method, align_corners=False) if method == "bilinear": - out_size = (int(round(in_height*scaleH)), int(round(in_width*scaleW))) + out_size = (int(round(in_height*scale_h)), int(round(in_width*scale_w))) b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False) else: - b_np = topi.testing.upsampling_python(a_np, (scaleH, scaleW), layout) + b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) def check_device(device): ctx = tvm.context(device, 0) From 4ed9eceecb28876aaf26acae573115cecc588a83 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Fri, 25 Oct 2019 20:00:04 +0000 Subject: [PATCH 7/8] fix lint --- nnvm/python/nnvm/to_relay.py | 3 ++- tests/python/relay/test_op_level2.py | 3 ++- topi/python/topi/nn/upsampling.py | 3 ++- topi/python/topi/testing/upsampling_python.py | 6 ++++-- topi/tests/python/test_topi_upsampling.py | 3 ++- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py index 26dba0f94a27..94a736dabe70 100644 --- a/nnvm/python/nnvm/to_relay.py +++ b/nnvm/python/nnvm/to_relay.py @@ -219,7 +219,8 @@ def _upsampling(children, attrs, odtype='float32'): method = attrs.get_str('method', 'NEAREST_NEIGHBOR') return op.nn.upsampling( children[0], - scale=scale, + scale_h=scale, + scale_w=scale, layout=layout, method=method) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 24a29ed8eaa4..982161d9899f 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -238,7 +238,8 @@ def test_upsampling_infer_type(): "method=\"BINLINEAR\"" in y.astext() yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)), - tvm.expr.Cast("int32", tvm.round(w*scale))), "float32") + tvm.expr.Cast("int32", tvm.round(w*scale))), + "float32") n, c = tvm.var("n"), tvm.var("c") x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear") diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index cfe0935df906..771c9e207a17 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -21,7 +21,8 @@ from ..util import simplify -def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', align_corners=False): +def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', + align_corners=False): """Perform upsampling on the data. Nearest neighbor and bilinear upsampling are supported. diff --git a/topi/python/topi/testing/upsampling_python.py b/topi/python/topi/testing/upsampling_python.py index 99f1e4a483b3..6ea7d6ad8835 100644 --- a/topi/python/topi/testing/upsampling_python.py +++ b/topi/python/topi/testing/upsampling_python.py @@ -37,14 +37,16 @@ def upsampling_python(data, scale, layout='NCHW'): ishape = data.shape if layout == 'NCHW': - oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])), int(round(ishape[3]*scale[1]))) + oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])), + int(round(ishape[3]*scale[1]))) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[1]): output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) return output_np if layout == 'NHWC': - oshape = (ishape[0], int(round(ishape[1]*scale[0])), int(round(ishape[2]*scale[1])), ishape[3]) + oshape = (ishape[0], int(round(ishape[1]*scale[0])), + int(round(ishape[2]*scale[1])), ishape[3]) output_np = np.zeros(oshape, dtype=data.dtype) for b in range(oshape[0]): for c in range(oshape[3]): diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index 63a3e6efa574..83909c085d14 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -23,7 +23,8 @@ from common import get_all_backend -def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, layout='NCHW', method="nearest_neighbor"): +def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, + layout='NCHW', method="nearest_neighbor"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') dtype = A.dtype From 643c903238aef1ba2b955a31b62de005be4e0155 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Sun, 27 Oct 2019 03:00:27 +0000 Subject: [PATCH 8/8] update scale description and rebase --- include/tvm/relay/attrs/nn.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 78597ffa26ca..f8e5af98c0a0 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -395,9 +395,9 @@ struct UpSamplingAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { TVM_ATTR_FIELD(scale_h) - .describe("Should be true to preserve the values at the corner pixels"); + .describe("The upsampling factor for height"); TVM_ATTR_FIELD(scale_w) - .describe("Should be true to preserve the values at the corner pixels"); + .describe("The upsampling factor for width"); TVM_ATTR_FIELD(layout).set_default("NCHW") .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"