[RELAY][DYN] Dynamic UpSampling3D Op #6353

Merged — 11 commits, Sep 4, 2020
python/tvm/relay/op/dyn/nn/_nn.py — 56 changes: 49 additions & 7 deletions
@@ -38,7 +38,21 @@ def compute_upsampling(attrs, inputs, out_dtype):
return [topi.nn.upsampling(data, scale_h, scale_w, layout,
method, align_corners, out_dtype.shape)]

# upsampling3d
@register_compute("dyn.nn.upsampling3d")
def compute_upsampling3d(attrs, inputs, out_dtype):
data = inputs[0]
scale_d = inputs[1]
scale_h = inputs[2]
scale_w = inputs[3]
layout = attrs.layout
method = attrs.method
coordinate_transformation_mode = attrs.coordinate_transformation_mode
return [topi.nn.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,\
coordinate_transformation_mode, out_dtype.shape)]

register_injective_schedule("dyn.nn.upsampling")
register_injective_schedule("dyn.nn.upsampling3d")
register_broadcast_schedule("dyn.nn.pad")

#####################
@@ -47,12 +61,12 @@ def compute_upsampling(attrs, inputs, out_dtype):

# upsampling
@script
def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis, channel_axis):
def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis):
out = output_tensor((4,), "int64")
out[0] = int64(dshape[0])
for i in const_range(4):
out[i] = int64(dshape[i])
out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
out[channel_axis] = int64(dshape[channel_axis])
return out

@register_shape_func("dyn.nn.upsampling", True)
@@ -65,11 +79,39 @@ def upsampling_shape_func(attrs, inputs, _):
height_axis = i
if letter == "W":
width_axis = i
if letter == "C":
channel_axis = i
return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2],
convert(height_axis), convert(width_axis),
convert(channel_axis))]
convert(height_axis), convert(width_axis))]

# upsampling3d
@script
def _upsampling3d_shape_func(dshape, scale_d, scale_h, scale_w,
depth_axis, height_axis, width_axis):
out = output_tensor((5,), "int64")
for i in const_range(5):
out[i] = int64(dshape[i])
out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[0]))
out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
return out


@register_shape_func("dyn.nn.upsampling3d", True)
def upsampling3d_shape_func(attrs, inputs, _):
"""Shape function for upsampling. Supports NCHW and NHWC layouts."""
layout = attrs.layout
depth_axis = height_axis = width_axis = 1
for i, letter in enumerate(layout):
if letter == "D":
depth_axis = i
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
return [_upsampling3d_shape_func(inputs[0].shape, inputs[1], inputs[2],
inputs[3], convert(depth_axis),
convert(height_axis),
convert(width_axis))]

# pad
@script
def _dyn_pad_shape_func(data, pad_width):
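For reference, the hybrid-script shape function above simply copies the input shape and rescales the depth, height, and width axes, rounding to the nearest integer. A plain-Python sketch of the same arithmetic (the helper name, axis defaults, and example shape are illustrative, not part of the patch):

```python
# Plain-Python sketch of what _upsampling3d_shape_func computes for an
# NCDHW input; axis positions and the example shape are illustrative.
def upsampling3d_out_shape(dshape, scale_d, scale_h, scale_w,
                           depth_axis=2, height_axis=3, width_axis=4):
    out = list(dshape)                                   # start from the input shape
    out[depth_axis] = int(round(dshape[depth_axis] * scale_d))
    out[height_axis] = int(round(dshape[height_axis] * scale_h))
    out[width_axis] = int(round(dshape[width_axis] * scale_w))
    return out

# (1, 3, 8, 8, 8) upsampled by 2x in D/H/W -> [1, 3, 16, 16, 16]
assert upsampling3d_out_shape((1, 3, 8, 8, 8), 2.0, 2.0, 2.0) == [1, 3, 16, 16, 16]
```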
python/tvm/relay/op/nn/nn.py — 9 changes: 9 additions & 0 deletions
@@ -1229,6 +1229,15 @@ def upsampling3d(data,
result : tvm.relay.Expr
The computed result.
"""
if isinstance(scale_d, Expr) or isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
if not isinstance(scale_d, Expr):
scale_d = const(scale_d, "float64")
if not isinstance(scale_h, Expr):
scale_h = const(scale_h, "float64")
if not isinstance(scale_w, Expr):
scale_w = const(scale_w, "float64")
return _dyn_make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,
coordinate_transformation_mode)
return _make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,
coordinate_transformation_mode)

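The Python entry point above keeps a single user-facing API: if any scale is already a Relay expression, the remaining scalar scales are wrapped in float64 constants and the dynamic op is emitted; otherwise the static op is built as before. A minimal sketch of that dispatch (input shape and scale values are illustrative):

```python
# Minimal usage sketch of relay.nn.upsampling3d dispatching between the
# static and dynamic ops; the input shape and scales are illustrative.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 8, 8, 8), dtype="float32")

static_call = relay.nn.upsampling3d(x, 2.0, 2.0, 2.0, layout="NCDHW")
dyn_call = relay.nn.upsampling3d(x, relay.const(2.0, "float64"), 2.0, 2.0,
                                 layout="NCDHW")

print(static_call.op.name)  # nn.upsampling3d
print(dyn_call.op.name)     # dyn.nn.upsampling3d
```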
python/tvm/topi/image/resize.py — 8 changes: 7 additions & 1 deletion
@@ -519,8 +519,9 @@ def resize(data, size, layout="NCHW", method="bilinear",
out_dtype: string, optional
Type to return. If left None will be same as input type.

output_shape: optional
output_shape: tvm.tir.container.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as output_shape)

Returns
-------
@@ -680,17 +681,22 @@ def resize3d(data, size, layout="NCDHW", method="nearest_neighbor",
inputs is a 5-D tensor with shape
[batch, channel, in_depth, in_height, in_width]
or [batch, in_depth, in_height, in_width, channel]

size: Tuple
Output resolution to scale to

layout: string, optional
"NCDHW", "NDHWC", or "NCDHWc".

coordinate_transformation_mode: string, optional
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.

Available options are "half_pixel", "align_corners" and "asymmetric".
method: {"trilinear", "nearest_neighbor"}
Method to be used for resizing.

out_dtype: string, optional
Type to return. If left None will be same as input type.

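The changes above only touch docstrings (documenting `output_shape` on `resize` and tidying the `resize3d` parameter list); the resize kernels themselves are unchanged. For context, a static TE-level call to `resize3d` with an explicit target size (shapes and method are illustrative):

```python
# Static TE-level sketch of topi.image.resize3d with an explicit target size;
# the placeholder shape and method choice are illustrative.
import tvm
from tvm import te, topi

data = te.placeholder((1, 3, 8, 8, 8), name="data", dtype="float32")
out = topi.image.resize3d(data, (16, 16, 16), layout="NCDHW",
                          method="trilinear")
print(out.shape)  # [1, 3, 16, 16, 16]
```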
python/tvm/topi/nn/upsampling.py — 41 changes: 32 additions & 9 deletions
@@ -44,6 +44,10 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
method : {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for upsampling.

output_shape: tvm.tir.container.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as output_shape)

Returns
-------
output : tvm.te.Tensor
@@ -79,7 +83,7 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',


def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
coordinate_transformation_mode="half_pixel"):
coordinate_transformation_mode="half_pixel", output_shape=None):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.

@@ -111,6 +115,10 @@ def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".

output_shape: tvm.tir.container.Array, optional
Shape to return. If left None will be inferred
(If shape is determined dynamically, pass out_dtype.shape as output_shape)

Returns
-------
output : tvm.te.Tensor
@@ -119,15 +127,30 @@ def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
"""
base_layout = layout[0:5]
if base_layout == "NCDHW":
out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_d), data.shape[2].dtype)),
simplify(topi.cast(te.round(data.shape[3] * scale_h), data.shape[3].dtype)),
simplify(topi.cast(te.round(data.shape[4] * scale_w), data.shape[4].dtype)))
if not output_shape: # static case
scaled_d = data.shape[2] * scale_d
scaled_h = data.shape[3] * scale_h
scaled_w = data.shape[4] * scale_w
resize_shape = (simplify(topi.cast(te.round(scaled_d), data.shape[2].dtype)),
simplify(topi.cast(te.round(scaled_h), data.shape[3].dtype)),
simplify(topi.cast(te.round(scaled_w), data.shape[4].dtype)))
else: # dynamic case -- don't need to scale; already done in shape func
resize_shape = (simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)),
simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)),
simplify(topi.cast(te.round(output_shape[4]), data.shape[4].dtype)))
elif layout == "NDHWC":
out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_d), data.shape[1].dtype)),
simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)),
simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype)))

if not output_shape: # static case
scaled_d = data.shape[1] * scale_d
scaled_h = data.shape[2] * scale_h
scaled_w = data.shape[3] * scale_w
resize_shape = (simplify(topi.cast(te.round(scaled_d), data.shape[1].dtype)),
simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
else: # dynamic case
resize_shape = (simplify(topi.cast(te.round(output_shape[1]), data.shape[1].dtype)),
simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)),
simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)))
else:
raise ValueError("not support this layout {} yet".format(layout))
return topi.image.resize3d(data, out_shape, layout=layout, method=method,
return topi.image.resize3d(data, resize_shape, layout=layout, method=method,
coordinate_transformation_mode=coordinate_transformation_mode)
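On the static path (`output_shape=None`) the scaled extents are computed here from the input shape and the scale factors and then handed to `resize3d`; on the dynamic path the already-scaled `output_shape` produced by the shape function is used directly. A static-case sketch (the placeholder shape and scales are illustrative):

```python
# Static-case sketch of topi.nn.upsampling3d: the scales are Python floats,
# so the output extents are computed inside this function. Sizes illustrative.
import tvm
from tvm import te, topi

data = te.placeholder((1, 3, 4, 8, 8), name="data", dtype="float32")
out = topi.nn.upsampling3d(data, 2.0, 2.0, 2.0, layout="NCDHW",
                           method="nearest_neighbor")
print(out.shape)  # [1, 3, 8, 16, 16]
```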
src/relay/op/dyn/nn/upsampling.cc — 86 changes: 82 additions & 4 deletions
@@ -95,16 +95,16 @@ RELAY_REGISTER_OP("dyn.nn.upsampling")
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC

- **scale_h**: scale_h is an integer of the amount to scale height by
- **scale_h**: scale_h is a double of the amount to scale height by

- **scale_w**: scale_w is an integer of the amount to scale width by
- **scale_w**: scale_w is a double of the amount to scale width by

- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, in_height*scale, in_width*scale)
(batch_size, channels, in_height*scale_h, in_width*scale_w)

for layout NHWC
(batch_size, in_height*scale, in_width*scale, channels)
(batch_size, in_height*scale_h, in_width*scale_w, channels)

)code" TVM_ADD_FILELINE)
.set_attrs_type<UpSamplingAttrs>()
@@ -118,6 +118,84 @@ RELAY_REGISTER_OP("dyn.nn.upsampling")
UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// UpSampling3D
bool UpSampling3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types = [data_type, scale_d_type, scale_h_type, scale_w_type, ret_type]
CHECK_EQ(types.size(), 5);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

static const Layout kNCDHW("NCDHW");

const UpSampling3DAttrs* param = attrs.as<UpSampling3DAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);

auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW);
CHECK(layout_converter.defined())
<< "UpSampling3D only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;

auto ncdhw_oshape = layout_converter.ForwardShape(data->shape);

ncdhw_oshape.Set(2, Any());
ncdhw_oshape.Set(3, Any());
ncdhw_oshape.Set(4, Any());

auto oshape = layout_converter.BackwardShape(ncdhw_oshape);

reporter->Assign(types[4], TensorType(oshape, data->dtype));
return true;
}

Expr MakeUpSampling3D(Expr data, Expr scale_d, Expr scale_h, Expr scale_w, String layout,
String method, String coordinate_transformation_mode) {
auto attrs = make_object<UpSampling3DAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->coordinate_transformation_mode = coordinate_transformation_mode;

static const Op& op = Op::Get("dyn.nn.upsampling3d");
return Call(op, {data, scale_d, scale_h, scale_w}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D);

RELAY_REGISTER_OP("dyn.nn.upsampling3d")
.describe(R"code(Perform upsampling on input array with nearest neighbour or
trilinear interpolation.

- **data**: data is 5D array of shape
(batch_size, channels, in_depth, in_height, in_width) for NCDHW
(batch_size, in_depth, in_height, in_width, channels) for NDHWC

- **scale_d**: scale_d is a double of the amount to scale depth by

- **scale_h**: scale_h is a double of the amount to scale height by

- **scale_w**: scale_w is a double of the amount to scale width by

- **out**: Output is 5D array of shape
for layout NCDHW
(batch_size, channels, in_depth*scale_d, in_height*scale_h, in_width*scale_w)

for layout NDHWC
(batch_size, in_depth*scale_d, in_height*scale_h, in_width*scale_w, channels)

)code" TVM_ADD_FILELINE)
.set_attrs_type<UpSampling3DAttrs>()
.set_num_inputs(4)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("scale_d", "double", "The scale for the depth.")
.add_argument("scale_h", "double", "The scale for the height.")
.add_argument("scale_w", "double", "The scale for the width.")
.set_support_level(2)
.add_type_rel("DynamicUpSampling3D", UpSampling3DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSampling3DAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace dyn
} // namespace relay
} // namespace tvm
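Because the scale inputs are only known at runtime, `UpSampling3DRel` can fix only the batch and channel extents; depth, height, and width are reported as `Any`. A small type-inference sketch (the variable names and shapes are illustrative):

```python
# Type-inference sketch for dyn.nn.upsampling3d: with runtime scales the
# D/H/W extents of the result type become Any ("?"). Shapes are illustrative.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 4, 8, 8), dtype="float32")
scale = relay.var("scale", shape=(), dtype="float64")
y = relay.nn.upsampling3d(x, scale, scale, scale, layout="NCDHW")

mod = tvm.IRModule.from_expr(relay.Function([x, scale], y))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # Tensor[(1, 3, ?, ?, ?), float32]
```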
src/relay/op/make_op.h — 3 changes: 3 additions & 0 deletions
@@ -77,6 +77,9 @@ Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataT
Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method,
bool align_corners);

Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, String layout,
String method, String coordinate_transformation_mode);

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
bool unbiased);

src/relay/transforms/dynamic_to_static.cc — 19 changes: 19 additions & 0 deletions
@@ -139,6 +139,25 @@ class DynamicToStaticMutator : public MixedModeMutator {
}
return Expr(nullptr);
}},
{Op::Get("dyn.nn.upsampling3d"),
[](const CallNode* call_node) {
const ConstantNode* scale_d = call_node->args[1].as<ConstantNode>();
const ConstantNode* scale_h = call_node->args[2].as<ConstantNode>();
const ConstantNode* scale_w = call_node->args[3].as<ConstantNode>();
if (scale_d && scale_h && scale_w) {
CHECK_EQ(scale_d->data->ndim, 0);
CHECK_EQ(scale_h->data->ndim, 0);
CHECK_EQ(scale_w->data->ndim, 0);
const UpSampling3DAttrs* param = call_node->attrs.as<UpSampling3DAttrs>();
CHECK(param);

return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data),
ToScalar(scale_h->data), ToScalar(scale_w->data),
param->layout, param->method,
param->coordinate_transformation_mode);
}
return Expr(nullptr);
}},
{Op::Get("dyn.nn.pad"),
[](const CallNode* call_node) {
const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
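The handler above mirrors the existing 2D case: when all three scale arguments are constants, the dynamic call is rewritten into the static `nn.upsampling3d`, so the usual static pipeline (including type inference with concrete output shapes) applies again. A small sketch of the pass in action (shapes and scale values are illustrative):

```python
# Sketch of DynamicToStatic folding dyn.nn.upsampling3d back to the static op
# when all scales are constants. Shapes and scale values are illustrative.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 4, 8, 8), dtype="float32")
y = relay.nn.upsampling3d(x, relay.const(2.0, "float64"),
                          relay.const(2.0, "float64"),
                          relay.const(2.0, "float64"), layout="NCDHW")

mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
mod = relay.transform.DynamicToStatic()(mod)
mod = relay.transform.InferType()(mod)

print(mod["main"].body.op.name)  # nn.upsampling3d
print(mod["main"].ret_type)      # Tensor[(1, 3, 8, 16, 16), float32]
```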