Add multi-GPU support for audio transcription #113

Open · wants to merge 1 commit into master
36 changes: 35 additions & 1 deletion gradio_tabs/dataset.py
@@ -3,6 +3,8 @@
from style_bert_vits2.logging import logger
from style_bert_vits2.utils.subprocess import run_script_with_log

import argparse
import transcribe

def do_slice(
model_name: str,
@@ -39,18 +41,34 @@ def do_slice(

def do_transcribe(
model_name,
whisper_model,
model,
compute_type,
language,
initial_prompt,
device,
device_indexes,
use_hf_whisper,
batch_size,
num_beams,
no_repeat_ngram_size: int = 10,
):
if model_name == "":
return "Error: モデル名を入力してください。"

success, message = transcribe.run(
model_name,
model,
compute_type,
language,
initial_prompt,
device,
device_indexes,
use_hf_whisper,
batch_size,
num_beams,
no_repeat_ngram_size,
)
'''
cmd = [
"transcribe.py",
"--model_name",
@@ -72,9 +90,19 @@ def do_transcribe(
cmd.append("--use_hf_whisper")
cmd.extend(["--batch_size", str(batch_size)])
success, message = run_script_with_log(cmd)
'''
if not success:
return f"Error: {message}. エラーメッセージが空の場合、何も問題がない可能性があるので、書き起こしファイルをチェックして問題なければ無視してください。"

import torch

# Get the list of available GPU indexes
def get_gpu_indexes():
gpu_indexes = []
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
gpu_indexes.append(str(i))
return ','.join(gpu_indexes)
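
For reference, a minimal sketch (not part of this patch) of how the comma-separated default produced here is consumed on the transcribe.run side, assuming a machine where torch sees two GPUs:

# Illustrative only: round-trip of the GPU index string (assumes two visible GPUs).
indexes_str = get_gpu_indexes()                            # e.g. "0,1"
device_indexes = [int(x) for x in indexes_str.split(",")]  # mirrors transcribe.run
print(device_indexes)                                      # [0, 1]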

how_to_md = """
Style-Bert-VITS2の学習用データセットを作成するためのツールです。以下の2つからなります。
@@ -192,6 +220,11 @@ def create_dataset_app() -> gr.Blocks:
visible=False,
)
device = gr.Radio(["cuda", "cpu"], label="デバイス", value="cuda")
device_indexes = gr.Textbox(
label="使用GPUインデックス",
value=get_gpu_indexes(),
info="使用するGPUインデックスをカンマ区切りで指定、例文(0,1,2)",
)
language = gr.Dropdown(["ja", "en", "zh"], value="ja", label="言語")
initial_prompt = gr.Textbox(
label="初期プロンプト",
@@ -229,6 +262,7 @@ def create_dataset_app() -> gr.Blocks:
language,
initial_prompt,
device,
device_indexes,
use_hf_whisper,
batch_size,
num_beams,
1 change: 1 addition & 0 deletions requirements.txt
@@ -28,3 +28,4 @@ tensorboard
torch>=2.1
transformers
umap-learn
portalocker
210 changes: 210 additions & 0 deletions transcribe.py
@@ -12,6 +12,54 @@
from style_bert_vits2.logging import logger
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT

from threading import Thread
import queue
import time
# Queue holding the indexes of GPUs that are currently free
device_queue = queue.Queue()

def transcribe_thread(
model,
device_index,
wav_file,
output_file,
model_name,
language_id,
initial_prompt,
language,
num_beams,
no_repeat_ngram_size
):

text = transcribe_with_faster_whisper(
model=model,
audio_file=wav_file,
initial_prompt=initial_prompt,
language=language,
num_beams=num_beams,
no_repeat_ngram_size=no_repeat_ngram_size,
)

with open(output_file, "a", encoding="utf-8") as f:
if lock_file(f):
f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n")
unlock_file(f)

device_queue.put(device_index)

import portalocker

def lock_file(file_obj):
try:
# Take an exclusive lock on the file
portalocker.lock(file_obj, portalocker.LOCK_EX)
return True
except portalocker.LockException:
return False

def unlock_file(file_obj):
# Release the lock on the file
portalocker.unlock(file_obj)
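
These helpers serialize the worker threads' appends to the shared esd.list; because portalocker takes an OS-level exclusive lock, the same pattern also guards against writers in other processes. A minimal usage sketch, with a placeholder file name and line content:

# Illustrative usage of lock_file/unlock_file; "example.list" and the written line are placeholders.
with open("example.list", "a", encoding="utf-8") as f:
    if lock_file(f):
        f.write("sample.wav|MyModel|JP|こんにちは\n")
        unlock_file(f)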

# faster-whisper does not get faster with parallel processing, so each model transcribes files in a loop
def transcribe_with_faster_whisper(
@@ -103,6 +151,150 @@ def transcribe_files_with_hf_whisper(

return results

def run(
model_name:str,
model:str="large-v3",
compute_type:str="bfloat16",
language:str="ja",
initial_prompt:str="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!",
device:str="cuda",
device_indexes:str="0",
use_hf_whisper=True,
batch_size:int=16,
num_beams:int=1,
no_repeat_ngram_size:int=10,
):

with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
path_config: dict[str, str] = yaml.safe_load(f.read())
dataset_root = Path(path_config["dataset_root"])

model_name = str(model_name)

input_dir = dataset_root / model_name / "raw"
output_file = dataset_root / model_name / "esd.list"
initial_prompt: str = initial_prompt
initial_prompt = initial_prompt.strip('"')
language: str = language
device: str = device
# List of GPU indexes to use
device_indexes = [int(x) for x in device_indexes.split(',')]
compute_type: str = compute_type
batch_size: int = batch_size
num_beams: int = num_beams
no_repeat_ngram_size: int = no_repeat_ngram_size

output_file.parent.mkdir(parents=True, exist_ok=True)

wav_files = [f for f in input_dir.rglob("*.wav") if f.is_file()]
wav_files = sorted(wav_files, key=lambda x: x.name)

if output_file.exists():
logger.warning(f"{output_file} exists, backing up to {output_file}.bak")
backup_path = output_file.with_name(output_file.name + ".bak")
if backup_path.exists():
logger.warning(f"{output_file}.bak exists, deleting...")
backup_path.unlink()
output_file.rename(backup_path)

if language == "ja":
language_id = Languages.JP.value
elif language == "en":
language_id = Languages.EN.value
elif language == "zh":
language_id = Languages.ZH.value
else:
raise ValueError(f"{language} is not supported.")

if not use_hf_whisper:
from faster_whisper import WhisperModel

logger.info(
f"Loading faster-whisper model ({model}) with compute_type={compute_type}"
)

models = {}

# Create one model per GPU to be used.
for device_index in device_indexes:
try:
model_object = WhisperModel(model, device=device, device_index=device_index, compute_type=compute_type)
except ValueError as e:
logger.warning(f"Failed to load model, so use `auto` compute_type: {e}")
model_object = WhisperModel(model, device=device, device_index=device_index)
models[device_index]=model_object
# Register this device index in the queue of available models
device_queue.put(device_index)

# Start the multithreaded transcription
threads = []
for wav_file in tqdm(wav_files):
while True:
# Loop until a model becomes available.
if not device_queue.empty():
device_index = device_queue.get()
thread = Thread(target=transcribe_thread, args=(
models[device_index],
device_index,
wav_file,
output_file,
model_name,
language_id,
initial_prompt,
language,
num_beams,
no_repeat_ngram_size
))
thread.start()
threads.append(thread)
break
time.sleep(0.01)

for thread in threads:
thread.join()

# Release the models
for device_index in device_indexes:
models[device_index]=None

'''
try:
model = WhisperModel(args.model, device=device, compute_type=compute_type)
except ValueError as e:
logger.warning(f"Failed to load model, so use `auto` compute_type: {e}")
model = WhisperModel(args.model, device=device)
for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
text = transcribe_with_faster_whisper(
model=model,
audio_file=wav_file,
initial_prompt=initial_prompt,
language=language,
num_beams=num_beams,
no_repeat_ngram_size=no_repeat_ngram_size,
)
with open(output_file, "a", encoding="utf-8") as f:
f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n")
'''
else:
model_id = f"openai/whisper-{model}"
logger.info(f"Loading HF Whisper model ({model_id})")
pbar = tqdm(total=len(wav_files), file=SAFE_STDOUT)
results = transcribe_files_with_hf_whisper(
audio_files=wav_files,
model_id=model_id,
initial_prompt=initial_prompt,
language=language,
batch_size=batch_size,
num_beams=num_beams,
no_repeat_ngram_size=no_repeat_ngram_size,
device=device,
pbar=pbar,
)
with open(output_file, "w", encoding="utf-8") as f:
for wav_file, text in zip(wav_files, results):
f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n")

return True, ""

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -124,6 +316,23 @@ def transcribe_files_with_hf_whisper(
parser.add_argument("--no_repeat_ngram_size", type=int, default=10)
args = parser.parse_args()

run(
args.model_name,
args.model,
args.compute_type,
args.language,
args.initial_prompt,
args.device,
args.device_indexes,
args.use_hf_whisper,
args.batch_size,
args.num_beams,
args.no_repeat_ngram_size,
)

sys.exit(0)

'''
with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
path_config: dict[str, str] = yaml.safe_load(f.read())
dataset_root = Path(path_config["dataset_root"])
@@ -205,3 +414,4 @@ def transcribe_files_with_hf_whisper(
f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n")

sys.exit(0)
'''