
Compute framewise alignment information of the LibriSpeech dataset.
csukuangfj committed Sep 23, 2021
1 parent 4580ff1 commit 27a6d5e
Showing 2 changed files with 150 additions and 25 deletions.
130 changes: 105 additions & 25 deletions egs/librispeech/ASR/conformer_ctc/ali.py
@@ -18,6 +18,7 @@
import argparse
import logging
from pathlib import Path
from typing import List, Tuple

import k2
import torch
@@ -32,6 +33,7 @@
AttributeDict,
encode_supervisions,
get_alignments,
save_alignments,
setup_logger,
)

@@ -56,23 +58,40 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)

parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="The lang dir",
)

parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)

parser.add_argument(
"--ali-dir",
type=str,
default="data/ali",
help="The experiment dir",
)
return parser


def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
@@ -86,9 +105,31 @@ def compute_alignments(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler,
) -> List[Tuple[str, List[int]]]:
"""Compute the framewise alignments of a dataset.
Args:
model:
The neural network model.
dl:
Dataloader containing the dataset.
params:
Parameters for computing alignments.
graph_compiler:
It converts token IDs to decoding graphs.
Returns:
Return a list of tuples. Each tuple contains two entries:
- Utterance ID
- Framewise alignments (token IDs) after subsampling
"""
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
num_cuts = 0

device = graph_compiler.device
ans = []
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]

@@ -97,11 +138,23 @@ def compute_alignments(
feature = feature.to(device)

supervisions = batch["supervisions"]

cut_ids = []
for cut in supervisions["cut"]:
assert len(cut.supervisions) == 1
cut_ids.append(cut.id)

nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
# We also need to reorder cut_ids since encode_supervisions()
# reorders "texts".
# In general, new2old is an identity map since lhotse already sorts the
# returned cuts by duration in descending order.
new2old = supervision_segments[:, 0].tolist()
cut_ids = [cut_ids[i] for i in new2old]

token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
@@ -113,22 +166,30 @@ def compute_alignments(
)

lattice = k2.intersect_dense(
decoding_graph,
dense_fsa_vec,
params.output_beam,
)

best_path = one_best_decoding(
lattice=lattice,
use_double_scores=params.use_double_scores,
)

ali_ids = get_alignments(best_path)
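# ali_ids[i] holds one token ID per output frame (i.e., after
# subsampling) for the i-th utterance of the sorted batch.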
assert len(ali_ids) == len(cut_ids)
ans += list(zip(cut_ids, ali_ids))

num_cuts += len(ali_ids)

if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)

return ans


@torch.no_grad()
@@ -138,6 +199,7 @@ def main():
args = parser.parse_args()

assert args.return_cuts is True
assert args.concatenate_cuts is False

params = get_params()
params.update(vars(args))
@@ -169,9 +231,7 @@ def main():
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)

@@ -190,20 +250,40 @@
model.eval()

librispeech = LibriSpeechAsrDataModule(args)

train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
test_dl = librispeech.test_dataloaders() # a list

ali_dir = Path(params.ali_dir)
ali_dir.mkdir(exist_ok=True)

enabled_datasets = {
"test_clean": test_dl[0],
"test_other": test_dl[1],
"train-960": train_dl,
"valid": valid_dl,
}

for name, dl in enabled_datasets.items():
logging.info(f"Processing {name}")
alignments = compute_alignments(
model=model,
dl=dl,
params=params,
graph_compiler=graph_compiler,
)
num_utt = len(alignments)
alignments = dict(alignments)
assert num_utt == len(alignments)
filename = ali_dir / f"{name}.pt"
save_alignments(
alignments=alignments,
subsampling_factor=params.subsampling_factor,
filename=filename,
)
logging.info(
f"For dataset {name}, its alignments are saved to {filename}"
)
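# With the default --ali-dir, this loop leaves test_clean.pt,
# test_other.pt, train-960.pt and valid.pt under data/ali, each readable
# with load_alignments() from icefall/utils.py below.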


torch.set_num_threads(1)
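For readers who want to sanity-check what compute_alignments() returns, below is a minimal sketch that is not part of this commit. It assumes the standard 10 ms feature frame shift (the same assumption the removed debug code made), the subsampling factor of 4 from get_params(), and a made-up toy alignment:

# Minimal sketch, not part of this commit: convert a framewise alignment
# (token IDs after subsampling) into approximate start times.
subsampling_factor = 4  # from get_params()
frame_shift = 0.01  # 10 ms per feature frame (assumed)

# Toy data in the format returned by compute_alignments():
# a list of (utterance ID, framewise token IDs) tuples.
alignments = [("1089-134686-0001", [0, 0, 25, 25, 31, 0])]

for cut_id, token_ids in alignments:
    print(cut_id)
    for i, token_id in enumerate(token_ids):
        start = i * subsampling_factor * frame_shift
        print(f"  {start:.2f}s  token_id={token_id}")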
45 changes: 45 additions & 0 deletions icefall/utils.py
@@ -325,6 +325,51 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
return labels.tolist()


def save_alignments(
alignments: Dict[str, List[int]],
subsampling_factor: int,
filename: str,
) -> None:
"""Save alignments to a file.
Args:
alignments:
A dict containing alignments. Keys of the dict are utterance IDs and
values are the corresponding framewise alignments after subsampling.
subsampling_factor:
The subsampling factor of the model.
filename:
Path to save the alignments.
Returns:
Return None.
"""
ali_dict = {
"subsampling_factor": subsampling_factor,
"alignments": alignments,
}
torch.save(ali_dict, filename)


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
"""Load alignments from a file.
Args:
filename:
Path to the file containing alignment information.
The file should be saved by :func:`save_alignments`.
Returns:
Return a tuple containing:
- subsampling_factor: The subsampling_factor used to compute
the alignments.
- alignments: A dict containing utterances and their corresponding
framewise alignment, after subsampling.
"""
ali_dict = torch.load(filename)
subsampling_factor = ali_dict["subsampling_factor"]
alignments = ali_dict["alignments"]
return subsampling_factor, alignments


def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:
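To close, a hedged round-trip through the two helpers added above. This is not part of the commit; the file name and toy data are invented, and it assumes icefall is importable:

from icefall.utils import load_alignments, save_alignments

# Toy alignments: utterance ID -> framewise token IDs after subsampling.
toy = {
    "cut-0001": [0, 0, 25, 25, 31],
    "cut-0002": [0, 17, 17, 4],
}
save_alignments(alignments=toy, subsampling_factor=4, filename="toy_ali.pt")

subsampling_factor, alignments = load_alignments("toy_ali.pt")
assert subsampling_factor == 4
assert alignments == toy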
