-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add diphone factored hybrid model output block (#56)
* Add diphone factored hybrid model backend * add from_flags function * improve naming, separate joint forwarding via dedicated method * activation before dropout * make property naming more consistent * remove superfluous property * make forward an alias of forward_factored * cleanup code a little, improve annotations thanks Nick * document kwargs * specify better name for test * directly specify device for range * fix classmethod decorator * add conversion function from dense label info * consistently name left context logits * rename variables for clarity * assert output shape for ONNX parser * remove flatten (not needed) * order properties and docstring * black
- Loading branch information
1 parent
ef22941
commit 31c284d
Showing
4 changed files
with
291 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
__all__ = ["FactoredDiphoneBlockV1Config", "FactoredDiphoneBlockV1", "BoundaryClassV1"] | ||
|
||
from .diphone import * | ||
from .util import BoundaryClassV1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
__all__ = [ | ||
"FactoredDiphoneBlockV1Config", | ||
"FactoredDiphoneBlockV1", | ||
] | ||
|
||
from dataclasses import dataclass | ||
from typing import Callable, Tuple, Union | ||
|
||
import torch | ||
from torch import nn, Tensor | ||
import torch.nn.functional as F | ||
|
||
from i6_models.config import ModelConfiguration | ||
|
||
from .util import BoundaryClassV1, get_center_dim, get_mlp | ||
|
||
|
||
@dataclass | ||
class FactoredDiphoneBlockV1Config(ModelConfiguration): | ||
""" | ||
Attributes: | ||
num_contexts: the number of raw phonemes/acoustic contexts | ||
num_hmm_states_per_phone: the number of HMM states per phoneme | ||
boundary_class: the phoneme state augmentation to apply | ||
activation: activation function to use in the context mixing MLP. | ||
context_mix_mlp_dim: inner dimension of the context mixing MLP layers | ||
context_mix_mlp_num_layers: how many hidden layers on the MLPs there should be | ||
left_context_embedding_dim: embedding dimension of the left context | ||
values. Good choice is in the order of num_contexts. | ||
dropout: dropout probabilty | ||
num_inputs: input dimension of the output block, must match w/ output dimension | ||
of main encoder (e.g. Conformer) | ||
""" | ||
|
||
num_contexts: int | ||
num_hmm_states_per_phone: int | ||
boundary_class: Union[int, BoundaryClassV1] | ||
|
||
activation: Callable[[], nn.Module] | ||
context_mix_mlp_dim: int | ||
context_mix_mlp_num_layers: int | ||
left_context_embedding_dim: int | ||
|
||
dropout: float | ||
num_inputs: int | ||
|
||
def __post_init__(self): | ||
super().__post_init__() | ||
|
||
assert self.num_contexts > 0 | ||
assert self.num_hmm_states_per_phone > 0 | ||
|
||
assert self.context_mix_mlp_dim > 0 | ||
assert self.context_mix_mlp_num_layers > 0 | ||
assert self.left_context_embedding_dim > 0 | ||
|
||
assert self.num_inputs > 0 | ||
assert 0.0 <= self.dropout <= 1.0, "dropout must be a probability" | ||
|
||
|
||
class FactoredDiphoneBlockV1(nn.Module): | ||
""" | ||
Diphone FH model output block. | ||
Consumes the output h(x) of a main encoder model and computes factored or joint | ||
output logits/probabilities for p(c|l,h(x)) and p(l|h(x)). | ||
""" | ||
|
||
def __init__(self, cfg: FactoredDiphoneBlockV1Config): | ||
super().__init__() | ||
|
||
self.boundary_class = cfg.boundary_class | ||
self.num_contexts = cfg.num_contexts | ||
self.num_hmm_states_per_phone = cfg.num_hmm_states_per_phone | ||
|
||
self.num_center = get_center_dim(self.num_contexts, self.num_hmm_states_per_phone, self.boundary_class) | ||
self.num_diphone = self.num_center * self.num_contexts | ||
|
||
self.left_context_encoder = get_mlp( | ||
num_input=cfg.num_inputs, | ||
num_output=cfg.num_contexts, | ||
hidden_dim=cfg.context_mix_mlp_dim, | ||
num_layers=cfg.context_mix_mlp_num_layers, | ||
dropout=cfg.dropout, | ||
activation=cfg.activation, | ||
) | ||
self.left_context_embedding = nn.Embedding(cfg.num_contexts, cfg.left_context_embedding_dim) | ||
self.center_encoder = get_mlp( | ||
num_input=cfg.num_inputs + cfg.left_context_embedding_dim, | ||
num_output=self.num_center, | ||
hidden_dim=cfg.context_mix_mlp_dim, | ||
num_layers=cfg.context_mix_mlp_num_layers, | ||
dropout=cfg.dropout, | ||
activation=cfg.activation, | ||
) | ||
|
||
def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: | ||
return self.forward_factored(*args, **kwargs) | ||
|
||
def forward_factored( | ||
self, | ||
features: Tensor, # B, T, F | ||
contexts_left: Tensor, # B, T | ||
**kwargs, # kwargs because the train_step function passes contexts_center for right context training | ||
) -> Tuple[Tensor, Tensor, Tensor]: | ||
""" | ||
:param features: Main encoder output. shape B, T, F. F=num_inputs | ||
:param contexts_left: The left contexts used to compute p(c|l,x), shape B, T. | ||
:return: tuple of logits for p(c|l,x), p(l|x) and the embedded left context values. | ||
""" | ||
|
||
logits_left = self.left_context_encoder(features) # B, T, C | ||
# in training we forward exactly one context per T, so: B, T, E | ||
contexts_embedded_left = self.left_context_embedding(contexts_left) | ||
features_center = torch.cat((features, contexts_embedded_left), -1) # B, T, F+E | ||
logits_center = self.center_encoder(features_center) # B, T, C | ||
|
||
return logits_center, logits_left, contexts_embedded_left | ||
|
||
def forward_joint(self, features: Tensor) -> Tensor: | ||
""" | ||
:param features: Main encoder output. shape B, T, F. F=num_inputs | ||
:return: log probabilities for p(c,l|x). | ||
""" | ||
|
||
logits_left = self.left_context_encoder(features) # B, T, C | ||
|
||
# here we forward every context to compute p(c, l|x) = p(c|l, x) * p(l|x) | ||
contexts_left = torch.arange(self.num_contexts, device=features.device) # C | ||
contexts_embedded_left = self.left_context_embedding(contexts_left) # C, E | ||
|
||
features_expanded = features.expand((self.num_contexts, -1, -1, -1)) # C, B, T, F | ||
contexts_embedded_left_ = contexts_embedded_left.reshape((self.num_contexts, 1, 1, -1)).expand( | ||
(-1, features.shape[0], features.shape[1], -1) | ||
) # C, B, T, E | ||
features_center = torch.cat((features_expanded, contexts_embedded_left_), dim=-1) # C, B, T, F+E | ||
logits_center = self.center_encoder(features_center) # C, B, T, F' | ||
log_probs_center = F.log_softmax(logits_center, -1) | ||
log_probs_center = log_probs_center.permute((1, 2, 3, 0)) # B, T, F', C | ||
log_probs_left = F.log_softmax(logits_left, -1) | ||
log_probs_left = log_probs_left.unsqueeze(-2) # B, T, 1, C | ||
|
||
joint_log_probs = log_probs_center + log_probs_left # B, T, F', C | ||
joint_log_probs = joint_log_probs.reshape( | ||
(features.shape[0], features.shape[1], self.num_diphone) | ||
) # B, T, F'*C | ||
|
||
return joint_log_probs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from enum import Enum | ||
from typing import Callable, Union | ||
|
||
from torch import nn | ||
|
||
|
||
class BoundaryClassV1(Enum): | ||
"""Phoneme state class augmentation selector""" | ||
|
||
none = 1 | ||
word_end = 2 | ||
boundary = 4 | ||
|
||
def factor(self): | ||
return self.value | ||
|
||
@classmethod | ||
def from_flags(cls, use_word_end_classes: bool, use_boundary_classes: bool) -> "BoundaryClassV1": | ||
assert not (use_word_end_classes and use_boundary_classes), "cannot use both classes" | ||
|
||
if use_boundary_classes: | ||
return cls.boundary | ||
elif use_word_end_classes: | ||
return cls.word_end | ||
else: | ||
return cls.none | ||
|
||
@classmethod | ||
def from_dense_label_info(cls, li: "i6_core.mm.context_label.DenseLabelInfo") -> "BoundaryClassV1": | ||
return cls.from_flags( | ||
use_word_end_classes=li.use_word_end_classes, | ||
use_boundary_classes=li.use_boundary_classes, | ||
) | ||
|
||
|
||
def get_center_dim( | ||
n_contexts: int, | ||
num_hmm_states_per_phone: int, | ||
ph_class: Union[int, BoundaryClassV1], | ||
) -> int: | ||
""" | ||
:return: number of center phonemes given the augmentation values | ||
""" | ||
|
||
factor = ph_class.factor() if isinstance(ph_class, BoundaryClassV1) else ph_class | ||
return n_contexts * num_hmm_states_per_phone * factor | ||
|
||
|
||
def get_mlp( | ||
num_input: int, | ||
num_output: int, | ||
hidden_dim: int, | ||
dropout: float, | ||
activation: Callable[[], nn.Module], | ||
num_layers, | ||
) -> nn.Module: | ||
""" | ||
:return: a context-mixing MLP according to the specifications | ||
""" | ||
|
||
assert num_input > 0 | ||
assert num_output > 0 | ||
assert num_layers > 0 | ||
assert hidden_dim > 0 | ||
assert 0.0 <= dropout <= 1.0 | ||
|
||
return nn.Sequential( | ||
*[ | ||
layer | ||
for in_dim in [ | ||
num_input, | ||
*[hidden_dim for _ in range(num_layers - 1)], | ||
] | ||
for layer in [ | ||
nn.Linear(in_dim, hidden_dim), | ||
activation(), | ||
nn.Dropout(dropout), | ||
] | ||
], | ||
nn.Linear(hidden_dim, num_output), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from itertools import product | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from i6_models.parts.factored_hybrid import BoundaryClassV1, FactoredDiphoneBlockV1, FactoredDiphoneBlockV1Config | ||
from i6_models.parts.factored_hybrid.util import get_center_dim | ||
|
||
|
||
def test_dim_calcs(): | ||
n_ctx = 42 | ||
|
||
assert get_center_dim(n_ctx, 1, BoundaryClassV1.none) == 42 | ||
assert get_center_dim(n_ctx, 1, BoundaryClassV1.word_end) == 84 | ||
assert get_center_dim(n_ctx, 3, BoundaryClassV1.word_end) == 252 | ||
assert get_center_dim(n_ctx, 3, BoundaryClassV1.boundary) == 504 | ||
|
||
|
||
def test_output_shape_and_norm(): | ||
n_ctx = 42 | ||
n_in = 32 | ||
|
||
for we_class, states_per_ph in product( | ||
[BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary], | ||
[1, 3], | ||
): | ||
block = FactoredDiphoneBlockV1( | ||
FactoredDiphoneBlockV1Config( | ||
activation=nn.ReLU, | ||
context_mix_mlp_dim=64, | ||
context_mix_mlp_num_layers=2, | ||
dropout=0.1, | ||
left_context_embedding_dim=32, | ||
num_contexts=n_ctx, | ||
num_hmm_states_per_phone=states_per_ph, | ||
num_inputs=n_in, | ||
boundary_class=we_class, | ||
) | ||
) | ||
|
||
for b, t in product([10, 50, 100], [10, 50, 100]): | ||
contexts_forward = torch.randint(0, n_ctx, (b, t)) | ||
encoder_output = torch.rand((b, t, n_in)) | ||
output_center, output_left, _ = block(features=encoder_output, contexts_left=contexts_forward) | ||
assert output_left.shape == (b, t, n_ctx) | ||
cdim = get_center_dim(n_ctx, states_per_ph, we_class) | ||
assert output_center.shape == (b, t, cdim) | ||
|
||
encoder_output = torch.rand((b, t, n_in)) | ||
output = block.forward_joint(features=encoder_output) | ||
cdim = get_center_dim(n_ctx, states_per_ph, we_class) | ||
assert output.shape == (b, t, cdim * n_ctx) | ||
output_p = torch.exp(output) | ||
ones_hopefully = torch.sum(output_p, dim=-1) | ||
close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3 | ||
assert all(close_to_one) |