Skip to content

Commit

Permalink
[Relay][Op] Enhance Upsample Operator to support float scales (apache…
Browse files Browse the repository at this point in the history
…#4206)

* :add scale2 for upsample

* update unit test for upsampling

* support latest upsample op for multiple frontend

* fix lint

* fix lint

* fix lint

* fix lint

* update scale description and rebase
  • Loading branch information
Xingyu Zhou authored and kevinthesun committed Oct 30, 2019
1 parent 6f3be60 commit 08776e4
Show file tree
Hide file tree
Showing 18 changed files with 104 additions and 68 deletions.
3 changes: 3 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,9 @@ inline Expr make_zero(Type t) {
} \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
} \
inline Expr Name(const Expr& a, double b) { \
return Name(a, make_const(Float(64), b)); \
}

#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
Expand Down
9 changes: 6 additions & 3 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,17 @@ struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {

/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale;
double scale_h;
double scale_w;
std::string layout;
std::string method;
bool align_corners;

TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
TVM_ATTR_FIELD(scale)
.describe("Should be true to preserve the values at the corner pixels");
TVM_ATTR_FIELD(scale_h)
.describe("The upsampling factor for height");
TVM_ATTR_FIELD(scale_w)
.describe("The upsampling factor for width");
TVM_ATTR_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
Expand Down
3 changes: 2 additions & 1 deletion nnvm/python/nnvm/to_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def _upsampling(children, attrs, odtype='float32'):
method = attrs.get_str('method', 'NEAREST_NEIGHBOR')
return op.nn.upsampling(
children[0],
scale=scale,
scale_h=scale,
scale_w=scale,
layout=layout,
method=method)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _impl(cls, inputs, args, params):
assert width_scale == height_scale

return _op.nn.upsampling(
inputs[0], scale=int(width_scale), method="NEAREST_NEIGHBOR")
inputs[0], scale_h=int(width_scale), scale_w=int(width_scale), method="NEAREST_NEIGHBOR")


class Sum(Caffe2OpConverter):
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ def _UpsampleLayerParams(op, inexpr, etab):
raise tvm.error.OpAttributeUnimplemented(
'Upsample height and width must be equal.')
interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear'
return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode)
return _op.nn.upsampling(inexpr, scale_h=op.scalingFactor[0],
scale_w=op.scalingFactor[1], method=interpolationMode)


def _L2NormalizeLayerParams(op, inexpr, etab):
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _darknet_shortcut(inputs, params, attrs, prefix):

if input_0_size > input_1_size:
scale = int(input_0_size/input_1_size)
input_1 = get_relay_op('upsampling')(input_1, scale=scale)
input_1 = get_relay_op('upsampling')(input_1, scale_h=scale, scale_w=scale)

elif input_0_size < input_1_size:
stride = int(input_1_size/input_0_size)
Expand Down Expand Up @@ -196,7 +196,8 @@ def _darknet_reshape(inputs, params, attrs, prefix):
def _darknet_upsampling(inputs, params, attrs, prefix):
"""Process the upsampling operation."""
new_attrs = {}
new_attrs['scale'] = attrs.get('scale', 1)
new_attrs['scale_h'] = attrs.get('scale', 1)
new_attrs['scale_w'] = attrs.get('scale', 1)
return get_relay_op('upsampling')(*inputs, **new_attrs)

def _darknet_l2normalize(inputs, params, attrs, prefix):
Expand Down
8 changes: 5 additions & 3 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,14 @@ def _convert_upsample(inexpr, keras_layer, _):
params = {}
if upsample_type == 'UpSampling1D':
h = keras_layer.size
params['scale'] = h
params['scale_h'] = h
elif upsample_type == 'UpSampling2D':
h, w = keras_layer.size
if h != w:
raise tvm.error.OpAttributeInvalid(
'Height must equal width for operator Upsample.')
params['scale'] = h
params['scale_h'] = h
params['scale_w'] = h

if hasattr(keras_layer, 'interpolation'):
interpolation = keras_layer.interpolation
Expand All @@ -418,7 +419,8 @@ def _convert_upsample(inexpr, keras_layer, _):
if h != w or w != d:
raise tvm.error.OpAttributeInvalid(
'Height, width, and depth must all be equal for operator Upsample.')
params['scale'] = h
params['scale_h'] = h
params['scale_w'] = h
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(upsample_type))
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/nnvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _transpose(inputs, attrs):

def _upsampling(inputs, attrs):
scale = attrs.get_int("scale")
return _op.nn.upsampling(inputs[0], scale=scale)
return _op.nn.upsampling(inputs[0], scale_h=scale, scale_w=scale)


def _elemwise_sum(inputs, _, _dtype='float32'):
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def _impl_v9(cls, inputs, attr, params):
assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs))
scales = params[inputs[1].name_hint].asnumpy()
inputs = inputs[:1]
assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0
mode = attr.get('mode')
if mode == b'nearest':
method = "nearest_neighbor"
Expand All @@ -590,7 +590,8 @@ def _impl_v9(cls, inputs, attr, params):
else:
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW', 'align_corners':True}
attr = {'scale_h':scales[-2], 'scale_w':scales[-1], 'method':method,
'layout':'NCHW', 'align_corners':True}
return AttrCvt('upsampling')(inputs, attr)


Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,12 @@ def schedule_upsampling(_, outs, target):

@reg.register_compute("nn.upsampling")
def compute_upsampling(attrs, inputs, out_dtype, target):
scale = attrs.scale
scale_h = attrs.scale_h
scale_w = attrs.scale_w
layout = attrs.layout
method = attrs.method
align_corners = attrs.align_corners
return [topi.nn.upsampling(inputs[0], scale, layout, method, align_corners)]
return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]

# pad
reg.register_schedule("nn.pad", schedule_broadcast)
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ def global_avg_pool2d(data,


def upsampling(data,
scale=1,
scale_h=1,
scale_w=1,
layout="NCHW",
method="nearest_neighbor",
align_corners=False):
Expand All @@ -492,7 +493,7 @@ def upsampling(data,
This operator takes data as input and does 2D scaling to the given scale factor.
In the default case, where the data_layout is `NCHW`
with data of shape (n, c, h, w)
out will have a shape (n, c, h*scale, w*scale)
out will have a shape (n, c, h*scale_h, w*scale_w)
method indicates the algorithm to be used while calculating the out value
and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
Expand All @@ -502,8 +503,11 @@ def upsampling(data,
data : tvm.relay.Expr
The input data to the operator.
scale : tvm.relay.Expr
The scale factor for upsampling.
scale_h : tvm.relay.Expr
The scale factor for height upsampling.
scale_w : tvm.relay.Expr
The scale factor for width upsampling.
layout : str, optional
Layout of the input.
Expand All @@ -519,7 +523,7 @@ def upsampling(data,
result : tvm.relay.Expr
The computed result.
"""
return _make.upsampling(data, scale, layout, method, align_corners)
return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)


def batch_flatten(data):
Expand Down
11 changes: 6 additions & 5 deletions src/relay/op/nn/upsampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,8 @@ bool UpSamplingRel(const Array<Type>& types,
<< " But got " << in_layout;

auto oshape = layout_converter.ForwardShape(data->shape);

oshape.Set(2, oshape[2] * param->scale);
oshape.Set(3, oshape[3] * param->scale);
oshape.Set(2, ir::Cast::make(oshape[2].type(), tvm::round(oshape[2] * param->scale_h)));
oshape.Set(3, ir::Cast::make(oshape[3].type(), tvm::round(oshape[3] * param->scale_w)));

// assign output type
reporter->Assign(types[1],
Expand All @@ -95,14 +94,16 @@ bool UpSamplingRel(const Array<Type>& types,
// Positional relay function to create upsampling operator
// used by frontend FFI.
Expr MakeUpSampling(Expr data,
int scale,
double scale_h,
double scale_w,
std::string layout,
std::string method,
bool align_corners) {
auto attrs = make_node<UpSamplingAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->scale = scale;
attrs->scale_h = scale_h;
attrs->scale_w = scale_w;
attrs->align_corners = align_corners;
static const Op& op = Op::Get("nn.upsampling");
return CallNode::make(op, {data}, Attrs(attrs), {});
Expand Down
25 changes: 15 additions & 10 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,17 @@ def test_conv2d_transpose_run():

def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
scale = tvm.const(2.0, "float64")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
"method=\"BINLINEAR\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32")
assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)),
tvm.expr.Cast("int32", tvm.round(w*scale))),
"float32")
n, c = tvm.var("n"), tvm.var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="bilinear")
y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")

Expand Down Expand Up @@ -504,29 +507,31 @@ def test_batch_flatten():

def _test_upsampling(layout, method, align_corners=False):
n, c, h, w = tvm.var("n"), 16, 32, 32
scale = 2
scale_h = 2.0
scale_w = 2.0
dtype = "float32"
def get_shape():
if layout == "NCHW":
return (c, h, w), (c, h*scale, w*scale)
return (c, h, w), (c, int(round(h*scale_h)), int(round(w*scale_w)))
else:
return (h, w, c), (h*scale, w*scale, c)
return (h, w, c), (int(round(h*scale_h)), int(round(w*scale_w)), c)
ishape, oshape = get_shape()
x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
y = relay.nn.upsampling(x, scale=scale, layout=layout,
y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout,
method=method, align_corners=align_corners)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
dshape = (1,) + ishape
x = relay.var("x", shape=dshape)
y = relay.nn.upsampling(x, scale=scale, layout=layout,
y = relay.nn.upsampling(x, scale_h=scale_h, scale_w=scale_w, layout=layout,
method=method, align_corners=align_corners)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
if method == "nearest_neighbor":
ref = topi.testing.upsampling_python(data, (scale, scale), layout)
ref = topi.testing.upsampling_python(data, (scale_h, scale_w), layout)
else:
ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout)
ref = topi.testing.bilinear_resize_python(data, (int(round(h*scale_h)),
int(round(w*scale_w))), layout)
for target, ctx in ctx_list():
executor = relay.create_executor("graph", ctx=ctx, target=target)
out = executor.evaluate(func)(data)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def before():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var('weight', shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
y = relay.nn.upsampling(y, scale=2)
y = relay.nn.upsampling(y, scale_h=2, scale_w=2)
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
y = relay.Function(analysis.free_vars(y), y)
return y
Expand All @@ -506,7 +506,7 @@ def expected():
x = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
data_layout="NCHW16c")
y = relay.nn.upsampling(y, scale=2, layout="NCHW16c")
y = relay.nn.upsampling(y, scale_h=2, scale_w=2, layout="NCHW16c")
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c')
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.Function(analysis.free_vars(y), y)
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_concatenate():
def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
concat = relay.concatenate((upsampled, x), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
return relay.Function(relay.analysis.free_vars(out), out)
Expand All @@ -138,7 +138,7 @@ def expected(dshape):

p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=dshape)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
concat = relay.concatenate((upsampled, p1), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
f1 = relay.Function([p0, p1], out)
Expand All @@ -164,7 +164,7 @@ def test_tuple_root():
def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
out = relay.Tuple((upsampled, x))
return relay.Function(relay.analysis.free_vars(out), out)

Expand All @@ -174,7 +174,7 @@ def expected(dshape):
f0 = relay.Function([x], pooled)

p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
f1 = relay.Function([p0], upsampled)

x = relay.var("x", shape=dshape)
Expand Down
20 changes: 14 additions & 6 deletions topi/python/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
"""TVM operator upsampling compute."""
from __future__ import absolute_import
import topi
import tvm
from ..util import simplify


def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corners=False):
def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
align_corners=False):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.
Expand All @@ -31,8 +33,11 @@ def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corn
[batch, channel, in_height, in_width]
or [batch, in_height, in_width, channel]
scale : int
Scaling factor
scale_h : float
Scaling factor for height
scale_w : float
Scaling factor for width
layout : string, optional
either "NCHW" or "NHWC"
Expand All @@ -43,14 +48,17 @@ def upsampling(data, scale, layout="NCHW", method='nearest_neighbor', align_corn
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, in_height*scale, in_width*scale]
4-D with shape [batch, channel, in_height*scale_h, in_width*scale_w]
or [batch, in_height*scale, in_width*scale, channel]
"""
base_layout = layout[0:4]
if base_layout == "NCHW":
out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scale_h), data.shape[2].dtype)),
simplify(topi.cast(tvm.round(data.shape[3] * scale_w), data.shape[3].dtype)))
elif layout == "NHWC":
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scale_h), data.shape[1].dtype)),
simplify(topi.cast(tvm.round(data.shape[2] * scale_w), data.shape[2].dtype)))

else:
raise ValueError("not support this layout {} yet".format(layout))
return topi.image.resize(data, out_shape, layout=layout,
Expand Down
Loading

0 comments on commit 08776e4

Please sign in to comment.