Skip to content

Commit 69b2a0a

Browse files
authored
Fix model downloading in bento (#3803)
Summary: Pull Request resolved: #3803 The model checkpoint path can not be created for Squim models. Use the latest download_asset method to fix it. Reviewed By: moto-meta Differential Revision: D59061348
1 parent 7f6209b commit 69b2a0a

File tree

1 file changed

+10
-30
lines changed

1 file changed

+10
-30
lines changed

src/torchaudio/pipelines/_squim_pipeline.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
22

3-
from torchaudio._internal import load_state_dict_from_url
3+
import torch
4+
import torchaudio
45

56
from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
67

@@ -42,26 +43,16 @@ class SquimObjectiveBundle:
4243
_path: str
4344
_sample_rate: float
4445

45-
def _get_state_dict(self, dl_kwargs):
46-
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
47-
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
48-
state_dict = load_state_dict_from_url(url, **dl_kwargs)
49-
return state_dict
50-
51-
def get_model(self, *, dl_kwargs=None) -> SquimObjective:
46+
def get_model(self) -> SquimObjective:
5247
"""Construct the SquimObjective model, and load the pretrained weight.
5348
54-
The weight file is downloaded from the internet and cached with
55-
:func:`torch.hub.load_state_dict_from_url`
56-
57-
Args:
58-
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
59-
6049
Returns:
6150
Variation of :py:class:`~torchaudio.models.SquimObjective`.
6251
"""
6352
model = squim_objective_base()
64-
model.load_state_dict(self._get_state_dict(dl_kwargs))
53+
path = torchaudio.utils.download_asset(f"models/{self._path}")
54+
state_dict = torch.load(path, weights_only=True)
55+
model.load_state_dict(state_dict)
6556
model.eval()
6657
return model
6758

@@ -128,26 +119,15 @@ class SquimSubjectiveBundle:
128119
_path: str
129120
_sample_rate: float
130121

131-
def _get_state_dict(self, dl_kwargs):
132-
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
133-
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
134-
state_dict = load_state_dict_from_url(url, **dl_kwargs)
135-
return state_dict
136-
137-
def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
122+
def get_model(self) -> SquimSubjective:
138123
"""Construct the SquimSubjective model, and load the pretrained weight.
139-
140-
The weight file is downloaded from the internet and cached with
141-
:func:`torch.hub.load_state_dict_from_url`
142-
143-
Args:
144-
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
145-
146124
Returns:
147125
Variation of :py:class:`~torchaudio.models.SquimObjective`.
148126
"""
149127
model = squim_subjective_base()
150-
model.load_state_dict(self._get_state_dict(dl_kwargs))
128+
path = torchaudio.utils.download_asset(f"models/{self._path}")
129+
state_dict = torch.load(path, weights_only=True)
130+
model.load_state_dict(state_dict)
151131
model.eval()
152132
return model
153133

0 commit comments

Comments
 (0)