# process_samples.py
from __future__ import print_function

import argparse
import math
import re

import numpy as np
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from scipy.special import comb  # scipy.misc.comb was removed in SciPy 1.0

from util import read_corpus
from vocab import Vocab, VocabEntry  # kept so torch.load can unpickle the vocab


def is_valid_sample(sent):
    """Keep samples that are non-empty and shorter than 50 tokens."""
    tokens = sent.split(' ')
    return 1 <= len(tokens) < 50
def sample_from_model(args):
    sample_file = args.sample_file
    output = args.output

    tgt_sent_pattern = re.compile(r'^\[(\d+)\] (.*?)$')
    para_data = [l.strip().split(' ||| ') for l in open(args.parallel_data)]

    f_out = open(output, 'w')
    f = open(sample_file)
    f.readline()  # skip the header line

    for src_sent, tgt_sent in para_data:
        line = f.readline().strip()
        assert line.startswith('****')

        line = f.readline().strip()
        print(line)
        assert line.startswith('target:')
        tgt_sent2 = line[len('target:'):]
        assert tgt_sent == tgt_sent2

        line = f.readline().strip()  # samples marker line
        tgt_sent = ' '.join(tgt_sent.split(' ')[1:-1])  # strip <s> and </s>

        tgt_samples = set()
        for i in range(1, 101):
            line = f.readline().rstrip('\n')
            m = tgt_sent_pattern.match(line)

            assert m, line
            assert int(m.group(1)) == i

            sampled_tgt_sent = m.group(2).strip()
            if is_valid_sample(sampled_tgt_sent):
                tgt_samples.add(sampled_tgt_sent)

        line = f.readline().strip()
        assert line.startswith('****')

        tgt_samples.add(tgt_sent)  # make sure the reference itself is among the samples
        tgt_samples = list(tgt_samples)
        assert len(tgt_samples) > 0

        tgt_ref_tokens = tgt_sent.split(' ')
        bleu_scores = []
        for tgt_sample in tgt_samples:
            bleu_score = sentence_bleu([tgt_ref_tokens], tgt_sample.split(' '))
            bleu_scores.append(bleu_score)

        tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True)

        print('%d samples' % len(tgt_samples))

        print('*' * 50, file=f_out)
        print('source: ' + src_sent, file=f_out)
        print('%d samples' % len(tgt_samples), file=f_out)
        for i in tgt_ranks:
            print('%s ||| %f' % (tgt_samples[i], bleu_scores[i]), file=f_out)
        print('*' * 50, file=f_out)

    f_out.close()
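
# Expected layout of --sample_file, reconstructed from the parsing above (the
# script that produces this file is not part of this module):
#
#   <header line, skipped>
#   **************************************************
#   target: <s> w_1 w_2 ... w_n </s>
#   <samples marker line>
#   [1] first sampled sentence
#   ...
#   [100] last sampled sentence
#   **************************************************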
def get_new_ngram(ngram, n, vocab):
    """Replace ngram `ngram` with a newly sampled ngram of the same length."""
    # word ids 0-2 are skipped, presumably reserved for special tokens
    new_ngram_wids = [np.random.randint(3, len(vocab)) for _ in range(n)]
    new_ngram = [vocab.id2word[wid] for wid in new_ngram_wids]

    return new_ngram
def sample_ngram(args):
    src_sents = read_corpus(args.src, 'src')
    tgt_sents = read_corpus(args.tgt, 'src')  # read as 'src' so <s> and </s> are not added
    f_out = open(args.output, 'w')

    vocab = torch.load(args.vocab)
    tgt_vocab = vocab.tgt

    smooth_bleu = args.smooth_bleu
    sm_func = None
    if smooth_bleu:
        sm_func = SmoothingFunction().method3

    for src_sent, tgt_sent in zip(src_sents, tgt_sents):
        src_sent = ' '.join(src_sent)

        tgt_len = len(tgt_sent)
        tgt_samples = []
        tgt_samples_distort_rates = []  # how many unigrams are replaced

        # generate `args.sample_size` samples, always including the reference itself
        tgt_samples.append(tgt_sent)
        tgt_samples_distort_rates.append(0)

        for sid in range(args.sample_size - 1):
            n = np.random.randint(1, min(tgt_len, args.max_ngram_size + 1))  # we do not replace the last token: it must be a period!

            idx = np.random.randint(tgt_len - n)
            ngram = tgt_sent[idx: idx + n]
            new_ngram = get_new_ngram(ngram, n, tgt_vocab)

            sampled_tgt_sent = list(tgt_sent)
            sampled_tgt_sent[idx: idx + n] = new_ngram

            # compute the probability of this sample
            # prob = 1. / args.max_ngram_size * 1. / (tgt_len - 1 + n) * 1 / (len(tgt_vocab) ** n)

            tgt_samples.append(sampled_tgt_sent)
            tgt_samples_distort_rates.append(n)

        # compute BLEU scores or negative distortion rates and rank the samples
        rewards = []
        for tgt_sample, tgt_sample_distort_rate in zip(tgt_samples, tgt_samples_distort_rates):
            if args.reward == 'bleu':
                reward = sentence_bleu([tgt_sent], tgt_sample, smoothing_function=sm_func)
            else:
                reward = -tgt_sample_distort_rate

            rewards.append(reward)

        tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: rewards[i], reverse=True)

        # convert each list of tokens back into a string
        tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples]

        print('*' * 50, file=f_out)
        print('source: ' + src_sent, file=f_out)
        print('%d samples' % len(tgt_samples), file=f_out)
        for i in tgt_ranks:
            print('%s ||| %f' % (tgt_samples[i], rewards[i]), file=f_out)
        print('*' * 50, file=f_out)

    f_out.close()
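
# Layout of each entry written to --output, reconstructed from the print
# statements above (sample_ngram_adapt below writes the same layout):
#
#   **************************************************
#   source: <the source sentence>
#   <k> samples
#   <sample with highest reward> ||| <reward>
#   ...
#   <sample with lowest reward> ||| <reward>
#   **************************************************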
def sample_ngram_adapt(args):
    src_sents = read_corpus(args.src, 'src')
    tgt_sents = read_corpus(args.tgt, 'src')  # read as 'src' so <s> and </s> are not added
    f_out = open(args.output, 'w')

    vocab = torch.load(args.vocab)
    tgt_vocab = vocab.tgt

    max_len = max([len(tgt_sent) for tgt_sent in tgt_sents]) + 1

    for src_sent, tgt_sent in zip(src_sents, tgt_sents):
        src_sent = ' '.join(src_sent)

        tgt_len = len(tgt_sent)
        tgt_samples = []

        # generate `args.sample_size` samples, always including the reference itself
        tgt_samples.append(tgt_sent)

        for sid in range(args.sample_size - 1):
            max_n = min(tgt_len - 1, 4)
            bias_n = int(max_n * tgt_len / max_len) + 1
            assert 1 <= bias_n <= 4, 'bias_n={}, not in [1,4], max_n={}, tgt_len={}, max_len={}'.format(
                bias_n, max_n, tgt_len, max_len)

            # put extra probability mass on n = bias_n, so longer sentences
            # tend to have longer n-grams replaced
            p = [1.0 / (max_n + 5)] * max_n
            p[bias_n - 1] = 1 - p[0] * (max_n - 1)
            assert abs(sum(p) - 1) < 1e-10, 'sum(p) != 1'

            n = np.random.choice(np.arange(1, int(max_n + 1)), p=p)  # we do not replace the last token: it must be a period!
            assert n < tgt_len, 'n={}, tgt_len={}'.format(n, tgt_len)

            idx = np.random.randint(tgt_len - n)
            ngram = tgt_sent[idx: idx + n]
            new_ngram = get_new_ngram(ngram, n, tgt_vocab)

            sampled_tgt_sent = list(tgt_sent)
            sampled_tgt_sent[idx: idx + n] = new_ngram

            tgt_samples.append(sampled_tgt_sent)

        # compute BLEU scores and rank the samples by BLEU score
        bleu_scores = []
        for tgt_sample in tgt_samples:
            bleu_score = sentence_bleu([tgt_sent], tgt_sample)
            bleu_scores.append(bleu_score)

        tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True)

        # convert each list of tokens back into a string
        tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples]

        print('*' * 50, file=f_out)
        print('source: ' + src_sent, file=f_out)
        print('%d samples' % len(tgt_samples), file=f_out)
        for i in tgt_ranks:
            print('%s ||| %f' % (tgt_samples[i], bleu_scores[i]), file=f_out)
        print('*' * 50, file=f_out)

    f_out.close()
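
# Worked example of the biased n-gram length distribution above (my own
# illustration, not from the original code): with max_n = 4 and bias_n = 2,
# p starts as [1/9, 1/9, 1/9, 1/9] and the bias entry is raised to
# 1 - 3 * (1/9) = 6/9, giving p = [1/9, 6/9, 1/9, 1/9], which sums to 1.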
def sample_from_hamming_distance_payoff_distribution(args):
    # NOTE: this sampler is not reachable from the --mode choices below, and as
    # written it collects samples and BLEU scores without ever writing to f_out
    src_sents = read_corpus(args.src, 'src')
    tgt_sents = read_corpus(args.tgt, 'src')  # read as 'src' so <s> and </s> are not added
    f_out = open(args.output, 'w')

    vocab = torch.load(args.vocab)
    tgt_vocab = vocab.tgt

    payoff_prob, Z_qs = generate_hamming_distance_payoff_distribution(max(len(sent) for sent in tgt_sents),
                                                                      vocab_size=len(vocab.tgt),
                                                                      tau=args.temp)

    for src_sent, tgt_sent in zip(src_sents, tgt_sents):
        tgt_samples = []  # make sure the ground truth y* is in the samples
        tgt_sent_len = len(tgt_sent) - 3  # exclude <s>, </s> and the ending period
        tgt_ref_tokens = tgt_sent[1:-1]
        bleu_scores = []

        # sample edit distances
        e_samples = np.random.choice(range(tgt_sent_len + 1), p=payoff_prob[tgt_sent_len], size=args.sample_size,
                                     replace=True)

        for i, e in enumerate(e_samples):
            if e > 0:
                # sample a new tgt_sent y by substituting e word positions
                old_word_pos = np.random.choice(range(1, tgt_sent_len + 1), size=e, replace=False)
                new_words = [vocab.tgt.id2word[wid] for wid in np.random.randint(3, len(vocab.tgt), size=e)]
                new_tgt_sent = list(tgt_sent)
                for pos, word in zip(old_word_pos, new_words):
                    new_tgt_sent[pos] = word

                bleu_score = sentence_bleu([tgt_ref_tokens], new_tgt_sent[1:-1])
                bleu_scores.append(bleu_score)
            else:
                new_tgt_sent = list(tgt_sent)
                bleu_scores.append(1.)

            # print('y: %s' % ' '.join(new_tgt_sent))
            tgt_samples.append(new_tgt_sent)
def generate_hamming_distance_payoff_distribution(max_sent_len, vocab_size, tau=1.):
    """Compute the q distribution for Hamming distance (substitution only) as in the RAML paper."""
    probs = dict()
    Z_qs = dict()
    for sent_len in range(1, max_sent_len + 1):
        counts = [1.]  # e = 0, count = 1
        for e in range(1, sent_len + 1):
            # apply the rescaling trick as in
            # https://gist.github.com/norouzi/8c4d244922fa052fa8ec18d8af52d366:
            # instead of q(e) ~ C(sent_len, e) * (vocab_size - 1)^e * exp(-e / tau),
            # fold in a (vocab_size - 1)^(-e / tau) factor, which effectively divides
            # the temperature by (1 + log(vocab_size - 1))
            count = comb(sent_len, e) * math.exp(-e / tau) * ((vocab_size - 1) ** (e - e / tau))
            counts.append(count)

        Z_qs[sent_len] = Z_q = sum(counts)
        prob = [count / Z_q for count in counts]
        probs[sent_len] = prob
        # print('sent_len=%d, %s' % (sent_len, prob))

    return probs, Z_qs
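
# A minimal sanity check for the distribution above (my own addition, not part
# of the original pipeline; it is defined but never called):
def _check_payoff_distribution(max_sent_len=10, vocab_size=100, tau=0.5):
    probs, Z_qs = generate_hamming_distance_payoff_distribution(max_sent_len, vocab_size, tau=tau)
    for sent_len, prob in probs.items():
        # each q(. | sent_len) is a distribution over edit counts 0..sent_len
        assert len(prob) == sent_len + 1
        assert abs(sum(prob) - 1.0) < 1e-10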
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['sample_from_model', 'sample_ngram_adapt', 'sample_ngram'], required=True)
    parser.add_argument('--vocab', type=str)
    parser.add_argument('--src', type=str)
    parser.add_argument('--tgt', type=str)
    parser.add_argument('--parallel_data', type=str)
    parser.add_argument('--sample_file', type=str)
    parser.add_argument('--output', type=str, required=True)
    parser.add_argument('--sample_size', type=int, default=100)
    parser.add_argument('--reward', choices=['bleu', 'edit_dist'], default='bleu')
    parser.add_argument('--max_ngram_size', type=int, default=4)
    parser.add_argument('--temp', type=float, default=0.5)
    parser.add_argument('--smooth_bleu', action='store_true', default=False)

    args = parser.parse_args()

    if args.mode == 'sample_ngram':
        sample_ngram(args)
    elif args.mode == 'sample_from_model':
        sample_from_model(args)
    elif args.mode == 'sample_ngram_adapt':
        sample_ngram_adapt(args)
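
# Example invocation (hypothetical file names, shown for illustration only):
#
#   python process_samples.py \
#       --mode sample_ngram \
#       --vocab data/vocab.bin \
#       --src data/train.de \
#       --tgt data/train.en \
#       --output samples.txt \
#       --sample_size 100 \
#       --max_ngram_size 4 \
#       --smooth_bleu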