Skip to content

Commit

Permalink
[Relay][Dynamic] Add Dynamic Resize Op (#6198)
Browse files Browse the repository at this point in the history
* WIP

* optionally remove output shape inference from topi

* fix resize

* add resize to dynamic_to_static pass

add resize to dynamic_to_static pass

* fix clang-format

* fix bad rebase

* add argument to dynamic resize doc string

* fix i386 test

* fix lint
  • Loading branch information
Matthew Brookhart authored Aug 8, 2020
1 parent bfd46ab commit 9ad33fe
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 27 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/op/dyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@
from . import _algorithm
from . import _transform
from . import _tensor

from .import image
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay namespace containing dynamic image ops."""

from . import _image
76 changes: 76 additions & 0 deletions python/tvm/relay/op/dyn/image/_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import

import tvm.topi
from tvm.runtime import convert
from tvm.te.hybrid import script
from tvm.topi.util import nchw_pack_layout, nchw_xc_layout
from ... import op as reg


# resize
@reg.register_compute("dyn.image.resize")
def compute_resize(attrs, inputs, out_type):
layout = attrs.layout
method = attrs.method
coord_trans = attrs.coordinate_transformation_mode
out_dtype = attrs.out_dtype
return [
tvm.topi.image.resize(inputs[0], inputs[1], layout, method, coord_trans, out_dtype,
out_type.shape)
]


reg.register_injective_schedule("dyn.image.resize")


@script
def _NCHW_resize_shape_func(dshape, size, ndim):
out = output_tensor((ndim, ), "int64")
for i in const_range(ndim):
out[i] = int64(dshape[i])
out[2] = int64(size[0])
out[3] = int64(size[1])
return out


@script
def _NHWC_resize_shape_func(dshape, size, ndim):
out = output_tensor((ndim, ), "int64")
for i in const_range(ndim):
out[i] = int64(dshape[i])
out[1] = int64(size[0])
out[2] = int64(size[1])
return out


@reg.register_shape_func("dyn.image.resize", True)
def resize_shape_func(attrs, inputs, _):
"""
Shape function for dyn.image.resize op.
"""
layout = attrs.layout
if layout == 'NHWC':
out = [_NHWC_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
elif (layout == 'NCHW') or nchw_pack_layout(layout) or nchw_xc_layout(layout):
out = [_NCHW_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
else:
raise ValueError("Resize Unsupported Layout", layout)
return out
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/image/_make.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

tvm._ffi._init_api("relay.op.dyn.image._make", __name__)
17 changes: 12 additions & 5 deletions python/tvm/relay/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.
"""Image operations."""
from . import _make
from ..dyn.image import _make as _dyn_make
from ...expr import Expr


def resize(data,
size,
Expand All @@ -38,7 +41,7 @@ def resize(data,
data : relay.Expr
The input data to the operator.
size: Tuple of Expr
size: Tuple of Int or Expr
The out size to which the image will be resized.
layout : str, optional
Expand All @@ -61,6 +64,9 @@ def resize(data,
result: relay.Expr
The resized result.
"""
if isinstance(size, Expr):
return _dyn_make.resize(data, size, layout, method, coordinate_transformation_mode,
out_dtype)
return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)


Expand Down Expand Up @@ -156,8 +162,8 @@ def crop_and_resize(data,
result: relay.Expr
The computed result.
"""
return _make.crop_and_resize(data, boxes, box_indices, crop_size,
layout, method, extrapolation_value, out_dtype)
return _make.crop_and_resize(data, boxes, box_indices, crop_size, layout, method,
extrapolation_value, out_dtype)


def dilation2d(data,
Expand Down Expand Up @@ -213,8 +219,8 @@ def dilation2d(data,
The computed result.
"""

return _make.dilation2d(data, weight, strides, padding, dilations, data_layout,
kernel_layout, out_dtype)
return _make.dilation2d(data, weight, strides, padding, dilations, data_layout, kernel_layout,
out_dtype)


def affine_grid(data, target_shape=None):
Expand All @@ -239,6 +245,7 @@ def affine_grid(data, target_shape=None):
"""
return _make.affine_grid(data, target_shape)


def grid_sample(data, grid, method='bilinear', layout='NCHW'):
"""Applies bilinear sampling to input feature map.
Expand Down
18 changes: 12 additions & 6 deletions python/tvm/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):


def resize(data, size, layout="NCHW", method="bilinear",
coordinate_transformation_mode="half_pixel", out_dtype=None):
coordinate_transformation_mode="half_pixel", out_dtype=None, output_shape=None):
"""Perform resize operation on the data.
Parameters
Expand Down Expand Up @@ -519,6 +519,9 @@ def resize(data, size, layout="NCHW", method="bilinear",
out_dtype: string, optional
Type to return. If left None will be same as input type.
output_shape: optional
Shape to return. If left None will be inferred
Returns
-------
output : tvm.te.Tensor
Expand All @@ -528,19 +531,22 @@ def resize(data, size, layout="NCHW", method="bilinear",
"""

method = method.lower()

if layout == 'NHWC':
in_n, in_h, in_w, in_c = data.shape
output_shape = [in_n, size[0], size[1], in_c]
if output_shape is None:
output_shape = [in_n, size[0], size[1], in_c]
elif layout == 'NCHW':
in_n, in_c, in_h, in_w = data.shape
output_shape = [in_n, in_c, size[0], size[1]]
if output_shape is None:
output_shape = [in_n, in_c, size[0], size[1]]
elif nchw_pack_layout(layout):# for NCHWinic
in_n, in_c, in_h, in_w, in_inum, in_ic = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic]
if output_shape is None:
output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic]
elif nchw_xc_layout(layout):# for NCHWxc
in_n, in_c, in_h, in_w, in_cc = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_cc]
if output_shape is None:
output_shape = [in_n, in_c, size[0], size[1], in_cc]
else:
raise ValueError('%s layout is not supported.' % layout)

Expand Down
109 changes: 109 additions & 0 deletions src/relay/op/dyn/image/resize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file resize.cc
* \brief Image resize operators
*/
#include <tvm/relay/attrs/image.h>
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>

#include "../../op_common.h"

namespace tvm {
namespace relay {
namespace dyn {

TVM_REGISTER_NODE_TYPE(ResizeAttrs);

bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// {data, size, out}
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

static const Layout kNCHW("NCHW");

const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;

auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, Any());
oshape.Set(3, Any());

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}

// assign output type
reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
return true;
}

// Positional relay function to create image operator
// used by frontend FFI.
Expr MakeResize(Expr data, Expr size, String layout, String method,
String coordinate_transformation_mode, DataType out_dtype) {
auto attrs = make_object<ResizeAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->coordinate_transformation_mode = coordinate_transformation_mode;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("dyn.image.resize");
return Call(op, {data, size}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize").set_body_typed(MakeResize);

RELAY_REGISTER_OP("dyn.image.resize")
.describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC
- **size**: data is 2D array of shape (2,) with values
(new_height, new_width)
- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, size[0], size[1])
for layout NHWC
(batch_size, size[0], size[1], channels)
)code" TVM_ADD_FILELINE)
.set_attrs_type<ResizeAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("size", "Tensor", "The output size tensor.")
.set_support_level(5)
.add_type_rel("DynResize", ResizeRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace dyn
} // namespace relay
} // namespace tvm
1 change: 1 addition & 0 deletions src/relay/op/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>

#include "../make_op.h"
#include "../op_common.h"

namespace tvm {
Expand Down
3 changes: 3 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ Expr MakeZeros(Array<Integer> shape, DataType dtype);

Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype);

Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
String coordinate_transformation_mode, DataType out_dtype);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
16 changes: 16 additions & 0 deletions src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* \brief Rewrite Dynamic Operations to Static operations where possible
*/
#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/attrs/image.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

Expand Down Expand Up @@ -98,6 +99,21 @@ class DynamicToStaticMutator : public MixedModeMutator {
}
return Expr(nullptr);
}},
{Op::Get("dyn.image.resize"),
[](const CallNode* call_node) {
if (const ConstantNode* size = call_node->args[1].as<ConstantNode>()) {
const ResizeAttrs* param = call_node->attrs.as<ResizeAttrs>();
CHECK(param);
auto size_int = ToVector(size->data);
Array<PrimExpr> size_prim;
for (size_t i = 0; i < size_int.size(); ++i) {
size_prim.push_back(size_int[i]);
}
return MakeResize(call_node->args[0], size_prim, param->layout, param->method,
param->coordinate_transformation_mode, param->out_dtype);
}
return Expr(nullptr);
}},
};
}

Expand Down
Loading

0 comments on commit 9ad33fe

Please sign in to comment.