Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finetune Whisper model on LibriSpeech #1571

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
59 changes: 53 additions & 6 deletions egs/librispeech/ASR/local/compute_fbank_librispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@
import sentencepiece as spm
import torch
from filter_cuts import filter_cuts
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
NumpyHdf5Writer,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
Expand Down Expand Up @@ -61,25 +69,55 @@ def get_args():
help="""Dataset parts to compute fbank. If None, we will use all""",
)

parser.add_argument(
"--output-dir",
type=str,
default="data/fbank",
help="Where to store the train/dev/test manifests and fbank features",
)

parser.add_argument(
"--perturb-speed",
type=str2bool,
default=True,
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
)

parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="If use Whisper configuration for fbank computation",
)

parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
)

parser.add_argument(
"--use-hdf5",
type=str2bool,
default=False,
help="If use hdf5 to store un-compressed features. Otherwise, use Lilcom"
)

return parser.parse_args()


def compute_fbank_librispeech(
bpe_model: Optional[str] = None,
dataset: Optional[str] = None,
output_dir: Optional[str] = None,
perturb_speed: Optional[bool] = True,
whisper_fbank: Optional[bool] = False,
num_mel_bins: Optional[int] = 80,
use_hdf5: Optional[bool] = False,
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80

if bpe_model:
logging.info(f"Loading {bpe_model}")
Expand Down Expand Up @@ -116,7 +154,12 @@ def compute_fbank_librispeech(
dataset_parts,
)

extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
Expand All @@ -134,7 +177,7 @@ def compute_fbank_librispeech(
if bpe_model:
cut_set = filter_cuts(cut_set, sp)
if perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
Expand All @@ -146,7 +189,7 @@ def compute_fbank_librispeech(
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
storage_type=LilcomChunkyWriter if not use_hdf5 else NumpyHdf5Writer,
)
cut_set.to_file(output_dir / cuts_filename)

Expand All @@ -160,5 +203,9 @@ def compute_fbank_librispeech(
compute_fbank_librispeech(
bpe_model=args.bpe_model,
dataset=args.dataset,
output_dir=args.output_dir,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
num_mel_bins=args.num_mel_bins,
use_hdf5=args.use_hdf5,
)
15 changes: 13 additions & 2 deletions egs/librispeech/ASR/local/compute_fbank_musan.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
FbankConfig,
LilcomChunkyWriter,
MonoCut,
NumpyHdf5Writer,
WhisperFbank,
WhisperFbankConfig,
combine,
Expand All @@ -55,7 +56,10 @@ def is_cut_long(c: MonoCut) -> bool:


def compute_fbank_musan(
num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/fbank"
num_mel_bins: int = 80,
whisper_fbank: bool = False,
output_dir: str = "data/fbank",
use_hdf5: bool = False,
):
src_dir = Path("data/manifests")
output_dir = Path(output_dir)
Expand Down Expand Up @@ -111,7 +115,7 @@ def compute_fbank_musan(
storage_path=f"{output_dir}/musan_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
storage_type=LilcomChunkyWriter if not use_hdf5 else NumpyHdf5Writer,
)
)
musan_cuts.to_file(musan_cuts_path)
Expand All @@ -137,6 +141,12 @@ def get_args():
default="data/fbank",
help="Output directory. Default: data/fbank.",
)
parser.add_argument(
"--use-hdf5",
type=str2bool,
default=False,
help="If use hdf5 to store un-compressed features. Otherwise, use Lilcom"
)
return parser.parse_args()


Expand All @@ -149,4 +159,5 @@ def get_args():
num_mel_bins=args.num_mel_bins,
whisper_fbank=args.whisper_fbank,
output_dir=args.output_dir,
use_hdf5=args.use_hdf5,
)
31 changes: 31 additions & 0 deletions egs/librispeech/ASR/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,34 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
$lang_dir/L_disambig.fst
fi
fi

# NOTE: This stage is optional and should only be done if you want to
# do Whisper related experiments
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Prepare whisper fbank feature"
  # Whisper training data is not speed-perturbed; whisper (non large-v3)
  # checkpoints expect 80-mel fbank features.
  perturb_speed=0
  whisper_mel_bins=80
  # Use lowercase "false"/"true" consistently with --whisper-fbank below.
  # (str2bool happens to be case-insensitive, but don't rely on it.)
  use_hdf5=false
  # NOTE(review): the "_test" suffix looks like a leftover from debugging --
  # confirm the intended production directory name before merging.
  output_dir=data/fbank_whisper_${whisper_mel_bins}D_test
  if [ ! -f "$output_dir/.librispeech.whisper.done" ]; then
    mkdir -p "$output_dir"
    # Extract whisper-style fbank for all LibriSpeech parts.
    ./local/compute_fbank_librispeech.py \
      --num-mel-bins ${whisper_mel_bins} \
      --perturb-speed ${perturb_speed} \
      --whisper-fbank true \
      --use-hdf5 ${use_hdf5} \
      --output-dir "$output_dir"
    # MUSAN is used for noise augmentation; extract matching features.
    ./local/compute_fbank_musan.py \
      --num-mel-bins ${whisper_mel_bins} \
      --whisper-fbank true \
      --use-hdf5 ${use_hdf5} \
      --output-dir "$output_dir"
    touch "$output_dir/.librispeech.whisper.done"
  fi
  # Combine the three train subsets into a single shuffled manifest,
  # matching the librispeech_cuts_train-all-shuf.jsonl.gz layout used
  # elsewhere in the recipe.
  if [ ! -f "${output_dir}/librispeech_cuts_train-all-shuf.jsonl.gz" ]; then
    cat <(gunzip -c "${output_dir}/librispeech_cuts_train-clean-100.jsonl.gz") \
      <(gunzip -c "${output_dir}/librispeech_cuts_train-clean-360.jsonl.gz") \
      <(gunzip -c "${output_dir}/librispeech_cuts_train-other-500.jsonl.gz") | \
      shuf | gzip -c > "${output_dir}/librispeech_cuts_train-all-shuf.jsonl.gz"
  fi
fi
49 changes: 45 additions & 4 deletions egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse import (
CutSet,
Fbank,
FbankConfig,
load_manifest,
load_manifest_lazy,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
Expand Down Expand Up @@ -215,6 +223,20 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
help="AudioSamples or PrecomputedFeatures",
)

group.add_argument(
"--use-whisper-fbank",
type=str2bool,
default=False,
help="Use whisper fbank feature as input",
)

group.add_argument(
"--whisper-fbank-n-mels",
type=int,
default=80,
help="Number of mels for whisper fbank, large-v3 uses 128-mel fbank",
)

def train_dataloaders(
self,
cuts_train: CutSet,
Expand Down Expand Up @@ -297,9 +319,15 @@ def train_dataloaders(
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
if self.args.use_whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=80))
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(extractor),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
Expand Down Expand Up @@ -355,9 +383,15 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:

logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
if self.args.use_whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=80))
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=OnTheFlyFeatures(extractor),
return_cuts=self.args.return_cuts,
)
else:
Expand All @@ -383,8 +417,15 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:

def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
if self.args.use_whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=80))

test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
input_strategy=OnTheFlyFeatures(extractor)
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
Expand Down
1 change: 1 addition & 0 deletions egs/librispeech/ASR/whisper/asr_datamodule.py
Loading
Loading