Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][DYN] Dynamic upsampling relay op #6273

Merged
merged 18 commits into from
Aug 21, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay namespace containing dynamic ops."""

from . import _nn
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/nn/_make.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

# Bind every C++ packed function registered under "relay.op.dyn.nn._make"
# (e.g. the TVM_REGISTER_GLOBAL for dyn.nn.upsampling) as an attribute of
# this module, so Python can call _make.upsampling(...).
tvm._ffi._init_api("relay.op.dyn.nn._make", __name__)
85 changes: 85 additions & 0 deletions python/tvm/relay/op/dyn/nn/_nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
"""Backend compiler related feature registration"""
electriclilies marked this conversation as resolved.
Show resolved Hide resolved

from __future__ import absolute_import

from tvm import topi

from tvm.runtime import convert
from tvm.te.hybrid import script
from ...op import register_shape_func, register_compute
from ...op import register_injective_schedule

# upsampling
@register_compute("dyn.nn.upsampling")
def compute_upsampling(attrs, inputs, out_dtype):
    """Compute registration for dyn.nn.upsampling.

    inputs = [data, scale_h, scale_w]; the two scales are runtime tensors.
    The checked output type's shape (already produced by the shape function)
    is forwarded to topi so the kernel does not rescale statically.
    """
    data, scale_h, scale_w = inputs
    return [topi.nn.upsampling(data, scale_h, scale_w, attrs.layout,
                               attrs.method, attrs.align_corners,
                               out_dtype.shape)]

register_injective_schedule("dyn.nn.upsampling")

#####################
# Shape functions #
#####################

# upsampling
@script
def _upsampling_nhwc_shape_func(dshape, scale_h, scale_w, ndim):
    # Hybrid-script shape function for dyn.nn.upsampling, NHWC layout.
    # dshape is the input shape (batch, height, width, channels);
    # scale_h / scale_w are 1-element tensors holding runtime scale factors.
    # Comments only here: hybrid script is parsed with a restricted grammar.
    out = output_tensor((ndim,), "int64")
    batch_size = dshape[0]
    in_height = dshape[1]
    in_width = dshape[2]
    channels = dshape[3]
    out[0] = int64(batch_size)
    # round() is available in hybrid script (registered in te/hybrid/calls.py
    # and te/hybrid/runtime.py by this same change set).
    out[1] = int64(round(in_height * scale_h[0]))
    out[2] = int64(round(in_width * scale_w[0]))
    out[3] = int64(channels)
    return out

@script
def _upsampling_nchw_shape_func(dshape, scale_h, scale_w, ndim):
    # Hybrid-script shape function for dyn.nn.upsampling, NCHW layout.
    # dshape is the input shape (batch, channels, height, width);
    # scale_h / scale_w are 1-element tensors holding runtime scale factors.
    out = output_tensor((ndim,), "int64")
    batch_size = dshape[0]
    channels = dshape[1]
    in_height = dshape[2]
    in_width = dshape[3]
    out[0] = int64(batch_size)
    out[1] = int64(channels)
    # Spatial dims are scaled at runtime and rounded to the nearest integer.
    out[2] = int64(round(in_height * scale_h[0]))
    out[3] = int64(round(in_width * scale_w[0]))
    return out

@register_shape_func("dyn.nn.upsampling", True)
def upsampling_shape_func(attrs, inputs, _):
    """Shape function for dyn.nn.upsampling. Supports NCHW and NHWC layouts.

    inputs = [data, scale_h, scale_w]; dispatches on attrs.layout to the
    matching hybrid-script shape function.
    """
    ndim = convert(len(inputs[0].shape))
    if attrs.layout == "NHWC":
        shape_func = _upsampling_nhwc_shape_func(inputs[0].shape, inputs[1],
                                                 inputs[2], ndim)
    elif attrs.layout == "NCHW":
        shape_func = _upsampling_nchw_shape_func(inputs[0].shape, inputs[1],
                                                 inputs[2], ndim)
    else:
        # Original read `assert false, ...`: `false` is undefined in Python,
        # so the error branch raised NameError instead of the intended
        # assertion. Raise explicitly (also survives `python -O`).
        raise ValueError(
            "Layout passed to the upsampling shape func must be NCHW or NHWC")
    return [shape_func]
14 changes: 11 additions & 3 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from tvm.relay import expr

from . import _make
from ..dyn.nn import _make as _dyn_make
from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
from ...expr import const, Expr


def conv1d(data,
Expand Down Expand Up @@ -1147,13 +1149,13 @@ def upsampling(data,

Parameters
----------
data : tvm.relay.Expr
data : tvm.relay.Expr or tuple<anytype> or list<anytype>
The input data to the operator.

scale_h : tvm.relay.Expr
scale_h : tvm.relay.Expr or int or float
The scale factor for height upsampling.

scale_w : tvm.relay.Expr
scale_w : tvm.relay.Expr or int or float
The scale factor for width upsampling.

layout : str, optional
Expand All @@ -1170,6 +1172,12 @@ def upsampling(data,
result : tvm.relay.Expr
The computed result.
"""
if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
if not isinstance(scale_h, Expr):
scale_h = const(scale_h, "float64")
if not isinstance(scale_w, Expr):
scale_w = const(scale_w, "float64")
return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _math_intrin(func_id, args):
from tvm.tir import op
return getattr(op, func_id)(*args)

sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
sqrt = log = exp = tanh = sigmoid = power = popcount = round = _math_intrin #pylint: disable=invalid-name


def _min_max(func_id, args):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/te/hybrid/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def max_num_threads(allow_none=True):
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount,
'round' : round,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
Expand Down
27 changes: 20 additions & 7 deletions python/tvm/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
align_corners=False):
align_corners=False, output_shape=None):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.

Expand Down Expand Up @@ -52,17 +52,30 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
"""
base_layout = layout[0:4]
if base_layout == "NCHW":
out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)),
simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype)))
if not output_shape: #static case
scaled_h = data.shape[2] * scale_h
scaled_w = data.shape[3] * scale_w
reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
else: #dynamic case -- we don't need to scale; already done in shape func
reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype)))
elif layout == "NHWC":
out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)),
simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype)))
if not output_shape: #static case
scaled_h = data.shape[1] * scale_h
scaled_w = data.shape[2] * scale_w
reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)),
simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype)))
else: #dynamic case
reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)),
simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)))

else:
raise ValueError("not support this layout {} yet".format(layout))
coord_trans = "align_corners" if align_corners else "asymmetric"
return topi.image.resize(data, out_shape, layout=layout,
method=method, coordinate_transformation_mode=coord_trans)
return topi.image.resize(data, reshape_size, layout=layout,
method=method, coordinate_transformation_mode=coord_trans,
output_shape=output_shape)


def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
Expand Down
123 changes: 123 additions & 0 deletions src/relay/op/dyn/nn/upsampling.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file upsampling.cc
* \brief upsampling operator
*/

#include "../../nn/upsampling.h"

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>

#include <vector>

#include "../../op_common.h"

namespace tvm {
namespace relay {
namespace dyn {

/*!
 * \brief Type relation for dyn.nn.upsampling.
 *
 * types = [data_type, scale_h_type, scale_w_type, ret_type]. The output
 * keeps the input layout and dtype; the spatial dimensions are Any() since
 * the scales are only known at runtime.
 */
bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* scale_h = types[1].as<TensorTypeNode>();
  const auto* scale_w = types[2].as<TensorTypeNode>();
  if (data == nullptr || scale_h == nullptr || scale_w == nullptr) {
    // Some input type is not resolved yet; ask the solver to revisit later.
    return false;
  }

  // data is 4-D; both scales are rank-0 (scalar) tensors.
  CHECK_EQ(data->shape.size(), 4);
  CHECK_EQ(scale_h->shape.size(), 0);
  CHECK_EQ(scale_w->shape.size(), 0);

  const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
  CHECK(param);

  // Map the input layout to NCHW so height/width can be located uniformly.
  static const Layout kNCHW("NCHW");
  const Layout in_layout(param->layout);
  auto to_nchw = tir::BijectiveLayout(in_layout, kNCHW);
  CHECK(to_nchw.defined())
      << "UpSampling only supports input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  // Height and width depend on the runtime scales, so mark them dynamic,
  // then convert the shape back to the caller's layout.
  auto nchw_oshape = to_nchw.ForwardShape(data->shape);
  nchw_oshape.Set(2, Any());
  nchw_oshape.Set(3, Any());
  auto oshape = to_nchw.BackwardShape(nchw_oshape);

  reporter->Assign(types[3], TensorType(oshape, data->dtype));
  return true;
}

// Positional relay function to create upsampling operator
// used by frontend FFI.
Expr MakeUpSampling(Expr data, Expr scale_h, Expr scale_w, String layout, String method,
                    bool align_corners) {
  static const Op& op = Op::Get("dyn.nn.upsampling");

  // Pack the compile-time configuration into the op's attribute node;
  // the scales stay as runtime inputs to the call.
  auto attrs = make_object<UpSamplingAttrs>();
  attrs->layout = std::move(layout);
  attrs->method = std::move(method);
  attrs->align_corners = align_corners;

  return Call(op, {data, scale_h, scale_w}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling").set_body_typed(MakeUpSampling);

// NOTE: the original describe() text called scale_h/scale_w "an integer",
// contradicting the `double` argument declarations below and the float64
// constants the Python frontend constructs; the description is corrected here.
RELAY_REGISTER_OP("dyn.nn.upsampling")
    .describe(
        R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.

- **data**: data is 4D array of shape
            (batch_size, channels, in_height, in_width) for NCHW
            (batch_size, in_height, in_width, channels) for NHWC

- **scale_h**: scale_h is a double scalar of the amount to scale height by

- **scale_w**: scale_w is a double scalar of the amount to scale width by

- **out**: Output is 4D array of shape
           for layout NCHW
           (batch_size, channels, in_height*scale_h, in_width*scale_w)

           for layout NHWC
           (batch_size, in_height*scale_h, in_width*scale_w, channels)

)code" TVM_ADD_FILELINE)
    .set_attrs_type<UpSamplingAttrs>()
    .set_num_inputs(3)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("scale_h", "double", "The scale for the height.")
    .add_argument("scale_w", "double", "The scale for the width.")
    .set_support_level(2)
    .add_type_rel("DynamicUpSampling", UpSamplingRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   UpsamplingInferCorrectLayout<UpSamplingAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace dyn
} // namespace relay
} // namespace tvm
3 changes: 3 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ Expr MakeTile(Expr data, Array<Integer> reps);

Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);

Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method,
bool align_corners);

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
bool unbiased);

Expand Down
Loading