diff --git a/wenet/cli/hub.py b/wenet/cli/hub.py index 12176c0b4..43169f6a5 100644 --- a/wenet/cli/hub.py +++ b/wenet/cli/hub.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import requests import sys import tarfile from pathlib import Path @@ -72,11 +73,11 @@ class Hub(object): # TODO(Mddct): make assets class to support other language Assets = { # wenetspeech - "chinese": - "https://github.com/wenet-e2e/wenet/releases/download/v2.0.1/chs.tar.gz", + "chinese": "wenetspeech_u2pp_conformer_libtorch.tar.gz", # gigaspeech - "english": - "https://github.com/wenet-e2e/wenet/releases/download/v2.0.1/en.tar.gz" + "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz", + # paraformer + "paraformer": "paraformer.tar.gz" } def __init__(self) -> None: @@ -89,14 +90,14 @@ def get_model_by_lang(lang: str) -> str: sys.exit(1) # NOTE(Mddct): model_dir structure - # Path.Home()/.went + # Path.Home()/.wenet # - chs # - units.txt # - final.zip # - en # - units.txt # - final.zip - model_url = Hub.Assets[lang] + model = Hub.Assets[lang] model_dir = os.path.join(Path.home(), ".wenet", lang) if not os.path.exists(model_dir): os.makedirs(model_dir) @@ -104,5 +105,12 @@ def get_model_by_lang(lang: str) -> str: if set(["final.zip", "units.txt"]).issubset(set(os.listdir(model_dir))): return model_dir + # If not exist, download + response = requests.get( + "https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree" # noqa + ) + model_info = next(data for data in response.json()["Data"] + if data["Key"] == model) + model_url = model_info['Url'] download(model_url, model_dir, only_child=True) return model_dir diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py index fab4e0090..3fd5c65e2 100644 --- a/wenet/cli/paraformer_model.py +++ b/wenet/cli/paraformer_model.py @@ -4,12 +4,12 @@ import torchaudio import torchaudio.compliance.kaldi as kaldi +from wenet.cli.hub import Hub from wenet.paraformer.search import paraformer_greedy_search from wenet.utils.file_utils import read_symbol_table class Paraformer: - def __init__(self, model_dir: str) -> None: model_path = os.path.join(model_dir, 'final.zip') @@ -60,7 +60,7 @@ def align(self, audio_file: str, label: str) -> dict: raise NotImplementedError("Align is currently not supported") -def load_model(language: str = None, model_dir: str = None) -> Paraformer: +def load_model(model_dir: str = None) -> Paraformer: if model_dir is None: - model_dir = Hub.get_model_by_lang(language) + model_dir = Hub.get_model_by_lang('paraformer') return Paraformer(model_dir) diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py index e12b8b436..53e101c41 100644 --- a/wenet/cli/transcribe.py +++ b/wenet/cli/transcribe.py @@ -53,7 +53,7 @@ def main(): args = get_args() if args.paraformer: - model = load_paraformer(args.language, args.model_dir) + model = load_paraformer(args.model_dir) else: model = load_model(args.language, args.model_dir) if args.align: