Skip to content

Commit

Permalink
Add random combine from k2-fsa#229.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Apr 23, 2022
1 parent 1614a68 commit 51cc648
Showing 1 changed file with 296 additions and 3 deletions.
299 changes: 296 additions & 3 deletions egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import copy
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch
from encoder_interface import EncoderInterface
Expand Down Expand Up @@ -61,6 +61,7 @@ def __init__(
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
aux_layer_period: int = 3,
) -> None:
super(Conformer, self).__init__()

Expand All @@ -86,7 +87,11 @@ def __init__(
layer_dropout,
cnn_module_kernel,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.encoder = ConformerEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
)

def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
Expand Down Expand Up @@ -276,13 +281,30 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb)
"""

def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
def __init__(
self,
encoder_layer: nn.Module,
num_layers: int,
aux_layers: List[int],
) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers

assert num_layers - 1 not in aux_layers
self.aux_layers = set(aux_layers + [num_layers - 1])

num_channels = encoder_layer.norm_final.weight.numel()
self.combiner = RandomCombine(
num_inputs=len(self.aux_layers),
num_channels=num_channels,
final_weight=0.5,
pure_prob=0.333,
stddev=2.0,
)

def forward(
self,
src: Tensor,
Expand All @@ -309,6 +331,8 @@ def forward(
"""
output = src

outputs = []

for i, mod in enumerate(self.layers):
output = mod(
output,
Expand All @@ -317,6 +341,10 @@ def forward(
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
if i in self.aux_layers:
outputs.append(output)

output = self.combiner(outputs)

return output

Expand Down Expand Up @@ -1025,6 +1053,269 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class RandomCombine(nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, in training time,
is a random combination of all the inputs; but which in test time
will be just the last input.
All but the last input will have a linear transform before we
randomly combine them; these linear transforms will be initialized
to the identity transform.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""

def __init__(
self,
num_inputs: int,
num_channels: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
) -> None:
"""
Args:
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
num_channels:
The number of channels on the input, e.g. 512.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
pure_prob:
The probability, on each frame, with which we choose
only a single layer to output (rather than an interpolation)
stddev:
A standard deviation that we add to log-probs for computing
randomized weights.
The method of choosing which layers, or combinations of layers, to use,
is conceptually as follows::
With probability `pure_prob`::
With probability `final_weight`: choose final layer,
Else: choose random non-final layer.
Else::
Choose initial log-weights that correspond to assigning
weight `final_weight` to the final layer and equal
weights to other layers; then add Gaussian noise
with variance `stddev` to these log-weights, and normalize
to weights (note: the average weight assigned to the
final layer here will not be `final_weight` if stddev>0).
"""
super().__init__()
assert 0 <= pure_prob <= 1, pure_prob
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1

self.linear = nn.ModuleList(
[
nn.Linear(num_channels, num_channels, bias=True)
for _ in range(num_inputs - 1)
]
)

self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
self.stddev = stddev

self.final_log_weight = (
torch.tensor(
(final_weight / (1 - final_weight)) * (self.num_inputs - 1)
)
.log()
.item()
)
self._reset_parameters()

def _reset_parameters(self):
for i in range(len(self.linear)):
nn.init.eye_(self.linear[i].weight)
nn.init.constant_(self.linear[i].bias, 0.0)

def forward(self, inputs: List[Tensor]) -> Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training:
return inputs[-1]

# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels

mod_inputs = []
for i in range(num_inputs - 1):
mod_inputs.append(self.linear[i](inputs[i]))
mod_inputs.append(inputs[num_inputs - 1])

ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)

# weights: (num_frames, num_inputs)
weights = self._get_random_weights(
inputs[0].dtype, inputs[0].device, num_frames
)

weights = weights.reshape(num_frames, num_inputs, 1)
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)

if __name__ == "__main__":
# for testing only...
print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans

def _get_random_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> Tensor:
"""Return a tensor of random weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired
Returns:
A tensor of shape (num_frames, self.num_inputs), such that
`ans.sum(dim=1)` is all ones.
"""
pure_prob = self.pure_prob
if pure_prob == 0.0:
return self._get_random_mixed_weights(dtype, device, num_frames)
elif pure_prob == 1.0:
return self._get_random_pure_weights(dtype, device, num_frames)
else:
p = self._get_random_pure_weights(dtype, device, num_frames)
m = self._get_random_mixed_weights(dtype, device, num_frames)
return torch.where(
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
)

def _get_random_pure_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
exactly one weight equal to 1.0 on each frame.
"""
final_prob = self.final_weight

# final contains self.num_inputs - 1 in all elements
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
nonfinal = torch.randint(
self.num_inputs - 1, (num_frames,), device=device
)

indexes = torch.where(
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
)
ans = torch.nn.functional.one_hot(
indexes, num_classes=self.num_inputs
).to(dtype=dtype)
return ans

def _get_random_mixed_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A tensor of shape (num_frames, self.num_inputs), which elements
in [0..1] that sum to one over the second axis, i.e.
`ans.sum(dim=1)` is all ones.
"""
logprobs = (
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
* self.stddev
)
logprobs[:, -1] += self.final_log_weight
return logprobs.softmax(dim=1)


def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
print(
f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"
)
num_inputs = 3
num_channels = 50
m = RandomCombine(
num_inputs=num_inputs,
num_channels=num_channels,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev,
)

x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]

y = m(x)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.


def _test_random_combine_main():
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.0)
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.3)
_test_random_combine(0.5, 1, 0.3)
_test_random_combine(0.5, 0.5, 0.3)

feature_dim = 50
c = Conformer(
num_features=feature_dim, output_dim=256, d_model=128, nhead=4
)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)


if __name__ == "__main__":
feature_dim = 50
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
Expand All @@ -1036,3 +1327,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5,
)

_test_random_combine_main()

0 comments on commit 51cc648

Please sign in to comment.