From f090da38bc98bcb21ad07bfaa7b08b838c8f6725 Mon Sep 17 00:00:00 2001 From: Your Name <> Date: Fri, 1 Nov 2024 00:45:48 -0700 Subject: [PATCH] refactor ksponspeech recipe --- egs/ksponspeech/ASR/local/__init__.py | 0 .../ASR/local/compute_fbank_musan.py | 159 +----------------- egs/ksponspeech/ASR/local/filter_cuts.py | 158 +---------------- egs/ksponspeech/ASR/local/train_bpe_model.py | 116 +------------ .../ASR/local/validate_manifest.py | 102 +---------- egs/ksponspeech/ASR/zipformer/README.md | 1 - 6 files changed, 4 insertions(+), 532 deletions(-) delete mode 100644 egs/ksponspeech/ASR/local/__init__.py mode change 100755 => 120000 egs/ksponspeech/ASR/local/compute_fbank_musan.py mode change 100644 => 120000 egs/ksponspeech/ASR/local/filter_cuts.py mode change 100755 => 120000 egs/ksponspeech/ASR/local/train_bpe_model.py mode change 100755 => 120000 egs/ksponspeech/ASR/local/validate_manifest.py delete mode 100644 egs/ksponspeech/ASR/zipformer/README.md diff --git a/egs/ksponspeech/ASR/local/__init__.py b/egs/ksponspeech/ASR/local/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/egs/ksponspeech/ASR/local/compute_fbank_musan.py b/egs/ksponspeech/ASR/local/compute_fbank_musan.py deleted file mode 100755 index c0bdacfe51..0000000000 --- a/egs/ksponspeech/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory `src_dir` (default is data/manifests). - -The generated fbank features are saved in data/fbank. -""" -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - MonoCut, - WhisperFbank, - WhisperFbankConfig, - combine, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def is_cut_long(c: MonoCut) -> bool: - return c.duration > 5 - - -def compute_fbank_musan( - src_dir: str = "data/manifests", - num_mel_bins: int = 80, - whisper_fbank: bool = False, - output_dir: str = "data/fbank", -): - src_dir = Path(src_dir) - output_dir = Path(output_dir) - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = output_dir / "musan_cuts.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(is_cut_long) - .compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/musan_feats", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - ) - musan_cuts.to_file(musan_cuts_path) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--src-dir", - type=str, - default="data/manifests", - help="Source manifests directory.", - ) - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - parser.add_argument( - "--output-dir", - type=str, - default="data/fbank", - help="Output directory. Default: data/fbank.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - compute_fbank_musan( - src_dir=args.src_dir, - num_mel_bins=args.num_mel_bins, - whisper_fbank=args.whisper_fbank, - output_dir=args.output_dir, - ) diff --git a/egs/ksponspeech/ASR/local/compute_fbank_musan.py b/egs/ksponspeech/ASR/local/compute_fbank_musan.py new file mode 120000 index 0000000000..5833f2484e --- /dev/null +++ b/egs/ksponspeech/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/local/filter_cuts.py b/egs/ksponspeech/ASR/local/filter_cuts.py deleted file mode 100644 index f081da5dfe..0000000000 --- a/egs/ksponspeech/ASR/local/filter_cuts.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script removes short and long utterances from a cutset. - -Caution: - You may need to tune the thresholds for your own dataset. - -Usage example: - - python3 ./local/filter_cuts.py \ - --bpe-model data/lang_bpe_5000/bpe.model \ - --in-cuts data/fbank/speechtools_cuts_test.jsonl.gz \ - --out-cuts data/fbank-filtered/speechtools_cuts_test.jsonl.gz -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -from lhotse import CutSet, load_manifest_lazy -from lhotse.cut import Cut - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=Path, - help="Path to the bpe.model", - ) - - parser.add_argument( - "--in-cuts", - type=Path, - help="Path to the input cutset", - ) - - parser.add_argument( - "--out-cuts", - type=Path, - help="Path to the output cutset", - ) - - return parser.parse_args() - - -def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): - total = 0 # number of total utterances before removal - removed = 0 # number of removed utterances - - def remove_short_and_long_utterances(c: Cut): - """Return False to exclude the input cut""" - nonlocal removed, total - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ./display_manifest_statistics.py - # - # You should use ./display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - total += 1 - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - removed += 1 - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./pruned_transducer_stateless2/conformer.py, the - # conv module uses the following expression - # for subsampling - if c.num_frames is None: - num_frames = c.duration * 100 # approximate - else: - num_frames = c.num_frames - - T = ((num_frames - 1) // 2 - 1) // 2 - # Note: for ./lstm_transducer_stateless/lstm.py, the formula is - # T = ((num_frames - 3) // 2 - 1) // 2 - - # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is - # T = ((num_frames - 7) // 2 + 1) // 2 - - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - removed += 1 - return False - - return True - - # We use to_eager() here so that we can print out the value of total - # and removed below. - ans = cut_set.filter(remove_short_and_long_utterances).to_eager() - ratio = removed / total * 100 - logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." - ) - return ans - - -def main(): - args = get_args() - logging.info(vars(args)) - - if args.out_cuts.is_file(): - logging.info(f"{args.out_cuts} already exists - skipping") - return - - assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" - assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" - - sp = spm.SentencePieceProcessor() - sp.load(str(args.bpe_model)) - - cut_set = load_manifest_lazy(args.in_cuts) - assert isinstance(cut_set, CutSet) - - cut_set = filter_cuts(cut_set, sp) - logging.info(f"Saving to {args.out_cuts}") - args.out_cuts.parent.mkdir(parents=True, exist_ok=True) - cut_set.to_file(args.out_cuts) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ksponspeech/ASR/local/filter_cuts.py b/egs/ksponspeech/ASR/local/filter_cuts.py new file mode 120000 index 0000000000..27aca17293 --- /dev/null +++ b/egs/ksponspeech/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/local/train_bpe_model.py b/egs/ksponspeech/ASR/local/train_bpe_model.py deleted file mode 100755 index 5979d5b986..0000000000 --- a/egs/ksponspeech/ASR/local/train_bpe_model.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path -from typing import Dict - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def generate_tokens(lang_dir: Path): - """ - Generate the tokens.txt from a bpe model. - """ - sp = spm.SentencePieceProcessor() - sp.load(str(lang_dir / "bpe.model")) - token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} - with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f: - for sym, i in token2id.items(): - f.write(f"{sym} {i}\n") - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - else: - print(f"{model_file} exists - skipping") - return - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - generate_tokens(lang_dir) - - -if __name__ == "__main__": - main() diff --git a/egs/ksponspeech/ASR/local/train_bpe_model.py b/egs/ksponspeech/ASR/local/train_bpe_model.py new file mode 120000 index 0000000000..6fad36421e --- /dev/null +++ b/egs/ksponspeech/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/local/validate_manifest.py b/egs/ksponspeech/ASR/local/validate_manifest.py deleted file mode 100755 index 98f2734196..0000000000 --- a/egs/ksponspeech/ASR/local/validate_manifest.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut -- Supervision time bounds are within cut time bounds - -We will add more checks later if needed. - -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/fbank/speechtools_cuts_train.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy -from lhotse.cut import Cut -from lhotse.dataset.speech_recognition import validate_for_asr - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def validate_one_supervision_per_cut(c: Cut): - if len(c.supervisions) != 1: - raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") - - -def validate_supervision_and_cut_time_bounds(c: Cut): - tol = 2e-3 # same tolerance as in 'validate_for_asr()' - s = c.supervisions[0] - - # Supervision start time is relative to Cut ... - # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html - if s.start < -tol: - raise ValueError( - f"{c.id}: Supervision start time {s.start} must not be negative." - ) - if s.start > tol: - raise ValueError( - f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`." - ) - if c.start + s.end > c.end + tol: - raise ValueError( - f"{c.id}: Supervision end time {c.start+s.end} is larger " - f"than cut end time {c.end}" - ) - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet) - - for c in cut_set: - validate_one_supervision_per_cut(c) - validate_supervision_and_cut_time_bounds(c) - - # Validation from K2 training - # - checks supervision start is 0 - # - checks supervision.duration is not longer than cut.duration - # - there is tolerance 2ms - validate_for_asr(cut_set) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ksponspeech/ASR/local/validate_manifest.py b/egs/ksponspeech/ASR/local/validate_manifest.py new file mode 120000 index 0000000000..0a9725e876 --- /dev/null +++ b/egs/ksponspeech/ASR/local/validate_manifest.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_manifest.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/README.md b/egs/ksponspeech/ASR/zipformer/README.md deleted file mode 100644 index c8c2104cdc..0000000000 --- a/egs/ksponspeech/ASR/zipformer/README.md +++ /dev/null @@ -1 +0,0 @@ -This recipe implements Zipformer model.