-
Notifications
You must be signed in to change notification settings - Fork 1
/
load_data.py
37 lines (31 loc) · 1.19 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import scipy.io as sio
import numpy as np
import torch
import torch_geometric.utils as utils
from torch_geometric.data import Data
def load_pyg_data(d_name, path=r'/root/data'):
    """Read ``{path}/{d_name}_str.mat`` and wrap it in a PyG ``Data`` object.

    The file is parsed with ``scipy.io.loadmat`` and must contain the keys
    ``'Label'``, ``'Attributes'`` and ``'Edge'``.

    Args:
        d_name: Dataset name; the file opened is ``{d_name}_str.mat``.
        path: Directory that holds the .mat file.

    Returns:
        ``Data`` with ``x`` = FloatTensor of ``'Attributes'``,
        ``edge_index`` = LongTensor of ``'Edge'`` and ``y`` = LongTensor of
        the flattened ``'Label'`` array.
    """
    contents = sio.loadmat(f'{path}/{d_name}_str.mat')
    flat_labels = contents['Label'].reshape(-1)
    # NOTE(review): 'Edge' is handed to PyG unchanged as edge_index, so it
    # presumably is stored as a (2, num_edges) array already — confirm.
    return Data(
        x=torch.FloatTensor(contents['Attributes']),
        edge_index=torch.LongTensor(contents['Edge']),
        y=torch.LongTensor(flat_labels),
    )
def _to_numpy(matrix):
    """Return ``matrix`` as a dense NumPy array, densifying SciPy sparse input."""
    import scipy.sparse as sp

    return matrix.toarray() if sp.issparse(matrix) else np.asarray(matrix)


def _adjacency_to_edge_index(adjacency):
    """Return the nonzero positions of ``adjacency`` as a ``(2, E)`` int64 tensor.

    Accepts a SciPy sparse matrix or a dense 2-D array. The work is done on
    the sparse structure, so an N x N adjacency is never materialised
    densely. Edges are emitted in row-major order (by row, then by column).
    Any nonzero entry counts as an edge.
    """
    import scipy.sparse as sp

    csr = sp.csr_matrix(adjacency, copy=True)  # copy: the caller's matrix is left untouched
    csr.sum_duplicates()   # merge repeated (row, col) entries and sort column indices
    csr.eliminate_zeros()  # explicitly stored zeros are not edges
    coo = csr.tocoo()      # CSR -> COO keeps the row-major ordering
    stacked = np.vstack((coo.row, coo.col)).astype(np.int64)
    return torch.from_numpy(stacked)


def load_mat(d_name, path='./data'):
    """Load ``{path}/{d_name}.mat`` into a PyG ``Data`` object with anomaly labels.

    The .mat file must contain ``'Network'`` (adjacency matrix, sparse or
    dense), ``'Attributes'`` (node features, sparse or dense), ``'Label'``,
    ``'str_anomaly_label'`` and ``'attr_anomaly_label'``.

    Args:
        d_name: Dataset name; the file opened is ``{d_name}.mat``.
        path: Directory that holds the .mat file.

    Returns:
        ``Data`` with ``x`` (float features), ``edge_index`` (``(2, E)``
        int64 nonzero positions of ``'Network'``), ``y``, ``str_y`` and
        ``attr_y`` (flattened int64 label vectors).
    """
    data = sio.loadmat(f'{path}/{d_name}.mat')
    # Build edge_index directly from the sparse adjacency. Densifying it first
    # (toarray + dense_to_sparse) needs O(N^2) memory for an N-node graph.
    edge_index = _adjacency_to_edge_index(data['Network'])
    attr = torch.FloatTensor(_to_numpy(data['Attributes']))
    label = torch.LongTensor(data['Label'].reshape(-1))
    str_label = torch.LongTensor(data['str_anomaly_label'].reshape(-1))
    attr_label = torch.LongTensor(data['attr_anomaly_label'].reshape(-1))
    return Data(x=attr, edge_index=edge_index, y=label,
                str_y=str_label, attr_y=attr_label)
from pygod.utils import load_data as ld
def load_weibo():
    """Load PyGOD's ``weibo`` dataset and expose its labels as ``str_y``/``attr_y``.

    The dataset's single ``y`` label object is attached under both extra
    attribute names. No copy is made, so ``y``, ``str_y`` and ``attr_y`` all
    reference the same object.

    Returns:
        The object returned by ``pygod.utils.load_data(name='weibo')``, with
        ``str_y`` and ``attr_y`` set.
    """
    graph = ld(name='weibo')
    labels = graph.y
    graph.str_y = labels
    graph.attr_y = labels
    return graph
if __name__ == '__main__':
    # Smoke run: load the 'cora' dataset from the default ./data directory and
    # print the resulting Data summary so the run actually reports something.
    print(load_mat('cora'))