from multiprocessing import Pool

import torch
from parsivar import Tokenizer, FindStems

from utils import compress_attention, create_mapping, BFS, build_graph

my_tokenizer = Tokenizer()
my_stemmer = FindStems()


def process_matrix(attentions, layer_idx=-1, head_num=0, avg_head=False, trim=True):
    """Reduce the encoder's attention outputs to a single token-by-token matrix."""
    if avg_head:
        # Average over the head dimension.
        attn = torch.mean(attentions[0][layer_idx], 0)
    else:
        # Select a single head.
        attn = attentions[0][layer_idx][head_num]
    attention_matrix = attn.detach().numpy()
    if trim:
        # Drop the special-token ([CLS]/[SEP]) rows and columns.
        attention_matrix = attention_matrix[1:-1, 1:-1]
    return attention_matrix
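

# A minimal usage sketch (illustrative, mirroring the call in parse_sentence
# below): with a Hugging Face-style encoder run with output_attentions=True,
# the attention tensors sit in outputs[2] and reduce to a single matrix:
#
#     outputs = encoder(**inputs, output_attentions=True)
#     averaged = process_matrix(outputs[2], avg_head=True)  # mean over heads
#     one_head = process_matrix(outputs[2], head_num=3)     # a single head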


def bfs(args):
    """Unpack the argument tuple so BFS can be driven by Pool.imap_unordered."""
    s, end, graph, max_size, black_list_relation = args
    return BFS(s, end, graph, max_size, black_list_relation)


def check_relations_validity(relations):
    """A relation candidate is invalid if any of its tokens is purely numeric."""
    for rel in relations:
        if rel.isnumeric():
            return False
    return True


# def global_initializer(nlp_object):
#     global pars_nlp
#     pars_nlp = nlp_object


def filter_relation_sets(params):
    """Map a scored token-id path back to words, taking the nearest following verb as the relation."""
    triplet, id2token, verbs, token2id = params
    triplet_idx = triplet[0]
    confidence = triplet[1]
    head, tail = triplet_idx[0], triplet_idx[-1]
    if head in id2token and tail in id2token:
        head = id2token[head]
        tail = id2token[tail]
        # Candidate relations: verbs that appear after the tail word.
        pre_relations = [n for n in verbs if token2id[n] - token2id[tail] > 0]
        if pre_relations:
            # Take the verb closest to the tail.
            near_verb_id = min(token2id[n] for n in pre_relations)
            # Alternatively, stem the verb: my_stemmer.convert_to_stem(id2token[near_verb_id])
            return {'h': head, 't': tail, 'r': id2token[near_verb_id], 'c': confidence}
    return {}
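

# A made-up illustration of filter_relation_sets: with
#     id2token = {0: 'علی', 3: 'کتاب', 4: 'خواند'}
#     token2id = {v: k for k, v in id2token.items()}
#     verbs = ['خواند']
# the scored path ((0, 2, 3), 0.7) maps to
#     {'h': 'علی', 't': 'کتاب', 'r': 'خواند', 'c': 0.7}
# because 'خواند' is the nearest verb after the tail 'کتاب'.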


def contains_number(value):
    return any(char.isdigit() for char in value)


def parse_sentence(sentence, tokenizer, encoder):
    """Implement the Match stage of MAMA: extract candidate (head, relation, tail) triplets."""
    verbs, inputs, tokenid2word_mapping, token2id, noun_chunks = create_mapping(
        sentence, return_pt=True, tokenizer=tokenizer,
    )
    with torch.no_grad():
        outputs = encoder(**inputs, output_attentions=True)
    # outputs[2] holds the attention tensors when output_attentions=True.
    attention = process_matrix(outputs[2], avg_head=True, trim=True)
    # Collapse sub-token attention back to word level, then build the search graph.
    merged_attention = compress_attention(attention, tokenid2word_mapping)
    attn_graph = build_graph(merged_attention)

    # Candidate (head, tail) pairs: ordered pairs of distinct non-verb noun
    # chunks, skipping heads that contain digits.
    tail_head_pairs = []
    for head in noun_chunks:
        for tail in noun_chunks:
            if head != tail and tail not in verbs and head not in verbs and not contains_number(head):
                tail_head_pairs.append((token2id[head], token2id[tail]))

    id2token = {value: key for key, value in token2id.items()}
    # Paths may only run through verbs; every non-verb node is blacklisted.
    black_list_relation = {token2id[n] for n in token2id if n not in verbs}

    all_relation_pairs = []
    with Pool(10) as pool:
        params = [(pair[0], pair[1], attn_graph, max(tokenid2word_mapping), black_list_relation)
                  for pair in tail_head_pairs]
        for output in pool.imap_unordered(bfs, params):
            if len(output):
                all_relation_pairs += [(o, id2token, verbs, token2id) for o in output]

    triplet_text = []
    with Pool(10) as pool:
        for triplet in pool.imap_unordered(filter_relation_sets, all_relation_pairs):
            if len(triplet) > 0:
                triplet_text.append(triplet)
    return triplet_text
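

# Each returned triplet is a dict of the form
# {'h': head, 'r': relation, 't': tail, 'c': confidence}; see the example
# driver at the bottom of this file for an end-to-end call.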


def deduplication2(output_tri):
    """Merge head mentions that share a non-numeric token, keeping the longer surface form."""
    for i in range(len(output_tri)):
        a_tokens = my_tokenizer.tokenize_words(output_tri[i]['h'])
        for j in range(len(output_tri)):
            b_tokens = my_tokenizer.tokenize_words(output_tri[j]['h'])
            for x in a_tokens:
                if x in b_tokens and not x.isnumeric():
                    if len(output_tri[i]['h']) > len(output_tri[j]['h']):
                        output_tri[j]['h'] = output_tri[i]['h']
                        break
                    elif len(output_tri[i]['h']) < len(output_tri[j]['h']):
                        output_tri[i]['h'] = output_tri[j]['h']
                        break
    return output_tri
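

# Made-up illustration: the heads 'دانشگاه تهران' and 'دانشگاه' share the token
# 'دانشگاه', so deduplication2 rewrites both entries to the longer mention
# 'دانشگاه تهران'.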


def deduplication(triplets):
    """Drop exact duplicate (h, r, t) triplets, keeping the first confidence seen."""
    unique_pairs = []
    pair_confidence = []
    for t in triplets:
        key = '{}\t{}\t{}'.format(t['h'], t['r'], t['t'])
        if key not in unique_pairs:
            unique_pairs.append(key)
            pair_confidence.append(t['c'])

    unique_triplets = []
    for idx, unique_pair in enumerate(unique_pairs):
        h, r, t = unique_pair.split('\t')
        unique_triplets.append({'h': h, 'r': r, 't': t, 'c': pair_confidence[idx]})
    return unique_triplets
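

# Example driver: a minimal sketch, not part of the original pipeline. The
# checkpoint name is an assumption (any BERT-style encoder that exposes
# attentions should work); the __main__ guard matters because parse_sentence
# spawns worker pools.
if __name__ == '__main__':
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')  # assumed checkpoint
    encoder = AutoModel.from_pretrained('bert-base-multilingual-cased')
    encoder.eval()

    triplets = parse_sentence('...', tokenizer, encoder)  # replace '...' with a real sentence
    for t in deduplication(triplets):
        print(t['h'], t['r'], t['t'], t['c'])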