model.py
import sys

import numpy as np
import tensorflow as tf

sys.path.append('submodule/')
from mytools import *  # provides embed_op and bleu_score, among others
import rnn_wrapper as wrapper

def _attention(params, memory, memory_length):
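    """Return the attention mechanism selected by params['attn']."""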
if params['attn'] == 'bahdanau':
return tf.contrib.seq2seq.BahdanauAttention(
params['hidden_size'] * 2,
memory,
memory_length)
elif params['attn'] == 'normed_bahdanau':
return tf.contrib.seq2seq.BahdanauAttention(
params['hidden_size'] * 2,
memory,
memory_length,
normalize = True)
elif params['attn'] == 'luong':
return tf.contrib.seq2seq.LuongAttention(
params['hidden_size'] * 2,
memory,
memory_length)
elif params['attn'] == 'scaled_luong':
return tf.contrib.seq2seq.LuongAttention(
params['hidden_size'] * 2,
memory,
memory_length,
scale = True)
else:
        raise ValueError('Unknown attention mechanism: %s' % params['attn'])

def q_generation(features, labels, mode, params):
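    """Estimator model_fn for seq2seq question generation.

    Encodes the source sentence and the answer with bidirectional LSTMs,
    then decodes a question with an attention-wrapped LSTM decoder,
    using greedy decoding or beam search at prediction time.
    """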
dtype = params['dtype']
hidden_size = params['hidden_size']
voca_size = params['voca_size']
sentence = features['s'] # [batch, length]
len_s = features['len_s']
answer = features['a']
len_a = features['len_a']
beam_width = params['beam_width']
length_penalty_weight = params['length_penalty_weight']
    # batch_size must not be fixed here: a leftover eval batch of a different
    # size would otherwise raise a shape error (possibly a TensorFlow API bug).
batch_size = tf.shape(sentence)[0]
if mode != tf.estimator.ModeKeys.PREDICT:
question = tf.cast(features['q'], tf.int32) # label
len_q = tf.cast(features['len_q'], tf.int32)
else:
question = None
len_q = None
    # Embeddings for sentence, answer, and question; RNN encoding of sentence
with tf.variable_scope('SharedScope'):
# Embedded inputs
# Same name == embedding sharing
embd_s = embed_op(sentence, params, name = 'embedding')
embd_a = embed_op(answer, params, name = 'embedding')
if question is not None:
embd_q = embed_op(question, params, name = 'embedding')
# Build encoder cell
def lstm_cell_enc():
cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
return tf.contrib.rnn.DropoutWrapper(cell,
input_keep_prob = 1 - params['rnn_dropout'] if mode == tf.estimator.ModeKeys.TRAIN else 1)
def lstm_cell_dec():
cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size * 2)
return tf.contrib.rnn.DropoutWrapper(cell,
input_keep_prob = 1 - params['rnn_dropout'] if mode == tf.estimator.ModeKeys.TRAIN else 1)
encoder_cell_fw = lstm_cell_enc() if params['encoder_layer'] == 1 else tf.nn.rnn_cell.MultiRNNCell([lstm_cell_enc() for _ in range(params['encoder_layer'])])
encoder_cell_bw = lstm_cell_enc() if params['encoder_layer'] == 1 else tf.nn.rnn_cell.MultiRNNCell([lstm_cell_enc() for _ in range(params['encoder_layer'])])
encoder_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(
encoder_cell_fw,
encoder_cell_bw,
inputs = embd_s,
sequence_length = len_s,
dtype = dtype)
encoder_outputs = tf.concat(encoder_outputs, -1)
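        # Merge forward/backward final states so the encoder state width
        # (2 * hidden_size) matches the decoder cells built by lstm_cell_dec().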
if params['encoder_layer'] == 1:
encoder_state_c = tf.concat([encoder_state[0].c, encoder_state[1].c], axis = 1)
encoder_state_h = tf.concat([encoder_state[0].h, encoder_state[1].h], axis = 1)
encoder_state = tf.contrib.rnn.LSTMStateTuple(c = encoder_state_c, h = encoder_state_h)
else:
_encoder_state = list()
for state_fw, state_bw in zip(encoder_state[0], encoder_state[1]):
partial_state_c = tf.concat([state_fw.c, state_bw.c], axis = 1)
partial_state_h = tf.concat([state_fw.h, state_bw.h], axis = 1)
partial_state = tf.contrib.rnn.LSTMStateTuple(c = partial_state_c, h = partial_state_h)
_encoder_state.append(partial_state)
encoder_state = tuple(_encoder_state)
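        # Encode the answer with its own bidirectional LSTM, under a separate
        # variable scope so its weights are not shared with the sentence encoder.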
answer_cell_fw = lstm_cell_enc() if params['answer_layer'] == 1 else tf.nn.rnn_cell.MultiRNNCell([lstm_cell_enc() for _ in range(params['answer_layer'])])
answer_cell_bw = lstm_cell_enc() if params['answer_layer'] == 1 else tf.nn.rnn_cell.MultiRNNCell([lstm_cell_enc() for _ in range(params['answer_layer'])])
answer_outputs, answer_state = tf.nn.bidirectional_dynamic_rnn(
answer_cell_fw,
answer_cell_bw,
inputs = embd_a,
sequence_length = len_a,
dtype = dtype,
scope = 'answer_scope')
answer_outputs = tf.concat(answer_outputs, -1)
if params['answer_layer'] == 1:
answer_state_c = tf.concat([answer_state[0].c, answer_state[1].c], axis = 1)
answer_state_h = tf.concat([answer_state[0].h, answer_state[1].h], axis = 1)
answer_state = tf.contrib.rnn.LSTMStateTuple(c = answer_state_c, h = answer_state_h)
else:
_answer_state = list()
for state_fw, state_bw in zip(answer_state[0], answer_state[1]):
partial_state_c = tf.concat([state_fw.c, state_bw.c], axis = 1)
partial_state_h = tf.concat([state_fw.h, state_bw.h], axis = 1)
partial_state = tf.contrib.rnn.LSTMStateTuple(c = partial_state_c, h = partial_state_h)
_answer_state.append(partial_state)
answer_state = tuple(_answer_state)
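        # Pick the state that will initialize the decoder: prefer the answer
        # encoding, fall back to the sentence encoding, and use a zero state
        # when the layer counts do not match.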
if params['dec_init_ans'] and params['decoder_layer'] == params['answer_layer']:
copy_state = answer_state
elif params['encoder_layer'] == params['decoder_layer']:
copy_state = encoder_state
else:
copy_state = None
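        # Beam search decodes beam_width hypotheses per example, so the
        # attention memory and its lengths must be tiled to batch * beam_width.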
        if mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
            encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, beam_width)
            len_s = tf.contrib.seq2seq.tile_batch(len_s, beam_width)
            if copy_state is not None:
                copy_state = tf.contrib.seq2seq.tile_batch(copy_state, beam_width)
# This part should be moved into QuestionGeneration scope
with tf.variable_scope('SharedScope/EmbeddingScope', reuse = True):
embedding_q = tf.get_variable('embedding')
# Rnn decoding of sentence with attention
with tf.variable_scope('QuestionGeneration'):
# Memory for attention
attention_states = encoder_outputs
# Create an attention mechanism
attention_mechanism = _attention(params, attention_states, len_s)
# Build decoder cell
decoder_cell = lstm_cell_dec() if params['decoder_layer'] == 1 else tf.nn.rnn_cell.MultiRNNCell([lstm_cell_dec() for _ in range(params['decoder_layer'])])
if params['use_keyword'] > 0:
#cell_input_fn = lambda inputs, attention : tf.concat([inputs, attention, o_s], -1)
def cell_input_fn(inputs, attention):
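                # Attend over the answer encoding, using the decoder's previous
                # attention vector as the query; the summary o_s is appended to
                # the raw decoder input in place of the attention vector.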
if mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
last_attention = attention[0::beam_width]
else:
last_attention = attention
last_attention = tf.expand_dims(last_attention, 2)
m = answer_outputs
p_s = tf.nn.softmax(tf.matmul(m, last_attention), name = 'p_s')
o_s = tf.reduce_sum(p_s * m, axis = 1)
if mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
o_s = tf.contrib.seq2seq.tile_batch(o_s, beam_width)
return tf.concat([inputs, o_s], -1)
else:
cell_input_fn = None
#def cell_input_fn(inputs, attention):
# global o_s
# if mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
# inputs = inputs[0::beam_width]
# o_s = tf.layers.dense(tf.concat([inputs, o_s], -1), params['hidden_size'], activation = tf.tanh)
# if mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
# o_s = tf.contrib.seq2seq.tile_batch(o_s, beam_width)
# return tf.concat([o_s, attention], -1)
decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
decoder_cell, attention_mechanism,
attention_layer_size = 2 * hidden_size,
cell_input_fn = cell_input_fn,
initial_cell_state = None)
#initial_cell_state = encoder_state if params['encoder_layer'] == params['decoder_layer'] else None)
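        # Map decoder output to vocabulary logits, either through the custom
        # WeanWrapper (which reuses the question embedding) or a plain
        # linear projection over the vocabulary.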
if params['if_wean']:
decoder_cell = wrapper.WeanWrapper(decoder_cell, embedding_q)
else:
decoder_cell = tf.contrib.rnn.OutputProjectionWrapper(decoder_cell, voca_size)
# Helper for decoder cell
if mode == tf.estimator.ModeKeys.TRAIN:
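            # Scheduled sampling: with probability 0.25 the next decoder input
            # is sampled from the model instead of taken from the gold question.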
helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
inputs = embd_q,
sequence_length = len_q,
embedding = embedding_q,
sampling_probability = 0.25)
        else: # EVAL & PREDICT (greedy decoding)
start_token = params['start_token'] * tf.ones([batch_size], dtype = tf.int32)
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
embedding_q, start_token, params['end_token']
)
# Decoder
if mode != tf.estimator.ModeKeys.PREDICT or beam_width == 0:
initial_state = decoder_cell.zero_state(dtype = dtype, batch_size = batch_size)
if copy_state is not None:
initial_state = initial_state.clone(cell_state = copy_state)
decoder = tf.contrib.seq2seq.BasicDecoder(
decoder_cell, helper, initial_state,
output_layer=None)
else:
initial_state = decoder_cell.zero_state(dtype = dtype, batch_size = batch_size * beam_width)
if copy_state is not None:
initial_state = initial_state.clone(cell_state = copy_state)
decoder = tf.contrib.seq2seq.BeamSearchDecoder(
cell = decoder_cell,
embedding = embedding_q,
start_tokens = start_token,
end_token = params['end_token'],
initial_state = initial_state,
beam_width = beam_width,
length_penalty_weight = length_penalty_weight)
# Dynamic decoding
if mode == tf.estimator.ModeKeys.TRAIN:
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True, maximum_iterations = None)
logits_q = outputs.rnn_output
softmax_q = tf.nn.softmax(logits_q)
predictions_q = tf.argmax(softmax_q, axis = -1)
elif mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished = False, maximum_iterations = params['maxlen_q_test'])
predictions_q = outputs.predicted_ids # [batch, length, beam_width]
predictions_q = tf.transpose(predictions_q, [0, 2, 1]) # [batch, beam_width, length]
predictions_q = predictions_q[:, 0, :] # [batch, length]
else:
max_iter = params['maxlen_q_test'] if mode == tf.estimator.ModeKeys.PREDICT else params['maxlen_q_dev']
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True, maximum_iterations = max_iter)
logits_q = outputs.rnn_output
softmax_q = tf.nn.softmax(logits_q)
predictions_q = tf.argmax(softmax_q, axis = -1)
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(
mode = mode,
predictions = {
'question' : predictions_q
})
# Loss
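    # Shift the gold question left by one step so the target at time t is
    # token t+1, padding the final position with zeros.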
label_q = tf.concat([question[:,1:], tf.zeros([batch_size, 1], dtype = tf.int32)], axis = 1, name = 'label_q')
maxlen_q = params['maxlen_q_train'] if mode == tf.estimator.ModeKeys.TRAIN else params['maxlen_q_dev']
current_length = tf.shape(logits_q)[1]
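    # Pad or truncate the logits along time so they line up with the
    # fixed-length labels before computing sequence loss.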
def concat_padding():
num_pad = maxlen_q - current_length
padding = tf.zeros([batch_size, num_pad, voca_size], dtype = dtype)
return tf.concat([logits_q, padding], axis = 1)
def slice_to_maxlen():
return tf.slice(logits_q, [0,0,0], [batch_size, maxlen_q, voca_size])
logits_q = tf.cond(current_length < maxlen_q,
concat_padding,
slice_to_maxlen)
weight_q = tf.sequence_mask(len_q, maxlen_q, dtype)
loss_q = tf.contrib.seq2seq.sequence_loss(
logits_q,
label_q,
weight_q, # [batch, length]
average_across_timesteps = True,
average_across_batch = True,
softmax_loss_function = None # default : sparse_softmax_cross_entropy
)
loss = loss_q
# eval_metric for estimator
eval_metric_ops = {
'bleu' : bleu_score(label_q, predictions_q)
}
# Summary
tf.summary.scalar('loss_question', loss_q)
tf.summary.scalar('total_loss', loss)
# Optimizer
learning_rate = params['learning_rate']
if params['decay_step'] is not None:
learning_rate = tf.train.exponential_decay(learning_rate, tf.train.get_global_step(), params['decay_step'], params['decay_rate'], staircase = True)
optimizer = tf.train.AdamOptimizer(learning_rate)
grad_and_var = optimizer.compute_gradients(loss, tf.trainable_variables())
grad, var = zip(*grad_and_var)
#clipped_grad, norm = tf.clip_by_global_norm(grad, 5)
train_op = optimizer.apply_gradients(zip(grad, var), global_step = tf.train.get_global_step())
return tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops)
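
# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). A hypothetical
# way to wire q_generation into tf.estimator; the hyperparameter values below
# are assumptions chosen for the example, and embed_op from mytools may
# require additional keys beyond those referenced in this file.
#
# params = {
#     'dtype': tf.float32, 'hidden_size': 256, 'voca_size': 30000,
#     'attn': 'normed_bahdanau', 'beam_width': 5, 'length_penalty_weight': 1.0,
#     'encoder_layer': 1, 'decoder_layer': 1, 'answer_layer': 1,
#     'dec_init_ans': True, 'use_keyword': 1, 'if_wean': True,
#     'rnn_dropout': 0.3, 'start_token': 1, 'end_token': 2,
#     'maxlen_q_train': 32, 'maxlen_q_dev': 32, 'maxlen_q_test': 32,
#     'learning_rate': 1e-3, 'decay_step': None, 'decay_rate': 0.5,
# }
# estimator = tf.estimator.Estimator(model_fn=q_generation, params=params)
# estimator.train(input_fn=train_input_fn)   # train_input_fn must yield the
#                                            # feature dict used above:
#                                            # s, len_s, a, len_a, q, len_q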