[torch_xla2] Add transposed conv ops (#8426)
dvhg authored Dec 21, 2024
1 parent 38d0868 commit 5711d1d
Showing 2 changed files with 43 additions and 19 deletions.
3 changes: 0 additions & 3 deletions experimental/torch_xla2/test/test_ops.py
@@ -31,9 +31,6 @@
     "nn.functional.adaptive_max_pool2d",
     "nn.functional.adaptive_max_pool3d",
     "nn.functional.alpha_dropout",
-    "nn.functional.conv_transpose1d",
-    "nn.functional.conv_transpose2d",
-    "nn.functional.conv_transpose3d",
     "nn.functional.ctc_loss",
     "nn.functional.dropout2d",
     "nn.functional.dropout3d",
59 changes: 43 additions & 16 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1028,9 +1028,6 @@ def _aten_convolution(
     output_padding,
     groups,
 ):
-  if transposed:
-    raise NotImplementedError("Transposed convolution is not implemented.")
-
   num_shape_dim = weight.ndim - 1
   batch_dims = input.shape[:-num_shape_dim]
 
@@ -1040,14 +1037,24 @@ def make_padding(padding, num_spatial_dims):
     # Expand single padding to pairs expected by jax
     if len(padding) == 1 and len(padding) < num_spatial_dims:
       padding *= num_spatial_dims
-    return ((p, p) for p in padding)
+    if transposed:
+      # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
+      pad_out = []
+      for i in range(num_spatial_dims):
+        front = dilation[i] * (weight.shape[i+2] - 1) - padding[i]
+        back = front + output_padding[i]
+        pad_out.append((front, back))
+      return pad_out
+    else:
+      return ((p, p) for p in padding)
 
   def create_default_conv_dimension_numbers(num_spatial_dims):
     # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211
     # (batch dimension, feature dimension, spatial dimensions...)
     lhs_spec = [0, 1]
     # (out feature dimension, in feature dimension, spatial dimensions...)
-    rhs_spec = [0, 1]
+    # swapped for transposed convolution
+    rhs_spec = [1, 0] if transposed else [0, 1]
     # (batch dimension, feature dimension, spatial dimensions...)
     out_spec = [0, 1]
     for i in range(0, num_spatial_dims):
@@ -1058,17 +1065,37 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
         *map(tuple, (lhs_spec, rhs_spec, out_spec))
     )
 
-  res = jax.lax.conv_general_dilated(
-      input,
-      weight,
-      stride,
-      make_padding(padding, len(stride)),
-      lhs_dilation=(1,) * len(stride),
-      rhs_dilation=dilation,
-      dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
-      feature_group_count=groups,
-      batch_group_count=1,
-  )
+  if transposed:
+    rhs = jnp.flip(weight, range(2, 1+num_shape_dim))
+    if groups != 1:
+      # reshape filters for transposed depthwise convolution
+      assert rhs.shape[0] % groups == 0
+      rhs_shape = [rhs.shape[0]//groups, rhs.shape[1]*groups]
+      rhs_shape.extend(rhs.shape[2:])
+      rhs = jnp.reshape(rhs, rhs_shape)
+    res = jax.lax.conv_general_dilated(
+        input,
+        rhs,
+        (1,) * len(stride),
+        make_padding(padding, len(stride)),
+        lhs_dilation=stride,
+        rhs_dilation=dilation,
+        dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
+        feature_group_count=groups,
+        batch_group_count=1,
+    )
+  else:
+    res = jax.lax.conv_general_dilated(
+        input,
+        weight,
+        stride,
+        make_padding(padding, len(stride)),
+        lhs_dilation=(1,) * len(stride),
+        rhs_dilation=dilation,
+        dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
+        feature_group_count=groups,
+        batch_group_count=1,
+    )
 
   if bias is not None:
     # TODO(qihqi): bias always on channel?
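For reference, the transposed branch above relies on the standard equivalence between a transposed convolution and a dilated ("fractionally strided") regular convolution: the kernel is flipped along its spatial axes, the stride moves into lhs_dilation, and the padding becomes the asymmetric pair (dilation*(K-1) - padding, dilation*(K-1) - padding + output_padding), which is exactly what make_padding emits when transposed is true. The snippet below is a minimal standalone sketch, not part of the commit, that checks this mapping against torch.nn.functional.conv_transpose1d for a 1D case with groups=1; the shapes, test values, and tolerance are illustrative assumptions.

# Minimal sketch (illustrative; not from the commit): a PyTorch conv_transpose1d
# reproduced with jax.lax.conv_general_dilated using the same flip/pad/dilation mapping.
import numpy as np
import torch
import jax
import jax.numpy as jnp

x = np.random.randn(1, 4, 9).astype(np.float32)   # (N, C_in, L)
w = np.random.randn(4, 6, 3).astype(np.float32)   # (C_in, C_out, K): PyTorch transposed-conv layout
stride, padding, output_padding, dilation = 2, 1, 1, 1

ref = torch.nn.functional.conv_transpose1d(
    torch.from_numpy(x), torch.from_numpy(w),
    stride=stride, padding=padding,
    output_padding=output_padding, dilation=dilation,
).numpy()
# Expected output length: (9 - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 18.

front = dilation * (w.shape[2] - 1) - padding     # same formula as make_padding above
back = front + output_padding
out = jax.lax.conv_general_dilated(
    jnp.asarray(x),
    jnp.flip(jnp.asarray(w), axis=2),             # flip the spatial dimension
    window_strides=(1,),                          # the stride moves into lhs_dilation
    padding=[(front, back)],
    lhs_dilation=(stride,),
    rhs_dilation=(dilation,),
    dimension_numbers=('NCH', 'IOH', 'NCH'),      # kernel laid out as (in, out, spatial)
    feature_group_count=1,
)
print(np.allclose(ref, np.asarray(out), atol=1e-4))   # expected: True

The same construction generalizes to 2D/3D by flipping every spatial axis and emitting one (front, back) pair per dimension, which is what the make_padding change and the jnp.flip call in the diff do.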
