diff --git a/README.md b/README.md index 77ab498..9ef66f2 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Predict subjective speech score with only 2 lines of code, with various MOS prediction systems. ```python -predictor = torch.hub.load("tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True) +predictor = torch.hub.load("tarepan/SpeechMOS:v1.1.0", "utmos22_strong", trust_repo=True) score = predictor(wave, sr) # tensor([3.7730]), good quality speech! ``` @@ -20,7 +20,7 @@ import torch import librosa wave, sr = librosa.load(".wav", sr=None, mono=True) -predictor = torch.hub.load("tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True) +predictor = torch.hub.load("tarepan/SpeechMOS:v1.1.0", "utmos22_strong", trust_repo=True) score = predictor(torch.from_numpy(wave).unsqueeze(0), sr) # tensor([3.7730]) ``` @@ -32,7 +32,7 @@ SpeechMOS use `torch.hub` built-in model loader, so no needs of library import First, instantiate a MOS predictor with model specifier string: ```python import torch -predictor = torch.hub.load("tarepan/SpeechMOS:v1.0.0", "", trust_repo=True) +predictor = torch.hub.load("tarepan/SpeechMOS:v1.1.0", "", trust_repo=True) ``` Then, pass tensor of speeches :: `(Batch, Time)`: diff --git a/hubconf.py b/hubconf.py index 5442a38..2fa3c1d 100644 --- a/hubconf.py +++ b/hubconf.py @@ -15,7 +15,7 @@ # Weight transfer code is in my fork (`/demo/utmos_strong_alt`). -def utmos22_strong(progress: bool = True) -> UTMOS22Strong: +def utmos22_strong(progress: bool = True, pretrained: bool = True) -> UTMOS22Strong: """ `UTMOS strong learner` speech naturalness MOS predictor. @@ -23,9 +23,15 @@ def utmos22_strong(progress: bool = True) -> UTMOS22Strong: progress - Whether to show model checkpoint load progress """ - state_dict = torch.hub.load_state_dict_from_url(url=URLS["utmos22_strong"], map_location="cpu", progress=progress) + # Init model = UTMOS22Strong() - model.load_state_dict(state_dict) + + # Pretrained weights + if pretrained: + state_dict = torch.hub.load_state_dict_from_url(url=URLS["utmos22_strong"], map_location="cpu", progress=progress) + model.load_state_dict(state_dict) + + # Mode model.eval() return model diff --git a/pyproject.toml b/pyproject.toml index 27a8bec..64041a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "speechmos" -version = "1.0.0" +version = "1.1.0" description = "Easy-to-Use Speech MOS predictors 🎧" authors = ["tarepan"] readme = "README.md"