Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conv1D supports paddings. #847

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 61 additions & 15 deletions axlearn/common/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,17 +1216,7 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti


class Conv2DWith1DPadding(Conv2D):
"""The 2-D convolution with 1-D padding on the time axis.

Kernel weights have the HWIO layout and in the shape of (window[0], window[1], input_dim,
output_dim). Both inputs and outputs will be in the NHWC layout.

For audio inputs/outputs, we assume dims correspond to [batch_size, time, frequency, input_dim].
This layer also returns paddings along the time axis. If specifying `cfg.padding` as a tuple of
(leading, trailing) paddings, leading padding frames are treated as valid (i.e. not masked by
the output paddings) while trailing padding frames are invalid (i.e. masked by the output
paddings).
"""
"""The 2-D convolution with 1-D padding on the time axis."""

@config_class
class Config(Conv2D.Config):
Expand Down Expand Up @@ -1499,25 +1489,81 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:

# NOTE(review): this span is a rendered diff hunk whose +/- markers and indentation were
# stripped, so superseded and replacement lines appear back to back. Presumably the first
# line of each duplicated pair is the removed one and the second is its replacement —
# confirm against the actual commit before treating this as runnable code.
def forward(self, x: Tensor) -> Tensor:
cfg = self.config
# Two consecutive assignments to `dilation`: the first yields a 1-tuple or None, the second
# yields a plain value that falls back to 1 when `cfg.rhs_dilation` is falsy.
dilation = (cfg.rhs_dilation,) if cfg.rhs_dilation else None
dilation = cfg.rhs_dilation or 1
conv_padding = conv_explicit_padding(
# The same argument line appears twice; the two differ only in whether `dilation` is wrapped
# into a 1-tuple at the call site.
window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=dilation
window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=(dilation,)
)
# Falls back to 1 when `cfg.lhs_dilation` is falsy.
transpose_dilation = cfg.lhs_dilation or 1
output = jax.lax.conv_general_dilated(
lhs=x,
rhs=self.parameters["weight"],
window_strides=(cfg.strides,),
# Input and output use the "NWC" layout and the kernel uses "WIO".
dimension_numbers=("NWC", "WIO", "NWC"),
padding=conv_padding,
feature_group_count=cfg.num_input_dim_groups,
# Two variants of the dilation arguments follow: one passes a single-element list or None,
# the other always passes a 1-tuple built from the locals computed above.
lhs_dilation=[cfg.lhs_dilation] if cfg.lhs_dilation is not None else None,
rhs_dilation=[cfg.rhs_dilation] if cfg.rhs_dilation is not None else None,
lhs_dilation=(transpose_dilation,),
rhs_dilation=(dilation,),
)
# The bias parameter is added only when `cfg.bias` is truthy.
if cfg.bias:
output += self.parameters["bias"]
return output


class Conv1DWithPadding(Conv1D):
    """A 1-D convolution that also consumes and produces 0/1 paddings on the time axis."""

    @config_class
    class Config(Conv1D.Config):
        """Configures Conv1DWithPadding."""

        # Optional anchor position within the convolution window, an integer in the range
        # [left_time_padding, window - right_time_padding). It determines the output paddings:
        # an output token is valid iff the input token at the anchor position of the
        # corresponding window is valid.
        # If None, defaults to the left time padding. See Conv2DWith1DPadding for more details.
        anchor: Optional[int] = None

    # Unlike the parent's forward, this one takes an extra keyword argument "paddings".
    # pylint: disable-next=arguments-differ
    def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]:
        """Runs the convolution and derives the matching output paddings.

        Padded input positions are zeroed before the convolution, and padded output positions
        are zeroed after it.

        Args:
            x: A Tensor of shape [batch_size, seq_len, input_dim]. Its rank must be exactly one
                more than the rank of `paddings`.
            paddings: 0/1 Tensor of shape [batch_size, seq_len].

        Returns:
            output: A Tensor of shape [batch_size, seq_len, output_dim].
            paddings: 0/1 Tensor of shape [batch_size, seq_len].
        """
        chex.assert_rank(x, paddings.ndim + 1)

        # Zero out padded input positions, then run the parent Conv1D.
        input_keep = 1 - paddings[..., None]
        conv_out = super().forward(x * input_keep)

        # TODO(dhwang2): Implement Conv1DTranspose separately for lhs_dilation. It's problematic
        # for lhs_dilation (Conv Transpose) and rhs_dilation (Dilated Convolution) to be part of
        # the same class. Not only are they never used together, but their combined usage would
        # result in undefined behavior. Additionally, the logic for handling explicit padding and
        # paddings is fundamentally different between them, so supporting both in a single class
        # makes the code error-prone.
        cfg = self.config
        # Compute the paddings of the convolution output.
        output_paddings = compute_conv_paddings(
            paddings,
            window=cfg.window,
            stride=cfg.strides,
            conv_padding=cfg.padding,
            dilation=cfg.rhs_dilation,
            anchor=cfg.anchor,
        )

        # Zero out padded output positions.
        output_keep = 1 - output_paddings[..., None]
        return conv_out * output_keep, output_paddings


class DepthwiseConv1D(BaseConv):
"""The 1-D depth-wise convolution layer.

Expand Down
85 changes: 85 additions & 0 deletions axlearn/common/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from functools import partial
from typing import Optional, Union

import einops
import jax.random
import numpy as np
import tensorflow as tf
Expand All @@ -39,6 +40,7 @@
CategoricalHingeLossMetric,
ClassificationMetric,
Conv1D,
Conv1DWithPadding,
Conv2D,
Conv2DTranspose,
Conv2DWith1DPadding,
Expand Down Expand Up @@ -1265,6 +1267,89 @@ def test_conv2d_with_1d_padding(
jnp.take_along_axis(ref_paddings, permute_idx[:, None], axis=0)[:, :output_len],
)

@parameterized.named_parameters(
    ("1_S1", 1, 1, "VALID", None),
    ("2_S1_VALID", 2, 1, "VALID", None),
    ("2_S2_SAME", 2, 2, "SAME", None),
    ("2_S_CAUSAL", 2, 1, "CAUSAL", None),
    ("2_S2_VALID", 2, 2, "VALID", None),
    ("2_S2_CAUSAL", 2, 2, "CAUSAL", None),
    ("3_S1_VALID", 3, 1, "VALID", None),
    ("3_S1_VALID_A0", 3, 1, "VALID", 0),
    ("3_S1_VALID_A1", 3, 1, "VALID", 1),
    ("3_S1_VALID_A2", 3, 1, "VALID", 2),
    ("3_S1_SAME", 3, 1, "SAME", None),
    ("3_S1_CAUSAL", 3, 1, "CAUSAL", None),
    ("3_S2_VALID", 3, 2, "VALID", None),
    ("3_S2_CAUSAL", 3, 2, "CAUSAL", None),
)
def test_conv1d_against_conv2d_with_1d_padding(
    self,
    window: int,
    strides: int,
    padding: ConvPaddingType,
    anchor: Optional[int],
):
    """Checks Conv1DWithPadding against Conv2DWith1DPadding using a (window, 1) kernel."""
    input_dim, output_dim = 4, 6
    # A batch of 10 sequences, each padded to length 10.
    batch_size, max_seq_len = 10, 10

    # Settings common to the reference (2-D) and tested (1-D) layers.
    common_kwargs = {
        "input_dim": input_dim,
        "output_dim": output_dim,
        "padding": padding,
        "anchor": anchor,
    }
    conv2d_cfg = Conv2DWith1DPadding.default_config().set(
        name="ref", window=(window, 1), strides=(strides, 1), **common_kwargs
    )
    conv2d_layer = conv2d_cfg.instantiate(parent=None)
    conv1d_cfg = Conv1DWithPadding.default_config().set(
        name="test", window=window, strides=strides, **common_kwargs
    )
    conv1d_layer = conv1d_cfg.instantiate(parent=None)

    # Initialize the reference parameters and reuse them for the 1-D layer by dropping the
    # singleton second kernel axis.
    prng_key = jax.random.PRNGKey(123)
    prng_key, init_key = jax.random.split(prng_key)
    conv2d_state = conv2d_layer.initialize_parameters_recursively(init_key)
    conv1d_state = {
        "bias": conv2d_state["bias"],
        "weight": einops.rearrange(conv2d_state["weight"], "t 1 i o -> t i o"),
    }

    prng_key, input_key = jax.random.split(prng_key)
    inputs_1d = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim])
    # Strictly upper-triangular ones: row i has valid length i + 1, so lengths run 1 to 10.
    paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1)

    (test_outputs, test_paddings), _ = F(
        conv1d_layer,
        inputs={"x": inputs_1d, "paddings": paddings},
        is_training=True,
        state=conv1d_state,
        prng_key=prng_key,
    )

    # The reference layer takes a 4-D input; insert a singleton axis before the channels.
    inputs_2d = einops.rearrange(inputs_1d, "b t i -> b t 1 i")
    (ref_outputs, ref_paddings), _ = F(
        conv2d_layer,
        inputs={"x": inputs_2d, "paddings": paddings},
        is_training=True,
        state=conv2d_state,
        prng_key=prng_key,
    )
    ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o")

    assert_allclose(ref_paddings, test_paddings)
    assert_allclose(ref_outputs, test_outputs)

@parameterized.named_parameters(
{
"testcase_name": "2x2",
Expand Down
Loading