@@ -36,15 +36,19 @@ def forward(self, module_inputs, target):
         output = self.module(*module_inputs)
         output = output.view(-1, output.size(2))
         target = target.view(-1)
+        output = nn.functional.log_softmax(output, -1)
+        # make sure criterion is not from_logits
         loss = self.criterion(output, target).view(1, 1)
+        nll = nn.functional.nll_loss(
+            output, target, ignore_index=self.ignore_index, reduction='sum')
         if self.get_accuracy:
             _, argmax = output.max(-1)
             invalid_targets = target.eq(self.ignore_index)
             accuracy = argmax.eq(target).masked_fill_(
                 invalid_targets, 0).long().sum()
-            return loss, accuracy.view(1, 1)
+            return loss, nll, accuracy.view(1, 1)
         else:
-            return loss
+            return loss, nll


 def _chunk_tuple(seq_tuple, num_chunks, batch_first=True):
@@ -116,7 +120,7 @@ def __init__(self, model, regime=None,
         super(Seq2SeqTrainer, self).__init__()
         self.model = model
         self.criterion = criterion or CrossEntropyLoss(
-            ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum')
+            ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum', from_logits=False)

         self.optimizer = OptimRegime(self.model, regime=regime)
         self.grad_clip = grad_clip
@@ -161,6 +165,7 @@ def batch_first(self):
     def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batch=1):
         loss_measure = 0
         accuracy_measure = 0
+        nll_measure = 0
         num_words = 0
         if training:
             self.optimizer.zero_grad()
@@ -194,7 +199,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batch=1):
             if training:
                 self.optimizer.pre_forward()
             # compute output
-            loss, accuracy = self.model_with_loss(inputs, target_labels)
+            loss, nll, accuracy = self.model_with_loss(inputs, target_labels)

             loss = loss.sum()
             loss_measure += float(loss / num_words)
@@ -203,6 +208,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batch=1):
             else:
                 loss /= target.size(batch_dim)
             accuracy_measure += float(accuracy.sum().float() / num_words)
+            nll_measure += float(nll.sum() / num_words)

             if training:
                 self.optimizer.pre_backward()
@@ -231,7 +237,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batch=1):
                 clip_grad_norm_(self.model.decoder.embedder.parameters(),
                                 self.embedding_grad_clip)
             self.optimizer.step()
-        return loss_measure, accuracy_measure, num_words
+        return loss_measure, nll_measure, accuracy_measure, num_words

     def _feed_data(self, data_loader, num_iterations=None, training=True, chunk_batch=1):
         if training:
@@ -261,13 +267,13 @@ def _feed_data(self, data_loader, num_iterations=None, training=True, chunk_batch=1):
                 # update optimizer according to epoch and steps
                 self.optimizer.update(self.epoch, self.training_steps)
             # do a train/evaluate iteration
-            loss, acc, num_words = self.iterate(src, target,
-                                                training=training,
-                                                chunk_batch=chunk_batch)
+            loss, nll, acc, num_words = self.iterate(src, target,
+                                                     training=training,
+                                                     chunk_batch=chunk_batch)

             # measure accuracy and record loss
             losses.update(loss, num_words)
-            perplexity.update(math.exp(loss), num_words)
+            perplexity.update(math.exp(nll), num_words)
             accuracy.update(acc, num_words)

             # measure elapsed time
@@ -470,8 +476,6 @@ def __init__(self, *kargs, **kwargs):
         _, target_tok = self.save_info['tokenizers'].values()
         target_words = target_tok.common_words(8188)
         self.contrast_batch = batch_nested_sequences(target_words)
-        import pdb
-        pdb.set_trace()

     def iterate(self, src_tuple, target_tuple, training=True):
         # limit number of tokens to avoid gpu overload
@@ -499,7 +503,7 @@ def iterate(self, src_tuple, target_tuple, training=True):
         target_labels = target[1:]

         # compute output
-        loss, accuracy = self.model_with_loss(inputs, target_labels)
+        loss, nll, accuracy = self.model_with_loss(inputs, target_labels)

         loss = loss.sum()
         loss_measure = float(loss / num_words)
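With label smoothing, the value produced by the smoothed criterion is no longer a true negative log-likelihood, so exp(loss) gives a biased perplexity estimate. The change above applies log_softmax once, feeds log-probabilities to a criterion constructed with from_logits=False, and reports perplexity from a separately accumulated plain NLL. A minimal sketch of that idea, assuming PAD = 0, illustrative tensor shapes, and plain PyTorch functions in place of the repository's CrossEntropyLoss:

    # Sketch only: illustrative shapes and PAD index, not the repository code.
    import math
    import torch
    import torch.nn.functional as F

    PAD = 0  # assumed padding index

    logits = torch.randn(7, 5, 100)         # (time, batch, vocab) -- dummy model output
    target = torch.randint(1, 100, (7, 5))  # dummy gold token ids (no padding here)

    # Flatten and convert to log-probabilities once, as the modified forward() does.
    log_probs = F.log_softmax(logits.view(-1, logits.size(2)), dim=-1)
    flat_target = target.view(-1)

    # Plain (unsmoothed) negative log-likelihood, summed over non-padding tokens;
    # this is the quantity perplexity should be derived from.
    nll = F.nll_loss(log_probs, flat_target, ignore_index=PAD, reduction='sum')

    num_words = int(flat_target.ne(PAD).sum())
    perplexity = math.exp(float(nll) / num_words)
    print(perplexity)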