
Commit 36dd712

dual learning
1 parent 3fc6082 commit 36dd712


5 files changed: +414 −0 lines changed


data.py

+63
@@ -0,0 +1,63 @@
import os
import torch
import pickle


class Dictionary(object):
    def __init__(self):
        self.word2idx = {'<unk>': 0}
        self.idx2word = ['<unk>']
        self.wordcnt = {}

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
            self.wordcnt[word] = 1
        else:
            self.wordcnt[word] = self.wordcnt[word] + 1
        return self.word2idx[word]

    def getid(self, word, thresh=10):
        if (word not in self.word2idx) or (self.wordcnt[word] < thresh):
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

        with open(os.path.join(path, 'dict.pkl'), 'wb') as f:
            pickle.dump(self.dictionary, f)

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = ['<sos>'] + line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = ['<sos>'] + line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.getid(word)
                    token += 1

        return ids
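
For reference, a minimal usage sketch of the new Corpus and Dictionary classes (the 'data/' directory name and the printed fields are assumptions for illustration, not part of the commit):

    # Hypothetical directory 'data/' containing train.txt, valid.txt, test.txt.
    from data import Corpus

    corpus = Corpus('data')                  # also writes data/dict.pkl
    print(len(corpus.dictionary))            # vocabulary size, including '<unk>'
    print(corpus.train.size(0))              # total number of tokens in train.txt
    print(corpus.dictionary.getid('the'))    # words seen fewer than 10 times map to '<unk>'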

dual.py

+143
@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-

import sys
import torch
import argparse
import random

from torch.autograd import Variable

from nmt import read_corpus, data_iter
from nmt import NMT, to_input_variable

from lm import LMProb
from lm import model


def dual(args):
    vocabs = {}
    opts = {}
    state_dicts = {}
    train_srcs = {}
    lms = {}

    # load model params & training data
    print('load modelA from [{:s}]'.format(args.modelA_bin), file=sys.stderr)
    params = torch.load(args.modelA_bin, map_location=lambda storage, loc: storage)
    vocabs['A'] = params['vocab']
    opts['A'] = params['args']
    state_dicts['A'] = params['state_dict']
    print('load train_srcA from [{:s}]'.format(args.train_srcA), file=sys.stderr)
    train_srcs['A'] = read_corpus(args.train_srcA, source='src')
    print('load lmA from [{:s}]'.format(args.lmA), file=sys.stderr)
    lms['A'] = LMProb(args.lmA, args.lmAdict)

    print('load modelB from [{:s}]'.format(args.modelB_bin), file=sys.stderr)
    params = torch.load(args.modelB_bin, map_location=lambda storage, loc: storage)
    vocabs['B'] = params['vocab']
    opts['B'] = params['args']
    state_dicts['B'] = params['state_dict']
    print('load train_srcB from [{:s}]'.format(args.train_srcB), file=sys.stderr)
    train_srcs['B'] = read_corpus(args.train_srcB, source='src')
    print('load lmB from [{:s}]'.format(args.lmB), file=sys.stderr)
    lms['B'] = LMProb(args.lmB, args.lmBdict)

    models = {}
    optimizers = {}

    for m in ['A', 'B']:
        # build model
        models[m] = NMT(opts[m], vocabs[m])
        models[m].load_state_dict(state_dicts[m])
        models[m].train()
        models[m] = models[m].cuda()

        random.shuffle(train_srcs[m])

        # optimizer
        optimizers[m] = torch.optim.Adam(models[m].parameters())

    # loss function
    loss_nll = torch.nn.NLLLoss()
    loss_ce = torch.nn.CrossEntropyLoss()

    epoch = 0
    while True:
        if epoch == 2:
            break
        epoch += 1
        print('start of epoch {:d}'.format(epoch))

        data = {}
        data['A'] = iter(train_srcs['A'])
        data['B'] = iter(train_srcs['B'])

        for t in range(0, len(train_srcs['A'])):
            print('sent', t)
            for m in ['A', 'B']:
                lm_probs = []

                NLL_losses = []
                CE_losses = []

                modelA = models[m]
                modelB = models[change(m)]
                lmB = lms[change(m)]
                optimizerA = optimizers[m]
                optimizerB = optimizers[change(m)]
                vocabB = vocabs[change(m)]
                s = next(data[m])

                hyps = modelA.beam(s, beam_size=5)

                for ids, smid, dist in hyps:
                    var_ids = torch.autograd.Variable(torch.LongTensor(ids[1:]), requires_grad=False)
                    NLL_losses.append(loss_nll(dist, var_ids).cpu())

                    lm_probs.append(lmB.get_prob(smid))

                    src_sent_var = to_input_variable([smid], vocabB.src, cuda=True)
                    tgt_sent_var = to_input_variable([['<s>'] + s + ['</s>']], vocabB.tgt, cuda=True)
                    src_sent_len = [len(smid)]

                    score = modelB(src_sent_var, src_sent_len, tgt_sent_var[:-1]).squeeze(1)

                    CE_losses.append(loss_ce(score, tgt_sent_var[1:].view(-1)).cpu())

                r1_mean = sum(lm_probs) / len(lm_probs)
                r1 = [Variable(torch.FloatTensor([p - r1_mean]), requires_grad=False) for p in lm_probs]

                r2_mean = sum(CE_losses) / len(CE_losses)
                r2 = [Variable(-(l.data - r2_mean.data), requires_grad=False) for l in CE_losses]

                rk = [a + b for a, b in zip(r1, r2)]

                optimizerA.zero_grad()
                optimizerB.zero_grad()

                torch.mean(torch.cat(NLL_losses) * torch.cat(rk)).backward()
                torch.mean(torch.cat(CE_losses)).backward()

                optimizerA.step()
                optimizerB.step()


def change(m):
    if m == 'A':
        return 'B'
    else:
        return 'A'


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('modelA_bin')
    parser.add_argument('modelB_bin')
    parser.add_argument('lmA')
    parser.add_argument('lmAdict')
    parser.add_argument('lmB')
    parser.add_argument('lmBdict')
    parser.add_argument('train_srcA')
    parser.add_argument('train_srcB')
    args = parser.parse_args()

    dual(args)
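
The training loop above weights model A's per-hypothesis NLL by a reward that combines the other language model's score of the intermediate translation (r1) with the negative, mean-centered reconstruction loss from model B (r2). The snippet below re-derives just that weighting with made-up numbers, outside the Variable-era API used in the commit, to show what rk ends up being:

    # Standalone sketch of the reward weighting; the values are invented for
    # illustration, the real script obtains them from the LM and from model B.
    import torch

    lm_probs  = [-2.1, -1.4, -3.0]                                         # LM log-probs of the beam hypotheses
    ce_losses = [torch.tensor(2.5), torch.tensor(1.8), torch.tensor(3.2)]  # model B reconstruction losses

    r1_mean = sum(lm_probs) / len(lm_probs)
    r1 = [p - r1_mean for p in lm_probs]           # center the fluency reward

    r2_mean = sum(ce_losses) / len(ce_losses)
    r2 = [-(l - r2_mean) for l in ce_losses]       # lower reconstruction loss => higher reward

    rk = [a + b for a, b in zip(r1, r2)]           # per-hypothesis weight applied to A's NLL
    print([float(r) for r in rk])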

model.py

+47
@@ -0,0 +1,47 @@
import torch.nn as nn
from torch.autograd import Variable


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
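
A minimal forward-pass sketch for the GRU language model (the sizes below are illustrative and not taken from the commit; the old Variable API is used to match the file):

    import torch
    from torch.autograd import Variable
    from model import RNNModel

    ntoken, ninp, nhid, nlayers = 1000, 200, 200, 2
    lm = RNNModel(ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=True)

    seq_len, bsz = 35, 4
    inp = Variable(torch.LongTensor(seq_len, bsz).random_(0, ntoken))  # random word ids
    hidden = lm.init_hidden(bsz)
    output, hidden = lm(inp, hidden)
    print(output.size())   # (seq_len, bsz, ntoken)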

util.py

+56
@@ -0,0 +1,56 @@
from collections import defaultdict
import numpy as np


def read_corpus(file_path, source):
    data = []
    for line in open(file_path):
        sent = line.strip().split(' ')
        # only append <s> and </s> to the target sentence
        if source == 'tgt':
            sent = ['<s>'] + sent + ['</s>']
        data.append(sent)

    return data


def batch_slice(data, batch_size, sort=True):
    batched_data = []
    batch_num = int(np.ceil(len(data) / float(batch_size)))
    for i in range(batch_num):
        cur_batch_size = batch_size if i < batch_num - 1 else len(data) - batch_size * i
        src_sents = [data[i * batch_size + b][0] for b in range(cur_batch_size)]
        tgt_sents = [data[i * batch_size + b][1] for b in range(cur_batch_size)]

        if sort:
            src_ids = sorted(range(cur_batch_size), key=lambda src_id: len(src_sents[src_id]), reverse=True)
            src_sents = [src_sents[src_id] for src_id in src_ids]
            tgt_sents = [tgt_sents[src_id] for src_id in src_ids]

        batched_data.append((src_sents, tgt_sents))

    return batched_data


def data_iter(data, batch_size, shuffle=True):
    """
    randomly permute data, then sort by source length, and partition into batches
    ensure that the length of source sentences in each batch is decreasing
    """

    buckets = defaultdict(list)
    for pair in data:
        src_sent = pair[0]
        buckets[len(src_sent)].append(pair)

    batched_data = []
    for src_len in buckets:
        tuples = buckets[src_len]
        if shuffle: np.random.shuffle(tuples)
        batched_data.extend(batch_slice(tuples, batch_size))

    if shuffle:
        np.random.shuffle(batched_data)

    for batch in batched_data:
        yield batch
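
A toy illustration of data_iter, with invented sentence pairs (not part of the commit), showing that each yielded batch groups pairs whose source sentences share the same length:

    from util import data_iter

    pairs = [
        (['a', 'b', 'c'], ['<s>', 'x', 'y', '</s>']),
        (['d', 'e'],      ['<s>', 'z', '</s>']),
        (['f', 'g', 'h'], ['<s>', 'w', 'v', '</s>']),
    ]

    for src_sents, tgt_sents in data_iter(pairs, batch_size=2, shuffle=False):
        # pairs are bucketed by source length, so lengths within a batch are equal
        print([len(s) for s in src_sents], len(tgt_sents))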
