Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(gui): add batch inference, enhance gui, add custom theme #582

Merged
merged 7 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/_static/gui.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions src/so_vits_svc_fork/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def train(
default=None,
help="path to cluster model",
)
@click.option(
"-re",
"--recursive",
type=bool,
default=False,
help="Search recursively",
is_flag=True,
)
@click.option("-t", "--transpose", type=int, default=0, help="transpose")
@click.option(
"-db", "--db-thresh", type=int, default=-20, help="threshold (DB) (RELATIVE)"
Expand Down Expand Up @@ -215,6 +223,7 @@ def infer(
output_path: Path,
model_path: Path,
config_path: Path,
recursive: bool,
# svc config
speaker: str,
cluster_model_path: Path | None = None,
Expand Down Expand Up @@ -244,6 +253,10 @@ def infer(
if output_path is None:
output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
output_path = Path(output_path)
if input_path.is_dir() and not recursive:
raise ValueError(
"input_path is a directory. Use -re or --recursive to infer recursively."
)
model_path = Path(model_path)
if model_path.is_dir():
model_path = list(
Expand All @@ -259,6 +272,7 @@ def infer(
output_path=output_path,
model_path=model_path,
config_path=config_path,
recursive=recursive,
# svc config
speaker=speaker,
cluster_model_path=cluster_model_path,
Expand Down
88 changes: 74 additions & 14 deletions src/so_vits_svc_fork/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,28 @@ def get_output_path(input_path: Path) -> Path:
return output_path


def get_supported_file_types() -> tuple[tuple[str, str], ...]:
    """Return ``(format name, ".ext")`` pairs for every soundfile format.

    The extension is the lower-cased format name prefixed with a dot.
    Common formats come first, in a fixed popularity order; every other
    format follows in the order reported by ``sf.available_formats()``.

    Returns:
        A tuple of ``(name, extension)`` tuples, e.g. ``("WAV", ".wav")``.
    """
    # Sort by popularity: listed formats get their list position as rank,
    # anything else shares the last rank. sorted() is stable, so formats with
    # the same rank keep the order in which soundfile reported them.
    common_file_types = ["WAV", "MP3", "FLAC", "OGG", "M4A", "WMA"]
    rank = {name: position for position, name in enumerate(common_file_types)}
    res = sorted(
        (
            (extension, f".{extension.lower()}")
            for extension in sf.available_formats()
        ),
        key=lambda x: rank.get(x[0], len(common_file_types)),
    )
    # sorted() produces a list; convert so the result matches the annotated
    # (and previously returned) tuple type.
    return tuple(res)


def get_supported_file_types_concat() -> tuple[tuple[str, str], ...]:
    """Return one "Audio" file-type entry covering all soundfile formats.

    The entry's pattern string is the space-separated list of format names
    reported by ``sf.available_formats()``.
    """
    # NOTE(review): the pattern holds bare format names (e.g. "WAV"), whereas
    # get_supported_file_types() builds ".wav"-style extensions -- confirm the
    # file dialog accepts this form.
    format_names = sf.available_formats().keys()
    pattern = " ".join(format_names)
    return (("Audio", pattern),)


def validate_output_file_type(output_path: Path) -> bool:
supported_file_types = sorted(
Expand Down Expand Up @@ -145,7 +159,24 @@ def after_inference(window: sg.Window, path: Path, auto_play: bool, output_path:
def main():
LOG.info(f"version: {__version__}")

sg.theme("Dark")
# sg.theme("Dark")
sg.theme_add_new(
"Very Dark",
{
"BACKGROUND": "#111111",
"TEXT": "#FFFFFF",
"INPUT": "#444444",
"TEXT_INPUT": "#FFFFFF",
"SCROLL": "#333333",
"BUTTON": ("white", "#112233"),
"PROGRESS": ("#111111", "#333333"),
"BORDER": 2,
"SLIDER_DEPTH": 2,
"PROGRESS_DEPTH": 2,
},
)
sg.theme("Very Dark")

model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))

frame_contents = {
Expand All @@ -165,7 +196,10 @@ def main():
if Path("./logs/44k/").exists()
else Path(".").absolute().as_posix(),
key="model_path_browse",
file_types=(("PyTorch", "*.pth"),),
file_types=(
("PyTorch", "G_*.pth G_*.pt"),
("Pytorch", "*.pth *.pt"),
),
),
],
[
Expand Down Expand Up @@ -201,7 +235,7 @@ def main():
if Path("./logs/44k/").exists()
else ".",
key="cluster_model_path_browse",
file_types=(("PyTorch", "*.pt"),),
file_types=(("PyTorch", "*.pt"), ("Pickle", "*.pt *.pth *.pkl")),
),
],
],
Expand Down Expand Up @@ -312,7 +346,17 @@ def main():
sg.Text("Input audio path"),
sg.Push(),
sg.InputText(key="input_path", enable_events=True),
sg.FileBrowse(initial_folder=".", key="input_path_browse"),
sg.FileBrowse(
initial_folder=".",
key="input_path_browse",
file_types=get_supported_file_types_concat(),
),
sg.FolderBrowse(
button_text="Browse(Folder)",
initial_folder=".",
key="input_path_folder_browse",
target="input_path",
),
sg.Button("Play", key="play_input"),
],
[
Expand Down Expand Up @@ -438,15 +482,15 @@ def main():
sg.Combo(
key="presets",
values=list(load_presets().keys()),
size=(20, 1),
size=(40, 1),
enable_events=True,
),
sg.Button("Delete preset", key="delete_preset"),
],
[
sg.Text("Preset name"),
sg.Stretch(),
sg.InputText(key="preset_name", size=(20, 1)),
sg.InputText(key="preset_name", size=(26, 1)),
sg.Button("Add current settings as a preset", key="add_preset"),
],
],
Expand Down Expand Up @@ -498,8 +542,15 @@ def main():
layout = [[column1, column2]]
# layout = [[sg.Column(layout, vertical_alignment="top", scrollable=True, expand_x=True, expand_y=True)]]
window = sg.Window(
f"{__name__.split('.')[0]}", layout, grab_anywhere=True, finalize=True
) # , use_custom_titlebar=True)
f"{__name__.split('.')[0].replace('_', '-')} v{__version__}",
layout,
grab_anywhere=True,
finalize=True,
# Below disables taskbar, which may be not useful for some users
# use_custom_titlebar=True, no_titlebar=False
# Keep on top
# keep_on_top=True
)
# for n in ["input_device", "output_device"]:
# window[n].Widget.configure(justify="right")
event, values = window.read(timeout=0.01)
Expand Down Expand Up @@ -620,11 +671,19 @@ def apply_preset(name: str) -> None:
# Set a sensible default output path
window.Element("output_path").Update(str(get_output_path(input_path)))
elif event == "infer":
if not input_path.exists() or not input_path.is_file():
LOG.warning(f"Input path {input_path} does not exist.")
if "Default VC" in values["presets"]:
window["presets"].update(
set_to_index=list(load_presets().keys()).index("Default File")
)
apply_preset("Default File")
if values["input_path"] == "":
LOG.warning("Input path is empty.")
continue
if not validate_output_file_type(output_path):
if not input_path.exists():
LOG.warning(f"Input path {input_path} does not exist.")
continue
# if not validate_output_file_type(output_path):
# continue

try:
from so_vits_svc_fork.inference.main import infer
Expand All @@ -639,6 +698,7 @@ def apply_preset(name: str) -> None:
output_path=output_path,
input_path=input_path,
config_path=Path(values["config_path"]),
recursive=True,
# svc config
speaker=values["speaker"],
cluster_model_path=Path(values["cluster_model_path"])
Expand Down
110 changes: 77 additions & 33 deletions src/so_vits_svc_fork/inference/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from logging import getLogger
from pathlib import Path
from typing import Literal
from typing import Literal, Sequence

import librosa
import numpy as np
import soundfile
import torch
from cm_time import timer
from tqdm import tqdm

from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc
from so_vits_svc_fork.utils import get_optimal_device
Expand All @@ -19,10 +20,11 @@
def infer(
*,
# paths
input_path: Path | str,
output_path: Path | str,
input_path: Path | str | Sequence[Path | str],
output_path: Path | str | Sequence[Path | str],
model_path: Path | str,
config_path: Path | str,
recursive: bool = False,
# svc config
speaker: int | str,
cluster_model_path: Path | str | None = None,
Expand All @@ -39,10 +41,36 @@ def infer(
max_chunk_seconds: float = 40,
device: str | torch.device = get_optimal_device(),
):
if isinstance(input_path, (str, Path)):
input_path = [input_path]
if isinstance(output_path, (str, Path)):
output_path = [output_path]
if len(input_path) != len(output_path):
raise ValueError(
f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}"
)

model_path = Path(model_path)
output_path = Path(output_path)
input_path = Path(input_path)
config_path = Path(config_path)
output_path = [Path(p) for p in output_path]
input_path = [Path(p) for p in input_path]
output_paths = []
input_paths = []

for input_path, output_path in zip(input_path, output_path):
if input_path.is_dir():
if not recursive:
raise ValueError(
f"input_path is a directory, but recursive is False: {input_path}"
)
input_paths.extend(list(input_path.rglob("*.*")))
output_paths.extend(
[output_path / p.relative_to(input_path) for p in input_paths]
)
continue
input_paths.append(input_path)
output_paths.append(output_path)

cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
svc_model = Svc(
net_g_path=model_path.as_posix(),
Expand All @@ -53,23 +81,35 @@ def infer(
device=device,
)

audio, _ = librosa.load(input_path, sr=svc_model.target_sample)
audio = svc_model.infer_silence(
audio.astype(np.float32),
speaker=speaker,
transpose=transpose,
auto_predict_f0=auto_predict_f0,
cluster_infer_ratio=cluster_infer_ratio,
noise_scale=noise_scale,
f0_method=f0_method,
db_thresh=db_thresh,
pad_seconds=pad_seconds,
chunk_seconds=chunk_seconds,
absolute_thresh=absolute_thresh,
max_chunk_seconds=max_chunk_seconds,
)

soundfile.write(output_path, audio, svc_model.target_sample)
try:
pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1)
for input_path, output_path in pbar:
pbar.set_description(f"{input_path}")
try:
audio, _ = librosa.load(input_path, sr=svc_model.target_sample)
except Exception as e:
LOG.error(f"Failed to load {input_path}")
LOG.exception(e)
continue
output_path.parent.mkdir(parents=True, exist_ok=True)
audio = svc_model.infer_silence(
audio.astype(np.float32),
speaker=speaker,
transpose=transpose,
auto_predict_f0=auto_predict_f0,
cluster_infer_ratio=cluster_infer_ratio,
noise_scale=noise_scale,
f0_method=f0_method,
db_thresh=db_thresh,
pad_seconds=pad_seconds,
chunk_seconds=chunk_seconds,
absolute_thresh=absolute_thresh,
max_chunk_seconds=max_chunk_seconds,
)
soundfile.write(output_path, audio, svc_model.target_sample)
finally:
del svc_model
torch.cuda.empty_cache()


def realtime(
Expand Down Expand Up @@ -215,14 +255,18 @@ def callback(
if rtf > 1:
LOG.warning("RTF is too high, consider increasing block_seconds")

with sd.Stream(
device=(input_device, output_device),
channels=1,
callback=callback,
samplerate=svc_model.target_sample,
blocksize=int(block_seconds * svc_model.target_sample),
latency="low",
) as stream:
LOG.info(f"Latency: {stream.latency}")
while True:
sd.sleep(1000)
try:
with sd.Stream(
device=(input_device, output_device),
channels=1,
callback=callback,
samplerate=svc_model.target_sample,
blocksize=int(block_seconds * svc_model.target_sample),
latency="low",
) as stream:
LOG.info(f"Latency: {stream.latency}")
while True:
sd.sleep(1000)
finally:
# del model, svc_model
torch.cuda.empty_cache()