This repository has been archived by the owner on Jul 22, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 163
/
edge_cls_tasker.py
84 lines (66 loc) · 2.38 KB
/
edge_cls_tasker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import taskers_utils as tu
import utils as u
class Edge_Cls_Tasker():
def __init__(self,args,dataset):
self.data = dataset
#max_time for link pred should be one before
self.max_time = dataset.max_time
self.args = args
self.num_classes = dataset.num_classes
if not args.use_1_hot_node_feats:
self.feats_per_node = dataset.feats_per_node
self.get_node_feats = self.build_get_node_feats(args,dataset)
self.prepare_node_feats = self.build_prepare_node_feats(args,dataset)
self.is_static = False
def build_prepare_node_feats(self,args,dataset):
if args.use_2_hot_node_feats or args.use_1_hot_node_feats:
def prepare_node_feats(node_feats):
return u.sparse_prepare_tensor(node_feats,
torch_size= [dataset.num_nodes,
self.feats_per_node])
else:
prepare_node_feats = self.data.prepare_node_feats
return prepare_node_feats
def build_get_node_feats(self,args,dataset):
if args.use_2_hot_node_feats:
max_deg_out, max_deg_in = tu.get_max_degs(args,dataset)
self.feats_per_node = max_deg_out + max_deg_in
def get_node_feats(adj):
return tu.get_2_hot_deg_feats(adj,
max_deg_out,
max_deg_in,
dataset.num_nodes)
elif args.use_1_hot_node_feats:
max_deg,_ = tu.get_max_degs(args,dataset)
self.feats_per_node = max_deg
def get_node_feats(adj):
return tu.get_1_hot_deg_feats(adj,
max_deg,
dataset.num_nodes)
else:
def get_node_feats(adj):
return dataset.nodes_feats
return get_node_feats
def get_sample(self,idx,test):
hist_adj_list = []
hist_ndFeats_list = []
hist_mask_list = []
for i in range(idx - self.args.num_hist_steps, idx+1):
cur_adj = tu.get_sp_adj(edges = self.data.edges,
time = i,
weighted = True,
time_window = self.args.adj_mat_time_window)
node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)
node_feats = self.get_node_feats(cur_adj)
cur_adj = tu.normalize_adj(adj = cur_adj, num_nodes = self.data.num_nodes)
hist_adj_list.append(cur_adj)
hist_ndFeats_list.append(node_feats)
hist_mask_list.append(node_mask)
label_adj = tu.get_edge_labels(edges = self.data.edges,
time = idx)
return {'idx': idx,
'hist_adj_list': hist_adj_list,
'hist_ndFeats_list': hist_ndFeats_list,
'label_sp': label_adj,
'node_mask_list': hist_mask_list}