Commit cc1c6fb
Merged single-GPU changes and some fixes
Merge commit cc1c6fb (2 parents: d429578 + c55596a)

29 files changed: +612 -180 lines

.gitignore (+3)

@@ -158,3 +158,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# output logs
+*.out

README.md (+430)

Large diffs are not rendered by default.

config/ViT.yaml (+39 -14)

@@ -3,25 +3,26 @@ base: &base
   # Model config
   embed_dim: 384
   depth: 12
-  dropout: 0
+  dropout: 0.0
   patch_size: 8
   num_heads: 8

   # Training config
   img_size: [360, 720]
   dt: 1
-  global_batch_size: 256 # number of samples per training batch
-  num_epochs: 60
+  global_batch_size: 16 # number of samples per training batch
+  num_epochs: 25
   amp_mode: none
-  enable_apex: False
-  enable_jit: False
+  enable_fused: false
+  enable_jit: false
   expdir: '/logs'
   lr_schedule: 'cosine'
   lr: 5E-4
+  warmup: 0

   # Data
   data_loader_config: 'pytorch'
-  num_data_workers: 2 # number of dataloader worker threads per proc
+  num_data_workers: 0 # number of dataloader worker threads per proc
   n_in_channels: 20
   n_out_channels: 20
   train_data_path: '/data/train'

@@ -30,28 +31,52 @@ base: &base
   time_means_path: '/data/stats/time_means.npy'
   global_means_path: '/data/stats/global_means.npy'
   global_stds_path: '/data/stats/global_stds.npy'
+  limit_nsamples: None
+  limit_nsamples_val: None

   # Comms
   wireup_info: env
   wireup_store: tcp

-short_noopt:
-  <<: *base
-  num_epochs: 10
-  num_data_workers: 0
-  global_batch_size: 64
-
 short: &short
   <<: *base
   num_epochs: 10
   num_data_workers: 8
   global_batch_size: 64
   embed_dim: 1024

+short_limitsamples: &short_ls
+  <<: *base
+  limit_nsamples: 512
+  limit_nsamples_val: 128
+  num_epochs: 4
+
 # Short config with full optimizations
 short_opt:
-  <<: *short
-  global_batch_size: 16
+  <<: *short_ls
+  global_batch_size: 64
+  data_loader_config: 'dali'
+  num_data_workers: 8
+  amp_mode: bf16
+  enable_jit: true
+  enable_fused: true
+
+
+# Model parallel configs
+short_mp: &short_mp
+  <<: *short_ls
+  global_batch_size: 64
   data_loader_config: 'dali'
   num_data_workers: 8
   embed_dim: 1024
+  amp_mode: bf16
+  enable_jit: true
+  enable_fused: true
+
+mp:
+  <<: *short_mp
+  global_batch_size: 64
+  limit_nsamples: None
+  limit_nsamples_val: None
+  num_epochs: 20
+  data_loader_config: 'dali'
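The reworked configs lean on YAML anchors and merge keys: short_limitsamples, short_mp, and mp each pull in a parent block and only list their overrides. A minimal sketch of how a named config resolves, assuming PyYAML is available (the repo's own YParams class is the actual loader):

# Sketch only: resolve a named config from config/ViT.yaml and inspect inherited keys.
import yaml

with open("config/ViT.yaml") as f:
    configs = yaml.safe_load(f)    # '<<' merge keys are expanded at load time

mp = configs["mp"]                 # mp <- short_mp <- short_limitsamples <- base
print(mp["amp_mode"])              # 'bf16', inherited from short_mp
print(mp["num_epochs"])            # 20, overriding short_limitsamples' value of 4
print(mp["global_batch_size"])     # 64, set explicitly in mp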

interactive_train.sh (+4 -6)

@@ -9,14 +9,12 @@ mkdir -p ${LOGDIR}

 ngpu=4
 config_file=./config/ViT.yaml
-config="short_opt"
-run_num="testgraphs"
-amp_mode="fp16"
+config="short_mp"
+run_num="0"
 col_parallel_size=1
-row_parallel_size=2
-local_batch_size=16
+row_parallel_size=1
 #cmd="python train.py --amp_mode=$amp_mode --yaml_config=$config_file --config=$config --run_num=$run_num --local_batch_size=$local_batch_size"
-cmd="python train_mp_graphs.py --local_batch_size=$local_batch_size --row_parallel_size=$row_parallel_size --col_parallel_size=$col_parallel_size --amp_mode=$amp_mode --yaml_config=$config_file --config=$config --run_num=$run_num"
+cmd="python train_mp.py --row_parallel_size=$row_parallel_size --col_parallel_size=$col_parallel_size --yaml_config=$config_file --config=$config --run_num=$run_num"


 srun -n $ngpu --cpus-per-task=32 --gpus-per-node $ngpu shifter --image=${image} --module=gpu,nccl-2.18 -V ${DATADIR}:/data -V ${LOGDIR}:/logs bash -c "source export_DDP_vars.sh && $cmd"
Binary files (not rendered)

Five binary files were updated (2.4 MB, 4.61 MB, 2.69 MB, 2.73 MB, and one of unlisted size). Only sample_nsys_profiles/dali.nsys-rep is named in this view; the remaining file names are not shown.

start_tensorboard.ipynb (+2 -2)

@@ -38,7 +38,7 @@
 },
 "outputs": [],
 "source": [
-"log_dir = os.path.expandvars('${SCRATCH}/sc23-dl-tutorial/logs/short_opt/4GPU/bs16_graphs')"
+"log_dir = os.path.expandvars('${SCRATCH}/sc23-dl-tutorial/logs/short_mp/4MP/bs64')"
 ]
 },
 {
@@ -65,7 +65,7 @@
 {
 "data": {
 "text/html": [
-"<a href=\"https://jupyter.nersc.gov/user/shas1693/perlmutter-login-node-base/proxy/33151/\">https://jupyter.nersc.gov/user/shas1693/perlmutter-login-node-base/proxy/33151/</a>"
+"<a href=\"https://jupyter.nersc.gov/user/shas1693/perlmutter-login-node-base/proxy/35199/\">https://jupyter.nersc.gov/user/shas1693/perlmutter-login-node-base/proxy/35199/</a>"
 ],
 "text/plain": [
 "<IPython.core.display.HTML object>"

submit_pm.sh (+8 -15)

@@ -1,34 +1,27 @@
 #!/bin/bash
-#SBATCH -C gpu
-#SBATCH --nodes=1
-#SBATCH -q regular
-#SBATCH -A nstaff
+#SBATCH -C 'gpu&hbm80g'
 #SBATCH --ntasks-per-node 4
 #SBATCH --cpus-per-task 32
 #SBATCH --gpus-per-node 4
-#SBATCH --time=03:00:00
+#SBATCH --time=06:00:00
 #SBATCH --image=nersc/pytorch:ngc-23.04-v0
 #SBATCH --module=gpu,nccl-2.18
-#SBATCH -J vit-era5
+#SBATCH -J vit-era5-dp-ng
 #SBATCH -o %x-%j.out

 DATADIR=/pscratch/sd/s/shas1693/data/sc23_tutorial_data/downsampled
 LOGDIR=${SCRATCH}/sc23-dl-tutorial/logs
 mkdir -p ${LOGDIR}

 config_file=./config/ViT.yaml
-config="short_opt"
-run_num="bs16_graphs"
-suffix="_graphs"
-#suffix=""
-amp_mode="fp16"
+config="mp"
+run_num="bs64"
+suffix=""
 col_parallel_size=1
-row_parallel_size=4
-args="--col_parallel_size=$col_parallel_size --row_parallel_size=$row_parallel_size --amp_mode=$amp_mode --yaml_config=$config_file --config=$config --run_num=$run_num"
-
+row_parallel_size=1
+args="--col_parallel_size=$col_parallel_size --row_parallel_size=$row_parallel_size --yaml_config=$config_file --config=$config --run_num=$run_num"

 export FI_MR_CACHE_MONITOR=userfaultfd
-export NCCL_NET_GDR_LEVEL=PHB
 export HDF5_USE_FILE_LOCKING=FALSE

 # Profiling

test_model_dims.py (+2 -3)

@@ -3,7 +3,6 @@
 from utils.YParams import YParams
 from torchinfo import summary

-params = YParams('./config/ViT.yaml', 'base')
-params.device = 'gpu'
+params = YParams('./config/ViT.yaml', 'short_mp')
 model = ViT(params)
-summary(model, input_size=(1,20,360,720))
+summary(model, input_size=(16,20,360,720))

train.py (+31 -39)

@@ -19,21 +19,9 @@
 from utils import get_data_loader_distributed
 from utils.loss import l2_loss, l2_loss_opt
 from utils.metrics import weighted_rmse
+from utils.plots import generate_images
 from networks import vit

-import apex.optimizers as aoptim
-
-def compute_grad_norm(p_list, device):
-    norm_type = 2.0
-    grads = [p.grad for p in p_list if p.grad is not None]
-    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
-    return total_norm
-
-def compute_parameter_norm(p_list, device):
-    norm_type = 2.0
-    total_norm = torch.norm(torch.stack([torch.norm(p.detach(), norm_type).to(device) for p in p_list]), norm_type)
-    return total_norm
-
 def train(params, args, local_rank, world_rank, world_size):
     # set device and benchmark mode
     torch.backends.cudnn.benchmark = True
@@ -48,6 +36,9 @@ def train(params, args, local_rank, world_rank, world_size):

     # create model
     model = vit.ViT(params).to(device)
+
+    if params.enable_jit:
+        model = torch.compile(model)

     if params.amp_dtype == torch.float16:
         scaler = GradScaler()
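With enable_jit now routed to torch.compile (PyTorch 2.x) instead of torch.jit.script, the wrapped model keeps the same call signature and only the first iterations pay compilation cost. A standalone sketch of the guarded pattern, not the repo's exact code:

# Sketch: opt-in torch.compile with an eager fallback when the flag is off.
import torch
import torch.nn as nn

def maybe_compile(model: nn.Module, enable_jit: bool) -> nn.Module:
    # torch.compile returns a wrapper with the same call signature as the model
    return torch.compile(model) if enable_jit else model

model = maybe_compile(nn.Linear(384, 384).cuda(), enable_jit=True)
out = model(torch.randn(16, 384, device="cuda"))  # first call triggers compilation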
@@ -61,9 +52,8 @@ def train(params, args, local_rank, world_rank, world_size):
         model = DistributedDataParallel(model, device_ids=[local_rank],
                                         bucket_cap_mb=args.bucket_cap_mb)

-    if params.enable_apex:
-        optimizer = aoptim.FusedAdam(model.parameters(), lr = params.lr,
-                                     adam_w_mode=False, set_grad_none=True)
+    if params.enable_fused:
+        optimizer = optim.Adam(model.parameters(), lr = params.lr, fused=True)
     else:
         optimizer = optim.Adam(model.parameters(), lr = params.lr)
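The apex FusedAdam dependency is dropped in favor of the fused kernel built into torch.optim.Adam. A small sketch of the flag, assuming a CUDA build of PyTorch recent enough to support fused=True (1.13+):

# Sketch: switch between fused and default Adam with a single flag.
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(384, 384).cuda()
enable_fused = True  # mirrors params.enable_fused from the config

# fused=True performs the parameter update in fewer, larger CUDA kernels,
# reducing launch overhead; the math is identical to the default implementation.
optimizer = optim.Adam(model.parameters(), lr=5e-4, fused=enable_fused)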

@@ -74,14 +64,14 @@ def train(params, args, local_rank, world_rank, world_size):
     startEpoch = 0

     if params.lr_schedule == 'cosine':
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_epochs, last_epoch=startEpoch-1)
+        if params.warmup > 0:
+            lr_scale = lambda x: min(params.lr*((x+1)/params.warmup), 0.5*params.lr*(1 + np.cos(np.pi*x/params.num_epochs)))
+            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)
+        else:
+            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_epochs, last_epoch=startEpoch-1)
     else:
         scheduler = None

-    if params.enable_jit:
-        model_handle = model.module if (params.distributed and not args.noddp) else model
-        model_handle = torch.jit.script(model_handle)
-
     # select loss function
     if params.enable_jit:
         loss_func = l2_loss_opt
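The scheduler branch adds an optional warmup phase in front of the cosine decay, keyed off the new warmup setting. As a general reference, warmup-plus-cosine is often written as a LambdaLR whose lambda returns a pure multiplier on the base learning rate; a generic sketch with illustrative names, not the repo's exact schedule:

# Sketch: linear warmup then cosine decay, expressed as a LambdaLR multiplier.
import math
import torch

def warmup_cosine(epoch, warmup_epochs=5, total_epochs=25):
    # Return a factor in [0, 1]; LambdaLR multiplies it onto the base lr.
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=5e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
for epoch in range(25):
    optimizer.step()      # a training epoch would run here
    scheduler.step()      # advance the schedule once per epoch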
@@ -124,51 +114,48 @@ def train(params, args, local_rank, world_rank, world_size):
         model.train()
         step_count = 0
         for i, data in enumerate(train_data_loader, 0):
-            if (args.enable_manual_profiling and world_rank==0):
+            if world_rank == 0:
                 if (epoch == 3 and i == 0):
                     torch.cuda.profiler.start()
-                if (epoch == 3 and i == 59):
+                if (epoch == 3 and i == len(train_data_loader) - 1):
                     torch.cuda.profiler.stop()

-            if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"step {i}")
+            torch.cuda.nvtx.range_push(f"step {i}")
             iters += 1
             dat_start = time.time()
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"data copy in {i}")
+            torch.cuda.nvtx.range_push(f"data copy in {i}")

             inp, tar = map(lambda x: x.to(device), data)
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # copy in
+            torch.cuda.nvtx.range_pop() # copy in

             tr_start = time.time()
             b_size = inp.size(0)

             optimizer.zero_grad()

-            if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"forward")
+            torch.cuda.nvtx.range_push(f"forward")
             with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype):
                 gen = model(inp)
                 loss = loss_func(gen, tar)
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() #forward
+            torch.cuda.nvtx.range_pop() #forward

             if params.amp_dtype == torch.float16:
                 scaler.scale(loss).backward()
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer")
+                torch.cuda.nvtx.range_push(f"optimizer")
                 scaler.step(optimizer)
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer
+                torch.cuda.nvtx.range_pop() # optimizer
                 scaler.update()
             else:
                 loss.backward()
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer")
+                torch.cuda.nvtx.range_push(f"optimizer")
                 optimizer.step()
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer
+                torch.cuda.nvtx.range_pop() # optimizer

             if params.distributed:
                 torch.distributed.all_reduce(loss)
             tr_loss.append(loss.item()/world_size)

-            if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # step
-
-            # g_norm = compute_grad_norm(model.parameters(), device)
-            # p_norm = compute_parameter_norm(model.parameters(), device)
+            torch.cuda.nvtx.range_pop() # step

             tr_end = time.time()
             tr_time += tr_end - tr_start
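The nvtx ranges and the torch.cuda.profiler start/stop calls are now unconditional (the old --enable_manual_profiling flag is gone), so they only have an effect when the run is captured with Nsight Systems. A standalone sketch of the same annotation pattern, assuming the job is launched with nsys profile --capture-range=cudaProfilerApi:

# Sketch: NVTX-annotated training step; ranges show up as named spans in an nsys timeline.
import torch

def run_step(model, inp, tar, loss_func, optimizer):
    torch.cuda.nvtx.range_push("step")

    torch.cuda.nvtx.range_push("forward")
    loss = loss_func(model(inp), tar)
    torch.cuda.nvtx.range_pop()   # forward

    loss.backward()

    torch.cuda.nvtx.range_push("optimizer")
    optimizer.step()
    optimizer.zero_grad()
    torch.cuda.nvtx.range_pop()   # optimizer

    torch.cuda.nvtx.range_pop()   # step
    return loss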
@@ -187,6 +174,8 @@ def train(params, args, local_rank, world_rank, world_size):
             args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters)
             args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters)
             args.tboard_writer.add_scalar('Avg iters per sec', step_count/(end-start), iters)
+            fig = generate_images([inp, tar, gen])
+            args.tboard_writer.add_figure('Visualization, t2m', fig, iters, close=True)

         val_start = time.time()
         val_loss = []
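The new logging lines push an input/target/prediction figure to TensorBoard each epoch via generate_images from utils.plots. A minimal sketch of the add_figure call with a hypothetical stand-in for that helper (the real plotting code lives in the repo):

# Sketch: log a side-by-side comparison figure to TensorBoard.
import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter

def plot_triplet(inp, tar, gen, channel=0):
    # Hypothetical stand-in for utils.plots.generate_images: one channel, first sample.
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, field, title in zip(axes, (inp, tar, gen), ("input", "target", "prediction")):
        ax.imshow(field[0, channel].detach().cpu().numpy())
        ax.set_title(title)
        ax.axis("off")
    return fig

writer = SummaryWriter(log_dir="/logs/example")
inp = tar = gen = torch.randn(1, 20, 360, 720)
writer.add_figure("Visualization, t2m", plot_triplet(inp, tar, gen), global_step=0, close=True)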
@@ -227,12 +216,12 @@
     parser.add_argument("--yaml_config", default='./config/ViT.yaml', type=str, help='path to yaml file containing training configs')
     parser.add_argument("--config", default='base', type=str, help='name of desired config in yaml file')
     parser.add_argument("--amp_mode", default='none', type=str, choices=['none', 'fp16', 'bf16'], help='select automatic mixed precision mode')
-    parser.add_argument("--enable_apex", action='store_true', help='enable apex fused Adam optimizer')
+    parser.add_argument("--enable_fused", action='store_true', help='enable fused Adam optimizer')
     parser.add_argument("--enable_jit", action='store_true', help='enable JIT compilation')
-    parser.add_argument("--enable_manual_profiling", action='store_true', help='enable manual nvtx ranges and profiler start/stop calls')
     parser.add_argument("--local_batch_size", default=None, type=int, help='local batchsize (manually override global_batch_size config setting)')
     parser.add_argument("--num_epochs", default=None, type=int, help='number of epochs to run')
     parser.add_argument("--num_data_workers", default=None, type=int, help='number of data workers for data loader')
+    parser.add_argument("--data_loader_config", default=None, type=str, choices=['pytorch', 'dali'], help="dataloader configuration. choices: 'pytorch', 'dali'")
     parser.add_argument("--bucket_cap_mb", default=25, type=int, help='max message bucket size in mb')
     parser.add_argument("--disable_broadcast_buffers", action='store_true', help='disable syncing broadcasting buffers')
     parser.add_argument("--noddp", action='store_true', help='disable DDP communication')
@@ -253,9 +242,12 @@ def train(params, args, local_rank, world_rank, world_size):
         amp_dtype = torch.bfloat16
     params.update({"amp_enabled": amp_dtype is not torch.float32,
                    "amp_dtype" : amp_dtype,
-                   "enable_apex" : args.enable_apex,
+                   "enable_fused" : args.enable_fused,
                    "enable_jit" : args.enable_jit
                    })
+
+    if args.data_loader_config:
+        params.update({"data_loader_config" : args.data_loader_config})

     if args.num_epochs:
         params.update({"num_epochs" : args.num_epochs})
