#!/usr/bin/env python
import json
import logging
import pickle as pkl

import matplotlib.pyplot as plt
import numpy as np
import seaborn
import torch
from scipy.ndimage import gaussian_filter
from tqdm import tqdm

import utils.data_handler as dh  # used as dh.create_dataset below
from utils.data_handler import *
from preprocess_data import *
from models.utils import *
from configs.test_configs import *  # expected to define `args` (test-time options)
def draw(data, x, y, ax, cbar=False):
    # Center the diverging colormap on the mean attention score.
    center = data.mean().item()
    seaborn.heatmap(data,
                    xticklabels=x, square=True, yticklabels=y,
                    center=center, cbar=cbar, ax=ax, cmap='coolwarm')
def make_plots(layers, att_heads, net, sublayer_idx, x, y, prefix, combine_scores=False):
    # Plot attention heatmaps: one panel per layer when head scores are
    # combined, otherwise one panel per (layer, head) pair.
    if combine_scores:
        fig, axs = plt.subplots(1, len(layers), figsize=(10, 10))
    else:
        fig, axs = plt.subplots(len(layers), len(att_heads), figsize=(40, 10))
    # Normalize the axes to a flat array so panels can be indexed uniformly.
    axs = np.asarray(axs).reshape(-1)
    for idx_l, l in enumerate(layers):
        for idx_h, h in enumerate(att_heads):
            if combine_scores and idx_h > 0:
                # Accumulate scores across heads within the same layer.
                att_scores += net.layers[l].attn[sublayer_idx].attn.data[0, h].data.cpu()
            else:
                att_scores = net.layers[l].attn[sublayer_idx].attn.data[0, h].data.cpu()
            # Smooth the attention map before plotting.
            norm_att_scores = gaussian_filter(att_scores, sigma=1)
            if combine_scores:
                if h != att_heads[-1]:
                    continue
                draw(norm_att_scores, x, y if idx_l == 0 else [], ax=axs[idx_l])
            else:
                draw(norm_att_scores, x, y if idx_h == 0 else [], ax=axs[idx_l * len(att_heads) + idx_h])
    plt.savefig('{}_combinedscores{}.png'.format(prefix, combine_scores), transparent=True)
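# Example usage (a sketch, not exercised by this script): after a forward pass
# has populated the attention buffers, visualize self-attention for the first
# two layers. `model`, `src_tokens`, and the choice of sublayer index 0 for
# self-attention are assumptions, not values defined in this file.
#
#   make_plots(layers=[0, 1], att_heads=list(range(4)), net=model.encoder,
#              sublayer_idx=0, x=src_tokens, y=src_tokens,
#              prefix='plots/enc_selfattn', combine_scores=True)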
# Evaluation routine: decode a dialogue state and/or a response for every test
# turn and collect the references alongside the predictions.
def generate_response(model, data, loader, vocab, slots):
    result_dialogues = {}
    model.eval()
    if args.verbose:
        it = enumerate(loader)
    else:
        it = tqdm(enumerate(loader), total=len(loader), desc="", ncols=0)
    with torch.no_grad():
        for batch_idx, batch in it:
            batch.to_cuda()
            dialogue_id = batch.dial_ids[0]
            ref_dialogue = data[dialogue_id]
            result_dialogues[dialogue_id] = {}
            # Record the ground-truth context, state, and response for this turn.
            dial_context = display_txt(ref_dialogue['in_txt'])
            result_dialogues[dialogue_id]['dial_context'] = dial_context
            original_state = display_state(ref_dialogue['state'], slots['domain_slots'])
            result_dialogues[dialogue_id]['state'] = original_state
            original_response = display_txt(ref_dialogue['out_txt_out'][:-1])
            result_dialogues[dialogue_id]['response'] = original_response
            dst_output = None
            output = {}
            if train_args.setting in ['dst', 'e2e']:
                if train_args.add_prev_dial_state and not args.gt_previous_bs:
                    # Feed the state predicted at the previous turn rather than
                    # the ground truth; turns arrive in order, so the previous
                    # prediction must belong to the same dialogue.
                    dialogue_name, turn_id = dialogue_id.split('_')
                    if int(turn_id) > 1:
                        assert dialogue_name == last_dialogue_name
                        assert int(turn_id) == (last_turn_id + 1)
                        batch.in_state = last_encoded_state
                        batch.in_state_mask = torch.ones(last_encoded_state.shape).unsqueeze(-2).long().cuda()
                output, dst_output = generate_dst(model, batch, vocab, slots, args)
                predicted_state = display_state(dst_output, slots['domain_slots'])
                if train_args.add_prev_dial_state and not args.gt_previous_bs:
                    # Re-encode the predicted state so the next turn can consume it.
                    _, last_domain_token_state = make_bs_txt(dst_output, slots['domain_slots'])
                    last_encoded_state = encode(lang['in+domain+bs']['word2idx'], last_domain_token_state)
                    last_encoded_state = torch.from_numpy(np.asarray(last_encoded_state)).unsqueeze(0).long().cuda()
                    last_dialogue_name = dialogue_name
                    last_turn_id = int(turn_id)
                result_dialogues[dialogue_id]['predicted_state'] = predicted_state
                if args.verbose:
                    print("Dialogue ID: {}".format(dialogue_id))
                    print("Original state: {}".format(original_state))
                    print("Decoded state: {}".format(predicted_state))
            if train_args.setting in ['c2t', 'e2e']:
                res_output = generate_res(model, batch, vocab, slots, args, output, dst_output)
                result_dialogues[dialogue_id]['predicted_response'] = res_output
                if args.verbose:
                    print("Original response: {}".format(original_response))
                    for idx, response in res_output.items():
                        print('HYP[{}]: {} ({})'.format(idx + 1, response['txt'], response['score']))
            if args.verbose:
                print('-----------------------')
    return result_dialogues
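# For reference, each entry of the returned dict looks roughly like the sketch
# below; exact keys depend on `train_args.setting`, and the dialogue ID shown
# is hypothetical (the real format is '<dialogue_name>_<turn_id>'):
#
#   result_dialogues['PMUL1234_3'] = {
#       'dial_context': '...',           # dialogue history text
#       'state': '...',                  # ground-truth belief state
#       'response': '...',               # ground-truth system response
#       'predicted_state': '...',        # decoded belief state ('dst'/'e2e')
#       'predicted_response': {...},     # ranked hypotheses ('c2t'/'e2e')
#   }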
# Load params and model
print('Loading training params from ' + args.out_dir + '/' + args.model + '.conf')
train_args = pkl.load(open(args.out_dir + '/' + args.model + '.conf', 'rb'))
print("Loading model weights from " + args.out_dir + '/' + args.model + '_{}.pth.tar'.format(args.tep))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
model = torch.load(args.out_dir + '/' + args.model + '_{}.pth.tar'.format(args.tep), map_location=device)
model.to(device)
# Load vocab
prefix = train_args.prefix
lang_dir = 'data{}/multi-woz/{}lang.pkl'.format(train_args.data_version, prefix)
print('Extracting vocab from ' + lang_dir)
lang = pkl.load(open(lang_dir, 'rb'))
# Load data
encoded_data_dir = 'data{}/multi-woz/{}encoded_data.pkl'.format(train_args.data_version, prefix)
print('Extracting data from ' + encoded_data_dir)
encoded_data = pkl.load(open(encoded_data_dir, 'rb'))
# Load slots
slots_dir = 'data{}/multi-woz/{}slots.pkl'.format(train_args.data_version, prefix)
print('Extracting slots from ' + slots_dir)
slots = pkl.load(open(slots_dir, 'rb'))
slots = merge_dst(slots)
test_data = encoded_data['test']
test_data = limit_his_len(test_data, lang['in']['word2idx'], train_args.max_dial_his_len, train_args.only_system_utt)
if train_args.detach_dial_his:
test_data = detach_dial_his(test_data, lang['in']['word2idx'])
test_dataloader, test_samples = dh.create_dataset(lang, slots, test_data, False, train_args, 1, 0)
print('#test samples = {}'.format(test_samples))
# generate
print('-----------------------generate--------------------------')
result = generate_response(model, test_data, test_dataloader, lang, slots)
logging.info('writing results to ' + args.out_dir + '/' + args.output)
json.dump(result, open(args.out_dir + '/' + args.output, 'w'), indent=4)
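# Example (a sketch for an interactive session): inspect the dumped results.
# The path mirrors the json.dump call above; which keys are present depends on
# the training setting.
#
#   import json
#   result = json.load(open(args.out_dir + '/' + args.output))
#   sample = result[next(iter(result))]
#   print(sample['response'])                # ground-truth response
#   print(sample.get('predicted_response'))  # model hypotheses, if 'c2t'/'e2e'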