|
1 | 1 | from dataclasses import dataclass
|
2 | 2 |
|
3 |
| -from torchaudio._internal import load_state_dict_from_url |
| 3 | +import torch |
| 4 | +import torchaudio |
4 | 5 |
|
5 | 6 | from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
|
6 | 7 |
|
@@ -42,26 +43,16 @@ class SquimObjectiveBundle:
|
42 | 43 | _path: str
|
43 | 44 | _sample_rate: float
|
44 | 45 |
|
45 |
| - def _get_state_dict(self, dl_kwargs): |
46 |
| - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" |
47 |
| - dl_kwargs = {} if dl_kwargs is None else dl_kwargs |
48 |
| - state_dict = load_state_dict_from_url(url, **dl_kwargs) |
49 |
| - return state_dict |
50 |
| - |
51 |
| - def get_model(self, *, dl_kwargs=None) -> SquimObjective: |
| 46 | + def get_model(self) -> SquimObjective: |
52 | 47 | """Construct the SquimObjective model, and load the pretrained weight.
|
53 | 48 |
|
54 |
| - The weight file is downloaded from the internet and cached with |
55 |
| - :func:`torch.hub.load_state_dict_from_url` |
56 |
| -
|
57 |
| - Args: |
58 |
| - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. |
59 |
| -
|
60 | 49 | Returns:
|
61 | 50 | Variation of :py:class:`~torchaudio.models.SquimObjective`.
|
62 | 51 | """
|
63 | 52 | model = squim_objective_base()
|
64 |
| - model.load_state_dict(self._get_state_dict(dl_kwargs)) |
| 53 | + path = torchaudio.utils.download_asset(f"models/{self._path}") |
| 54 | + state_dict = torch.load(path, weights_only=True) |
| 55 | + model.load_state_dict(state_dict) |
65 | 56 | model.eval()
|
66 | 57 | return model
|
67 | 58 |
|
@@ -128,26 +119,15 @@ class SquimSubjectiveBundle:
|
128 | 119 | _path: str
|
129 | 120 | _sample_rate: float
|
130 | 121 |
|
131 |
| - def _get_state_dict(self, dl_kwargs): |
132 |
| - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" |
133 |
| - dl_kwargs = {} if dl_kwargs is None else dl_kwargs |
134 |
| - state_dict = load_state_dict_from_url(url, **dl_kwargs) |
135 |
| - return state_dict |
136 |
| - |
137 |
| - def get_model(self, *, dl_kwargs=None) -> SquimSubjective: |
| 122 | + def get_model(self) -> SquimSubjective: |
138 | 123 | """Construct the SquimSubjective model, and load the pretrained weight.
|
139 |
| -
|
140 |
| - The weight file is downloaded from the internet and cached with |
141 |
| - :func:`torch.hub.load_state_dict_from_url` |
142 |
| -
|
143 |
| - Args: |
144 |
| - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. |
145 |
| -
|
146 | 124 | Returns:
|
147 | 125 | Variation of :py:class:`~torchaudio.models.SquimObjective`.
|
148 | 126 | """
|
149 | 127 | model = squim_subjective_base()
|
150 |
| - model.load_state_dict(self._get_state_dict(dl_kwargs)) |
| 128 | + path = torchaudio.utils.download_asset(f"models/{self._path}") |
| 129 | + state_dict = torch.load(path, weights_only=True) |
| 130 | + model.load_state_dict(state_dict) |
151 | 131 | model.eval()
|
152 | 132 | return model
|
153 | 133 |
|
|
0 commit comments