Skip to content

Commit

Permalink
Minor fixes to the onnx inference script for ljspeech matcha-tts. (#1838
Browse files Browse the repository at this point in the history
)
  • Loading branch information
csukuangfj authored Dec 19, 2024
1 parent 92ed170 commit ad966fb
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
20 changes: 13 additions & 7 deletions .github/scripts/ljspeech/TTS/run-matcha.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ function infer() {
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

./matcha/infer.py \
--num-buckets 2 \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
Expand Down Expand Up @@ -97,19 +98,23 @@ function export_onnx() {
python3 ./matcha/export_onnx_hifigan.py
else
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx
fi

ls -lh *.onnx

python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-6.onnx \
--vocoder ./hifigan_v1.onnx \
--tokens ./data/tokens.txt \
--input-text "how are you doing?" \
--output-wav /icefall/generated-matcha-tts-steps-6-v1.wav
for v in v1 v2 v3; do
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-6.onnx \
--vocoder ./hifigan_$v.onnx \
--tokens ./data/tokens.txt \
--input-text "how are you doing?" \
--output-wav /icefall/generated-matcha-tts-steps-6-$v.wav
done

ls -lh /icefall/*.wav
soxi /icefall/generated-matcha-tts-steps-6-v1.wav
soxi /icefall/generated-matcha-tts-steps-6-*.wav
}

prepare_data
Expand All @@ -118,3 +123,4 @@ infer
export_onnx

rm -rfv generator_v* matcha/exp
git checkout .
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/matcha/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def main():
(x, x_lengths, temperature, length_scale),
filename,
opset_version=opset_version,
input_names=["x", "x_length", "temperature", "length_scale"],
input_names=["x", "x_length", "noise_scale", "length_scale"],
output_names=["mel"],
dynamic_axes={
"x": {0: "N", 1: "L"},
Expand Down
19 changes: 14 additions & 5 deletions egs/ljspeech/TTS/matcha/onnx_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __call__(self, x: torch.tensor):
self.model.get_inputs()[0].name: x.numpy(),
},
)[0]
# audio: (batch_size, num_samples)

return torch.from_numpy(audio)

Expand All @@ -97,19 +98,24 @@ class OnnxModel:
def __init__(
self,
filename: str,
tokens: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 2

self.session_opts = session_opts
self.tokenizer = Tokenizer("./data/tokens.txt")
self.tokenizer = Tokenizer(tokens)
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
metadata = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(metadata["sample_rate"])

for i in self.model.get_inputs():
print(i)

Expand Down Expand Up @@ -138,6 +144,7 @@ def __call__(self, x: torch.tensor):
self.model.get_inputs()[3].name: length_scale.numpy(),
},
)[0]
# mel: (batch_size, feat_dim, num_frames)

return torch.from_numpy(mel)

Expand All @@ -147,7 +154,7 @@ def main():
params = get_parser().parse_args()
logging.info(vars(params))

model = OnnxModel(params.acoustic_model)
model = OnnxModel(params.acoustic_model, params.tokens)
vocoder = OnnxHifiGANModel(params.vocoder)
text = params.input_text
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
Expand All @@ -164,15 +171,17 @@ def main():
print("audio", audio.shape) # (1, 1, num_samples)
audio = audio.squeeze()

sample_rate = model.sample_rate

t = (end_t - start_t).total_seconds()
t2 = (end_t2 - start_t2).total_seconds()
rtf_am = t * 22050 / audio.shape[-1]
rtf_vocoder = t2 * 22050 / audio.shape[-1]
rtf_am = t * sample_rate / audio.shape[-1]
rtf_vocoder = t2 * sample_rate / audio.shape[-1]
print("RTF for acoustic model ", rtf_am)
print("RTF for vocoder", rtf_vocoder)

# skip denoiser
sf.write(params.output_wav, audio, 22050, "PCM_16")
sf.write(params.output_wav, audio, sample_rate, "PCM_16")
logging.info(f"Saved to {params.output_wav}")


Expand Down

0 comments on commit ad966fb

Please sign in to comment.