
Commit b77976c

Author: Elad Hoffer (committed)
measure nll for perplexity
1 parent bfce86f commit b77976c

File tree: 2 files changed, +17 -13 lines


seq2seq/tools/trainer.py (+16 -12)
@@ -36,15 +36,19 @@ def forward(self, module_inputs, target):
         output = self.module(*module_inputs)
         output = output.view(-1, output.size(2))
         target = target.view(-1)
+        output = nn.functional.log_softmax(output, -1)
+        # make sure criterion is not from_logits
         loss = self.criterion(output, target).view(1, 1)
+        nll = nn.functional.nll_loss(
+            output, target, ignore_index=self.ignore_index, reduction='sum')
         if self.get_accuracy:
             _, argmax = output.max(-1)
             invalid_targets = target.eq(self.ignore_index)
             accuracy = argmax.eq(target).masked_fill_(
                 invalid_targets, 0).long().sum()
-            return loss, accuracy.view(1, 1)
+            return loss, nll, accuracy.view(1, 1)
         else:
-            return loss
+            return loss, nll


 def _chunk_tuple(seq_tuple, num_chunks, batch_first=True):
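Note: the hunk above computes log-probabilities once and derives two quantities from them: the training loss from self.criterion (which may apply label smoothing) and an exact, unsmoothed NLL kept only for reporting. A minimal standalone sketch of that pattern, using plain torch.nn.functional calls rather than the repo's custom CrossEntropyLoss; the shapes, PAD index and smoothing value here are made up for illustration:

    import torch
    import torch.nn.functional as F

    PAD = 0                                  # assumed padding index
    logits = torch.randn(8, 1000)            # (tokens, vocab)
    target = torch.randint(1, 1000, (8,))    # toy targets, no PAD tokens

    log_probs = F.log_softmax(logits, dim=-1)

    # exact negative log-likelihood, summed over tokens (used for perplexity)
    nll = F.nll_loss(log_probs, target, ignore_index=PAD, reduction='sum')

    # a simple label-smoothed training loss over the same log-probabilities
    eps = 0.1
    uniform = -log_probs.mean(dim=-1)        # cross-entropy against a uniform target
    per_token = F.nll_loss(log_probs, target, ignore_index=PAD, reduction='none')
    loss = ((1 - eps) * per_token + eps * uniform).sum()

    print(float(loss), float(nll))           # the two differ whenever eps > 0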
@@ -116,7 +120,7 @@ def __init__(self, model, regime=None,
         super(Seq2SeqTrainer, self).__init__()
         self.model = model
         self.criterion = criterion or CrossEntropyLoss(
-            ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum')
+            ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum', from_logits=False)

         self.optimizer = OptimRegime(self.model, regime=regime)
         self.grad_clip = grad_clip
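Note: passing from_logits=False matches the wrapper change above: the criterion now receives log-probabilities that have already gone through log_softmax, so it must not normalize its input again. In plain PyTorch terms (an illustrative equivalence, not the repo's custom CrossEntropyLoss):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)
    target = torch.tensor([1, 3, 5, 7])

    from_logits = F.cross_entropy(logits, target, reduction='sum')
    from_log_probs = F.nll_loss(F.log_softmax(logits, dim=-1), target, reduction='sum')

    # identical results; applying softmax to already-normalized log-probs would not be
    assert torch.allclose(from_logits, from_log_probs)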
@@ -161,6 +165,7 @@ def batch_first(self):
     def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batch=1):
         loss_measure = 0
         accuracy_measure = 0
+        nll_measure = 0
         num_words = 0
         if training:
             self.optimizer.zero_grad()
@@ -194,7 +199,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
             if training:
                 self.optimizer.pre_forward()
             # compute output
-            loss, accuracy = self.model_with_loss(inputs, target_labels)
+            loss, nll, accuracy = self.model_with_loss(inputs, target_labels)

             loss = loss.sum()
             loss_measure += float(loss / num_words)
@@ -203,6 +208,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
             else:
                 loss /= target.size(batch_dim)
             accuracy_measure += float(accuracy.sum().float() / num_words)
+            nll_measure += float(nll.sum() / num_words)

             if training:
                 self.optimizer.pre_backward()
@@ -231,7 +237,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
                 clip_grad_norm_(self.model.decoder.embedder.parameters(),
                                 self.embedding_grad_clip)
             self.optimizer.step()
-        return loss_measure, accuracy_measure, num_words
+        return loss_measure, nll_measure, accuracy_measure, num_words

     def _feed_data(self, data_loader, num_iterations=None, training=True, chunk_batch=1):
         if training:
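Note: iterate() now returns a 4-tuple; the added nll_measure is, like loss_measure, a per-word quantity accumulated over the chunked sub-batches. A small hypothetical helper illustrating the same accumulation (reduce_chunks, chunk_outputs and num_words are stand-ins, not names from the repo):

    def reduce_chunks(chunk_outputs, num_words):
        # per-chunk sums normalized by the total word count, so the final
        # measures are per-word averages over the whole (chunked) batch
        loss_measure = nll_measure = accuracy_measure = 0.0
        for loss, nll, accuracy in chunk_outputs:
            loss_measure += float(loss.sum() / num_words)
            nll_measure += float(nll.sum() / num_words)
            accuracy_measure += float(accuracy.sum().float() / num_words)
        return loss_measure, nll_measure, accuracy_measure, num_words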
@@ -261,13 +267,13 @@ def _feed_data(self, data_loader, num_iterations=None, training=True, chunk_batc
             # update optimizer according to epoch and steps
             self.optimizer.update(self.epoch, self.training_steps)
             # do a train/evaluate iteration
-            loss, acc, num_words = self.iterate(src, target,
-                                                training=training,
-                                                chunk_batch=chunk_batch)
+            loss, nll, acc, num_words = self.iterate(src, target,
+                                                     training=training,
+                                                     chunk_batch=chunk_batch)

             # measure accuracy and record loss
             losses.update(loss, num_words)
-            perplexity.update(math.exp(loss), num_words)
+            perplexity.update(math.exp(nll), num_words)
             accuracy.update(acc, num_words)

             # measure elapsed time
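Note: this hunk is the point of the commit. With label smoothing enabled, the training loss is no longer the negative log-likelihood, so exp(loss) is not a valid perplexity; the separately measured per-word NLL is exponentiated instead. A minimal sketch of the corrected quantity (function and argument names are illustrative, not from the repo):

    import math

    def batch_perplexity(nll_sum, num_words):
        # perplexity = exp(total NLL / number of target words)
        return math.exp(nll_sum / num_words)

    print(batch_perplexity(nll_sum=693.1, num_words=100))  # exp(6.931) ~= 1024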
@@ -470,8 +476,6 @@ def __init__(self, *kargs, **kwargs):
         _, target_tok = self.save_info['tokenizers'].values()
         target_words = target_tok.common_words(8188)
         self.contrast_batch = batch_nested_sequences(target_words)
-        import pdb
-        pdb.set_trace()

     def iterate(self, src_tuple, target_tuple, training=True):
         # limit number of tokens to avoid gpu overload
@@ -499,7 +503,7 @@ def iterate(self, src_tuple, target_tuple, training=True):
         target_labels = target[1:]

         # compute output
-        loss, accuracy = self.model_with_loss(inputs, target_labels)
+        loss, nll, accuracy = self.model_with_loss(inputs, target_labels)

         loss = loss.sum()
         loss_measure = float(loss / num_words)
