-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 #4484
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -248,10 +248,17 @@ bool Conv2DTransposeRel(const Array<Type>& types, | |
} | ||
// dilation | ||
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); | ||
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - | ||
2 * param->padding[0] + param->output_padding[0])); | ||
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - | ||
2 * param->padding[1] + param->output_padding[1])); | ||
if ( param->padding.size() == 2 ) { | ||
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - | ||
2 * param->padding[0] + param->output_padding[0])); | ||
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - | ||
2 * param->padding[1] + param->output_padding[1])); | ||
} else if (param->padding.size() == 4) { | ||
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - | ||
param->padding[0] - param->padding[2] + param->output_padding[0])); | ||
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - | ||
param->padding[1] - param->padding[3] + param->output_padding[1])); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Formulas are very similar except of using
Can we calculate correct padding and then use the same formulas with calculated padding for both cases? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The additional condition is to handle the head padding and tail padding are diffenct, for even kernel in this case. I didn't quite understand your point. Please clearify. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about the following?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is better. Thanks! |
||
} | ||
|
||
DataType out_dtype = param->out_dtype; | ||
if (out_dtype.bits() == 0) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -362,10 +362,16 @@ def test_forward_convolution(): | |
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW') | ||
_test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', | ||
'NCHW', [4, 176, 8, 8]) | ||
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME', | ||
'NCHW', [4, 176, 8, 8]) | ||
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME', | ||
'NCHW', [4, 176, 8, 8]) | ||
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', | ||
'NCHW', [4, 19, 17, 17]) | ||
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', | ||
'NCHW', [4, 124, 17, 17]) | ||
_test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also add test case for
E.g. if input is 5x5 then valid outputs are 9x9 or 10x10 (you can use one or another in the output_shape tensor) regardless of the kernel size. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The strides in transpose convolution is used to dilate the input, see this. If the input is dilated, there would be no way to get 'SAME' size of output by only padding. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok. I can get your point. 'SAME' means the size of the enlarged input size(by dilation). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TF code for kernel 2x2, strides 2x2 and padding SAME is
|
||
'NCHW', [4, 124, 17, 17]) | ||
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', | ||
'NCHW', [4, 12, 17, 17]) | ||
# kernel 2x2, strides (2,2) | ||
|
@@ -388,10 +394,16 @@ def test_forward_convolution(): | |
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC') | ||
_test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', | ||
'NHWC', [4, 8, 8, 176]) | ||
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME', | ||
'NHWC', [4, 8, 8, 176]) | ||
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME', | ||
'NHWC', [4, 8, 8, 176]) | ||
_test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', | ||
'NHWC', [4, 17, 17, 19]) | ||
_test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', | ||
'NHWC', [4, 17, 17, 124]) | ||
_test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME', | ||
'NHWC', [4, 17, 17, 124]) | ||
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', | ||
'NHWC', [4, 17, 17, 12]) | ||
# kernel 2x2, strides (2,2) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor formatting comment. I think we do not need spaces after
(
and before)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will fix it.