-
Notifications
You must be signed in to change notification settings - Fork 6
/
library_data.py
109 lines (95 loc) · 4.47 KB
/
library_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- coding: utf-8 -*
'''
This is a supporting library for loading the data.
Paper: Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks. S. Kumar, X. Zhang, J. Leskovec. ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2019.
'''
from __future__ import division
import numpy as np
import random
import sys
import operator
import copy
from collections import defaultdict
import os, re
import cPickle
import argparse
from sklearn.preprocessing import scale
# LOAD THE NETWORK
def load_network(args, time_scaling=True):
    '''
    Load the input interaction network from `args.datapath`.

    The file should be in the following format:
    The first line is a header and is skipped.
    After that, one line per interaction/edge; blank lines are ignored.
    Each line should be: user, item, timestamp, state label, array of features.
    Timestamp should be in cardinal format (not in datetime).
    State label should be 1 whenever the user state changes, 0 otherwise. If there are no state labels, use 0 for all interactions.
    Feature list can be as long as desired. It should be at least 1 dimensional. If there are no features, use 0 for all interactions.

    Args:
        args: object exposing `network` (a name, used only in the log message)
            and `datapath` (path of the comma-separated interaction file).
        time_scaling: if True, both time-difference sequences are standardized
            with sklearn.preprocessing.scale and returned as numpy arrays;
            otherwise they are returned as plain lists.

    Returns:
        A list of ten items, in this order:
          user2id                       dict: raw user token -> dense id (first-seen order)
          user_sequence_id              list: user id of each interaction
          user_timedifference_sequence  time since the same user's previous
                                        interaction (measured from 0 for its first)
          user_previous_itemid_sequence list: id of the item the user interacted
                                        with previously; len(item2id) if none
          item2id                       dict: raw item token -> dense id (first-seen order)
          item_sequence_id              list: item id of each interaction
          item_timedifference_sequence  time since the same item's previous
                                        interaction (measured from 0 for its first)
          timestamp_sequence            numpy array of timestamps relative to the
                                        first data row
          feature_sequence              list of per-interaction float lists
          y_true_labels                 list of int state labels
    '''
    network = args.network
    datapath = args.datapath
    print("\n\n**** Loading %s network from file: %s ****" % (network, datapath))
    (user_sequence, item_sequence, timestamp_sequence,
     feature_sequence, y_true_labels) = _read_interactions(datapath)

    print("Formating item sequence")
    item2id, item_sequence_id, item_timedifference_sequence = \
        _index_sequence(item_sequence, timestamp_sequence)
    num_items = len(item2id)

    print("Formating user sequence")
    user2id, user_sequence_id, user_timedifference_sequence = \
        _index_sequence(user_sequence, timestamp_sequence)
    user_previous_itemid_sequence = _previous_item_ids(
        user_sequence, item_sequence_id, num_items)

    if time_scaling:
        print("Scaling timestamps")
        # NOTE(review): scale() centers to zero mean by default, so the +1 offset
        # presumably does not change the result; kept as-is for fidelity.
        user_timedifference_sequence = scale(np.array(user_timedifference_sequence) + 1)
        item_timedifference_sequence = scale(np.array(item_timedifference_sequence) + 1)
    print("*** Network loading completed ***\n\n")
    return [user2id, user_sequence_id, user_timedifference_sequence, user_previous_itemid_sequence,
            item2id, item_sequence_id, item_timedifference_sequence,
            timestamp_sequence,
            feature_sequence,
            y_true_labels]


def _read_interactions(datapath):
    '''
    Parse the interaction file at `datapath` into per-column sequences.

    The first line is treated as a header and skipped; blank lines are ignored.
    Timestamps are shifted so the first data row is at time 0. The shift uses the
    first row, not the minimum, so rows are presumably expected in chronological
    order (TODO confirm against the data preparation scripts).

    Returns (user_sequence, item_sequence, timestamp_sequence, feature_sequence,
    y_true_labels): the first three as numpy arrays, the last two as lists.
    '''
    user_sequence = []
    item_sequence = []
    timestamp_sequence = []
    feature_sequence = []
    y_true_labels = []
    start_timestamp = None
    # `with` guarantees the file is closed even if a malformed row raises.
    with open(datapath, "r") as f:
        f.readline()  # skip the header line
        for raw_line in f:
            row = raw_line.strip()
            if not row:
                # A blank line (e.g. a trailing newline at EOF) would otherwise
                # fail with an IndexError on ls[1] below.
                continue
            # FORMAT: user, item, timestamp, state label, feature list
            ls = row.split(",")
            user_sequence.append(ls[0])
            item_sequence.append(ls[1])
            timestamp = float(ls[2])
            if start_timestamp is None:
                start_timestamp = timestamp
            timestamp_sequence.append(timestamp - start_timestamp)
            y_true_labels.append(int(ls[3]))  # label = 1 at state change, 0 otherwise
            # A list comprehension (rather than map) yields a real list on both
            # Python 2 and Python 3.
            feature_sequence.append([float(x) for x in ls[4:]])
    return (np.array(user_sequence), np.array(item_sequence),
            np.array(timestamp_sequence), feature_sequence, y_true_labels)


def _index_sequence(node_sequence, timestamp_sequence):
    '''
    Assign dense integer ids to nodes in order of first appearance and compute,
    for each interaction, the time elapsed since the same node's previous
    interaction. A node's first interaction is measured from time 0.

    Returns (node2id, node_sequence_id, timedifference_sequence); the last two
    are lists aligned with `node_sequence`.
    '''
    node2id = {}
    timedifference_sequence = []
    last_timestamp = defaultdict(float)  # unseen nodes default to time 0.0
    for node, timestamp in zip(node_sequence, timestamp_sequence):
        if node not in node2id:
            node2id[node] = len(node2id)
        timedifference_sequence.append(timestamp - last_timestamp[node])
        last_timestamp[node] = timestamp
    node_sequence_id = [node2id[node] for node in node_sequence]
    return node2id, node_sequence_id, timedifference_sequence


def _previous_item_ids(user_sequence, item_sequence_id, num_items):
    '''
    For each interaction, return the id of the item the same user interacted
    with most recently before it. A user's first interaction gets `num_items`,
    one past the last real item id, as a "no previous item" placeholder.
    '''
    latest_itemid = defaultdict(lambda: num_items)
    previous_itemids = []
    for user, item_id in zip(user_sequence, item_sequence_id):
        previous_itemids.append(latest_itemid[user])
        latest_itemid[user] = item_id
    return previous_itemids