Add broadcast of strides and kernel_dilation to nn.ConvTranspose #3731

Merged 1 commit on Mar 1, 2024
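In short: this change lets `strides` and `kernel_dilation` on `nn.ConvTranspose` be passed as a plain integer, which is broadcast across all spatial dimensions. A minimal usage sketch of the resulting behavior (shapes and argument values here are illustrative, not taken from the PR):

import jax
import jax.numpy as jnp
import flax.linen as nn

# After this change, a scalar broadcasts across all spatial dimensions,
# so these two modules are equivalent:
conv_tuple = nn.ConvTranspose(features=4, kernel_size=(3, 3), strides=(2, 2), kernel_dilation=(1, 1))
conv_scalar = nn.ConvTranspose(features=4, kernel_size=(3, 3), strides=2, kernel_dilation=1)

x = jnp.ones((1, 8, 8, 3))  # (batch, height, width, channels)
params = conv_scalar.init(jax.random.PRNGKey(0), x)
y = conv_scalar.apply(params, x)
print(y.shape)  # (1, 16, 16, 4) with the default 'SAME' padding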
flax/linen/linear.py (18 additions, 9 deletions)
@@ -738,15 +738,16 @@ class ConvTranspose(Module):
       kernel size can be passed as an integer, which will be interpreted as a
       tuple of the single integer. For all other cases, it must be a sequence of
       integers.
-    strides: a sequence of `n` integers, representing the inter-window strides.
+    strides: an integer or a sequence of `n` integers, representing the
+      inter-window strides.
     padding: either the string `'SAME'`, the string `'VALID'`, the string
       `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
       high)` integer pairs that give the padding to apply before and after each
       spatial dimension. A single int is interpreted as applying the same
       padding in all dims and assigning a single int in a sequence causes the
       same padding to be used on both sides.
-    kernel_dilation: ``None``, or a sequence of ``n`` integers, giving the
-      dilation factor to apply in each spatial dimension of the convolution
+    kernel_dilation: ``None``, or an integer or a sequence of ``n`` integers,
+      giving the dilation factor to apply in each spatial dimension of the convolution
       kernel. Convolution with kernel dilation is also known as 'atrous
       convolution'.
     use_bias: whether to add a bias to the output (default: True).
@@ -804,6 +805,17 @@ def __call__(self, inputs: Array) -> Array:
     else:
       kernel_size = tuple(self.kernel_size)

+    def maybe_broadcast(
+      x: Optional[Union[int, Sequence[int]]],
+    ) -> Tuple[int, ...]:
+      if x is None:
+        # backward compatibility with using None as sentinel for
+        # broadcast 1
+        x = 1
+      if isinstance(x, int):
+        return (x,) * len(kernel_size)
+      return tuple(x)
+
     # Combine all input batch dimensions into a single leading batch axis.
     num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
     if num_batch_dimensions != 1:
@@ -814,11 +826,8 @@ def __call__(self, inputs: Array) -> Array:
       ]
       inputs = jnp.reshape(inputs, flat_input_shape)

-    strides: Tuple[int, ...]
-    if self.strides is None:
-      strides = (1,) * (inputs.ndim - 2)
-    else:
-      strides = tuple(self.strides)
+    strides = maybe_broadcast(self.strides)
+    kernel_dilation = maybe_broadcast(self.kernel_dilation)

     in_features = jnp.shape(inputs)[-1]
     if self.transpose_kernel:
@@ -857,7 +866,7 @@ def __call__(self, inputs: Array) -> Array:
       kernel,
       strides,
       padding_lax,
-      rhs_dilation=self.kernel_dilation,
+      rhs_dilation=kernel_dilation,
       transpose_kernel=self.transpose_kernel,
       precision=self.precision,
     )
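For reference, the broadcasting rule implemented by `maybe_broadcast` above can be exercised in isolation. This sketch passes the number of spatial dimensions explicitly, whereas the method version closes over `len(kernel_size)`:

from typing import Optional, Sequence, Tuple, Union

def maybe_broadcast(
    x: Optional[Union[int, Sequence[int]]], num_spatial_dims: int
) -> Tuple[int, ...]:
  # None remains a sentinel meaning "1 in every spatial dimension",
  # preserving backward compatibility.
  if x is None:
    x = 1
  # A scalar is replicated once per spatial dimension.
  if isinstance(x, int):
    return (x,) * num_spatial_dims
  return tuple(x)

assert maybe_broadcast(None, 2) == (1, 1)
assert maybe_broadcast(2, 3) == (2, 2, 2)
assert maybe_broadcast((2, 1), 2) == (2, 1)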
tests/linen/linen_linear_test.py (1 addition, 1 deletion)
@@ -980,7 +980,7 @@ def test_transpose_kernel_conv_transpose(self, use_bias):
     conv_module = nn.ConvTranspose(
         features=4,
         use_bias=use_bias,
-        strides=(2, 2),
+        strides=2,
         kernel_size=(6, 6),
         padding='CIRCULAR',
         transpose_kernel=True,
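A quick equivalence check in the spirit of the updated test (the input shape and RNG seed are assumed, not taken from the test):

import jax
import jax.numpy as jnp
import flax.linen as nn

kwargs = dict(
    features=4, kernel_size=(6, 6), padding='CIRCULAR', transpose_kernel=True
)
x = jnp.ones((1, 8, 8, 3))
m_scalar = nn.ConvTranspose(strides=2, **kwargs)
m_tuple = nn.ConvTranspose(strides=(2, 2), **kwargs)

# Parameter shapes do not depend on strides, so one set of params can be
# applied with both modules; broadcasting makes the outputs identical.
params = m_scalar.init(jax.random.PRNGKey(0), x)
assert jnp.allclose(m_scalar.apply(params, x), m_tuple.apply(params, x))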