Skip to content

Commit

Permalink
[ONNX][Relay] Support "tf_crop_and_resize" in relay Resize op. (#9475)
Browse files Browse the repository at this point in the history
* add fallback to opset 11

* Support tf_crop_and_resize in resize op

* change api use in the rest of the codebase

really fix the tests

* respond to review comments, improve doc strings

* fix docstring indentation

* remove N anc C from resize roi
  • Loading branch information
Matthew Brookhart authored Nov 16, 2021
1 parent 22ba652 commit eda12cb
Show file tree
Hide file tree
Showing 22 changed files with 419 additions and 104 deletions.
24 changes: 24 additions & 0 deletions include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,21 @@ namespace relay {
/*! \brief Attributes used in image resize1d operator */
struct Resize1DAttrs : public tvm::AttrsNode<Resize1DAttrs> {
Array<IndexExpr> size;
Array<FloatImm> roi;
std::string layout;
std::string method;
std::string coordinate_transformation_mode;
std::string rounding_method;
double cubic_alpha;
int cubic_exclude;
double extrapolation_value;
DataType out_dtype;

TVM_DECLARE_ATTRS(Resize1DAttrs, "relay.attrs.Resize1DAttrs") {
TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
TVM_ATTR_FIELD(roi)
.set_default(NullValue<Array<FloatImm> >())
.describe("Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
TVM_ATTR_FIELD(layout).set_default("NCW").describe(
"Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel and width"
Expand Down Expand Up @@ -73,23 +78,31 @@ struct Resize1DAttrs : public tvm::AttrsNode<Resize1DAttrs> {
TVM_ATTR_FIELD(cubic_exclude)
.set_default(0)
.describe("Flag to exclude exterior of the image during cubic interpolation");
TVM_ATTR_FIELD(extrapolation_value)
.set_default(0.0)
.describe("Value to return when roi is outside of the image");
TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
}
};

/*! \brief Attributes used in image resize2d operator */
struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
Array<IndexExpr> size;
Array<FloatImm> roi;
std::string layout;
std::string method;
std::string coordinate_transformation_mode;
std::string rounding_method;
double cubic_alpha;
int cubic_exclude;
double extrapolation_value;
DataType out_dtype;

TVM_DECLARE_ATTRS(Resize2DAttrs, "relay.attrs.Resize2DAttrs") {
TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
TVM_ATTR_FIELD(roi)
.set_default(NullValue<Array<FloatImm> >())
.describe("Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
Expand Down Expand Up @@ -118,23 +131,31 @@ struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
TVM_ATTR_FIELD(cubic_exclude)
.set_default(0)
.describe("Flag to exclude exterior of the image during bicubic interpolation");
TVM_ATTR_FIELD(extrapolation_value)
.set_default(0.0)
.describe("Value to return when roi is outside of the image");
TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
}
};

/*! \brief Attributes used in image resize3d operator */
struct Resize3DAttrs : public tvm::AttrsNode<Resize3DAttrs> {
Array<IndexExpr> size;
Array<FloatImm> roi;
std::string layout;
std::string method;
std::string coordinate_transformation_mode;
std::string rounding_method;
double cubic_alpha;
int cubic_exclude;
double extrapolation_value;
DataType out_dtype;

TVM_DECLARE_ATTRS(Resize3DAttrs, "relay.attrs.Resize3DAttrs") {
TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
TVM_ATTR_FIELD(roi)
.set_default(NullValue<Array<FloatImm> >())
.describe("Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
TVM_ATTR_FIELD(layout).set_default("NCDHW").describe(
"Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
Expand Down Expand Up @@ -163,6 +184,9 @@ struct Resize3DAttrs : public tvm::AttrsNode<Resize3DAttrs> {
TVM_ATTR_FIELD(cubic_exclude)
.set_default(0)
.describe("Flag to exclude exterior of the image during tricubic interpolation");
TVM_ATTR_FIELD(extrapolation_value)
.set_default(0.0)
.describe("Value to return when roi is outside of the image");
TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
}
};
Expand Down
66 changes: 56 additions & 10 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2610,13 +2610,13 @@ def _impl_v10(cls, inputs, attr, params):
out = None
if ndims == 3:
out_size = fold_constant(_op.strided_slice(size, [2], [3]))
out = _op.image.resize1d(inputs[0], out_size, "NCW", method, "asymmetric")
out = _op.image.resize1d(inputs[0], out_size, None, "NCW", method, "asymmetric")
elif ndims == 4:
out_size = fold_constant(_op.strided_slice(size, [2], [4]))
out = _op.image.resize2d(inputs[0], out_size, "NCHW", method, "asymmetric")
out = _op.image.resize2d(inputs[0], out_size, None, "NCHW", method, "asymmetric")
elif ndims == 5:
out_size = fold_constant(_op.strided_slice(size, [2], [5]))
out = _op.image.resize3d(inputs[0], out_size, "NCDHW", method, "asymmetric")
out = _op.image.resize3d(inputs[0], out_size, None, "NCDHW", method, "asymmetric")
else:
raise NotImplementedError("Resize only supports 3, 4, or 5 dims")
return out
Expand All @@ -2639,6 +2639,12 @@ def _impl_v11(cls, inputs, attr, params):
def _impl_v13(cls, inputs, attr, params):
scale = inputs[2]
size = inputs[3]

# Some versions of onnx exporters produce an opset 13 model with the opset 11
# resize op, handle that edge case
if scale is not None and size is not None:
return cls._impl_v11(inputs, attr, params)

if size is not None:
assert scale is None, "One of scale or size should be passed, not both."
else:
Expand All @@ -2657,6 +2663,9 @@ def v11_13_common(cls, inputs, size, attr, params):
they handle the passing of scale and size. This utility
provides the implementation for both
"""
roi = inputs[1]
if roi is not None and infer_shape(roi)[0] == 0:
roi = None
ndims = len(infer_shape(inputs[0]))
mode = attr.get("mode").decode("ascii")
if mode == "nearest":
Expand All @@ -2674,23 +2683,60 @@ def v11_13_common(cls, inputs, size, attr, params):
nearest_mode = attr.get("nearest_mode", b"round_prefer_floor").decode("ascii")
alpha = attr.get("cubic_coeff_a", -0.75)
exclude = attr.get("exclude_outside", 0)
extrapolation_value = attr.get("extrapolation_value", 0.0)

if roi is not None:
roi = fold_constant(
_op.concatenate(
[
_op.strided_slice(roi, [2], [ndims]),
_op.strided_slice(roi, [ndims + 2], [2 * ndims]),
],
axis=0,
)
)

out_size = fold_constant(_op.strided_slice(size, [2], [ndims]))

out_size = fold_constant(_op.strided_slice(size, [2], [4]))
out = None
if ndims == 3:
out_size = fold_constant(_op.strided_slice(size, [2], [3]))
out = _op.image.resize1d(
inputs[0], out_size, "NCW", method, coord_trans, nearest_mode, alpha, exclude
inputs[0],
out_size,
roi,
"NCW",
method,
coord_trans,
nearest_mode,
alpha,
exclude,
extrapolation_value,
)
elif ndims == 4:
out_size = fold_constant(_op.strided_slice(size, [2], [4]))
out = _op.image.resize2d(
inputs[0], out_size, "NCHW", method, coord_trans, nearest_mode, alpha, exclude
inputs[0],
out_size,
roi,
"NCHW",
method,
coord_trans,
nearest_mode,
alpha,
exclude,
extrapolation_value,
)
elif ndims == 5:
out_size = fold_constant(_op.strided_slice(size, [2], [5]))
out = _op.image.resize3d(
inputs[0], out_size, "NCDHW", method, coord_trans, nearest_mode, alpha, exclude
inputs[0],
out_size,
roi,
"NCDHW",
method,
coord_trans,
nearest_mode,
alpha,
exclude,
extrapolation_value,
)
else:
raise NotImplementedError("Resize only supports 3, 4, or 5 dims")
Expand Down
8 changes: 5 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,7 +1822,7 @@ def upsample(inputs, input_types):

def func(x):
return _op.image.resize2d(
x, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75
x, out_size, None, "NCHW", method, coord_trans, cubic_alpha=-0.75
)

if self.is_quantized_tensor(data):
Expand Down Expand Up @@ -1854,7 +1854,7 @@ def upsample3d(inputs, input_types):
else:
coord_trans = "half_pixel"

return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans)
return _op.image.resize3d(data, out_size, None, "NCDHW", method, coord_trans)

return upsample3d

Expand Down Expand Up @@ -2186,7 +2186,9 @@ def interpolate(self, inputs, input_types):
else:
coord_trans = "half_pixel"

return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75)
return _op.image.resize2d(
data, out_size, None, "NCHW", method, coord_trans, cubic_alpha=-0.75
)

def numel(self, inputs, input_types):
return _op.ndarray_size(inputs[0])
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,9 @@ def _impl(inputs, attr, params, mod):

# Ignore the new attributes from TF2.0, for now.
return AttrCvt(
op_name="resize2d", ignores=["Tdim", "half_pixel_centers"], extras={"method": method}
op_name="resize2d",
ignores=["Tdim", "half_pixel_centers"],
extras={"method": method, "roi": None},
)(inputs, attr)

return _impl
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def _convert_resize(self, method, op):
if bilinear_method and input_tensor.qnn_params:
in_expr = self.dequantize(in_expr, input_tensor)
out = _op.image.resize2d(
in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans
in_expr, target_size, None, "NHWC", method, coordinate_transformation_mode=coord_trans
)
if bilinear_method and output_tensor.qnn_params:
out = self.quantize(out, output_tensor)
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/op/dyn/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,29 @@
# resize
@reg.register_compute("dyn.image.resize2d")
def compute_resize2d(attrs, inputs, out_type):
"""
Compute function calls into topi
"""
layout = attrs.layout
method = attrs.method
coord_trans = attrs.coordinate_transformation_mode
rounding_method = attrs.rounding_method
cubic_alpha = attrs.cubic_alpha
cubic_exclude = attrs.cubic_exclude
extrapolation_value = attrs.extrapolation_value
out_dtype = attrs.out_dtype
return [
tvm.topi.image.resize2d(
inputs[0],
inputs[2],
inputs[1],
layout,
method,
coord_trans,
rounding_method,
cubic_alpha,
cubic_exclude,
extrapolation_value,
out_dtype,
out_type.shape,
)
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,27 @@
def compute_resize1d(attrs, inputs, out_type):
"""compute definition for resize1d op"""
size = attrs.size
roi = attrs.roi
layout = attrs.layout
method = attrs.method
coord_trans = attrs.coordinate_transformation_mode
rounding_method = attrs.rounding_method
cubic_alpha = attrs.cubic_alpha
cubic_exclude = attrs.cubic_exclude
extrapolation_value = attrs.extrapolation_value
out_dtype = attrs.out_dtype
return [
topi.image.resize1d(
inputs[0],
roi,
size,
layout,
method,
coord_trans,
rounding_method,
cubic_alpha,
cubic_exclude,
extrapolation_value,
out_dtype,
)
]
Expand Down Expand Up @@ -128,23 +132,27 @@ def resize1d_shape_func(attrs, inputs, _):
def compute_resize2d(attrs, inputs, out_type):
"""compute definition for resize2d op"""
size = attrs.size
roi = attrs.roi
layout = attrs.layout
method = attrs.method
coord_trans = attrs.coordinate_transformation_mode
rounding_method = attrs.rounding_method
cubic_alpha = attrs.cubic_alpha
cubic_exclude = attrs.cubic_exclude
extrapolation_value = attrs.extrapolation_value
out_dtype = attrs.out_dtype
return [
topi.image.resize2d(
inputs[0],
roi,
size,
layout,
method,
coord_trans,
rounding_method,
cubic_alpha,
cubic_exclude,
extrapolation_value,
out_dtype,
)
]
Expand Down Expand Up @@ -225,23 +233,27 @@ def resize2d_shape_func(attrs, inputs, _):
def compute_resize3d(attrs, inputs, out_type):
"""compute definition for resize3d op"""
size = attrs.size
roi = attrs.roi
layout = attrs.layout
method = attrs.method
coord_trans = attrs.coordinate_transformation_mode
rounding_method = attrs.rounding_method
cubic_alpha = attrs.cubic_alpha
cubic_exclude = attrs.cubic_exclude
extrapolation_value = attrs.extrapolation_value
out_dtype = attrs.out_dtype
return [
topi.image.resize3d(
inputs[0],
roi,
size,
layout,
method,
coord_trans,
rounding_method,
cubic_alpha,
cubic_exclude,
extrapolation_value,
out_dtype,
)
]
Expand Down
Loading

0 comments on commit eda12cb

Please sign in to comment.