 from utils import get_data_loader_distributed
 from utils.loss import l2_loss, l2_loss_opt
 from utils.metrics import weighted_rmse
+from utils.plots import generate_images
 from networks import vit
 
-import apex.optimizers as aoptim
-
-def compute_grad_norm(p_list, device):
-    norm_type = 2.0
-    grads = [p.grad for p in p_list if p.grad is not None]
-    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
-    return total_norm
-
-def compute_parameter_norm(p_list, device):
-    norm_type = 2.0
-    total_norm = torch.norm(torch.stack([torch.norm(p.detach(), norm_type).to(device) for p in p_list]), norm_type)
-    return total_norm
-
 def train(params, args, local_rank, world_rank, world_size):
     # set device and benchmark mode
     torch.backends.cudnn.benchmark = True
@@ -48,6 +36,9 @@ def train(params, args, local_rank, world_rank, world_size):
 
     # create model
     model = vit.ViT(params).to(device)
+
+    if params.enable_jit:
+        model = torch.compile(model)
 
     if params.amp_dtype == torch.float16:
         scaler = GradScaler()
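Aside: `torch.compile` (PyTorch 2.x) takes over from the `torch.jit.script` path removed further down. A minimal sketch on a toy module; the `nn.Sequential` stand-in is illustrative, not the repo's `vit.ViT`:

```python
import torch
import torch.nn as nn

# toy stand-in for vit.ViT(params); torch.compile wraps any nn.Module (PyTorch 2.x)
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))
compiled = torch.compile(model)     # default TorchInductor backend

x = torch.randn(8, 64)
out = compiled(x)                   # first call triggers compilation; later calls reuse it
```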
@@ -61,9 +52,8 @@ def train(params, args, local_rank, world_rank, world_size):
         model = DistributedDataParallel(model, device_ids=[local_rank],
                                         bucket_cap_mb=args.bucket_cap_mb)
 
-    if params.enable_apex:
-        optimizer = aoptim.FusedAdam(model.parameters(), lr=params.lr,
-                                     adam_w_mode=False, set_grad_none=True)
+    if params.enable_fused:
+        optimizer = optim.Adam(model.parameters(), lr=params.lr, fused=True)
     else:
         optimizer = optim.Adam(model.parameters(), lr=params.lr)
 
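Aside: a self-contained sketch of the fused-Adam path introduced above, assuming a CUDA device (the fused implementation in `torch.optim.Adam` requires CUDA parameters; available in recent PyTorch). The toy model and sizes are illustrative:

```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda")                       # fused=True needs parameters on GPU
model = nn.Linear(64, 64).to(device)                # toy stand-in for the ViT model
optimizer = optim.Adam(model.parameters(), lr=1e-3, fused=True)

loss = model(torch.randn(8, 64, device=device)).square().mean()
loss.backward()
optimizer.step()                                    # update runs in fused multi-tensor kernels
optimizer.zero_grad()
```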
@@ -74,14 +64,14 @@ def train(params, args, local_rank, world_rank, world_size):
     startEpoch = 0
 
     if params.lr_schedule == 'cosine':
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_epochs, last_epoch=startEpoch-1)
+        if params.warmup > 0:
+            lr_scale = lambda x: min(params.lr*((x+1)/params.warmup), 0.5*params.lr*(1 + np.cos(np.pi*x/params.num_epochs)))
+            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)
+        else:
+            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_epochs, last_epoch=startEpoch-1)
     else:
         scheduler = None
 
-    if params.enable_jit:
-        model_handle = model.module if (params.distributed and not args.noddp) else model
-        model_handle = torch.jit.script(model_handle)
-
     # select loss function
     if params.enable_jit:
         loss_func = l2_loss_opt
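Aside: a standalone sketch of the warmup-plus-cosine pattern used by the `LambdaLR` branch above, with illustrative values in place of `params`. Note that `LambdaLR` multiplies the optimizer's base learning rate by the factor the lambda returns:

```python
import numpy as np
import torch.nn as nn
import torch.optim as optim

warmup, num_epochs, base_lr = 5, 50, 1e-3           # illustrative values

model = nn.Linear(8, 8)                             # toy stand-in for the ViT model
optimizer = optim.Adam(model.parameters(), lr=base_lr)

# linear ramp over `warmup` epochs, then cosine decay toward zero at num_epochs;
# LambdaLR applies this factor on top of base_lr each epoch
lr_scale = lambda x: min((x + 1) / warmup, 0.5 * (1 + np.cos(np.pi * x / num_epochs)))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_scale)

for epoch in range(num_epochs):
    # ... one epoch of training ...
    scheduler.step()
```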
@@ -124,51 +114,48 @@ def train(params, args, local_rank, world_rank, world_size):
         model.train()
         step_count = 0
         for i, data in enumerate(train_data_loader, 0):
-            if (args.enable_manual_profiling and world_rank == 0):
+            if world_rank == 0:
                 if (epoch == 3 and i == 0):
                     torch.cuda.profiler.start()
-                if (epoch == 3 and i == 59):
+                if (epoch == 3 and i == len(train_data_loader) - 1):
                     torch.cuda.profiler.stop()
 
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"step {i}")
+            torch.cuda.nvtx.range_push(f"step {i}")
             iters += 1
             dat_start = time.time()
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"data copy in {i}")
+            torch.cuda.nvtx.range_push(f"data copy in {i}")
 
             inp, tar = map(lambda x: x.to(device), data)
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # copy in
+            torch.cuda.nvtx.range_pop() # copy in
 
             tr_start = time.time()
             b_size = inp.size(0)
 
             optimizer.zero_grad()
 
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"forward")
+            torch.cuda.nvtx.range_push(f"forward")
             with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype):
                 gen = model(inp)
                 loss = loss_func(gen, tar)
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() #forward
+            torch.cuda.nvtx.range_pop() #forward
 
             if params.amp_dtype == torch.float16:
                 scaler.scale(loss).backward()
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer")
+                torch.cuda.nvtx.range_push(f"optimizer")
                 scaler.step(optimizer)
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer
+                torch.cuda.nvtx.range_pop() # optimizer
                 scaler.update()
             else:
                 loss.backward()
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer")
+                torch.cuda.nvtx.range_push(f"optimizer")
                 optimizer.step()
-                if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer
+                torch.cuda.nvtx.range_pop() # optimizer
 
             if params.distributed:
                 torch.distributed.all_reduce(loss)
             tr_loss.append(loss.item()/world_size)
 
-            if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # step
-
-            # g_norm = compute_grad_norm(model.parameters(), device)
-            # p_norm = compute_parameter_norm(model.parameters(), device)
+            torch.cuda.nvtx.range_pop() # step
 
             tr_end = time.time()
             tr_time += tr_end - tr_start
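Aside: the NVTX annotation pattern used in the loop above, reduced to a minimal sketch. It requires a CUDA build of PyTorch; the ranges and the profiler start/stop calls are typically consumed by Nsight Systems (e.g. under `nsys profile --capture-range=cudaProfilerApi`):

```python
import torch

assert torch.cuda.is_available()                 # NVTX/profiler hooks need a CUDA build

torch.cuda.profiler.start()                      # begin capture when run under the profiler
for i in range(10):
    torch.cuda.nvtx.range_push(f"step {i}")      # named range shows up on the timeline
    torch.cuda.nvtx.range_push("forward")
    # ... forward pass ...
    torch.cuda.nvtx.range_pop()                  # forward
    # ... backward / optimizer ...
    torch.cuda.nvtx.range_pop()                  # step
torch.cuda.profiler.stop()
```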
@@ -187,6 +174,8 @@ def train(params, args, local_rank, world_rank, world_size):
             args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters)
             args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters)
             args.tboard_writer.add_scalar('Avg iters per sec', step_count/(end - start), iters)
+            fig = generate_images([inp, tar, gen])
+            args.tboard_writer.add_figure('Visualization, t2m', fig, iters, close=True)
 
         val_start = time.time()
         val_loss = []
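Aside: a minimal sketch of logging a matplotlib figure to TensorBoard as done above; `generate_images` from `utils/plots.py` is assumed to return a figure similar to the one built here:

```python
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./logs")          # illustrative log directory

# stand-in for generate_images([inp, tar, gen]): any matplotlib Figure works
fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, title in zip(axes, ["input", "target", "prediction"]):
    ax.imshow(np.random.rand(32, 32))
    ax.set_title(title)

# close=True releases the figure once it has been rendered into the event file
writer.add_figure("Visualization, t2m", fig, global_step=0, close=True)
writer.close()
```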
@@ -227,12 +216,12 @@ def train(params, args, local_rank, world_rank, world_size):
     parser.add_argument("--yaml_config", default='./config/ViT.yaml', type=str, help='path to yaml file containing training configs')
     parser.add_argument("--config", default='base', type=str, help='name of desired config in yaml file')
     parser.add_argument("--amp_mode", default='none', type=str, choices=['none', 'fp16', 'bf16'], help='select automatic mixed precision mode')
-    parser.add_argument("--enable_apex", action='store_true', help='enable apex fused Adam optimizer')
+    parser.add_argument("--enable_fused", action='store_true', help='enable fused Adam optimizer')
     parser.add_argument("--enable_jit", action='store_true', help='enable JIT compilation')
-    parser.add_argument("--enable_manual_profiling", action='store_true', help='enable manual nvtx ranges and profiler start/stop calls')
     parser.add_argument("--local_batch_size", default=None, type=int, help='local batchsize (manually override global_batch_size config setting)')
     parser.add_argument("--num_epochs", default=None, type=int, help='number of epochs to run')
     parser.add_argument("--num_data_workers", default=None, type=int, help='number of data workers for data loader')
+    parser.add_argument("--data_loader_config", default=None, type=str, choices=['pytorch', 'dali'], help="dataloader configuration. choices: 'pytorch', 'dali'")
     parser.add_argument("--bucket_cap_mb", default=25, type=int, help='max message bucket size in mb')
     parser.add_argument("--disable_broadcast_buffers", action='store_true', help='disable syncing broadcasting buffers')
     parser.add_argument("--noddp", action='store_true', help='disable DDP communication')
@@ -253,9 +242,12 @@ def train(params, args, local_rank, world_rank, world_size):
         amp_dtype = torch.bfloat16
     params.update({"amp_enabled": amp_dtype is not torch.float32,
                    "amp_dtype": amp_dtype,
-                   "enable_apex": args.enable_apex,
+                   "enable_fused": args.enable_fused,
                    "enable_jit": args.enable_jit
                    })
+
+    if args.data_loader_config:
+        params.update({"data_loader_config": args.data_loader_config})
 
     if args.num_epochs:
         params.update({"num_epochs": args.num_epochs})