-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
69 lines (65 loc) · 2.75 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from model import *
from utils import *
from dataloader import *
from beamsearch import *
def load_model():
    """Build the encoder-decoder model from the command-line arguments.

    Reads ``sys.argv[2]`` and ``sys.argv[3]`` with ``load_tkn_to_idx``,
    ``sys.argv[4]`` with ``load_idx_to_tkn``, sizes the network from the
    lengths of the three resulting vocabularies, prints the network, and then
    loads the checkpoint named by ``sys.argv[1]`` into it.

    Returns:
        A ``(model, x_cti, x_wti, y_itw)`` tuple: the model followed by the
        three vocabularies in the order they were loaded.
    """
    src_char_vocab = load_tkn_to_idx(sys.argv[2])
    src_word_vocab = load_tkn_to_idx(sys.argv[3])
    tgt_vocab = load_idx_to_tkn(sys.argv[4])

    # The network is sized from the three vocabulary lengths.
    vocabularies = (src_char_vocab, src_word_vocab, tgt_vocab)
    net = rnn_encoder_decoder(*map(len, vocabularies))
    print(net)

    load_checkpoint(sys.argv[1], net)
    return (net, *vocabularies)
def run_model(model, data, itw):
    """Decode every batch of *data* with *model* and yield the predictions.

    This is a generator. For each batch returned by ``data.split()`` it calls
    the encoder once, then steps the decoder until every entry of the EOS
    flag list is set or ``MAX_LEN`` steps have been taken. Each step's output
    is passed to ``beam_search`` when ``BEAM_SIZE > 1`` and to
    ``greedy_search`` otherwise; the search result is the next decoder input.

    Args:
        model: Network exposing ``enc`` and ``dec`` sub-modules.
        data: Loader providing ``split()`` and ``tensor()``.
        itw: Lookup indexed with the ids stored in ``batch.y1``.

    Yields:
        ``(x0, y0, y1)`` for every ``BEAM_SIZE``-th row of each batch, where
        ``y1`` is the row's ids mapped through *itw* with the last id dropped.
    """
    with torch.no_grad():
        model.eval()
        for batch in data.split():
            xc, xw, lens = batch.sort()
            xc, xw = data.tensor(xc, xw, lens, eos = True)
            batch_size = len(xw)
            finished = [False] * batch_size  # EOS states, one flag per row
            mask, lens = maskset(xw)

            # Hand the encoder result and its hidden state to the decoder.
            decoder = model.dec
            decoder.hs = model.enc(batch_size, xc, xw, lens)
            decoder.hidden = model.enc.hidden
            yi = LongTensor([[SOS_IDX]] * batch_size)
            if decoder.feed_input:
                decoder.attn.v = zeros(batch_size, 1, HIDDEN_SIZE)

            # Per-row accumulators; the search functions receive `batch`.
            batch.y1 = [[] for _ in range(batch_size)]
            batch.prob = [Tensor([0]) for _ in range(batch_size)]
            batch.attn = [[["", *batch.x1[i], EOS]] for i in batch.idx]

            t = 0  # time step
            while t < MAX_LEN:
                if sum(finished) >= len(finished):
                    break
                yo = decoder(yi, mask, t)
                search_args = (decoder, batch, itw, finished, lens, yo)
                if BEAM_SIZE > 1:
                    yi = beam_search(*search_args, t)
                else:
                    yi = greedy_search(*search_args)
                t += 1
            batch.unsort()

            if VERBOSE:
                print()
                for row, attn in enumerate(batch.attn):
                    if row % BEAM_SIZE:
                        continue
                    print("attn[%d] =" % (row // BEAM_SIZE))
                    print(mat2csv(attn, rh = True))

            rows = zip(batch.x0, batch.y0, batch.y1)
            for row, (x0, y0, y1) in enumerate(rows):
                # use the best candidate from each beam
                if row % BEAM_SIZE:
                    continue
                yield x0, y0, [itw[y] for y in y1[:-1]]
def predict(filename, model, x_cti, x_wti, y_itw):
    """Load the sentences in *filename* and return a prediction generator.

    Each line is stripped and tokenized with ``tokenize(x0, UNIT)``. Every
    token is mapped to a word index and to a list of character indices, with
    ``UNK_IDX`` used for anything missing from the vocabularies. Each sentence
    is appended to the loader once, then ``BEAM_SIZE - 1`` more times, each
    copy in its own row.

    Args:
        filename: Path of the text file to read, one input per line.
        model: Model forwarded to ``run_model``.
        x_cti: Mapping from source character to index.
        x_wti: Mapping from source token to index.
        y_itw: Target lookup forwarded to ``run_model``.

    Returns:
        Whatever ``run_model(model, data, y_itw)`` returns. The input file is
        read in full and closed before that call.
    """
    data = dataloader()
    # The context manager closes the file even if tokenizing or indexing a
    # line raises part-way through.
    with open(filename) as fo:
        for line in fo:
            x0 = line.strip()
            x1 = tokenize(x0, UNIT)
            xc = [[x_cti.get(c, UNK_IDX) for c in w] for w in x1]
            xw = [x_wti.get(w, UNK_IDX) for w in x1]
            data.append_item(x0, x1, xc, xw)
            # Repeat the sentence in further rows, one per extra beam slot.
            for _ in range(BEAM_SIZE - 1):
                data.append_row()
                data.append_item(x0, x1, xc, xw)
            data.append_row()
    data.strip()
    return run_model(model, data, y_itw)
if __name__ == "__main__":
    # Exactly five arguments are expected after the script name.
    if len(sys.argv) != 6:
        sys.exit("Usage: %s model vocab.src.char_to_idx vocab.src.word_to_idx vocab.tgt.word_to_idx test_data" % sys.argv[0])
    results = predict(sys.argv[5], *load_model())
    for source, y0, predicted in results:
        # Print the middle field only when it is truthy.
        if y0:
            print((source, y0, predicted))
        else:
            print((source, predicted))