Skip to content

Commit

Permalink
[RELAY][DYN] Dynamic upsampling relay op (apache#6273)
Browse files Browse the repository at this point in the history
* implementing upsampling op

* fix lint

* fix lint again

* add doc to upsampling shape func

* fix set attrs build problem

* fixing imports

* reverting data layout transform changes

* moved layout template to header file

* changing python module from nn.dyn to dyn.nn

* adding support for more layouts to upsampling

* fix lint

* fix upsampling doc

* change _nn.py doc

* failed flakey test

* fix build after merge
  • Loading branch information
electriclilies authored and trevor-m committed Sep 3, 2020
1 parent bf6a8e3 commit a9078e8
Show file tree
Hide file tree
Showing 10 changed files with 359 additions and 41 deletions.
49 changes: 45 additions & 4 deletions python/tvm/relay/op/dyn/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,62 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
"""Backend compiler related feature registration"""
"""Backend compiler related feature registration for dynamic relay ops in nn namespace"""

from __future__ import absolute_import

from tvm import topi

from tvm.runtime import convert
from tvm.te.hybrid import script
from ...op import register_shape_func
from ...op import register_broadcast_schedule
from ...op import register_shape_func, register_compute
from ...op import register_injective_schedule, register_broadcast_schedule

# pad
# upsampling
@register_compute("dyn.nn.upsampling")
def compute_upsampling(attrs, inputs, out_dtype):
data = inputs[0]
scale_h = inputs[1]
scale_w = inputs[2]
layout = attrs.layout
method = attrs.method
align_corners = attrs.align_corners
return [topi.nn.upsampling(data, scale_h, scale_w, layout,
method, align_corners, out_dtype.shape)]

register_injective_schedule("dyn.nn.upsampling")
register_broadcast_schedule("dyn.nn.pad")

#####################
# Shape functions #
#####################

# upsampling
@script
def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis, channel_axis):
out = output_tensor((4,), "int64")
out[0] = int64(dshape[0])
out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
out[channel_axis] = int64(dshape[channel_axis])
return out

@register_shape_func("dyn.nn.upsampling", True)
def upsampling_shape_func(attrs, inputs, _):
"""Shape function for upsampling. Supports NCHW and NHWC layouts."""
layout = attrs.layout
height_axis = width_axis = 1
for i, letter in enumerate(layout):
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
if letter == "C":
channel_axis = i
return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2],
convert(height_axis), convert(width_axis),
convert(channel_axis))]
# pad
@script
def _dyn_pad_shape_func(data, pad_width):
ndim = len(data.shape)
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,10 +1152,10 @@ def upsampling(data,
data : tvm.relay.Expr
The input data to the operator.
scale_h : tvm.relay.Expr
scale_h : tvm.relay.Expr or int or float
The scale factor for height upsampling.
scale_w : tvm.relay.Expr
scale_w : tvm.relay.Expr or int or float
The scale factor for width upsampling.
layout : str, optional
Expand All @@ -1172,6 +1172,12 @@ def upsampling(data,
result : tvm.relay.Expr
The computed result.
"""
if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
if not isinstance(scale_h, Expr):
scale_h = const(scale_h, "float64")
if not isinstance(scale_w, Expr):
scale_w = const(scale_w, "float64")
return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)


Expand Down
28 changes: 21 additions & 7 deletions python/tvm/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
align_corners=False):
align_corners=False, output_shape=None):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.
Expand Down Expand Up @@ -52,16 +52,30 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
"""
base_layout = layout[0:4]
if base_layout == "NCHW":
out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)),
simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype)))
if not output_shape: #static case
scaled_h = data.shape[2] * scale_h
scaled_w = data.shape[3] * scale_w
reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
else: #dynamic case -- we don't need to scale; already done in shape func
reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype)))
elif layout == "NHWC":
out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)),
simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype)))
if not output_shape: #static case
scaled_h = data.shape[1] * scale_h
scaled_w = data.shape[2] * scale_w
reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)),
simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype)))
else: #dynamic case
reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)),
simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)))

else:
raise ValueError("not support this layout {} yet".format(layout))
coord_trans = "align_corners" if align_corners else "asymmetric"
return topi.image.resize(data, out_shape, layout=layout,
method=method, coordinate_transformation_mode=coord_trans)
return topi.image.resize(data, reshape_size, layout=layout,
method=method, coordinate_transformation_mode=coord_trans,
output_shape=output_shape)


def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
Expand Down
123 changes: 123 additions & 0 deletions src/relay/op/dyn/nn/upsampling.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file upsampling.cc
* \brief upsampling operator
*/

#include "../../nn/upsampling.h"

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>

#include <vector>

#include "../../op_common.h"

namespace tvm {
namespace relay {
namespace dyn {

bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types = [data_type, scale_h_type, scale_w_type, ret_type]
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
const auto* scale_h = types[1].as<TensorTypeNode>();
const auto* scale_w = types[2].as<TensorTypeNode>();
if (data == nullptr) return false;
if (scale_h == nullptr) return false;
if (scale_w == nullptr) return false;

CHECK_EQ(data->shape.size(), 4);
CHECK_EQ(scale_h->shape.size(), 0);
CHECK_EQ(scale_w->shape.size(), 0);
static const Layout kNCHW("NCHW");

const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
CHECK(param);
const Layout in_layout(param->layout);

auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "UpSampling only supports input layouts that are convertible from NCHW."
<< " But got " << in_layout;

auto nchw_oshape = layout_converter.ForwardShape(data->shape);

nchw_oshape.Set(2, Any());
nchw_oshape.Set(3, Any());
auto oshape = layout_converter.BackwardShape(nchw_oshape);

reporter->Assign(types[3], TensorType(oshape, data->dtype));
return true;
}

// Positional relay function to create upsampling operator
// used by frontend FFI.
Expr MakeUpSampling(Expr data, Expr scale_h, Expr scale_w, String layout, String method,
bool align_corners) {
auto attrs = make_object<UpSamplingAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->align_corners = align_corners;

static const Op& op = Op::Get("dyn.nn.upsampling");
return Call(op, {data, scale_h, scale_w}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling").set_body_typed(MakeUpSampling);

RELAY_REGISTER_OP("dyn.nn.upsampling")
.describe(
R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC
- **scale_h**: scale_h is an integer of the amount to scale height by
- **scale_w**: scale_w is an integer of the amount to scale width by
- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, in_height*scale, in_width*scale)
for layout NHWC
(batch_size, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
.set_attrs_type<UpSamplingAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("scale_h", "double", "The scale for the height.")
.add_argument("scale_w", "double", "The scale for the width.")
.set_support_level(2)
.add_type_rel("DynamicUpSampling", UpSamplingRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace dyn
} // namespace relay
} // namespace tvm
3 changes: 3 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ Expr MakeTile(Expr data, Array<Integer> reps);

Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);

Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method,
bool align_corners);

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
bool unbiased);

Expand Down
31 changes: 4 additions & 27 deletions src/relay/op/nn/upsampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
* \file upsampling.cc
* \brief upsampling operator
*/

#include "upsampling.h"

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>

#include <utility>
#include <vector>

#include "../op_common.h"
Expand All @@ -36,33 +40,6 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());

if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);

Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
(input.IndexOf(LayoutAxis::Get('D')) == -1 ||
(input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
!input.Contains(LayoutAxis::Get('d'))))) {
params->layout = input.name(); // modify self to follow the input layout
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}

bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
Expand Down
67 changes: 67 additions & 0 deletions src/relay/op/nn/upsampling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
*
* \file src/relay/op/nn/upsampling.h
* \brief implementation of the InferCorrectLayout pass for upsampling
*/

#ifndef TVM_RELAY_OP_NN_UPSAMPLING_H_
#define TVM_RELAY_OP_NN_UPSAMPLING_H_

#include <tvm/relay/attrs/nn.h>
#include <tvm/tir/data_layout.h>

#include "../op_common.h"

namespace tvm {
namespace relay {

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());

if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);

Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
(input.IndexOf(LayoutAxis::Get('D')) == -1 ||
(input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
!input.Contains(LayoutAxis::Get('d'))))) {
params->layout = input.name(); // modify self to follow the input layout
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_OP_NN_UPSAMPLING_H_
Loading

0 comments on commit a9078e8

Please sign in to comment.