gene_cas.py
import pickle
import random
import time

import networkx as nx
from absl import app, flags

# flags
FLAGS = flags.FLAGS

# observation and prediction time settings:
# for the twitter dataset, we use 3600*24*1 (86400, 1 day) or 3600*24*2 (172800, 2 days) as the observation time
# and 3600*24*32 (2764800, 32 days) as the prediction time
# for the weibo dataset, we use 1800 (0.5 hour) or 3600 (1 hour) as the observation time
# and 3600*24 (86400, 1 day) as the prediction time
# for the aps dataset, we use 365*3 (1095, 3 years) or 365*5+1 (1826, 5 years) as the observation time
# and 365*20+5 (7305, 20 years) as the prediction time
flags.DEFINE_integer('observation_time', 1800, 'Observation time.')
flags.DEFINE_integer('prediction_time', 86400, 'Prediction time.')

# path
flags.DEFINE_string('input', '', 'Dataset path.')
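
# Example invocation (the path here is hypothetical; the directory must contain
# a 'dataset.txt' file in the format described in generate_cascades below):
#   python gene_cas.py --input=./dataset/weibo/ --observation_time=3600 --prediction_time=86400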


def generate_cascades(ob_time, pred_time, filename, file_train, file_val, file_test, seed):

    # a list to save the cascades
    filtered_data = list()
    cascades_type = dict()  # 0 for train, 1 for val, 2 for test
    cascades_time_dict = dict()
    cascades_total = 0
    cascades_valid_total = 0

    # Important note: for the weibo dataset, if you want to compare CasFlow with baselines
    # such as DeepHawkes and CasCN, make sure ob_time is set consistently.
    if ob_time in [3600, 3600*2, 3600*3]:  # end_hour is set to 19 in DeepHawkes and CasCN, but it should be 18
        end_hour = 19
    else:
        end_hour = 18

    with open(filename) as file:
        for line in file:
            # each cascade line consists of 5 tab-separated parts:
            # 1: cascade id
            # 2: user/item id
            # 3: publish date/time
            # 4: number of adoptions
            # 5: a list of adoptions
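            # a hypothetical example line (tab-separated), assuming the weibo-style format:
            #   123\t456\t1464710400\t4\t456:0 456/789:20 456/111:150 789/222:1700
            # here cascade 123, published by user 456 at timestamp 1464710400, has 4
            # adoptions; each 'a/b:t' entry means b adopted from a at relative time t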
            cascades_total += 1
            parts = line.split('\t')
            cascade_id = parts[0]

            # filter cascades by their publish date/time
            if 'weibo' in FLAGS.input:
                # convert to UTC+8 (Beijing time), independent of the local timezone
                hour = int(time.strftime('%H', time.gmtime(float(parts[2])))) + 8
                if hour < 8 or hour >= end_hour:
                    continue
            elif 'twitter' in FLAGS.input:
                month = int(time.strftime('%m', time.localtime(float(parts[2]))))
                day = int(time.strftime('%d', time.localtime(float(parts[2]))))
                if month == 4 and day > 10:
                    continue
            elif 'aps' in FLAGS.input:
                publish_time = parts[2]
                if publish_time > '1997':
                    continue
            else:
                pass

            paths = parts[4].strip().split(' ')

            observation_path = list()
            # number of adoptions observed within the observation window
            p_o = 0
            for p in paths:
                # observed adoption/participant
                nodes = p.split(':')[0].split('/')
                time_now = int(p.split(':')[1])
                if time_now < ob_time:
                    p_o += 1
                    # save the observed adoption/participant into 'observation_path'
                    observation_path.append((nodes, time_now))

            # filter out cascades whose observed popularity is less than 10
            if p_o < 10:
                continue

            # sort the list by publish time/date
            observation_path.sort(key=lambda tup: tup[1])

            # for each cascade, save its publish time into a dict
            if 'aps' in FLAGS.input:
                cascades_time_dict[cascade_id] = int(0)
            else:
                cascades_time_dict[cascade_id] = int(parts[2])

            o_path = list()
            for i in range(len(observation_path)):
                nodes = observation_path[i][0]
                t = observation_path[i][1]
                o_path.append('/'.join(nodes) + ':' + str(t))

            # keep the cascade if it was not excluded above
            line = parts[0] + '\t' + parts[1] + '\t' + parts[2] + '\t' \
                   + parts[3] + '\t' + ' '.join(o_path) + '\n'
            filtered_data.append(line)
            cascades_valid_total += 1

    # open three files to save the train, val, and test sets, respectively
    with open(file_train, 'w') as data_train, \
            open(file_val, 'w') as data_val, \
            open(file_test, 'w') as data_test:

        def shuffle_cascades():
            # shuffle all cascades
            shuffle_time = list(cascades_time_dict.keys())
            random.seed(seed)
            random.shuffle(shuffle_time)

            count = 0
            # split the dataset
            for key in shuffle_time:
                if count < cascades_valid_total * .7:
                    cascades_type[key] = 0  # training set, 70%
                elif count < cascades_valid_total * .85:
                    cascades_type[key] = 1  # validation set, 15%
                else:
                    cascades_type[key] = 2  # test set, 15%
                count += 1

        shuffle_cascades()

        # number of valid cascades
        print("Number of valid cascades: {}/{}".format(cascades_valid_total, cascades_total))

        # 3 lists to save the filtered sets
        filtered_data_train = list()
        filtered_data_val = list()
        filtered_data_test = list()
        for line in filtered_data:
            cascade_id = line.split('\t')[0]
            if cascades_type[cascade_id] == 0:
                filtered_data_train.append(line)
            elif cascades_type[cascade_id] == 1:
                filtered_data_val.append(line)
            elif cascades_type[cascade_id] == 2:
                filtered_data_test.append(line)
            else:
                print('Error: cascade {} has no train/val/test assignment.'.format(cascade_id))

        print("Number of valid train cascades: {}".format(len(filtered_data_train)))
        print("Number of valid val cascades: {}".format(len(filtered_data_val)))
        print("Number of valid test cascades: {}".format(len(filtered_data_test)))

        # shuffle the train set again
        random.seed(seed)
        random.shuffle(filtered_data_train)

        def file_write(file_name):
            # write one cascade; note that compared to the original 'dataset.txt', only the
            # cascade_id and the observed adoptions are saved, plus the label at the end
            file_name.write(cascade_id + '\t' + '\t'.join(observation_path) + '\t' + label + '\n')
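
        # for illustration, a hypothetical output line would look like
        # (tab-separated, with the incremental-popularity label last):
        #   123\t456:0\t456,789:20\t456,111:150\t25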

        # write the cascades into files
        for line in filtered_data_train + filtered_data_val + filtered_data_test:
            # split the cascade into its 5 parts
            parts = line.split('\t')
            cascade_id = parts[0]
            observation_path = list()
            label = int()
            edges = set()
            paths = parts[4].split(' ')
            for p in paths:
                nodes = p.split(':')[0].split('/')
                time_now = int(p.split(':')[1])
                if time_now < ob_time:
                    observation_path.append(','.join(nodes) + ':' + str(time_now))
                    # collect the cascade's edges (built here but not written to the output files)
                    for i in range(1, len(nodes)):
                        edges.add(nodes[i - 1] + ':' + nodes[i] + ':1')
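                    # e.g., a path 'a/b/c:t' contributes the edges 'a:b:1' and 'b:c:1'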
                # increment the label depending on prediction_time, e.g., 24 hours for the weibo dataset
                if time_now < pred_time:
                    label += 1
            # calculate the incremental popularity, i.e., adoptions between ob_time and pred_time
            label = str(label - len(observation_path))
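            # e.g., if 15 adoptions occur before ob_time and 40 before pred_time,
            # the label (incremental popularity) is 40 - 15 = 25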

            # write files by cascade type
            # 0 to train, 1 to val, 2 to test
            if cascade_id in cascades_type and cascades_type[cascade_id] == 0:
                file_write(data_train)
            elif cascade_id in cascades_type and cascades_type[cascade_id] == 1:
                file_write(data_val)
            elif cascade_id in cascades_type and cascades_type[cascade_id] == 2:
                file_write(data_test)


def generate_global_graph(file_name, graph_save_path):
    # build an undirected global graph over all users appearing in the dataset
    g = nx.Graph()
    with open(file_name, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            paths = parts[4].strip().split(' ')
            for path in paths:
                nodes = path.split(':')[0].split('/')
                if len(nodes) < 2:
                    # a root-only path contributes a single node
                    g.add_node(nodes[-1])
                else:
                    # connect the last two nodes of the path (adopter and its parent)
                    g.add_edge(nodes[-1], nodes[-2])

    print("Number of nodes in global graph:", g.number_of_nodes())
    print("Number of edges in global graph:", g.number_of_edges())

    with open(graph_save_path, 'wb') as f:
        pickle.dump(g, f)
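
# The pickled graph can later be reloaded with, for instance:
#   with open('global_graph.pkl', 'rb') as f:
#       g = pickle.load(f)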


def main(argv):
    time_start = time.time()
    print('Start to run the CasFlow code!\n')
    print('It should finish in a few minutes.\n')
    print('Dataset path: {}\n'.format(FLAGS.input))
    if 'xovee' in FLAGS.input:
        print('Note: this is just a sample dataset.\n')

    generate_cascades(FLAGS.observation_time, FLAGS.prediction_time,
                      FLAGS.input + 'dataset.txt',
                      FLAGS.input + 'train.txt',
                      FLAGS.input + 'val.txt',
                      FLAGS.input + 'test.txt',
                      seed=0)
    generate_global_graph(FLAGS.input + 'dataset.txt',
                          FLAGS.input + 'global_graph.pkl')

    print('Processing time: {:.2f}s'.format(time.time() - time_start))


if __name__ == '__main__':
    app.run(main)