Add diphone factored hybrid model output block (#56)
* Add diphone factored hybrid model backend

* add from_flags function

* improve naming, separate joint forwarding via dedicated method

* activation before dropout

* make property naming more consistent

* remove superfluous property

* make forward an alias of forward_factored

* cleanup code a little, improve annotations

thanks Nick

* document kwargs

* specify better name for test

* directly specify device for range

* fix classmethod decorator

* add conversion function from dense label info

* consistently name left context logits

* rename variables for clarity

* assert output shape for ONNX parser

* remove flatten (not needed)

* order properties and docstring

* black
NeoLegends authored Jul 28, 2024
1 parent ef22941 commit 31c284d
Showing 4 changed files with 291 additions and 0 deletions.
4 changes: 4 additions & 0 deletions i6_models/parts/factored_hybrid/__init__.py
@@ -0,0 +1,4 @@
__all__ = ["FactoredDiphoneBlockV1Config", "FactoredDiphoneBlockV1", "BoundaryClassV1"]

from .diphone import *
from .util import BoundaryClassV1
150 changes: 150 additions & 0 deletions i6_models/parts/factored_hybrid/diphone.py
@@ -0,0 +1,150 @@
__all__ = [
"FactoredDiphoneBlockV1Config",
"FactoredDiphoneBlockV1",
]

from dataclasses import dataclass
from typing import Callable, Tuple, Union

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from i6_models.config import ModelConfiguration

from .util import BoundaryClassV1, get_center_dim, get_mlp


@dataclass
class FactoredDiphoneBlockV1Config(ModelConfiguration):
    """
    Attributes:
        num_contexts: the number of raw phonemes/acoustic contexts
        num_hmm_states_per_phone: the number of HMM states per phoneme
        boundary_class: the phoneme state augmentation to apply
        activation: activation function to use in the context mixing MLP
        context_mix_mlp_dim: inner dimension of the context mixing MLP layers
        context_mix_mlp_num_layers: number of hidden layers in the context mixing MLPs
        left_context_embedding_dim: embedding dimension of the left context
            values. A good choice is on the order of num_contexts.
        dropout: dropout probability
        num_inputs: input dimension of the output block; must match the output
            dimension of the main encoder (e.g. a Conformer)
    """

    num_contexts: int
    num_hmm_states_per_phone: int
    boundary_class: Union[int, BoundaryClassV1]

    activation: Callable[[], nn.Module]
    context_mix_mlp_dim: int
    context_mix_mlp_num_layers: int
    left_context_embedding_dim: int

    dropout: float
    num_inputs: int

    def __post_init__(self):
        super().__post_init__()

        assert self.num_contexts > 0
        assert self.num_hmm_states_per_phone > 0

        assert self.context_mix_mlp_dim > 0
        assert self.context_mix_mlp_num_layers > 0
        assert self.left_context_embedding_dim > 0

        assert self.num_inputs > 0
        assert 0.0 <= self.dropout <= 1.0, "dropout must be a probability"

class FactoredDiphoneBlockV1(nn.Module):
    """
    Diphone FH model output block.
    Consumes the output h(x) of a main encoder model and computes factored or joint
    output logits/probabilities for p(c|l,h(x)) and p(l|h(x)).
    """

    def __init__(self, cfg: FactoredDiphoneBlockV1Config):
        super().__init__()

        self.boundary_class = cfg.boundary_class
        self.num_contexts = cfg.num_contexts
        self.num_hmm_states_per_phone = cfg.num_hmm_states_per_phone

        self.num_center = get_center_dim(self.num_contexts, self.num_hmm_states_per_phone, self.boundary_class)
        self.num_diphone = self.num_center * self.num_contexts

        self.left_context_encoder = get_mlp(
            num_input=cfg.num_inputs,
            num_output=cfg.num_contexts,
            hidden_dim=cfg.context_mix_mlp_dim,
            num_layers=cfg.context_mix_mlp_num_layers,
            dropout=cfg.dropout,
            activation=cfg.activation,
        )
        self.left_context_embedding = nn.Embedding(cfg.num_contexts, cfg.left_context_embedding_dim)
        self.center_encoder = get_mlp(
            num_input=cfg.num_inputs + cfg.left_context_embedding_dim,
            num_output=self.num_center,
            hidden_dim=cfg.context_mix_mlp_dim,
            num_layers=cfg.context_mix_mlp_num_layers,
            dropout=cfg.dropout,
            activation=cfg.activation,
        )

    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
        return self.forward_factored(*args, **kwargs)

    def forward_factored(
        self,
        features: Tensor,  # B, T, F
        contexts_left: Tensor,  # B, T
        **kwargs,  # kwargs because the train_step function passes contexts_center for right-context training
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        :param features: Main encoder output. shape B, T, F. F=num_inputs
        :param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
        :return: tuple of logits for p(c|l,x), p(l|x) and the embedded left context values.
        """

        logits_left = self.left_context_encoder(features)  # B, T, C
        # in training we forward exactly one context per T, so: B, T, E
        contexts_embedded_left = self.left_context_embedding(contexts_left)
        features_center = torch.cat((features, contexts_embedded_left), -1)  # B, T, F+E
        logits_center = self.center_encoder(features_center)  # B, T, F'

        return logits_center, logits_left, contexts_embedded_left

    def forward_joint(self, features: Tensor) -> Tensor:
        """
        :param features: Main encoder output. shape B, T, F. F=num_inputs
        :return: log probabilities for p(c,l|x).
        """

        logits_left = self.left_context_encoder(features)  # B, T, C

        # here we forward every context to compute p(c,l|x) = p(c|l,x) * p(l|x)
        contexts_left = torch.arange(self.num_contexts, device=features.device)  # C
        contexts_embedded_left = self.left_context_embedding(contexts_left)  # C, E

        features_expanded = features.expand((self.num_contexts, -1, -1, -1))  # C, B, T, F
        contexts_embedded_left_ = contexts_embedded_left.reshape((self.num_contexts, 1, 1, -1)).expand(
            (-1, features.shape[0], features.shape[1], -1)
        )  # C, B, T, E
        features_center = torch.cat((features_expanded, contexts_embedded_left_), dim=-1)  # C, B, T, F+E
        logits_center = self.center_encoder(features_center)  # C, B, T, F'
        log_probs_center = F.log_softmax(logits_center, -1)
        log_probs_center = log_probs_center.permute((1, 2, 3, 0))  # B, T, F', C
        log_probs_left = F.log_softmax(logits_left, -1)
        log_probs_left = log_probs_left.unsqueeze(-2)  # B, T, 1, C

        joint_log_probs = log_probs_center + log_probs_left  # B, T, F', C
        joint_log_probs = joint_log_probs.reshape(
            (features.shape[0], features.shape[1], self.num_diphone)
        )  # B, T, F'*C

        return joint_log_probs
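
A minimal usage sketch of the block above (not part of this commit); the encoder dimension, batch/time sizes, and config values are illustrative assumptions:

import torch
import torch.nn as nn

from i6_models.parts.factored_hybrid import BoundaryClassV1, FactoredDiphoneBlockV1, FactoredDiphoneBlockV1Config

block = FactoredDiphoneBlockV1(
    FactoredDiphoneBlockV1Config(
        num_contexts=42,
        num_hmm_states_per_phone=3,
        boundary_class=BoundaryClassV1.word_end,
        activation=nn.ReLU,
        context_mix_mlp_dim=64,
        context_mix_mlp_num_layers=2,
        left_context_embedding_dim=32,
        dropout=0.1,
        num_inputs=512,  # must match the encoder output dimension (assumed here)
    )
)

features = torch.rand((8, 100, 512))  # B, T, F from the main encoder
contexts_left = torch.randint(0, 42, (8, 100))  # B, T ground-truth left contexts

# factored path (training): logits for p(c|l,x) and p(l|x)
logits_center, logits_left, _ = block(features=features, contexts_left=contexts_left)
assert logits_left.shape == (8, 100, 42)
assert logits_center.shape == (8, 100, 42 * 3 * 2)  # num_center with word_end augmentation

# joint path (recognition): log p(c,l|x) = log p(c|l,x) + log p(l|x)
joint_log_probs = block.forward_joint(features)
assert joint_log_probs.shape == (8, 100, 42 * 3 * 2 * 42)  # num_diphone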
81 changes: 81 additions & 0 deletions i6_models/parts/factored_hybrid/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from enum import Enum
from typing import Callable, Union

from torch import nn


class BoundaryClassV1(Enum):
    """Phoneme state class augmentation selector"""

    none = 1
    word_end = 2
    boundary = 4

    def factor(self) -> int:
        """:return: multiplicative factor this augmentation applies to the number of center states"""
        return self.value

    @classmethod
    def from_flags(cls, use_word_end_classes: bool, use_boundary_classes: bool) -> "BoundaryClassV1":
        assert not (use_word_end_classes and use_boundary_classes), "cannot use both classes"

        if use_boundary_classes:
            return cls.boundary
        elif use_word_end_classes:
            return cls.word_end
        else:
            return cls.none

    @classmethod
    def from_dense_label_info(cls, li: "i6_core.mm.context_label.DenseLabelInfo") -> "BoundaryClassV1":
        return cls.from_flags(
            use_word_end_classes=li.use_word_end_classes,
            use_boundary_classes=li.use_boundary_classes,
        )


def get_center_dim(
    n_contexts: int,
    num_hmm_states_per_phone: int,
    ph_class: Union[int, BoundaryClassV1],
) -> int:
    """
    :return: number of center phonemes given the augmentation values
    """

    factor = ph_class.factor() if isinstance(ph_class, BoundaryClassV1) else ph_class
    return n_contexts * num_hmm_states_per_phone * factor


def get_mlp(
    num_input: int,
    num_output: int,
    hidden_dim: int,
    dropout: float,
    activation: Callable[[], nn.Module],
    num_layers: int,
) -> nn.Module:
    """
    :return: a context-mixing MLP according to the specifications
    """

    assert num_input > 0
    assert num_output > 0
    assert num_layers > 0
    assert hidden_dim > 0
    assert 0.0 <= dropout <= 1.0

    # num_layers blocks of Linear -> activation -> Dropout (the first projecting
    # num_input to hidden_dim), followed by a final projection to num_output
    return nn.Sequential(
        *[
            layer
            for in_dim in [
                num_input,
                *[hidden_dim for _ in range(num_layers - 1)],
            ]
            for layer in [
                nn.Linear(in_dim, hidden_dim),
                activation(),
                nn.Dropout(dropout),
            ]
        ],
        nn.Linear(hidden_dim, num_output),
    )
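
For reference, a quick check of the dimension arithmetic above; the concrete values are chosen for illustration only:

from i6_models.parts.factored_hybrid import BoundaryClassV1
from i6_models.parts.factored_hybrid.util import get_center_dim

# factor(): none -> 1, word_end -> 2, boundary -> 4
cls = BoundaryClassV1.from_flags(use_word_end_classes=True, use_boundary_classes=False)
assert cls is BoundaryClassV1.word_end

# center dim = n_contexts * num_hmm_states_per_phone * factor
assert get_center_dim(42, 3, cls) == 42 * 3 * 2  # 252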
56 changes: 56 additions & 0 deletions tests/test_fh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from itertools import product

import torch
import torch.nn as nn

from i6_models.parts.factored_hybrid import BoundaryClassV1, FactoredDiphoneBlockV1, FactoredDiphoneBlockV1Config
from i6_models.parts.factored_hybrid.util import get_center_dim


def test_dim_calcs():
    n_ctx = 42

    assert get_center_dim(n_ctx, 1, BoundaryClassV1.none) == 42
    assert get_center_dim(n_ctx, 1, BoundaryClassV1.word_end) == 84
    assert get_center_dim(n_ctx, 3, BoundaryClassV1.word_end) == 252
    assert get_center_dim(n_ctx, 3, BoundaryClassV1.boundary) == 504


def test_output_shape_and_norm():
    n_ctx = 42
    n_in = 32

    for we_class, states_per_ph in product(
        [BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary],
        [1, 3],
    ):
        block = FactoredDiphoneBlockV1(
            FactoredDiphoneBlockV1Config(
                activation=nn.ReLU,
                context_mix_mlp_dim=64,
                context_mix_mlp_num_layers=2,
                dropout=0.1,
                left_context_embedding_dim=32,
                num_contexts=n_ctx,
                num_hmm_states_per_phone=states_per_ph,
                num_inputs=n_in,
                boundary_class=we_class,
            )
        )

        for b, t in product([10, 50, 100], [10, 50, 100]):
            contexts_forward = torch.randint(0, n_ctx, (b, t))
            encoder_output = torch.rand((b, t, n_in))
            output_center, output_left, _ = block(features=encoder_output, contexts_left=contexts_forward)
            assert output_left.shape == (b, t, n_ctx)
            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
            assert output_center.shape == (b, t, cdim)

            encoder_output = torch.rand((b, t, n_in))
            output = block.forward_joint(features=encoder_output)
            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
            assert output.shape == (b, t, cdim * n_ctx)
            output_p = torch.exp(output)
            ones_hopefully = torch.sum(output_p, dim=-1)
            close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3
            assert all(close_to_one)
