Merge pull request #3731 from IvyZX:conv-trans
PiperOrigin-RevId: 611686211
Flax Authors committed Mar 1, 2024
2 parents d1f219f + fdbc640 commit 1abfa87
Showing 2 changed files with 19 additions and 10 deletions.
flax/linen/linear.py: 18 additions & 9 deletions
@@ -738,15 +738,16 @@ class ConvTranspose(Module):
      kernel size can be passed as an integer, which will be interpreted as a
      tuple of the single integer. For all other cases, it must be a sequence of
      integers.
-    strides: a sequence of `n` integers, representing the inter-window strides.
+    strides: an integer or a sequence of `n` integers, representing the
+      inter-window strides.
    padding: either the string `'SAME'`, the string `'VALID'`, the string
      `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
      high)` integer pairs that give the padding to apply before and after each
      spatial dimension. A single int is interpreted as applying the same
      padding in all dims, and passing a single int in a sequence causes the
      same padding to be used on both sides.
-    kernel_dilation: ``None``, or a sequence of ``n`` integers, giving the
-      dilation factor to apply in each spatial dimension of the convolution
+    kernel_dilation: ``None``, or an integer or a sequence of ``n`` integers,
+      giving the dilation factor to apply in each spatial dimension of the convolution
      kernel. Convolution with kernel dilation is also known as 'atrous
      convolution'.
    use_bias: whether to add a bias to the output (default: True).
@@ -804,6 +805,17 @@ def __call__(self, inputs: Array) -> Array:
    else:
      kernel_size = tuple(self.kernel_size)

+    def maybe_broadcast(
+      x: Optional[Union[int, Sequence[int]]],
+    ) -> Tuple[int, ...]:
+      if x is None:
+        # backward compatibility with using None as sentinel for
+        # broadcast 1
+        x = 1
+      if isinstance(x, int):
+        return (x,) * len(kernel_size)
+      return tuple(x)

    # Combine all input batch dimensions into a single leading batch axis.
    num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
    if num_batch_dimensions != 1:
@@ -814,11 +826,8 @@ def __call__(self, inputs: Array) -> Array:
      ]
      inputs = jnp.reshape(inputs, flat_input_shape)

-    strides: Tuple[int, ...]
-    if self.strides is None:
-      strides = (1,) * (inputs.ndim - 2)
-    else:
-      strides = tuple(self.strides)
+    strides = maybe_broadcast(self.strides)
+    kernel_dilation = maybe_broadcast(self.kernel_dilation)

    in_features = jnp.shape(inputs)[-1]
    if self.transpose_kernel:
@@ -857,7 +866,7 @@ def __call__(self, inputs: Array) -> Array:
      kernel,
      strides,
      padding_lax,
-      rhs_dilation=self.kernel_dilation,
+      rhs_dilation=kernel_dilation,
      transpose_kernel=self.transpose_kernel,
      precision=self.precision,
    )
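For context, here is a minimal sketch of how the relaxed argument types can be used after this change. The module configuration, input shape, and resulting output shape are illustrative assumptions, not taken from the diff:

import jax
import jax.numpy as jnp
import flax.linen as nn

# `strides` (and `kernel_dilation`) may now be a single int, which is
# broadcast across all spatial dimensions; here `strides=2` is equivalent
# to `strides=(2, 2)`.
model = nn.ConvTranspose(features=4, kernel_size=(3, 3), strides=2)

x = jnp.ones((1, 8, 8, 3))  # (batch, height, width, channels)
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
print(y.shape)  # with the default 'SAME' padding, stride 2 doubles each spatial dim: (1, 16, 16, 4)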
tests/linen/linen_linear_test.py: 1 addition & 1 deletion
@@ -980,7 +980,7 @@ def test_transpose_kernel_conv_transpose(self, use_bias):
    conv_module = nn.ConvTranspose(
      features=4,
      use_bias=use_bias,
-      strides=(2, 2),
+      strides=2,
      kernel_size=(6, 6),
      padding='CIRCULAR',
      transpose_kernel=True,
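The updated test now exercises the integer form of `strides`. A hypothetical sanity check of the intended equivalence, not part of the test suite, assuming the same configuration as the test above:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 8, 8, 3))
kwargs = dict(features=4, kernel_size=(6, 6), padding='CIRCULAR', transpose_kernel=True)
m_int = nn.ConvTranspose(strides=2, **kwargs)       # new: single int
m_tup = nn.ConvTranspose(strides=(2, 2), **kwargs)  # old: explicit tuple

# Both modules have the same parameter structure, so one set of parameters
# can be applied through either; the outputs should match exactly.
params = m_int.init(jax.random.PRNGKey(0), x)
assert jnp.allclose(m_int.apply(params, x), m_tup.apply(params, x))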
