-
Notifications
You must be signed in to change notification settings - Fork 8
/
sampler.py
99 lines (84 loc) · 3.04 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import dgl
import numpy as np
import torch as th
class Sampler:
    """Sample skip-gram style positive and negative node pairs from a graph.

    Positive pairs come from random walks drawn with
    ``dgl.sampling.random_walk(..., prob="weight")``, so the graph must carry
    a ``"weight"`` edge feature. Negative pairs are drawn with probability
    proportional to each node's in-degree.
    """

    def __init__(self,
                 graph,
                 walk_length,
                 num_walks,
                 window_size,
                 num_negative):
        """
        Args:
            graph: DGL graph to walk on; its ``in_degrees()`` also provides
                the negative-sampling weights.
            walk_length: passed to ``random_walk`` as ``length``.
            num_walks: number of walks started from every target node.
            window_size: skip-gram context radius on each side of a node.
            num_negative: negative pairs drawn per positive pair.
        """
        self.graph = graph
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.window_size = window_size
        self.num_negative = num_negative
        # Computed once; reused by every generate_neg_pairs call.
        self.node_weights = self.compute_node_sample_weight()

    def sample(self, batch, sku_info):
        """Sample labelled positive and negative pairs for a batch of nodes.

        Each target node is repeated ``num_walks`` times so that it seeds
        that many random walks.

        Args:
            batch: target node IDs (anything ``np.repeat`` accepts).
            sku_info: indexable by node ID; each entry must be convertible
                by ``th.tensor`` (presumably equal-length feature sequences
                -- confirm with caller).

        Returns:
            ``(srcs, dsts, labels)`` tensors. Positive pairs (label 1) come
            first, followed by negative pairs (label 0).
        """
        batch = np.repeat(batch, self.num_walks)
        pos_pairs = self.generate_pos_pairs(batch)
        neg_pairs = self.generate_neg_pairs(pos_pairs)
        # Replace node IDs with their SKU info.
        srcs, dsts, labels = [], [], []
        for src, dst, label in pos_pairs + neg_pairs:
            srcs.append(sku_info[src])
            dsts.append(sku_info[dst])
            labels.append(label)
        return th.tensor(srcs), th.tensor(dsts), th.tensor(labels)

    def filter_padding(self, traces):
        """Remove ``-1`` entries from every trace, mutating ``traces`` in place.

        Random walks that terminate early are padded with -1 (see the
        ``dgl.sampling.random_walk`` documentation).
        """
        for i, trace in enumerate(traces):
            traces[i] = [node for node in trace if node != -1]

    def generate_pos_pairs(self, nodes):
        """Random-walk from ``nodes`` and emit skip-gram positive pairs.

        Every node in a walk is paired with each other node at most
        ``window_size`` positions away. For the walk [1, 2, 3, 4] with
        ``window_size=1``, node 2 yields ``[2, 1, 1]`` and ``[2, 3, 1]``.

        Returns:
            List of ``[center, context, 1]`` triples.
        """
        traces, _ = dgl.sampling.random_walk(
            g=self.graph,
            nodes=nodes,
            length=self.walk_length,
            prob="weight",
        )
        traces = traces.tolist()
        self.filter_padding(traces)
        return self._window_pairs(traces)

    def _window_pairs(self, traces):
        """Return ``[center, context, 1]`` for every in-window pair of each trace."""
        pairs = []
        for trace in traces:
            for i, center in enumerate(trace):
                left = max(0, i - self.window_size)
                # Slicing past the end is safe, so no clamp to len(trace).
                right = i + self.window_size + 1
                pairs.extend([center, ctx, 1] for ctx in trace[left:i])
                pairs.extend([center, ctx, 1] for ctx in trace[i + 1:right])
        return pairs

    def compute_node_sample_weight(self):
        """Return each node's in-degree as a float tensor of sampling weights."""
        return self.graph.in_degrees().float()

    def generate_neg_pairs(self, pos_pairs):
        """Draw ``num_negative`` negative pairs for every positive pair.

        Negative destinations are sampled with replacement, with probability
        proportional to ``self.node_weights`` (node in-degree), so
        high-degree nodes are picked more often. Samples are not checked
        against the positive pairs, so a true neighbour (or the source node
        itself) can occasionally appear as a negative.

        Returns:
            List of ``[src, neg_dst, 0]`` triples: ``num_negative`` per
            positive pair, in the order of ``pos_pairs``. Empty when no
            samples are requested.
        """
        num_samples = len(pos_pairs) * self.num_negative
        # torch.multinomial raises on a request for zero samples, which
        # happens when a batch yields no positive pairs.
        if num_samples == 0:
            return []
        negs = th.multinomial(
            self.node_weights,
            num_samples,
            replacement=True,
        ).tolist()
        # Repeat each source so it lines up with its block of negatives.
        srcs = np.repeat([pair[0] for pair in pos_pairs], self.num_negative)
        assert len(srcs) == len(negs)
        return [[src, neg, 0] for src, neg in zip(srcs, negs)]