-
Notifications
You must be signed in to change notification settings - Fork 8
/
BiopointData.py
48 lines (39 loc) · 1.63 KB
/
BiopointData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from torch_geometric.data import InMemoryDataset,Data
from os.path import join, isfile
from os import listdir
import numpy as np
import os.path as osp
from utils.construct_graph import read_data
class BiopointDataset(InMemoryDataset):
    """In-memory graph dataset built from the files under ``<root>/raw``.

    ``process`` converts the raw directory into a ``(data, slices)`` pair via
    ``read_data`` and saves it to ``self.processed_paths[0]``; ``__init__``
    loads that file back into ``self.data`` / ``self.slices``.

    Args:
        root: Dataset root directory; raw files are listed from ``<root>/raw``.
        name: Label shown by ``__repr__``.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; when set, ``process``
            applies it to every sample before saving.
        pre_filter: Forwarded to ``InMemoryDataset``; when set, ``process``
            keeps only the samples for which it returns a truthy value.
            Defaults to ``None`` so existing callers are unaffected.
    """

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None):
        # Both attributes are assigned before the base constructor runs, as in
        # the original code; `raw_file_names` and `__repr__` read them.
        self.root = root
        self.name = name
        # Forward `pre_filter` so the filtering branch in `process` is
        # actually reachable.
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): recent PyTorch releases changed the default of
        # `torch.load(weights_only=...)`; confirm this pickled tuple still
        # loads with the torch version this project pins.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Return the sorted names of the regular files in ``<root>/raw``."""
        data_dir = osp.join(self.root, 'raw')
        return sorted(
            f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))
        )

    @property
    def processed_file_names(self):
        """Return the name of the single processed cache file."""
        return 'data.pt'

    def download(self):
        """No-op: this dataset does not fetch anything into ``self.raw_dir``."""
        return

    def process(self):
        """Build the processed cache file from the raw directory.

        Reads ``(data, slices)`` with ``read_data(self.raw_dir)``, optionally
        applies ``pre_filter`` and then ``pre_transform`` per sample, and saves
        the resulting pair to ``self.processed_paths[0]``.
        """
        self.data, self.slices = read_data(self.raw_dir)

        if self.pre_filter is not None or self.pre_transform is not None:
            # Split the storage into individual samples once, filter first and
            # transform second (same order as before), then collate a single
            # time instead of re-collating after each step.
            data_list = [self.get(idx) for idx in range(len(self))]
            if self.pre_filter is not None:
                data_list = [
                    data for data in data_list if self.pre_filter(data)
                ]
            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]
            self.data, self.slices = self.collate(data_list)

        torch.save((self.data, self.slices), self.processed_paths[0])

    def __repr__(self):
        """Return ``"<name>(<number of samples>)"``."""
        return '{}({})'.format(self.name, len(self))