diff --git a/docs/_static/gui.png b/docs/_static/gui.png index 4417f1b8..543e1986 100644 Binary files a/docs/_static/gui.png and b/docs/_static/gui.png differ diff --git a/src/so_vits_svc_fork/__main__.py b/src/so_vits_svc_fork/__main__.py index 5b701b24..87bc47d1 100644 --- a/src/so_vits_svc_fork/__main__.py +++ b/src/so_vits_svc_fork/__main__.py @@ -164,6 +164,14 @@ def train( default=None, help="path to cluster model", ) +@click.option( + "-re", + "--recursive", + type=bool, + default=False, + help="Search recursively", + is_flag=True, +) @click.option("-t", "--transpose", type=int, default=0, help="transpose") @click.option( "-db", "--db-thresh", type=int, default=-20, help="threshold (DB) (RELATIVE)" ) @@ -215,6 +223,7 @@ def infer( output_path: Path, model_path: Path, config_path: Path, + recursive: bool, # svc config speaker: str, cluster_model_path: Path | None = None, @@ -244,6 +253,10 @@ def infer( if output_path is None: output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}" output_path = Path(output_path) + if input_path.is_dir() and not recursive: + raise ValueError( + "input_path is a directory. Use -re or --recursive to infer recursively." 
+ ) model_path = Path(model_path) if model_path.is_dir(): model_path = list( @@ -259,6 +272,7 @@ def infer( output_path=output_path, model_path=model_path, config_path=config_path, + recursive=recursive, # svc config speaker=speaker, cluster_model_path=cluster_model_path, diff --git a/src/so_vits_svc_fork/gui.py b/src/so_vits_svc_fork/gui.py index 602a6397..1795c996 100644 --- a/src/so_vits_svc_fork/gui.py +++ b/src/so_vits_svc_fork/gui.py @@ -73,14 +73,28 @@ def get_output_path(input_path: Path) -> Path: return output_path -def get_supported_file_types() -> tuple[tuple[str], ...]: - return tuple( +def get_supported_file_types() -> tuple[tuple[str, str], ...]: + res = tuple( [ - ((extension, f".{extension.lower()}")) + (extension, f".{extension.lower()}") for extension in sf.available_formats().keys() ] ) + # Sort by popularity + common_file_types = ["WAV", "MP3", "FLAC", "OGG", "M4A", "WMA"] + res = sorted( + res, + key=lambda x: common_file_types.index(x[0]) + if x[0] in common_file_types + else len(common_file_types), + ) + return res + + +def get_supported_file_types_concat() -> tuple[tuple[str, str], ...]: + return (("Audio", " ".join(sf.available_formats().keys())),) + def validate_output_file_type(output_path: Path) -> bool: supported_file_types = sorted( @@ -145,7 +159,24 @@ def after_inference(window: sg.Window, path: Path, auto_play: bool, output_path: def main(): LOG.info(f"version: {__version__}") - sg.theme("Dark") + # sg.theme("Dark") + sg.theme_add_new( + "Very Dark", + { + "BACKGROUND": "#111111", + "TEXT": "#FFFFFF", + "INPUT": "#444444", + "TEXT_INPUT": "#FFFFFF", + "SCROLL": "#333333", + "BUTTON": ("white", "#112233"), + "PROGRESS": ("#111111", "#333333"), + "BORDER": 2, + "SLIDER_DEPTH": 2, + "PROGRESS_DEPTH": 2, + }, + ) + sg.theme("Very Dark") + model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth"))) frame_contents = { @@ -165,7 +196,10 @@ def main(): if Path("./logs/44k/").exists() else Path(".").absolute().as_posix(), 
key="model_path_browse", - file_types=(("PyTorch", "*.pth"),), + file_types=( + ("PyTorch", "G_*.pth G_*.pt"), + ("Pytorch", "*.pth *.pt"), + ), ), ], [ @@ -201,7 +235,7 @@ def main(): if Path("./logs/44k/").exists() else ".", key="cluster_model_path_browse", - file_types=(("PyTorch", "*.pt"),), + file_types=(("PyTorch", "*.pt"), ("Pickle", "*.pt *.pth *.pkl")), ), ], ], @@ -312,7 +346,17 @@ def main(): sg.Text("Input audio path"), sg.Push(), sg.InputText(key="input_path", enable_events=True), - sg.FileBrowse(initial_folder=".", key="input_path_browse"), + sg.FileBrowse( + initial_folder=".", + key="input_path_browse", + file_types=get_supported_file_types_concat(), + ), + sg.FolderBrowse( + button_text="Browse(Folder)", + initial_folder=".", + key="input_path_folder_browse", + target="input_path", + ), sg.Button("Play", key="play_input"), ], [ @@ -438,7 +482,7 @@ def main(): sg.Combo( key="presets", values=list(load_presets().keys()), - size=(20, 1), + size=(40, 1), enable_events=True, ), sg.Button("Delete preset", key="delete_preset"), @@ -446,7 +490,7 @@ def main(): [ sg.Text("Preset name"), sg.Stretch(), - sg.InputText(key="preset_name", size=(20, 1)), + sg.InputText(key="preset_name", size=(26, 1)), sg.Button("Add current settings as a preset", key="add_preset"), ], ], @@ -498,8 +542,15 @@ def main(): layout = [[column1, column2]] # layout = [[sg.Column(layout, vertical_alignment="top", scrollable=True, expand_x=True, expand_y=True)]] window = sg.Window( - f"{__name__.split('.')[0]}", layout, grab_anywhere=True, finalize=True - ) # , use_custom_titlebar=True) + f"{__name__.split('.')[0].replace('_', '-')} v{__version__}", + layout, + grab_anywhere=True, + finalize=True, + # Below disables taskbar, which may be not useful for some users + # use_custom_titlebar=True, no_titlebar=False + # Keep on top + # keep_on_top=True + ) # for n in ["input_device", "output_device"]: # window[n].Widget.configure(justify="right") event, values = window.read(timeout=0.01) @@ 
-620,11 +671,19 @@ def apply_preset(name: str) -> None: # Set a sensible default output path window.Element("output_path").Update(str(get_output_path(input_path))) elif event == "infer": - if not input_path.exists() or not input_path.is_file(): - LOG.warning(f"Input path {input_path} does not exist.") + if "Default VC" in values["presets"]: + window["presets"].update( + set_to_index=list(load_presets().keys()).index("Default File") + ) + apply_preset("Default File") + if values["input_path"] == "": + LOG.warning("Input path is empty.") continue - if not validate_output_file_type(output_path): + if not input_path.exists(): + LOG.warning(f"Input path {input_path} does not exist.") continue + # if not validate_output_file_type(output_path): + # continue try: from so_vits_svc_fork.inference.main import infer @@ -639,6 +698,7 @@ def apply_preset(name: str) -> None: output_path=output_path, input_path=input_path, config_path=Path(values["config_path"]), + recursive=True, # svc config speaker=values["speaker"], cluster_model_path=Path(values["cluster_model_path"]) diff --git a/src/so_vits_svc_fork/inference/main.py b/src/so_vits_svc_fork/inference/main.py index b58ff0bb..9c75e197 100644 --- a/src/so_vits_svc_fork/inference/main.py +++ b/src/so_vits_svc_fork/inference/main.py @@ -2,13 +2,14 @@ from logging import getLogger from pathlib import Path -from typing import Literal +from typing import Literal, Sequence import librosa import numpy as np import soundfile import torch from cm_time import timer +from tqdm import tqdm from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc from so_vits_svc_fork.utils import get_optimal_device @@ -19,10 +20,11 @@ def infer( *, # paths - input_path: Path | str, - output_path: Path | str, + input_path: Path | str | Sequence[Path | str], + output_path: Path | str | Sequence[Path | str], model_path: Path | str, config_path: Path | str, + recursive: bool = False, # svc config speaker: int | str, cluster_model_path: Path | 
str | None = None, @@ -39,10 +41,36 @@ def infer( max_chunk_seconds: float = 40, device: str | torch.device = get_optimal_device(), ): + if isinstance(input_path, (str, Path)): + input_path = [input_path] + if isinstance(output_path, (str, Path)): + output_path = [output_path] + if len(input_path) != len(output_path): + raise ValueError( + f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}" + ) + model_path = Path(model_path) - output_path = Path(output_path) - input_path = Path(input_path) config_path = Path(config_path) + output_path = [Path(p) for p in output_path] + input_path = [Path(p) for p in input_path] + output_paths = [] + input_paths = [] + + for input_path, output_path in zip(input_path, output_path): + if input_path.is_dir(): + if not recursive: + raise ValueError( + f"input_path is a directory, but recursive is False: {input_path}" + ) + input_paths.extend(found := list(input_path.rglob("*.*"))) + output_paths.extend( + [output_path / p.relative_to(input_path) for p in found] + ) + continue + input_paths.append(input_path) + output_paths.append(output_path) + cluster_model_path = Path(cluster_model_path) if cluster_model_path else None svc_model = Svc( net_g_path=model_path.as_posix(), @@ -53,23 +81,35 @@ def infer( device=device, ) - audio, _ = librosa.load(input_path, sr=svc_model.target_sample) - audio = svc_model.infer_silence( - audio.astype(np.float32), - speaker=speaker, - transpose=transpose, - auto_predict_f0=auto_predict_f0, - cluster_infer_ratio=cluster_infer_ratio, - noise_scale=noise_scale, - f0_method=f0_method, - db_thresh=db_thresh, - pad_seconds=pad_seconds, - chunk_seconds=chunk_seconds, - absolute_thresh=absolute_thresh, - max_chunk_seconds=max_chunk_seconds, - ) - - soundfile.write(output_path, audio, svc_model.target_sample) + try: + pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1) + for input_path, output_path in pbar: + 
pbar.set_description(f"{input_path}") + try: + audio, _ = librosa.load(input_path, sr=svc_model.target_sample) + except Exception as e: + LOG.error(f"Failed to load {input_path}") + LOG.exception(e) + continue + output_path.parent.mkdir(parents=True, exist_ok=True) + audio = svc_model.infer_silence( + audio.astype(np.float32), + speaker=speaker, + transpose=transpose, + auto_predict_f0=auto_predict_f0, + cluster_infer_ratio=cluster_infer_ratio, + noise_scale=noise_scale, + f0_method=f0_method, + db_thresh=db_thresh, + pad_seconds=pad_seconds, + chunk_seconds=chunk_seconds, + absolute_thresh=absolute_thresh, + max_chunk_seconds=max_chunk_seconds, + ) + soundfile.write(output_path, audio, svc_model.target_sample) + finally: + del svc_model + torch.cuda.empty_cache() def realtime( @@ -215,14 +255,18 @@ def callback( if rtf > 1: LOG.warning("RTF is too high, consider increasing block_seconds") - with sd.Stream( - device=(input_device, output_device), - channels=1, - callback=callback, - samplerate=svc_model.target_sample, - blocksize=int(block_seconds * svc_model.target_sample), - latency="low", - ) as stream: - LOG.info(f"Latency: {stream.latency}") - while True: - sd.sleep(1000) + try: + with sd.Stream( + device=(input_device, output_device), + channels=1, + callback=callback, + samplerate=svc_model.target_sample, + blocksize=int(block_seconds * svc_model.target_sample), + latency="low", + ) as stream: + LOG.info(f"Latency: {stream.latency}") + while True: + sd.sleep(1000) + finally: + # del model, svc_model + torch.cuda.empty_cache()