# preprocess.py
import csv
from collections import defaultdict

import numpy as np
import torch
import fasttext as ft

class FeatureGenerator(object):
    """Name-embedding feature generator backed by a pretrained fastText model."""

    def __init__(self, model_path):
        super(FeatureGenerator, self).__init__()
        self.model = ft.load_model(model_path)

    def generateEmbFeature(self, name, sent=True):
        # Strip stray quotes, then embed the string either as a whole
        # sentence or as a single word.
        if sent:
            return self.model.get_sentence_vector(name.replace('"', ''))
        return self.model.get_word_vector(name.replace('"', ''))

# The interface to generate a numpy feature matrix for every node in the
# graph: each node's feature is the sum of the sentence embeddings of its
# attribute strings (assumes the fastText model is 100-dimensional).
def warp_feature(model, source, mapping):
    feature_matrix = np.zeros((len(source), 100))
    for k in source:
        # source[k] is [type, value, attr, ...]; embed every attribute
        # string from index 1 onward and sum into the node's row.
        for attr in source[k][1:]:
            feature_matrix[mapping[k], :] += model.generateEmbFeature(attr, sent=True)
    return feature_matrix
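
# Minimal usage sketch of the two pieces above (the model path is a
# placeholder; warp_feature assumes 100-dimensional vectors):
#
#   gen = FeatureGenerator('path/to/fasttext_100d.bin')
#   source = {'ID_0': ['Song', 'Hey Jude', 'The Beatles']}
#   feats = warp_feature(gen, source, {'ID_0': 0})   # shape (1, 100)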

def generateTrainWithType(in_path, graph_a, graph_b, positive_only=False):
    """Load the train/valid/test splits as LongTensors of
    [left_index, right_index, label] rows, keeping at most one pair per
    left id and per right id. Note: positive_only is currently unused."""
    def load_split(file_name):
        data = []
        left_set, right_set = set(), set()
        with open(in_path + file_name) as IN:
            IN.readline()  # skip the CSV header
            for line in IN:
                tmp = line.strip().split(',')
                # keep only the first occurrence of each left/right id
                if tmp[0] in left_set or tmp[1] in right_set:
                    continue
                left_set.add(tmp[0])
                right_set.add(tmp[1])
                data.append([graph_a.id2idx['ID_{}'.format(tmp[0])],
                             graph_b.id2idx['ID_{}'.format(tmp[1])], int(tmp[2])])
        return torch.LongTensor(data)

    return load_split('train.csv'), load_split('valid.csv'), load_split('test.csv')
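
# Sketch of the returned layout (indices hypothetical): each row of a
# returned tensor is [left_node_index, right_node_index, label], e.g.
#
#   train, val, test = generateTrainWithType(path, graph_a, graph_b)
#   left_idx, right_idx, label = train[0].tolist()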

class Graph(object):
    """Entity graph built from one CSV table: records and their typed
    attribute values become nodes, and edges link each record to its
    attribute nodes."""

    def __init__(self, pretrain):
        super(Graph, self).__init__()
        self.id2idx = {}
        self.entity_table = {}
        self.features = None
        self.edge_src = []
        self.edge_dst = []
        self.edge_type = []
        self.pretrain_path = pretrain

    def buildGraph(self, table):
        with open(table) as IN:
            spamreader = csv.reader(IN, delimiter=',')
            fields = next(spamreader)
            # Columns named like "artist_name" are typed attributes that
            # become their own nodes; the rest are plain attributes.
            type_list, attr_list = [], []
            for field in fields[1:]:
                if '_' in field:
                    type_list.append(field.split('_')[0])
                else:
                    attr_list.append(field)
            for tmp in spamreader:
                for idx, value in enumerate(tmp[1:]):
                    if idx < len(type_list):
                        if idx == 0:
                            # first typed column: the record node itself
                            _ID = 'ID_{}'.format(tmp[0])
                            self.entity_table[_ID] = [type_list[idx], value]
                            self.id2idx[_ID] = len(self.id2idx)
                            target_id = self.id2idx[_ID]
                        else:
                            # other typed columns: shared attribute nodes,
                            # created once and linked back to the record
                            _id = '{}_{}'.format(type_list[idx], value)
                            if _id not in self.entity_table:
                                self.entity_table[_id] = [type_list[idx], value]
                                self.id2idx[_id] = len(self.id2idx)
                            self.edge_src.append(target_id)
                            self.edge_dst.append(self.id2idx[_id])
                            self.edge_type.append(idx - 1)
                    else:
                        # untyped columns become extra attributes of the record
                        self.entity_table[_ID].append(value)
        feat = FeatureGenerator(self.pretrain_path)
        self.features = warp_feature(feat, self.entity_table, self.id2idx)
        assert self.features.shape[0] == len(self.id2idx)
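
# Minimal usage sketch (placeholder model path; see __main__ below):
#
#   g = Graph('path/to/fasttext_100d.bin')
#   g.buildGraph('data/itunes_amazon_exp_data/exp_data/tableA.csv')
#   g.features                              # (num_nodes, 100) numpy matrix
#   g.edge_src, g.edge_dst, g.edge_type     # parallel edge lists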

def checkTest(mapping_a, mapping_b, in_file):
    """Count how many tab-separated test pairs are covered by both
    mappings, and tally the entity types seen on each side."""
    type_cnt_a, type_cnt_b = defaultdict(int), defaultdict(int)
    str_pair = set()
    with open(in_file) as IN:
        for line in IN:
            tmp = line.strip().split('\t')
            if tmp[0] in mapping_a and tmp[1] in mapping_b:
                str_pair.add('{}_{}'.format(tmp[0], tmp[1]))
                for x in mapping_a[tmp[0]]['type']:
                    type_cnt_a[x] += 1
                for x in mapping_b[tmp[1]]['type']:
                    type_cnt_b[x] += 1
    print("Number of covered pairs: {}".format(len(str_pair)))
    print(type_cnt_a, type_cnt_b)

if __name__ == '__main__':
    dataset = 'itunes'  # or 'imdb'
    # Placeholder path: Graph requires a pretrained fastText binary whose
    # vectors are 100-dimensional to match warp_feature.
    pretrain_path = 'path/to/fasttext_100d.bin'
    graph_a, graph_b = Graph(pretrain_path), Graph(pretrain_path)
    graph_a.buildGraph('data/itunes_amazon_exp_data/exp_data/tableA.csv')
    graph_b.buildGraph('data/itunes_amazon_exp_data/exp_data/tableB.csv')
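
    # A possible continuation (sketch), assuming the train/valid/test CSV
    # splits live alongside the tables as generateTrainWithType expects:
    train, val, test = generateTrainWithType(
        'data/itunes_amazon_exp_data/exp_data/', graph_a, graph_b)
    print(graph_a.features.shape, graph_b.features.shape)
    print(train.shape, val.shape, test.shape)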