diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index d79b35e533a..fb8c0375de0 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -31,9 +31,6 @@
     "nn.functional.adaptive_max_pool2d",
     "nn.functional.adaptive_max_pool3d",
     "nn.functional.alpha_dropout",
-    "nn.functional.conv_transpose1d",
-    "nn.functional.conv_transpose2d",
-    "nn.functional.conv_transpose3d",
     "nn.functional.ctc_loss",
     "nn.functional.dropout2d",
     "nn.functional.dropout3d",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 1f9e9f4a045..cf5928eada1 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1028,9 +1028,6 @@ def _aten_convolution(
     output_padding,
     groups,
 ):
-  if transposed:
-    raise NotImplementedError("Transposed convolution is not implemented.")
-
   num_shape_dim = weight.ndim - 1
   batch_dims = input.shape[:-num_shape_dim]
 
@@ -1040,14 +1037,24 @@ def make_padding(padding, num_spatial_dims):
     # Expand single padding to pairs expected by jax
     if len(padding) == 1 and len(padding) < num_spatial_dims:
       padding *= num_spatial_dims
-    return ((p, p) for p in padding)
+    if transposed:
+      # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
+      pad_out = []
+      for i in range(num_spatial_dims):
+        front = dilation[i] * (weight.shape[i+2] - 1) - padding[i]
+        back = front + output_padding[i]
+        pad_out.append((front, back))
+      return pad_out
+    else:
+      return ((p, p) for p in padding)
 
   def create_default_conv_dimension_numbers(num_spatial_dims):
     # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211
     # (batch dimension, feature dimension, spatial dimensions...)
     lhs_spec = [0, 1]
     # (out feature dimension, in feature dimension, spatial dimensions...)
-    rhs_spec = [0, 1]
+    # swapped for transposed convolution
+    rhs_spec = [1, 0] if transposed else [0, 1]
     # (batch dimension, feature dimension, spatial dimensions...)
     out_spec = [0, 1]
     for i in range(0, num_spatial_dims):
@@ -1058,17 +1065,37 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
         *map(tuple, (lhs_spec, rhs_spec, out_spec))
     )
 
-  res = jax.lax.conv_general_dilated(
-      input,
-      weight,
-      stride,
-      make_padding(padding, len(stride)),
-      lhs_dilation=(1,) * len(stride),
-      rhs_dilation=dilation,
-      dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
-      feature_group_count=groups,
-      batch_group_count=1,
-  )
+  if transposed:
+    rhs = jnp.flip(weight, range(2, 1+num_shape_dim))
+    if groups != 1:
+      # reshape filters for transposed depthwise convolution
+      assert rhs.shape[0] % groups == 0
+      rhs_shape = [rhs.shape[0]//groups, rhs.shape[1]*groups]
+      rhs_shape.extend(rhs.shape[2:])
+      rhs = jnp.reshape(rhs, rhs_shape)
+    res = jax.lax.conv_general_dilated(
+        input,
+        rhs,
+        (1,) * len(stride),
+        make_padding(padding, len(stride)),
+        lhs_dilation=stride,
+        rhs_dilation=dilation,
+        dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
+        feature_group_count=groups,
+        batch_group_count=1,
+    )
+  else:
+    res = jax.lax.conv_general_dilated(
+        input,
+        weight,
+        stride,
+        make_padding(padding, len(stride)),
+        lhs_dilation=(1,) * len(stride),
+        rhs_dilation=dilation,
+        dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
+        feature_group_count=groups,
+        batch_group_count=1,
+    )
 
   if bias is not None:
     # TODO(qihqi): bias always on channel?
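
For reference, a small standalone sketch (not part of the patch) of the equivalence the jaten.py change relies on: a transposed convolution can be computed with jax.lax.conv_general_dilated by flipping the filter spatially, swapping its in/out channel axes, dilating the input by the stride (lhs_dilation), and padding each spatial dim with (dilation * (k - 1) - padding, dilation * (k - 1) - padding + output_padding). The helper name conv_transpose2d_via_dilated and the "NCHW"/"IOHW" dimension strings are illustrative choices, not part of the patch, and groups != 1 is not handled here.

import jax
import jax.numpy as jnp
import numpy as np
import torch

def conv_transpose2d_via_dilated(x, w, stride, padding, output_padding, dilation):
  # x: (N, C_in, H, W); w: (C_in, C_out, kH, kW), the layout torch's conv_transpose2d uses.
  kernel = jnp.flip(w, axis=(2, 3))  # flip the spatial dims of the filter
  pads = []
  for i in range(2):
    front = dilation[i] * (w.shape[i + 2] - 1) - padding[i]
    pads.append((front, front + output_padding[i]))
  return jax.lax.conv_general_dilated(
      x,
      kernel,
      window_strides=(1, 1),  # the transposed stride becomes input dilation instead
      padding=pads,
      lhs_dilation=stride,    # fractionally-strided input
      rhs_dilation=dilation,
      dimension_numbers=("NCHW", "IOHW", "NCHW"),  # "IO" swaps the filter's channel axes
  )

x = np.random.randn(2, 3, 8, 8).astype(np.float32)
w = np.random.randn(3, 4, 3, 3).astype(np.float32)
ref = torch.nn.functional.conv_transpose2d(
    torch.from_numpy(x), torch.from_numpy(w),
    stride=(2, 2), padding=(1, 1), output_padding=(1, 1), dilation=(1, 1))
out = conv_transpose2d_via_dilated(
    jnp.asarray(x), jnp.asarray(w), (2, 2), (1, 1), (1, 1), (1, 1))
print(np.allclose(ref.numpy(), np.asarray(out), atol=1e-4))  # expected: True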