From 9161ad080ce791e63286f599332d1d2b33153ab3 Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Wed, 26 Jun 2024 08:37:23 -0700 Subject: [PATCH] Fix model downloading in bento Summary: The model checkpoint path can not be created for Squim models. Use the latest download_asset method to fix it. Differential Revision: D59061348 --- src/torchaudio/pipelines/_squim_pipeline.py | 40 ++++++--------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/src/torchaudio/pipelines/_squim_pipeline.py b/src/torchaudio/pipelines/_squim_pipeline.py index f1d11bff538..0c70db4aef7 100644 --- a/src/torchaudio/pipelines/_squim_pipeline.py +++ b/src/torchaudio/pipelines/_squim_pipeline.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from torchaudio._internal import load_state_dict_from_url +import torch +import torchaudio from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective @@ -42,26 +43,16 @@ class SquimObjectiveBundle: _path: str _sample_rate: float - def _get_state_dict(self, dl_kwargs): - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" - dl_kwargs = {} if dl_kwargs is None else dl_kwargs - state_dict = load_state_dict_from_url(url, **dl_kwargs) - return state_dict - - def get_model(self, *, dl_kwargs=None) -> SquimObjective: + def get_model(self) -> SquimObjective: """Construct the SquimObjective model, and load the pretrained weight. - The weight file is downloaded from the internet and cached with - :func:`torch.hub.load_state_dict_from_url` - - Args: - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. - Returns: Variation of :py:class:`~torchaudio.models.SquimObjective`. """ model = squim_objective_base() - model.load_state_dict(self._get_state_dict(dl_kwargs)) + path = torchaudio.utils.download_asset(f"models/{self._path}") + state_dict = torch.load(path, weights_only=True) + model.load_state_dict(state_dict) model.eval() return model @@ -128,26 +119,15 @@ class SquimSubjectiveBundle: _path: str _sample_rate: float - def _get_state_dict(self, dl_kwargs): - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" - dl_kwargs = {} if dl_kwargs is None else dl_kwargs - state_dict = load_state_dict_from_url(url, **dl_kwargs) - return state_dict - - def get_model(self, *, dl_kwargs=None) -> SquimSubjective: + def get_model(self) -> SquimSubjective: """Construct the SquimSubjective model, and load the pretrained weight. - - The weight file is downloaded from the internet and cached with - :func:`torch.hub.load_state_dict_from_url` - - Args: - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. - Returns: Variation of :py:class:`~torchaudio.models.SquimObjective`. """ model = squim_subjective_base() - model.load_state_dict(self._get_state_dict(dl_kwargs)) + path = torchaudio.utils.download_asset(f"models/{self._path}") + state_dict = torch.load(path, weights_only=True) + model.load_state_dict(state_dict) model.eval() return model