Skip to content

Commit

Permalink
feat(gui): add batch inference, enhance gui, add custom theme (#582)
Browse files Browse the repository at this point in the history
  • Loading branch information
34j authored May 6, 2023
1 parent 0e0943c commit 3ce110b
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 47 deletions.
Binary file modified docs/_static/gui.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions src/so_vits_svc_fork/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def train(
default=None,
help="path to cluster model",
)
@click.option(
"-re",

This comment has been minimized.

Copy link
@Orangey64

Orangey64 Jun 12, 2023

Тоді

"--recursive",
type=bool,
default=False,
help="Search recursively",
is_flag=True,
)
@click.option("-t", "--transpose", type=int, default=0, help="transpose")
@click.option(
"-db", "--db-thresh", type=int, default=-20, help="threshold (DB) (RELATIVE)"
Expand Down Expand Up @@ -215,6 +223,7 @@ def infer(
output_path: Path,
model_path: Path,
config_path: Path,
recursive: bool,
# svc config
speaker: str,
cluster_model_path: Path | None = None,
Expand Down Expand Up @@ -244,6 +253,10 @@ def infer(
if output_path is None:
output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
output_path = Path(output_path)
if input_path.is_dir() and not recursive:
raise ValueError(
"input_path is a directory. Use -re or --recursive to infer recursively."

This comment has been minimized.

Copy link
@Lordmau5

Lordmau5 May 6, 2023

Collaborator

Spotted a typo - 0re instead of -re 👀

)
model_path = Path(model_path)
if model_path.is_dir():
model_path = list(
Expand All @@ -259,6 +272,7 @@ def infer(
output_path=output_path,
model_path=model_path,
config_path=config_path,
recursive=recursive,
# svc config
speaker=speaker,
cluster_model_path=cluster_model_path,
Expand Down
88 changes: 74 additions & 14 deletions src/so_vits_svc_fork/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,28 @@ def get_output_path(input_path: Path) -> Path:
return output_path


def get_supported_file_types() -> tuple[tuple[str], ...]:
return tuple(
def get_supported_file_types() -> tuple[tuple[str, str], ...]:
    """Return (format name, ".extension") pairs for every format that
    soundfile supports, with the most common audio formats listed first.

    Returns:
        A tuple of ``(FORMAT, ".ext")`` pairs, e.g. ``("WAV", ".wav")``,
        ordered so that WAV/MP3/FLAC/... appear before the long tail of
        uncommon formats.
    """
    pairs = [
        (extension, f".{extension.lower()}")
        for extension in sf.available_formats().keys()
    ]

    # Sort by popularity: well-known formats get their rank in
    # common_file_types; everything else shares the same (last) rank, and
    # sorted() being stable keeps soundfile's original relative order.
    common_file_types = ["WAV", "MP3", "FLAC", "OGG", "M4A", "WMA"]
    pairs.sort(
        key=lambda x: common_file_types.index(x[0])
        if x[0] in common_file_types
        else len(common_file_types)
    )
    # Convert back to a tuple so the return value matches the annotation
    # (the original returned the list produced by sorted()).
    return tuple(pairs)


def get_supported_file_types_concat() -> tuple[tuple[str, str], ...]:
    """Return a single file-dialog filter matching all supported audio files.

    Returns:
        A one-element tuple ``(("Audio", "*.wav *.flac ..."),)`` suitable for
        the ``file_types`` argument of a file-browse dialog.
    """
    # File-dialog filters need glob patterns, not bare format names:
    # sf.available_formats() yields keys like "WAV"/"FLAC", so map each
    # to "*.<ext>" (the original joined the raw names, which matches nothing).
    return (
        ("Audio", " ".join(f"*.{fmt.lower()}" for fmt in sf.available_formats())),
    )


def validate_output_file_type(output_path: Path) -> bool:
supported_file_types = sorted(
Expand Down Expand Up @@ -145,7 +159,24 @@ def after_inference(window: sg.Window, path: Path, auto_play: bool, output_path:
def main():
LOG.info(f"version: {__version__}")

sg.theme("Dark")
# sg.theme("Dark")
sg.theme_add_new(
"Very Dark",
{
"BACKGROUND": "#111111",
"TEXT": "#FFFFFF",
"INPUT": "#444444",
"TEXT_INPUT": "#FFFFFF",
"SCROLL": "#333333",
"BUTTON": ("white", "#112233"),
"PROGRESS": ("#111111", "#333333"),
"BORDER": 2,
"SLIDER_DEPTH": 2,
"PROGRESS_DEPTH": 2,
},
)
sg.theme("Very Dark")

model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))

frame_contents = {
Expand All @@ -165,7 +196,10 @@ def main():
if Path("./logs/44k/").exists()
else Path(".").absolute().as_posix(),
key="model_path_browse",
file_types=(("PyTorch", "*.pth"),),
file_types=(
("PyTorch", "G_*.pth G_*.pt"),
("Pytorch", "*.pth *.pt"),
),
),
],
[
Expand Down Expand Up @@ -201,7 +235,7 @@ def main():
if Path("./logs/44k/").exists()
else ".",
key="cluster_model_path_browse",
file_types=(("PyTorch", "*.pt"),),
file_types=(("PyTorch", "*.pt"), ("Pickle", "*.pt *.pth *.pkl")),
),
],
],
Expand Down Expand Up @@ -312,7 +346,17 @@ def main():
sg.Text("Input audio path"),
sg.Push(),
sg.InputText(key="input_path", enable_events=True),
sg.FileBrowse(initial_folder=".", key="input_path_browse"),
sg.FileBrowse(
initial_folder=".",
key="input_path_browse",
file_types=get_supported_file_types_concat(),
),
sg.FolderBrowse(
button_text="Browse(Folder)",
initial_folder=".",
key="input_path_folder_browse",
target="input_path",
),
sg.Button("Play", key="play_input"),
],
[
Expand Down Expand Up @@ -438,15 +482,15 @@ def main():
sg.Combo(
key="presets",
values=list(load_presets().keys()),
size=(20, 1),
size=(40, 1),
enable_events=True,
),
sg.Button("Delete preset", key="delete_preset"),
],
[
sg.Text("Preset name"),
sg.Stretch(),
sg.InputText(key="preset_name", size=(20, 1)),
sg.InputText(key="preset_name", size=(26, 1)),
sg.Button("Add current settings as a preset", key="add_preset"),
],
],
Expand Down Expand Up @@ -498,8 +542,15 @@ def main():
layout = [[column1, column2]]
# layout = [[sg.Column(layout, vertical_alignment="top", scrollable=True, expand_x=True, expand_y=True)]]
window = sg.Window(
f"{__name__.split('.')[0]}", layout, grab_anywhere=True, finalize=True
) # , use_custom_titlebar=True)
f"{__name__.split('.')[0].replace('_', '-')} v{__version__}",
layout,
grab_anywhere=True,
finalize=True,
# Below disables taskbar, which may be not useful for some users
# use_custom_titlebar=True, no_titlebar=False
# Keep on top
# keep_on_top=True
)
# for n in ["input_device", "output_device"]:
# window[n].Widget.configure(justify="right")
event, values = window.read(timeout=0.01)
Expand Down Expand Up @@ -620,11 +671,19 @@ def apply_preset(name: str) -> None:
# Set a sensible default output path
window.Element("output_path").Update(str(get_output_path(input_path)))
elif event == "infer":
if not input_path.exists() or not input_path.is_file():
LOG.warning(f"Input path {input_path} does not exist.")
if "Default VC" in values["presets"]:
window["presets"].update(
set_to_index=list(load_presets().keys()).index("Default File")
)
apply_preset("Default File")
if values["input_path"] == "":
LOG.warning("Input path is empty.")
continue
if not validate_output_file_type(output_path):
if not input_path.exists():
LOG.warning(f"Input path {input_path} does not exist.")
continue
# if not validate_output_file_type(output_path):
# continue

try:
from so_vits_svc_fork.inference.main import infer
Expand All @@ -639,6 +698,7 @@ def apply_preset(name: str) -> None:
output_path=output_path,
input_path=input_path,
config_path=Path(values["config_path"]),
recursive=True,
# svc config
speaker=values["speaker"],
cluster_model_path=Path(values["cluster_model_path"])
Expand Down
110 changes: 77 additions & 33 deletions src/so_vits_svc_fork/inference/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from logging import getLogger
from pathlib import Path
from typing import Literal
from typing import Literal, Sequence

import librosa
import numpy as np
import soundfile
import torch
from cm_time import timer
from tqdm import tqdm

from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc
from so_vits_svc_fork.utils import get_optimal_device
Expand All @@ -19,10 +20,11 @@
def infer(
*,
# paths
input_path: Path | str,
output_path: Path | str,
input_path: Path | str | Sequence[Path | str],
output_path: Path | str | Sequence[Path | str],
model_path: Path | str,
config_path: Path | str,
recursive: bool = False,
# svc config
speaker: int | str,
cluster_model_path: Path | str | None = None,
Expand All @@ -39,10 +41,36 @@ def infer(
max_chunk_seconds: float = 40,
device: str | torch.device = get_optimal_device(),
):
if isinstance(input_path, (str, Path)):
input_path = [input_path]
if isinstance(output_path, (str, Path)):
output_path = [output_path]
if len(input_path) != len(output_path):
raise ValueError(
f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}"
)

model_path = Path(model_path)
output_path = Path(output_path)
input_path = Path(input_path)
config_path = Path(config_path)
output_path = [Path(p) for p in output_path]
input_path = [Path(p) for p in input_path]
output_paths = []
input_paths = []

for input_path, output_path in zip(input_path, output_path):
if input_path.is_dir():
if not recursive:
raise ValueError(
f"input_path is a directory, but recursive is False: {input_path}"
)
input_paths.extend(list(input_path.rglob("*.*")))
output_paths.extend(
[output_path / p.relative_to(input_path) for p in input_paths]
)
continue
input_paths.append(input_path)
output_paths.append(output_path)

cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
svc_model = Svc(
net_g_path=model_path.as_posix(),
Expand All @@ -53,23 +81,35 @@ def infer(
device=device,
)

audio, _ = librosa.load(input_path, sr=svc_model.target_sample)
audio = svc_model.infer_silence(
audio.astype(np.float32),
speaker=speaker,
transpose=transpose,
auto_predict_f0=auto_predict_f0,
cluster_infer_ratio=cluster_infer_ratio,
noise_scale=noise_scale,
f0_method=f0_method,
db_thresh=db_thresh,
pad_seconds=pad_seconds,
chunk_seconds=chunk_seconds,
absolute_thresh=absolute_thresh,
max_chunk_seconds=max_chunk_seconds,
)

soundfile.write(output_path, audio, svc_model.target_sample)
try:
pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1)
for input_path, output_path in pbar:
pbar.set_description(f"{input_path}")
try:
audio, _ = librosa.load(input_path, sr=svc_model.target_sample)
except Exception as e:
LOG.error(f"Failed to load {input_path}")
LOG.exception(e)
continue
output_path.parent.mkdir(parents=True, exist_ok=True)
audio = svc_model.infer_silence(
audio.astype(np.float32),
speaker=speaker,
transpose=transpose,
auto_predict_f0=auto_predict_f0,
cluster_infer_ratio=cluster_infer_ratio,
noise_scale=noise_scale,
f0_method=f0_method,
db_thresh=db_thresh,
pad_seconds=pad_seconds,
chunk_seconds=chunk_seconds,
absolute_thresh=absolute_thresh,
max_chunk_seconds=max_chunk_seconds,
)
soundfile.write(output_path, audio, svc_model.target_sample)
finally:
del svc_model
torch.cuda.empty_cache()


def realtime(
Expand Down Expand Up @@ -215,14 +255,18 @@ def callback(
if rtf > 1:
LOG.warning("RTF is too high, consider increasing block_seconds")

with sd.Stream(
device=(input_device, output_device),
channels=1,
callback=callback,
samplerate=svc_model.target_sample,
blocksize=int(block_seconds * svc_model.target_sample),
latency="low",
) as stream:
LOG.info(f"Latency: {stream.latency}")
while True:
sd.sleep(1000)
try:
with sd.Stream(
device=(input_device, output_device),
channels=1,
callback=callback,
samplerate=svc_model.target_sample,
blocksize=int(block_seconds * svc_model.target_sample),
latency="low",
) as stream:
LOG.info(f"Latency: {stream.latency}")
while True:
sd.sleep(1000)
finally:
# del model, svc_model
torch.cuda.empty_cache()

3 comments on commit 3ce110b

@Orangey64
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ggg

@Orangey64
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cf

@Orangey64
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bv

Please sign in to comment.