-
Notifications
You must be signed in to change notification settings - Fork 18
/
utilities.py
96 lines (75 loc) · 3.64 KB
/
utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gzip
import pickle
import datetime
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric
def log(str, logfile=None):
    """Print a timestamped message and optionally append it to a log file.

    Args:
        str: Message to log; formatted with an f-string, so any object works.
            (The name shadows the builtin ``str``; it is kept so existing
            keyword callers keep working.)
        logfile: Optional path. When given, the same timestamped line is also
            appended to this file.
    """
    stamped = f'[{datetime.datetime.now()}] {str}'
    print(stamped)
    if logfile is None:
        return
    with open(logfile, mode='a') as handle:
        print(stamped, file=handle)
def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """Split a tensor into variable-length chunks and stack them with padding.

    ``input_`` is split along dim 0 into consecutive chunks whose lengths are
    given by ``pad_sizes``. Each chunk is right-padded along dim 0 with
    ``pad_value`` up to the longest chunk, and the chunks are stacked.

    Args:
        input_: Tensor whose dim 0 is the concatenation of all chunks
            (``input_.size(0)`` must equal ``pad_sizes.sum()``, otherwise
            ``split`` raises).
        pad_sizes: 1-D integer tensor of chunk lengths.
        pad_value: Fill value for padded positions. The default is a large
            negative number, so padded entries compare below realistic scores.

    Returns:
        Tensor of shape ``(len(pad_sizes), pad_sizes.max(), *input_.shape[1:])``.
    """
    # Convert once to plain Python ints: F.pad and split both want ints, and
    # this avoids carrying 0-dim tensors (and device syncs) into every pad call.
    max_pad_size = int(pad_sizes.max())
    sizes = pad_sizes.tolist()
    # F.pad takes (left, right) pairs starting from the LAST dim. Zero-pad every
    # trailing dim and right-pad dim 0 only; for 1-D input this reduces to
    # (0, max_pad_size - n), exactly as before.
    trailing = (0, 0) * (input_.dim() - 1)
    padded = [
        F.pad(chunk, trailing + (0, max_pad_size - chunk.size(0)), 'constant', pad_value)
        for chunk in input_.split(sizes)
    ]
    return torch.stack(padded, dim=0)
class BipartiteNodeData(torch_geometric.data.Data):
    """PyG graph for one sample: a bipartite constraint/variable graph plus candidates.

    ``__inc__`` tells PyG how to offset ``edge_index`` and ``candidates`` when
    several graphs are concatenated into a mini-batch.

    Every constructor argument defaults to ``None`` because PyG may construct
    ``Data`` subclasses without arguments, e.g. when separating a ``Batch``
    back into individual graphs. Existing positional/keyword callers are
    unaffected.
    """

    def __init__(self, constraint_features=None, edge_indices=None, edge_features=None,
                 variable_features=None, candidates=None, nb_candidates=None,
                 candidate_choice=None, candidate_scores=None):
        super().__init__()
        self.constraint_features = constraint_features
        # Stored under PyG's conventional names for connectivity and edge features.
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.candidates = candidates
        self.nb_candidates = nb_candidates
        # NOTE(review): the argument is singular but the attribute is plural
        # (`candidate_choices`); presumably read elsewhere under this name, so kept.
        self.candidate_choices = candidate_choice
        self.candidate_scores = candidate_scores

    def __inc__(self, key, value, store, *args, **kwargs):
        """Return the increment applied to attribute ``key`` when batching graphs."""
        if key == 'edge_index':
            # Bipartite edges: row 0 is offset by the number of constraints,
            # row 1 by the number of variables.
            return torch.tensor([[self.constraint_features.size(0)],
                                 [self.variable_features.size(0)]])
        if key == 'candidates':
            # Candidates index into the variable features.
            return self.variable_features.size(0)
        # Forward `store` too, so the parent receives the full call signature.
        return super().__inc__(key, value, store, *args, **kwargs)
class GraphDataset(torch_geometric.data.Dataset):
    """Dataset reading one gzip-compressed pickle file per sample.

    Each sample is converted on access into a ``BipartiteNodeData`` graph.

    Args:
        sample_files: Sequence of paths to gzip-compressed pickle files.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        """Return the number of sample files."""
        return len(self.sample_files)

    def get(self, index):
        """Load sample ``index`` and build its ``BipartiteNodeData`` graph.

        The file must unpickle to a dict whose ``'data'`` entry is
        ``(observation, action, action_set, scores)``, where ``observation`` is
        ``(constraint_features, (edge_indices, edge_features), variable_features)``.

        Raises:
            ValueError: if the recorded action is not in the sample's action set.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load trusted sample files.
        with gzip.open(self.sample_files[index], 'rb') as f:
            sample = pickle.load(f)

        sample_observation, sample_action, sample_action_set, sample_scores = sample['data']
        constraint_features, (edge_indices, edge_features), variable_features = sample_observation

        constraint_features = torch.FloatTensor(constraint_features)
        # Explicit cast normalises whatever integer dtype is stored to int64
        # (LongTensor's dtype) without the former narrowing through int32.
        edge_indices = torch.LongTensor(edge_indices.astype(np.int64))
        # One feature per edge -> trailing feature axis of size 1.
        edge_features = torch.FloatTensor(np.expand_dims(edge_features, axis=-1))
        variable_features = torch.FloatTensor(variable_features)

        candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int64))
        # Position of the recorded action within the candidate list.
        matches = torch.where(candidates == int(sample_action))[0]
        if matches.numel() == 0:
            raise ValueError(
                f'action {sample_action} is not in the action set of sample '
                f'{self.sample_files[index]}')
        candidate_choice = matches[0]
        # Gather the scores of the candidate variables in one vectorised lookup.
        candidate_scores = torch.FloatTensor(
            np.asarray(sample_scores, dtype=np.float64)[candidates.numpy()])

        graph = BipartiteNodeData(constraint_features, edge_indices, edge_features, variable_features,
                                  candidates, len(candidates), candidate_choice, candidate_scores)
        # No single `x` attribute exists, so give PyG the total node count explicitly.
        graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]
        return graph
class Scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """``ReduceLROnPlateau`` variant that lowers the learning rate once per plateau.

    ``step`` reduces the LR exactly when the count of consecutive
    non-improving steps *reaches* ``patience``. The count is not reset
    afterwards, so another reduction only happens after the metric improves
    and then stalls again for ``patience`` steps. ``cooldown`` is not used.
    All constructor keyword arguments are forwarded to ``ReduceLROnPlateau``.
    """

    def __init__(self, optimizer, **kwargs):
        super().__init__(optimizer, **kwargs)

    def step(self, metrics, epoch=None):
        """Record one metric value and reduce the LR if the plateau just hit ``patience``.

        Args:
            metrics: Monitored value, as a Python number or a zero-dim tensor.
            epoch: Optional explicit epoch number (mirrors the parent
                signature). When omitted, the internal counter is incremented.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        # Advance the epoch counter (was `=+1`, which reset it to 1 on every call).
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        # `==` rather than `>`: fire once when patience is reached, not on every later step.
        if self.num_bad_epochs == self.patience:
            self._reduce_lr(self.last_epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]