-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
model.py
195 lines (162 loc) · 7.48 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# -*- coding: utf-8 -*-
# /usr/bin/python3
'''
Feb. 2019 by kyubyong park.
https://www.github.com/kyubyong/transformer
Transformer network
'''
import tensorflow as tf
from data_load import load_vocab
from modules import get_token_embeddings, ff, positional_encoding, multihead_attention, label_smoothing, noam_scheme
from utils import convert_idx_to_token_tensor
from tqdm import tqdm
import logging
# Configure the root logger at import time so that INFO-level messages
# (e.g. the notice logged while the inference graph is being built) are emitted.
logging.basicConfig(level=logging.INFO)
class Transformer:
    '''
    Encoder-decoder Transformer built as a TensorFlow 1.x graph.

    Conventions used by the methods below:

    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N, )
        sents2: str tensor. (N,)
    training: boolean.
    '''

    def __init__(self, hp):
        '''
        hp: hyperparameter object. This class reads hp.vocab, hp.vocab_size,
            hp.d_model, hp.d_ff, hp.num_blocks, hp.num_heads, hp.dropout_rate,
            hp.maxlen1, hp.maxlen2, hp.lr and hp.warmup_steps.
        '''
        self.hp = hp
        self.token2idx, self.idx2token = load_vocab(hp.vocab)
        # A single embedding table is shared by the encoder input, the decoder
        # input and (transposed) the final output projection in `decode`.
        self.embeddings = get_token_embeddings(self.hp.vocab_size, self.hp.d_model, zero_pad=True)

    def encode(self, xs, training=True):
        '''
        Builds the encoder stack.

        xs: tuple of (x, x_seqlens, sents1). See the class docstring.
        training: boolean. Forwarded to dropout and attention.

        Returns
        memory: encoder outputs. (N, T1, d_model)
        sents1: str tensor. (N,). Passed through unchanged from xs.
        src_masks: bool tensor. (N, T1). True where the source token id is 0.
        '''
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            x, seqlens, sents1 = xs

            # src_masks: True at positions whose token id is 0
            # (presumably <pad>; verify against the vocabulary file).
            src_masks = tf.math.equal(x, 0)  # (N, T1)

            # embedding
            enc = tf.nn.embedding_lookup(self.embeddings, x)  # (N, T1, d_model)
            enc *= self.hp.d_model ** 0.5  # scale

            enc += positional_encoding(enc, self.hp.maxlen1)
            enc = tf.layers.dropout(enc, self.hp.dropout_rate, training=training)

            # Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # self-attention
                    enc = multihead_attention(queries=enc,
                                              keys=enc,
                                              values=enc,
                                              key_masks=src_masks,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=False)
                    # feed forward
                    enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
        memory = enc
        return memory, sents1, src_masks

    def decode(self, ys, memory, src_masks, training=True):
        '''
        Builds the decoder stack and the output projection.

        ys: tuple of (decoder_input, y, y_seqlen, sents2). See the class docstring.
        memory: encoder outputs. (N, T1, d_model)
        src_masks: (N, T1)
        training: boolean. Forwarded to dropout and attention.

        Returns
        logits: (N, T2, V). float32.
        y_hat: (N, T2). int32
        y: (N, T2). int32
        sents2: (N,). string.
        '''
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            decoder_inputs, y, seqlens, sents2 = ys

            # tgt_masks: True at positions whose token id is 0
            # (presumably <pad>; verify against the vocabulary file).
            tgt_masks = tf.math.equal(decoder_inputs, 0)  # (N, T2)

            # embedding
            dec = tf.nn.embedding_lookup(self.embeddings, decoder_inputs)  # (N, T2, d_model)
            dec *= self.hp.d_model ** 0.5  # scale

            dec += positional_encoding(dec, self.hp.maxlen2)
            dec = tf.layers.dropout(dec, self.hp.dropout_rate, training=training)

            # Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # Masked self-attention (Note that causality is True at this time)
                    dec = multihead_attention(queries=dec,
                                              keys=dec,
                                              values=dec,
                                              key_masks=tgt_masks,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=True,
                                              scope="self_attention")

                    # Vanilla attention: queries come from the decoder,
                    # keys/values from the encoder memory.
                    dec = multihead_attention(queries=dec,
                                              keys=memory,
                                              values=memory,
                                              key_masks=src_masks,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=False,
                                              scope="vanilla_attention")
                    # Feed Forward
                    dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])

        # Final linear projection (embedding weights are shared)
        weights = tf.transpose(self.embeddings)  # (d_model, vocab_size)
        logits = tf.einsum('ntd,dk->ntk', dec, weights)  # (N, T2, vocab_size)
        y_hat = tf.to_int32(tf.argmax(logits, axis=-1))

        return logits, y_hat, y, sents2

    def train(self, xs, ys):
        '''
        Builds the training graph: forward pass, loss, optimizer and summaries.

        xs, ys: see the class docstring.

        Returns
        loss: scalar.
        train_op: training operation
        global_step: scalar.
        summaries: training summary node
        '''
        # forward
        memory, sents1, src_masks = self.encode(xs)
        logits, preds, y, sents2 = self.decode(ys, memory, src_masks)

        # train scheme
        y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size))
        ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_)
        nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["<pad>"]))  # 0: <pad>
        # Average the cross entropy over non-padding positions only; the small
        # constant avoids dividing by zero when a batch has no such positions.
        loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)

        global_step = tf.train.get_or_create_global_step()
        lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps)
        optimizer = tf.train.AdamOptimizer(lr)
        train_op = optimizer.minimize(loss, global_step=global_step)

        tf.summary.scalar('lr', lr)
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("global_step", global_step)

        summaries = tf.summary.merge_all()

        return loss, train_op, global_step, summaries

    def eval(self, xs, ys):
        '''Predicts autoregressively

        The decoder input contained in ys is not used: decoding starts from a
        column of <s> tokens. The remaining elements of ys are passed through
        to `decode`, and sents2 is used for the text summary.

        Returns
        y_hat: (N, T2)
        summaries: merged summary node (includes the sampled sent1/pred/sent2)
        '''
        decoder_inputs, y, y_seqlen, sents2 = ys

        # One <s> token per batch element: (N, 1)
        decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"]
        ys = (decoder_inputs, y, y_seqlen, sents2)

        memory, sents1, src_masks = self.encode(xs, False)

        logging.info("Inference graph is being built. Please be patient.")
        for _ in tqdm(range(self.hp.maxlen2)):
            logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
            # NOTE: this loop runs while the graph is being constructed, so
            # y_hat is a symbolic tensor whose values are not known yet. A
            # Python-level "stop early when everything is <pad>" test cannot
            # be evaluated here, so the graph is always unrolled for maxlen2
            # steps.

            # Feed <s> followed by everything predicted so far to the next step.
            _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
            ys = (_decoder_inputs, y, y_seqlen, sents2)

        # monitor a random sample
        # maxval is exclusive, so the batch size itself is the upper bound;
        # this covers every row and stays valid for a batch of size 1.
        n = tf.random_uniform((), 0, tf.shape(y_hat)[0], tf.int32)
        sent1 = sents1[n]
        pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
        sent2 = sents2[n]

        tf.summary.text("sent1", sent1)
        tf.summary.text("pred", pred)
        tf.summary.text("sent2", sent2)
        summaries = tf.summary.merge_all()

        return y_hat, summaries