Skip to content

Commit

Permalink
Add non-pretrained option
Browse files Browse the repository at this point in the history
  • Loading branch information
tarepan committed Oct 3, 2023
1 parent 350b986 commit 0056278
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Predict subjective speech score with only 2 lines of code, with various MOS prediction systems.

```python
predictor = torch.hub.load("tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True)
predictor = torch.hub.load("tarepan/SpeechMOS:v1.1.0", "utmos22_strong", trust_repo=True)
score = predictor(wave, sr)
# tensor([3.7730]), good quality speech!
```
Expand All @@ -20,7 +20,7 @@ import torch
import librosa

wave, sr = librosa.load("<your_audio>.wav", sr=None, mono=True)
predictor = torch.hub.load("tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True)
predictor = torch.hub.load("tarepan/SpeechMOS:v1.1.0", "utmos22_strong", trust_repo=True)
score = predictor(torch.from_numpy(wave).unsqueeze(0), sr)
# tensor([3.7730])
```
Expand All @@ -32,7 +32,7 @@ SpeechMOS use `torch.hub` built-in model loader, so no needs of library import
First, instantiate a MOS predictor with model specifier string:
```python
import torch
predictor = torch.hub.load("tarepan/SpeechMOS:v1.0.0", "<model_specifier>", trust_repo=True)
predictor = torch.hub.load("tarepan/SpeechMOS:v1.1.0", "<model_specifier>", trust_repo=True)
```

Then, pass tensor of speeches :: `(Batch, Time)`:
Expand Down
12 changes: 9 additions & 3 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
# Weight transfer code is in my fork (`/demo/utmos_strong_alt`).


def utmos22_strong(progress: bool = True) -> UTMOS22Strong:
def utmos22_strong(progress: bool = True, pretrained: bool = True) -> UTMOS22Strong:
"""
`UTMOS strong learner` speech naturalness MOS predictor.
Args:
progress - Whether to show model checkpoint load progress
"""

state_dict = torch.hub.load_state_dict_from_url(url=URLS["utmos22_strong"], map_location="cpu", progress=progress)
# Init
model = UTMOS22Strong()
model.load_state_dict(state_dict)

# Pretrained weights
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(url=URLS["utmos22_strong"], map_location="cpu", progress=progress)
model.load_state_dict(state_dict)

# Mode
model.eval()

return model
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "speechmos"
version = "1.0.0"
version = "1.1.0"
description = "Easy-to-Use Speech MOS predictors 🎧"
authors = ["tarepan"]
readme = "README.md"
Expand Down

0 comments on commit 0056278

Please sign in to comment.