Skip to content

Commit ced4cfd

Browse files
agrinherogol
authored andcommitted
Allow saving / loading checkpoints from cloud paths (coqui-ai#683)
* Allow saving / loading checkpoints from cloud paths Allows saving and loading checkpoints directly from cloud paths like Amazon S3 (s3://) and Google Cloud Storage (gs://) by using fsspec. Note: The user will have to install the relevant dependency for each protocol. Otherwise fsspec will fail and specify which dependency is missing. * Append suffix _fsspec to save/load function names * Add a lower bound to the fsspec dependency Skips the 0 major version. * Add missing changes from refactor * Use fsspec for remaining artifacts * Add test case with path requiring fsspec * Avoid writing logs to file unless output_path is local * Document the possibility of using paths supported by fsspec * Fix style and lint * Add missing lint fixes * Add type annotations to new functions * Use Coqpit method for converting config to dict * Fix type annotation in semi-new function * Add return type for load_fsspec * Fix bug where fs not always created * Restore the experiment removal functionality
1 parent 181177a commit ced4cfd

31 files changed

+218
-94
lines changed

TTS/bin/convert_melgan_torch_to_tf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tensorflow as tf
77
import torch
88

9-
from TTS.utils.io import load_config
9+
from TTS.utils.io import load_config, load_fsspec
1010
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
1111
compare_torch_tf,
1212
convert_tf_name,
@@ -33,7 +33,7 @@
3333

3434
# init torch model
3535
model = setup_generator(c)
36-
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
36+
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
3737
state_dict = checkpoint["model"]
3838
model.load_state_dict(state_dict)
3939
model.remove_weight_norm()

TTS/bin/convert_tacotron2_torch_to_tf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
1414
from TTS.tts.tf.utils.generic_utils import save_checkpoint
1515
from TTS.tts.utils.text.symbols import phonemes, symbols
16-
from TTS.utils.io import load_config
16+
from TTS.utils.io import load_config, load_fsspec
1717

1818
sys.path.append("/home/erogol/Projects")
1919
os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -32,7 +32,7 @@
3232

3333
# init torch model
3434
model = setup_model(c)
35-
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
35+
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
3636
state_dict = checkpoint["model"]
3737
model.load_state_dict(state_dict)
3838

TTS/bin/extract_tts_spectrograms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from TTS.tts.utils.speakers import get_speaker_manager
1717
from TTS.utils.audio import AudioProcessor
1818
from TTS.utils.generic_utils import count_parameters
19+
from TTS.utils.io import load_fsspec
1920

2021
use_cuda = torch.cuda.is_available()
2122

@@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name
239240
model = setup_model(c)
240241

241242
# restore model
242-
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
243+
checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
243244
model.load_state_dict(checkpoint["model"])
244245

245246
if use_cuda:

TTS/bin/train_encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from TTS.tts.datasets import load_meta_data
1818
from TTS.utils.audio import AudioProcessor
1919
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
20+
from TTS.utils.io import load_fsspec
2021
from TTS.utils.radam import RAdam
2122
from TTS.utils.training import NoamLR, check_update
2223

@@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name
169170
raise Exception("The %s not is a loss supported" % c.loss)
170171

171172
if args.restore_path:
172-
checkpoint = torch.load(args.restore_path)
173+
checkpoint = load_fsspec(args.restore_path)
173174
try:
174175
model.load_state_dict(checkpoint["model"])
175176

TTS/config/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
from typing import Dict
55

6+
import fsspec
67
import yaml
78
from coqpit import Coqpit
89

@@ -13,7 +14,7 @@
1314
def read_json_with_comments(json_path):
1415
"""for backward compat."""
1516
# fallback to json
16-
with open(json_path, "r", encoding="utf-8") as f:
17+
with fsspec.open(json_path, "r", encoding="utf-8") as f:
1718
input_str = f.read()
1819
# handle comments
1920
input_str = re.sub(r"\\\n", "", input_str)
@@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
7677
config_dict = {}
7778
ext = os.path.splitext(config_path)[1]
7879
if ext in (".yml", ".yaml"):
79-
with open(config_path, "r", encoding="utf-8") as f:
80+
with fsspec.open(config_path, "r", encoding="utf-8") as f:
8081
data = yaml.safe_load(f)
8182
elif ext == ".json":
8283
try:
83-
with open(config_path, "r", encoding="utf-8") as f:
84-
input_str = f.read()
85-
data = json.loads(input_str)
84+
with fsspec.open(config_path, "r", encoding="utf-8") as f:
85+
data = json.load(f)
8686
except json.decoder.JSONDecodeError:
8787
# backwards compat.
8888
data = read_json_with_comments(config_path)

TTS/config/shared_configs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,10 @@ class BaseTrainingConfig(Coqpit):
225225
num_eval_loader_workers (int):
226226
Number of workers for evaluation time dataloader.
227227
output_path (str):
228-
Path for training output folder. The nonexist part of the given path is created automatically.
229-
All training outputs are saved there.
228+
Path for training output folder, either a local file path or other
229+
URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
230+
S3 (s3://) paths. The nonexist part of the given path is created
231+
automatically. All training artefacts are saved there.
230232
"""
231233

232234
model: str = None

TTS/speaker_encoder/models/lstm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import torch
33
from torch import nn
44

5+
from TTS.utils.io import load_fsspec
6+
57

68
class LSTMWithProjection(nn.Module):
79
def __init__(self, input_size, hidden_size, proj_size):
@@ -120,7 +122,7 @@ def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
120122

121123
# pylint: disable=unused-argument, redefined-builtin
122124
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
123-
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
125+
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
124126
self.load_state_dict(state["model"])
125127
if use_cuda:
126128
self.cuda()

TTS/speaker_encoder/models/resnet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import torch
33
import torch.nn as nn
44

5+
from TTS.utils.io import load_fsspec
6+
57

68
class SELayer(nn.Module):
79
def __init__(self, channel, reduction=8):
@@ -201,7 +203,7 @@ def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
201203
return embeddings
202204

203205
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
204-
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
206+
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
205207
self.load_state_dict(state["model"])
206208
if use_cuda:
207209
self.cuda()

TTS/speaker_encoder/utils/generic_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from multiprocessing import Manager
77

88
import numpy as np
9-
import torch
109
from scipy import signal
1110

1211
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
1312
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
13+
from TTS.utils.io import save_fsspec
1414

1515

1616
class Storage(object):
@@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
198198
"loss": model_loss,
199199
"date": datetime.date.today().strftime("%B %d, %Y"),
200200
}
201-
torch.save(state, checkpoint_path)
201+
save_fsspec(state, checkpoint_path)
202202

203203

204204
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
@@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
216216
bestmodel_path = "best_model.pth.tar"
217217
bestmodel_path = os.path.join(out_path, bestmodel_path)
218218
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
219-
torch.save(state, bestmodel_path)
219+
save_fsspec(state, bestmodel_path)
220220
return best_loss

TTS/speaker_encoder/utils/io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import os
33

4-
import torch
4+
from TTS.utils.io import save_fsspec
55

66

77
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
@@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
1717
"loss": model_loss,
1818
"date": datetime.date.today().strftime("%B %d, %Y"),
1919
}
20-
torch.save(state, checkpoint_path)
20+
save_fsspec(state, checkpoint_path)
2121

2222

2323
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
@@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s
3434
bestmodel_path = "best_model.pth.tar"
3535
bestmodel_path = os.path.join(out_path, bestmodel_path)
3636
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
37-
torch.save(state, bestmodel_path)
37+
save_fsspec(state, bestmodel_path)
3838
return best_loss

TTS/trainer.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
22

3-
import glob
43
import importlib
54
import logging
65
import os
@@ -12,7 +11,9 @@
1211
from argparse import Namespace
1312
from dataclasses import dataclass, field
1413
from typing import Dict, List, Tuple, Union
14+
from urllib.parse import urlparse
1515

16+
import fsspec
1617
import torch
1718
from coqpit import Coqpit
1819
from torch import nn
@@ -29,13 +30,13 @@
2930
from TTS.utils.generic_utils import (
3031
KeepAverage,
3132
count_parameters,
32-
create_experiment_folder,
33+
get_experiment_folder_path,
3334
get_git_branch,
3435
remove_experiment_folder,
3536
set_init_dict,
3637
to_cuda,
3738
)
38-
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
39+
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
3940
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
4041
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
4142
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@@ -173,7 +174,6 @@ def __init__(
173174
self.best_loss = float("inf")
174175
self.train_loader = None
175176
self.eval_loader = None
176-
self.output_audio_path = os.path.join(output_path, "test_audios")
177177

178178
self.keep_avg_train = None
179179
self.keep_avg_eval = None
@@ -309,7 +309,7 @@ def _restore_list_objs(states, obj):
309309
return obj
310310

311311
print(" > Restoring from %s ..." % os.path.basename(restore_path))
312-
checkpoint = torch.load(restore_path)
312+
checkpoint = load_fsspec(restore_path)
313313
try:
314314
print(" > Restoring Model...")
315315
model.load_state_dict(checkpoint["model"])
@@ -776,7 +776,7 @@ def _fit(self) -> None:
776776
"""🏃 train -> evaluate -> test for the number of epochs."""
777777
if self.restore_step != 0 or self.args.best_path:
778778
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
779-
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
779+
self.best_loss = load_fsspec(self.args.best_path, map_location="cpu")["model_loss"]
780780
print(f" > Starting with loaded last best loss {self.best_loss}.")
781781

782782
self.total_steps_done = self.restore_step
@@ -834,9 +834,16 @@ def save_best_model(self) -> None:
834834

835835
@staticmethod
836836
def _setup_logger_config(log_file: str) -> None:
837-
logging.basicConfig(
838-
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
839-
)
837+
handlers = [logging.StreamHandler()]
838+
839+
# Only add a log file if the output location is local due to poor
840+
# support for writing logs to file-like objects.
841+
parsed_url = urlparse(log_file)
842+
if not parsed_url.scheme or parsed_url.scheme == "file":
843+
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
844+
handlers.append(logging.FileHandler(schemeless_path))
845+
846+
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
840847

841848
@staticmethod
842849
def _is_apex_available() -> bool:
@@ -926,22 +933,27 @@ def init_arguments():
926933
return parser
927934

928935

929-
def get_last_checkpoint(path):
936+
def get_last_checkpoint(path: str) -> Tuple[str, str]:
930937
"""Get latest checkpoint or/and best model in path.
931938
932939
It is based on globbing for `*.pth.tar` and the RegEx
933940
`(checkpoint|best_model)_([0-9]+)`.
934941
935942
Args:
936-
path (list): Path to files to be compared.
943+
path: Path to files to be compared.
937944
938945
Raises:
939946
ValueError: If no checkpoint or best_model files are found.
940947
941948
Returns:
942-
last_checkpoint (str): Last checkpoint filename.
949+
Path to the last checkpoint
950+
Path to best checkpoint
943951
"""
944-
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
952+
fs = fsspec.get_mapper(path).fs
953+
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
954+
scheme = urlparse(path).scheme
955+
if scheme: # scheme is not preserved in fs.glob, add it back
956+
file_names = [scheme + "://" + file_name for file_name in file_names]
945957
last_models = {}
946958
last_model_nums = {}
947959
for key in ["checkpoint", "best_model"]:
@@ -963,7 +975,7 @@ def get_last_checkpoint(path):
963975
key_file_names = [fn for fn in file_names if key in fn]
964976
if last_model is None and len(key_file_names) > 0:
965977
last_model = max(key_file_names, key=os.path.getctime)
966-
last_model_num = torch.load(last_model)["step"]
978+
last_model_num = load_fsspec(last_model)["step"]
967979

968980
if last_model is not None:
969981
last_models[key] = last_model
@@ -1030,12 +1042,11 @@ def process_args(args, config=None):
10301042
print(" > Mixed precision mode is ON")
10311043
experiment_path = args.continue_path
10321044
if not experiment_path:
1033-
experiment_path = create_experiment_folder(config.output_path, config.run_name)
1045+
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
10341046
audio_path = os.path.join(experiment_path, "test_audios")
10351047
# setup rank 0 process in distributed training
10361048
tb_logger = None
10371049
if args.rank == 0:
1038-
os.makedirs(audio_path, exist_ok=True)
10391050
new_fields = {}
10401051
if args.restore_path:
10411052
new_fields["restore_path"] = args.restore_path
@@ -1047,8 +1058,6 @@ def process_args(args, config=None):
10471058
used_characters = parse_symbols()
10481059
new_fields["characters"] = used_characters
10491060
copy_model_files(config, experiment_path, new_fields)
1050-
os.chmod(audio_path, 0o775)
1051-
os.chmod(experiment_path, 0o775)
10521061
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
10531062
# write model desc to tensorboard
10541063
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)

TTS/tts/models/align_tts.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from TTS.tts.utils.measures import alignment_diagonal_score
1717
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
1818
from TTS.utils.audio import AudioProcessor
19+
from TTS.utils.io import load_fsspec
1920

2021

2122
@dataclass
@@ -389,7 +390,7 @@ def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
389390
def load_checkpoint(
390391
self, config, checkpoint_path, eval=False
391392
): # pylint: disable=unused-argument, redefined-builtin
392-
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
393+
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
393394
self.load_state_dict(state["model"])
394395
if eval:
395396
self.eval()

TTS/tts/models/base_tacotron.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
1414
from TTS.tts.utils.text import make_symbols
1515
from TTS.utils.generic_utils import format_aux_input
16+
from TTS.utils.io import load_fsspec
1617
from TTS.utils.training import gradual_training_scheduler
1718

1819

@@ -113,7 +114,7 @@ def inference(self):
113114
def load_checkpoint(
114115
self, config, checkpoint_path, eval=False
115116
): # pylint: disable=unused-argument, redefined-builtin
116-
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
117+
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
117118
self.load_state_dict(state["model"])
118119
if "r" in state:
119120
self.decoder.set_r(state["r"])

0 commit comments

Comments
 (0)