Skip to content

Commit

Permalink
Merge pull request #407 from jhj0517/feature/direct-hf-model
Browse files Browse the repository at this point in the history
Enable direct use of model from huggingface
  • Loading branch information
jhj0517 authored Nov 23, 2024
2 parents 8b06244 + 5a11504 commit 12bd736
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create_pipeline_inputs(self):

with gr.Row():
dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value=whisper_params["model_size"],
label=_("Model"))
label=_("Model"), allow_custom_value=True)
dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
else whisper_params["lang"], label=_("Language"))
Expand Down
18 changes: 16 additions & 2 deletions modules/whisper/faster_whisper_inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
import huggingface_hub
import numpy as np
import torch
from typing import BinaryIO, Union, Tuple, List
Expand Down Expand Up @@ -118,15 +119,28 @@ def update_model(self,
Parameters
----------
model_size: str
Size of whisper model
Size of whisper model. If you enter the huggingface repo id, it will try to download the model
automatically from huggingface.
compute_type: str
Compute type for transcription.
see more info : https://opennmt.net/CTranslate2/quantization.html
progress: gr.Progress
Indicator to show progress directly in gradio.
"""
progress(0, desc="Initializing Model..")
self.current_model_size = self.model_paths[model_size]

model_size_dirname = model_size.replace("/", "--") if "/" in model_size else model_size
if model_size not in self.model_paths and model_size_dirname not in self.model_paths:
print(f"Model is not detected. Trying to download \"{model_size}\" from huggingface to "
f"\"{os.path.join(self.model_dir, model_size_dirname)} ...")
huggingface_hub.snapshot_download(
model_size,
local_dir=os.path.join(self.model_dir, model_size_dirname),
)
self.model_paths = self.get_model_paths()
gr.Info(f"Model is downloaded with the name \"{model_size_dirname}\"")

self.current_model_size = self.model_paths[model_size_dirname]

local_files_only = False
hf_prefix = "models--Systran--faster-whisper-"
Expand Down

0 comments on commit 12bd736

Please sign in to comment.