Skip to content

Commit

Permalink
fine tuned training & inference
Browse files Browse the repository at this point in the history
  • Loading branch information
roatienza committed May 30, 2023
1 parent 4055043 commit 9c67a72
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
dynamic_axes={
"inputs": {1: "phoneme"},
# ideally, this works but repeat_interleave is fixed
"outputs": {1: "wav"}
"outputs": {0: "wav", 1: "lengths", 2: "duration"}
})
elif args.jit is not None:
with torch.no_grad():
Expand Down
8 changes: 7 additions & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,14 @@ def tts(lexicon, g2p, preprocess_config, model, is_onnx, args, verbose=False):
model = model.to(args.infer_device)
model.eval()

# default number of threads is 128 on AMD
# this is too high and causes the model to run slower
# set it to a lower number eg --threads 24
# https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
if args.threads is not None:
torch.set_num_threads(args.threads)
if args.compile:
model = torch.compile(model)
model = torch.compile(model, mode="reduce-overhead", backend="inductor")

if args.text is not None:
rtf = []
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ g2p-en
validators
onnx
onnxruntime
protobuf==3.20
protobuf==3.20.2
numpy==1.24.3
# needed for data preparation
librosa
unidecode
Expand Down
3 changes: 2 additions & 1 deletion utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def get_args():

parser.add_argument("--devices", type=int, default=1)
parser.add_argument("--iter", type=int, default=1)
parser.add_argument("--threads", type=int, default=24)

#choices = ["bf16-mixed", "16-mixed", 16, 32, 64]
parser.add_argument("--precision", default=16)
Expand Down Expand Up @@ -436,7 +437,7 @@ def get_args():
help='Convert to onnx model')
parser.add_argument('--onnx-insize',
type=int,
default=128,
default=None,
help='Max input size for the onnx model')
parser.add_argument('--onnx-opset',
type=int,
Expand Down

0 comments on commit 9c67a72

Please sign in to comment.