From 31c284d6c9b400f4d1aaf20b492449ec7c3fb99b Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Sun, 28 Jul 2024 17:34:23 +0200
Subject: [PATCH] Add diphone factored hybrid model output block (#56)

* Add diphone factored hybrid model backend
* add from_flags function
* improve naming, separate joint forwarding via dedicated method
* activation before dropout
* make property naming more consistent
* remove superfluous property
* make forward an alias of forward_factored
* cleanup code a little, improve annotations thanks Nick
* document kwargs
* specify better name for test
* directly specify device for range
* fix classmethod decorator
* add conversion function from dense label info
* consistently name left context logits
* rename variables for clarity
* assert output shape for ONNX parser
* remove flatten (not needed)
* order properties and docstring
* black
---
 i6_models/parts/factored_hybrid/__init__.py |   4 +
 i6_models/parts/factored_hybrid/diphone.py  | 150 ++++++++++++++++++++
 i6_models/parts/factored_hybrid/util.py     |  81 +++++++++++
 tests/test_fh.py                            |  56 ++++++++
 4 files changed, 291 insertions(+)
 create mode 100644 i6_models/parts/factored_hybrid/__init__.py
 create mode 100644 i6_models/parts/factored_hybrid/diphone.py
 create mode 100644 i6_models/parts/factored_hybrid/util.py
 create mode 100644 tests/test_fh.py

diff --git a/i6_models/parts/factored_hybrid/__init__.py b/i6_models/parts/factored_hybrid/__init__.py
new file mode 100644
index 00000000..3d9da599
--- /dev/null
+++ b/i6_models/parts/factored_hybrid/__init__.py
@@ -0,0 +1,4 @@
+__all__ = ["FactoredDiphoneBlockV1Config", "FactoredDiphoneBlockV1", "BoundaryClassV1"]
+
+from .diphone import *
+from .util import BoundaryClassV1
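
The `__init__.py` above defines the package's public surface, so downstream code can import all of the new symbols from one place; a minimal sketch:

```python
from i6_models.parts.factored_hybrid import (
    BoundaryClassV1,
    FactoredDiphoneBlockV1,
    FactoredDiphoneBlockV1Config,
)
```
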
diff --git a/i6_models/parts/factored_hybrid/diphone.py b/i6_models/parts/factored_hybrid/diphone.py
new file mode 100644
index 00000000..e21ae423
--- /dev/null
+++ b/i6_models/parts/factored_hybrid/diphone.py
@@ -0,0 +1,150 @@
+__all__ = [
+    "FactoredDiphoneBlockV1Config",
+    "FactoredDiphoneBlockV1",
+]
+
+from dataclasses import dataclass
+from typing import Callable, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from i6_models.config import ModelConfiguration
+
+from .util import BoundaryClassV1, get_center_dim, get_mlp
+
+
+@dataclass
+class FactoredDiphoneBlockV1Config(ModelConfiguration):
+    """
+    Attributes:
+        num_contexts: the number of raw phonemes/acoustic contexts
+        num_hmm_states_per_phone: the number of HMM states per phoneme
+        boundary_class: the phoneme state augmentation to apply
+
+        activation: activation function to use in the context mixing MLP.
+        context_mix_mlp_dim: inner dimension of the context mixing MLP layers
+        context_mix_mlp_num_layers: how many hidden layers the MLPs should have
+        left_context_embedding_dim: embedding dimension of the left context
+            values. A good choice is on the order of num_contexts.
+
+        dropout: dropout probability
+        num_inputs: input dimension of the output block, must match the output
+            dimension of the main encoder (e.g. Conformer)
+    """
+
+    num_contexts: int
+    num_hmm_states_per_phone: int
+    boundary_class: Union[int, BoundaryClassV1]
+
+    activation: Callable[[], nn.Module]
+    context_mix_mlp_dim: int
+    context_mix_mlp_num_layers: int
+    left_context_embedding_dim: int
+
+    dropout: float
+    num_inputs: int
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        assert self.num_contexts > 0
+        assert self.num_hmm_states_per_phone > 0
+
+        assert self.context_mix_mlp_dim > 0
+        assert self.context_mix_mlp_num_layers > 0
+        assert self.left_context_embedding_dim > 0
+
+        assert self.num_inputs > 0
+        assert 0.0 <= self.dropout <= 1.0, "dropout must be a probability"
+
+
+class FactoredDiphoneBlockV1(nn.Module):
+    """
+    Diphone FH model output block.
+
+    Consumes the output h(x) of a main encoder model and computes factored or joint
+    output logits/probabilities for p(c|l,h(x)) and p(l|h(x)).
+    """
+
+    def __init__(self, cfg: FactoredDiphoneBlockV1Config):
+        super().__init__()
+
+        self.boundary_class = cfg.boundary_class
+        self.num_contexts = cfg.num_contexts
+        self.num_hmm_states_per_phone = cfg.num_hmm_states_per_phone
+
+        self.num_center = get_center_dim(self.num_contexts, self.num_hmm_states_per_phone, self.boundary_class)
+        self.num_diphone = self.num_center * self.num_contexts
+
+        self.left_context_encoder = get_mlp(
+            num_input=cfg.num_inputs,
+            num_output=cfg.num_contexts,
+            hidden_dim=cfg.context_mix_mlp_dim,
+            num_layers=cfg.context_mix_mlp_num_layers,
+            dropout=cfg.dropout,
+            activation=cfg.activation,
+        )
+        self.left_context_embedding = nn.Embedding(cfg.num_contexts, cfg.left_context_embedding_dim)
+        self.center_encoder = get_mlp(
+            num_input=cfg.num_inputs + cfg.left_context_embedding_dim,
+            num_output=self.num_center,
+            hidden_dim=cfg.context_mix_mlp_dim,
+            num_layers=cfg.context_mix_mlp_num_layers,
+            dropout=cfg.dropout,
+            activation=cfg.activation,
+        )
+
+    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
+        return self.forward_factored(*args, **kwargs)
+
+    def forward_factored(
+        self,
+        features: Tensor,  # B, T, F
+        contexts_left: Tensor,  # B, T
+        **kwargs,  # kwargs because the train_step function passes contexts_center for right context training
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        :param features: Main encoder output. shape B, T, F. F=num_inputs
+        :param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
+        :return: tuple of logits for p(c|l,x), p(l|x) and the embedded left context values.
+        """
+
+        logits_left = self.left_context_encoder(features)  # B, T, C
+        # in training we forward exactly one context per T, so: B, T, E
+        contexts_embedded_left = self.left_context_embedding(contexts_left)
+        features_center = torch.cat((features, contexts_embedded_left), -1)  # B, T, F+E
+        logits_center = self.center_encoder(features_center)  # B, T, F'
+
+        return logits_center, logits_left, contexts_embedded_left
+
+    def forward_joint(self, features: Tensor) -> Tensor:
+        """
+        :param features: Main encoder output. shape B, T, F. F=num_inputs
+        :return: log probabilities for p(c,l|x).
+        """
+
+        logits_left = self.left_context_encoder(features)  # B, T, C
+
+        # here we forward every context to compute p(c, l|x) = p(c|l, x) * p(l|x)
+        contexts_left = torch.arange(self.num_contexts, device=features.device)  # C
+        contexts_embedded_left = self.left_context_embedding(contexts_left)  # C, E
+
+        features_expanded = features.expand((self.num_contexts, -1, -1, -1))  # C, B, T, F
+        contexts_embedded_left_ = contexts_embedded_left.reshape((self.num_contexts, 1, 1, -1)).expand(
+            (-1, features.shape[0], features.shape[1], -1)
+        )  # C, B, T, E
+        features_center = torch.cat((features_expanded, contexts_embedded_left_), dim=-1)  # C, B, T, F+E
+        logits_center = self.center_encoder(features_center)  # C, B, T, F'
+        log_probs_center = F.log_softmax(logits_center, -1)
+        log_probs_center = log_probs_center.permute((1, 2, 3, 0))  # B, T, F', C
+        log_probs_left = F.log_softmax(logits_left, -1)
+        log_probs_left = log_probs_left.unsqueeze(-2)  # B, T, 1, C
+
+        joint_log_probs = log_probs_center + log_probs_left  # B, T, F', C
+        joint_log_probs = joint_log_probs.reshape(
+            (features.shape[0], features.shape[1], self.num_diphone)
+        )  # B, T, F'*C
+
+        return joint_log_probs
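
To make the intended call pattern concrete, here is a minimal usage sketch of the block added above. All names come from this diff; the configuration values are arbitrary examples, not recommendations:

```python
import torch
import torch.nn as nn

from i6_models.parts.factored_hybrid import (
    BoundaryClassV1,
    FactoredDiphoneBlockV1,
    FactoredDiphoneBlockV1Config,
)

cfg = FactoredDiphoneBlockV1Config(
    num_contexts=42,  # C: number of raw phonemes/acoustic contexts
    num_hmm_states_per_phone=3,
    boundary_class=BoundaryClassV1.word_end,
    activation=nn.ReLU,
    context_mix_mlp_dim=64,
    context_mix_mlp_num_layers=2,
    left_context_embedding_dim=32,
    dropout=0.1,
    num_inputs=512,  # F: must match the encoder output dimension
)
block = FactoredDiphoneBlockV1(cfg)

features = torch.rand((8, 100, 512))            # B, T, F encoder output
contexts_left = torch.randint(0, 42, (8, 100))  # B, T left-context labels

# training: factored logits, one left context per frame
logits_center, logits_left, _ = block(features=features, contexts_left=contexts_left)
# logits_center: (8, 100, block.num_center), logits_left: (8, 100, 42)

# recognition/export: joint log-probs over all (center, left) pairs
joint_log_probs = block.forward_joint(features)  # (8, 100, block.num_diphone)
```
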
diff --git a/i6_models/parts/factored_hybrid/util.py b/i6_models/parts/factored_hybrid/util.py
new file mode 100644
index 00000000..0d8f1502
--- /dev/null
+++ b/i6_models/parts/factored_hybrid/util.py
@@ -0,0 +1,81 @@
+from enum import Enum
+from typing import Callable, Union
+
+from torch import nn
+
+
+class BoundaryClassV1(Enum):
+    """Phoneme state class augmentation selector"""
+
+    none = 1
+    word_end = 2
+    boundary = 4
+
+    def factor(self) -> int:
+        return self.value
+
+    @classmethod
+    def from_flags(cls, use_word_end_classes: bool, use_boundary_classes: bool) -> "BoundaryClassV1":
+        assert not (use_word_end_classes and use_boundary_classes), "cannot use both classes"
+
+        if use_boundary_classes:
+            return cls.boundary
+        elif use_word_end_classes:
+            return cls.word_end
+        else:
+            return cls.none
+
+    @classmethod
+    def from_dense_label_info(cls, li: "i6_core.mm.context_label.DenseLabelInfo") -> "BoundaryClassV1":
+        return cls.from_flags(
+            use_word_end_classes=li.use_word_end_classes,
+            use_boundary_classes=li.use_boundary_classes,
+        )
+
+
+def get_center_dim(
+    n_contexts: int,
+    num_hmm_states_per_phone: int,
+    ph_class: Union[int, BoundaryClassV1],
+) -> int:
+    """
+    :return: number of center phonemes given the augmentation values
+    """
+
+    factor = ph_class.factor() if isinstance(ph_class, BoundaryClassV1) else ph_class
+    return n_contexts * num_hmm_states_per_phone * factor
+
+
+def get_mlp(
+    num_input: int,
+    num_output: int,
+    hidden_dim: int,
+    dropout: float,
+    activation: Callable[[], nn.Module],
+    num_layers: int,
+) -> nn.Module:
+    """
+    :return: a context-mixing MLP according to the specifications
+    """
+
+    assert num_input > 0
+    assert num_output > 0
+    assert num_layers > 0
+    assert hidden_dim > 0
+    assert 0.0 <= dropout <= 1.0
+
+    return nn.Sequential(
+        *[
+            layer
+            for in_dim in [
+                num_input,
+                *[hidden_dim for _ in range(num_layers - 1)],
+            ]
+            for layer in [
+                nn.Linear(in_dim, hidden_dim),
+                activation(),
+                nn.Dropout(dropout),
+            ]
+        ],
+        nn.Linear(hidden_dim, num_output),
+    )
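
The nested comprehension in `get_mlp` is compact but dense. For illustration, with `num_layers=2` it builds the following stack (written out by hand; the dimensions are example values):

```python
import torch.nn as nn

# get_mlp(num_input=32, num_output=252, hidden_dim=64, dropout=0.1,
#         activation=nn.ReLU, num_layers=2) is equivalent to:
mlp = nn.Sequential(
    nn.Linear(32, 64),   # in_dim = num_input
    nn.ReLU(),           # activation before dropout (see commit history)
    nn.Dropout(0.1),
    nn.Linear(64, 64),   # in_dim = hidden_dim, repeated num_layers - 1 times
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 252),  # final projection to num_output
)
```
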
diff --git a/tests/test_fh.py b/tests/test_fh.py
new file mode 100644
index 00000000..224362de
--- /dev/null
+++ b/tests/test_fh.py
@@ -0,0 +1,56 @@
+from itertools import product
+
+import torch
+import torch.nn as nn
+
+from i6_models.parts.factored_hybrid import BoundaryClassV1, FactoredDiphoneBlockV1, FactoredDiphoneBlockV1Config
+from i6_models.parts.factored_hybrid.util import get_center_dim
+
+
+def test_dim_calcs():
+    n_ctx = 42
+
+    assert get_center_dim(n_ctx, 1, BoundaryClassV1.none) == 42
+    assert get_center_dim(n_ctx, 1, BoundaryClassV1.word_end) == 84
+    assert get_center_dim(n_ctx, 3, BoundaryClassV1.word_end) == 252
+    assert get_center_dim(n_ctx, 3, BoundaryClassV1.boundary) == 504
+
+
+def test_output_shape_and_norm():
+    n_ctx = 42
+    n_in = 32
+
+    for we_class, states_per_ph in product(
+        [BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary],
+        [1, 3],
+    ):
+        block = FactoredDiphoneBlockV1(
+            FactoredDiphoneBlockV1Config(
+                activation=nn.ReLU,
+                context_mix_mlp_dim=64,
+                context_mix_mlp_num_layers=2,
+                dropout=0.1,
+                left_context_embedding_dim=32,
+                num_contexts=n_ctx,
+                num_hmm_states_per_phone=states_per_ph,
+                num_inputs=n_in,
+                boundary_class=we_class,
+            )
+        )
+
+        for b, t in product([10, 50, 100], [10, 50, 100]):
+            contexts_forward = torch.randint(0, n_ctx, (b, t))
+            encoder_output = torch.rand((b, t, n_in))
+            output_center, output_left, _ = block(features=encoder_output, contexts_left=contexts_forward)
+            assert output_left.shape == (b, t, n_ctx)
+            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
+            assert output_center.shape == (b, t, cdim)
+
+            encoder_output = torch.rand((b, t, n_in))
+            output = block.forward_joint(features=encoder_output)
+            cdim = get_center_dim(n_ctx, states_per_ph, we_class)
+            assert output.shape == (b, t, cdim * n_ctx)
+            output_p = torch.exp(output)
+            ones_hopefully = torch.sum(output_p, dim=-1)
+            close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3
+            assert all(close_to_one)
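
As a closing illustration (not part of the patch): the `forward_factored` kwargs comment mentions a `train_step` function that also passes `contexts_center`. One plausible way such a step could combine the two factored outputs is a sum of cross-entropy losses; `train_step_sketch` and its target tensors are hypothetical names, assumed for this sketch only:

```python
import torch.nn.functional as F

def train_step_sketch(block, features, contexts_left, targets_center):
    """Hypothetical training step: sum of CE losses on the two factored outputs."""
    logits_center, logits_left, _ = block.forward_factored(
        features=features, contexts_left=contexts_left
    )
    # p(c|l,x): CE against dense center targets; p(l|x): CE against left contexts.
    # transpose(1, 2) moves the class dim to position 1, as F.cross_entropy expects.
    loss_center = F.cross_entropy(logits_center.transpose(1, 2), targets_center)
    loss_left = F.cross_entropy(logits_left.transpose(1, 2), contexts_left)
    return loss_center + loss_left
```
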