Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND][TF] Add conv3d #4604

Merged
merged 2 commits into from
Jan 1, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom,"
"right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
Expand Down
133 changes: 130 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,18 @@ def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
if len(kernel) == 3:
return prefix + '3d' + surfix
raise tvm.error.OpAttributeInvalid(
'Only 2D kernels are supported for operator {}'.format(prefix + '2d'))
'Only 2D or 3D kernels are supported for operator {}'.format(prefix + '2d or 3d'))
return _impl

def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
if len(attrs['kernel_shape']) in (2, 3):
return True
return False
return _dim_check, "Only 2d kernel supported."
return _dim_check, "Only 2d or 3d kernel supported."

def _get_param(params, input_node):
if isinstance(input_node, _expr.Constant):
Expand Down Expand Up @@ -425,6 +427,130 @@ def _impl(inputs, attr, params):
return out
return _impl

def _conv3d(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False

inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]

# NCDHW Layout require weights transpose
if attr['data_format'] == 'NCDHW':
tmp_shape = attr['_input_shapes'][inputs[1]]
tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)]
inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
attr['_input_shapes'][inputs[1]] = tmp_shape

input_shape = attr['_input_shapes'][inputs_data]
weights_shape = attr['_input_shapes'][inputs[1]]

if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC":
input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3))
weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)]
inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))

attr['data_format'] = "NCDHW"
attr['strides'] = [attr['strides'][ii] for ii in (0, 4, 1, 2, 3)]
flip_layout = True

if attr['data_format'] == 'NDHWC':
kernel_d, kernel_h, kernel_w, _, _ = weights_shape
attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
if opname == 'conv':
attr['channels'] = weights_shape[4]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[3]

if 'dilations' in attr:
attr['dilations'] =\
(attr['dilations'][1], attr['dilations'][2], attr['dilations'][3])
attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
elif attr['data_format'] == 'NCDHW':
_, _, kernel_d, kernel_h, kernel_w = weights_shape
attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
if opname == 'conv':
attr['channels'] = weights_shape[0]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[1]

if 'dilations' in attr:
attr['dilations'] =\
(attr['dilations'][2], attr['dilations'][3], attr['dilations'][4])
attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
else:
msg = 'Value {} in attribute "data_format" of operator Conv is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")

if attr['padding'] == 'VALID':
attr['padding'] = [0, 0, 0]
elif attr['padding'] == 'SAME':
stride_d, stride_h, stride_w = attr['strides']
kernel_d, kernel_h, kernel_w = attr['kernel_shape']

pdata_shape = input_shape
if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
pdata_shape = attr['_output_shapes'][0]

if attr['data_format'] == 'NDHWC':
in_d = pdata_shape[1]
in_h = pdata_shape[2]
in_w = pdata_shape[3]
else:
in_d = pdata_shape[2]
in_h = pdata_shape[3]
in_w = pdata_shape[4]

dilation_d = attr['dilations'][0]
dilation_h = attr['dilations'][1]
dilation_w = attr['dilations'][2]
dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_v[0], pad_v[1], pad_h[1]]

else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' \
'valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

if 'kernel_layout' not in attr:
attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW'

use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
channel_axis = 1 if attr['data_format'] == "NCDHW" else 3

# Ignore the new attributes from TF2.0, for now.
out = AttrCvt(
op_name=_dimension_picker('conv', \
surfix="_transpose" if opname == 'conv_transpose' else ""),
ignores=['explicit_paddings'],
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'data_layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)

if use_bias:
out = _op.nn.bias_add(out,
inputs[2] if opname != 'conv_transpose' else inputs[3],
axis=channel_axis)

if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 4, 1))

return out
return _impl

def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
Expand Down Expand Up @@ -1442,6 +1568,7 @@ def _impl(inputs, attr, params):
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'),
'Conv3D' : _conv3d('conv'),
'Conv2DBackpropInput' : _conv('conv_transpose'),
'CropAndResize' : _crop_and_resize(),
'DecodeJpeg' : _decode_image(),
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def _get_out_depth():
assert len(weight_shape) == 5
C, M, _, _, VC = weight_shape
return C * VC * M

if groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding,
Expand Down Expand Up @@ -330,7 +331,7 @@ def compute_conv3d(attrs, inputs, out_type, target):
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)

assert layout in ["NCDHW"]
assert layout in ["NCDHW", "NDHWC"]
(dilation_d, dilation_h, dilation_w) = dilation
if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
Expand All @@ -353,6 +354,8 @@ def schedule_conv3d(attrs, outs, target):
with target:
if groups == 1 and layout == "NCDHW":
return topi.generic.schedule_conv3d_ncdhw(outs)
elif groups == 1 and layout == "NDHWC":
return topi.generic.schedule_conv3d_ndhwc(outs)

raise ValueError("No compatible schedule")

Expand Down
19 changes: 10 additions & 9 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);

template<typename T>
Array<Array<Layout> > Conv2DInferCorrectLayout(
Array<Array<Layout> > ConvInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
Expand Down Expand Up @@ -105,7 +105,7 @@ with the layer input to produce a tensor of outputs.
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

// relay.nn.conv3d
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
Expand Down Expand Up @@ -163,7 +163,8 @@ with the layer input to produce a tensor of outputs.
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>);
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);


// relay.nn.conv2d_transpose
Expand Down Expand Up @@ -337,7 +338,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
ConvInferCorrectLayout<Conv2DTransposeAttrs>)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);


Expand Down Expand Up @@ -635,7 +636,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
.set_support_level(10)
.add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);
ConvInferCorrectLayout<Conv2DWinogradAttrs>);

// relay.nn.contrib_conv2d_winograd_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
Expand Down Expand Up @@ -744,7 +745,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
Expand Down Expand Up @@ -854,7 +855,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
.set_support_level(10)
.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
ConvInferCorrectLayout<Conv2DAttrs>);

// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
Expand Down Expand Up @@ -903,7 +904,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.set_support_level(10)
.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
ConvInferCorrectLayout<Conv2DAttrs>);


// Positional relay function to create depthwise conv2d NCHWc operator
Expand Down Expand Up @@ -953,7 +954,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
ConvInferCorrectLayout<Conv2DAttrs>);


bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down
13 changes: 9 additions & 4 deletions src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <string>
#include <utility>

#include "../op_common.h"

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -187,7 +189,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
param->kernel_size[1], param->kernel_size[2]}};
}

/*wshape = trans_kernel_layout.BackwardShape(wshape); */
wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
Expand All @@ -196,6 +198,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (weight != nullptr) {
weight_dtype = weight->dtype;
}

// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
} else {
Expand Down Expand Up @@ -225,22 +228,24 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// dilation
Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});

IndexExpr pad_d, pad_h, pad_w;
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
if (!dshape_ncdhw[2].as<ir::Any>()) {
oshape.Set(2, indexdiv(dshape_ncdhw[2] + param->padding[0] * 2 - dilated_ksize_z,
oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}

if (!dshape_ncdhw[3].as<ir::Any>()) {
oshape.Set(3, indexdiv(dshape_ncdhw[3] + param->padding[1] * 2 - dilated_ksize_y,
oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}

if (!dshape_ncdhw[4].as<ir::Any>()) {
oshape.Set(4, indexdiv(dshape_ncdhw[4] + param->padding[2] * 2 - dilated_ksize_x,
oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x,
param->strides[2]) + 1);
} else {
oshape.Set(4, dshape_ncdhw[4]);
Expand Down
39 changes: 39 additions & 0 deletions src/relay/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,45 @@ inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) {
}
}

/*! \brief A utility function to get padding height and width from a 1, 2, 4 ints tuple. */
inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h,
IndexExpr* pad_w) {
if (padding.size() == 1) {
*pad_h = padding[0] * 2;
*pad_w = padding[0] * 2;
} else if (padding.size() == 2) {
*pad_h = padding[0] * 2;
*pad_w = padding[1] * 2;
} else if (padding.size() == 4) {
*pad_h = padding[0] + padding[2];
*pad_w = padding[1] + padding[3];
} else {
CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got "
<< padding.size();
}
}

/*! \brief A utility function to get padding depth, height and width from a 1, 3, 6 ints tuple. */
inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d,
IndexExpr* pad_h, IndexExpr* pad_w) {
if (padding.size() == 1) {
*pad_d = padding[0] * 2;
*pad_h = padding[0] * 2;
*pad_w = padding[0] * 2;
} else if (padding.size() == 3) {
*pad_d = padding[0] * 2;
*pad_h = padding[1] * 2;
*pad_w = padding[2] * 2;
} else if (padding.size() == 6) {
*pad_d = padding[0] + padding[3];
*pad_h = padding[1] + padding[4];
*pad_w = padding[2] + padding[5];
} else {
CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got "
<< padding.size();
}
}

} // namespace relay
} // namespace tvm

Expand Down
Loading