-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
113 lines (88 loc) · 3.65 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import dgl
import torch as th
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics
from sklearn.manifold import TSNE
import utils
from model import EGES
from sampler import Sampler
def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates,
          test_g=None):
    """Train an EGES model on pairs sampled from ``train_g``.

    Node ids of ``train_g`` are batched by a ``DataLoader``; every batch is
    handed to ``Sampler.sample`` together with ``sku_info``, which yields the
    ``(srcs, dsts, labels)`` triple fed to the model. After each epoch the
    model is scored with ``eval`` on ``test_g`` when a test graph is
    available.

    Args:
        args: Parsed options. Reads ``walk_length``, ``num_walks``,
            ``window_size``, ``num_negative``, ``batch_size``, ``dim``,
            ``lr``, ``epochs`` and ``log_every``.
        train_g: Training graph; must expose ``num_nodes()``.
        sku_info: Per-SKU side information, forwarded to the sampler and to
            ``eval``.
        num_skus: Forwarded to the ``EGES`` constructor.
        num_brands: Forwarded to the ``EGES`` constructor.
        num_shops: Forwarded to the ``EGES`` constructor.
        num_cates: Forwarded to the ``EGES`` constructor.
        test_g: Optional held-out edges passed to ``eval``. When omitted, the
            module-level ``test_g`` (set by the script entry point) is used
            if it exists; if there is none, evaluation is skipped.

    Returns:
        The trained ``EGES`` model.
    """
    if test_g is None:
        # Backward compatibility: this function used to read ``test_g``
        # straight from module scope, so existing callers that do not pass
        # it keep working; callers from other modules no longer hit a
        # NameError at the end of the first epoch.
        test_g = globals().get("test_g")

    sampler = Sampler(
        train_g,
        args.walk_length,
        args.num_walks,
        args.window_size,
        args.num_negative
    )
    # for each node in the graph, we sample pos and neg
    # pairs for it, and feed these sampled pairs into the model.
    # (nodes in the graph are of course batched before sampling)
    dataloader = DataLoader(
        th.arange(train_g.num_nodes()),
        # this is the batch_size of input nodes
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=lambda x: sampler.sample(x, sku_info)
    )
    model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        epoch_total_loss = 0
        for step, (srcs, dsts, labels) in enumerate(dataloader):
            # the batch size of output pairs is unfixed
            # TODO: shuffle the triples?
            srcs_embeds, dsts_embeds = model(srcs, dsts)
            loss = model.loss(srcs_embeds, dsts_embeds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for the running total and
            # the log line.
            step_loss = loss.item()
            epoch_total_loss += step_loss

            if step % args.log_every == 0:
                print('Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}'.format(
                    epoch, step, step_loss, epoch_total_loss / (step + 1)))

        # ``eval`` here is the module-level function of that name, not the
        # builtin.
        if test_g is not None:
            eval(model, test_g, sku_info)

    return model
def eval(model, test_graph, sku_info):
    """Print the link-prediction AUC of ``model`` over ``test_graph``.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    in this file use it.

    Args:
        model: Object exposing ``query_node_embed``.
        test_graph: Iterable of edges; each edge provides ``src`` and ``dst``
            (tensors whose first element is used as the key into
            ``sku_info``) and a ``label``.
        sku_info: Mapping from that key to the 4 fields describing a SKU
            (the ``view(1, 4)`` below fixes the count at four).

    Returns:
        None. The AUC is printed, not returned.
    """
    preds, labels = [], []
    # Pure scoring: no gradients are needed, so skip autograd bookkeeping.
    with th.no_grad():
        for edge in test_graph:
            src = th.tensor(sku_info[edge.src.numpy()[0]]).view(1, 4)
            dst = th.tensor(sku_info[edge.dst.numpy()[0]]).view(1, 4)
            # (1, dim)
            src = model.query_node_embed(src)
            dst = model.query_node_embed(dst)
            # Sum of the element-wise product, squashed by a sigmoid, is the
            # predicted probability that the edge exists (a 0-d tensor).
            prob = th.sigmoid(th.sum(src * dst))
            preds.append(prob.item())
            labels.append(edge.label)

    fpr, tpr, _ = metrics.roc_curve(labels, preds, pos_label=1)
    print("Evaluate link prediction AUC: {:.4f}".format(metrics.auc(fpr, tpr)))
def tsne(model, sku_info):
    """Embed every SKU in ``sku_info`` and project the embeddings to 2-D.

    Args:
        model: Object exposing ``query_node_embed``.
        sku_info: Mapping whose values are the per-SKU field lists; the keys
            are not used.

    Returns:
        The array produced by ``TSNE.fit_transform`` with two components,
        one row per entry of ``sku_info``. Previously the projection was
        computed and then discarded, so callers always received ``None``.
    """
    nodes = th.tensor(list(sku_info.values()), dtype=th.int32)
    embeds = model.query_node_embed(nodes).detach().numpy()
    return TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(embeds)
if __name__ == "__main__":
    # Script entry point: build the SKU graph, split it, encode the SKU side
    # information, train the model and compute a t-SNE projection.
    args = utils.init_args()

    valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data)
    g, sku_encoder, sku_decoder = utils.construct_graph(
        args.action_data,
        args.session_interval_sec,
        valid_sku_raw_ids,
        args.min_sku_freq,
    )

    # NOTE: train() picks ``test_g`` up from module scope, so this name must
    # not be changed.
    train_g, test_g = utils.split_train_test_graph(g, args.num_negative)

    sku_info_encoder, sku_info_decoder, sku_info = utils.encode_sku_fields(
        args.item_info_data, sku_encoder, sku_decoder
    )

    # Vocabulary sizes passed on to train().
    num_skus = len(sku_encoder)
    num_brands, num_shops, num_cates = (
        len(sku_info_encoder[field]) for field in ("brand", "shop", "cate")
    )

    summary = "Num skus: {}, num brands: {}, num shops: {}, num cates: {}"
    print(summary.format(num_skus, num_brands, num_shops, num_cates))

    model = train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates)
    tsne_embeds = tsne(model, sku_info)