-
Notifications
You must be signed in to change notification settings - Fork 11
/
invertible_network_utils.py
142 lines (116 loc) · 4.65 KB
/
invertible_network_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Create invertible mixing networks."""
import numpy as np
import torch
from torch import nn
from scipy.stats import ortho_group
from typing import Union
from typing_extensions import Literal
import encoders
__all__ = ["construct_invertible_flow", "construct_invertible_mlp"]
def construct_invertible_mlp(
    n: int = 20,
    n_layers: int = 2,
    n_iter_cond_thresh: int = 10000,
    cond_thresh_ratio: float = 0.25,
    weight_matrix_init: Union[Literal["pcl"], Literal["rvs"]] = "pcl",
    act_fct: Union[
        Literal["relu"],
        Literal["leaky_relu"],
        Literal["elu"],
        Literal["smooth_leaky_relu"],
        Literal["softplus"],
    ] = "leaky_relu",
):
    """
    Create an (approximately) invertible mixing network based on an MLP.

    Based on the mixing code by Hyvarinen et al.

    Args:
        n: Dimensionality of the input and output data.
        n_layers: Number of linear layers in the MLP.
        n_iter_cond_thresh: How many random matrices to use as a pool to
            determine the condition number threshold (only used for "pcl").
        cond_thresh_ratio: Relative threshold how much the invertibility
            (based on the condition number) can be violated in each layer.
            It selects the position in the ascending pool of condition numbers
            that becomes the threshold; the position is clamped to the pool.
        weight_matrix_init: How to initialize the weight matrices.
            "pcl" rejection-samples uniform(-1, 1) matrices with unit-norm
            columns until the condition number is at most the threshold,
            "rvs" draws a random orthogonal matrix and "expand" keeps the
            default initialization of ``nn.Linear``.
        act_fct: Activation function for hidden layers.

    Returns:
        An ``nn.Sequential`` of ``n_layers`` bias-free ``nn.Linear(n, n)``
        layers with the chosen activation between consecutive linear layers.
        All parameters have ``requires_grad`` set to ``False``.

    Raises:
        NotImplementedError: If ``act_fct`` is "max_out".
        ValueError: If ``act_fct`` or ``weight_matrix_init`` is unknown, or if
            "pcl" is requested with ``n_iter_cond_thresh < 1``.
    """

    class SmoothLeakyReLU(nn.Module):
        """Smooth leaky ReLU: ``alpha * x + (1 - alpha) * log(1 + exp(x))``."""

        def __init__(self, alpha=0.2):
            super().__init__()
            self.alpha = alpha

        def forward(self, x):
            # softplus(x) equals log(1 + exp(x)) but does not overflow to inf
            # for large positive x the way the literal formula does.
            return self.alpha * x + (1 - self.alpha) * nn.functional.softplus(x)

    # Activation name -> (module class, constructor keyword arguments).
    activations = {
        "relu": (torch.nn.ReLU, {}),
        "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": 0.2}),
        "elu": (torch.nn.ELU, {"alpha": 1.0}),
        "smooth_leaky_relu": (SmoothLeakyReLU, {"alpha": 0.2}),
        "softplus": (torch.nn.Softplus, {"beta": 1}),
    }
    if act_fct == "max_out":
        raise NotImplementedError("activation function max_out is not implemented.")
    if act_fct not in activations:
        raise ValueError(f"activation function {act_fct} not defined.")
    act_cls, act_kwargs = activations[act_fct]

    # Subfunction to normalize mixing matrix
    def l2_normalize(Amat, axis=0):
        # axis: 0=column-normalization, 1=row-normalization
        # keepdims makes the division broadcast along the reduced axis, so
        # row-normalization works as well as column-normalization.
        l2norm = np.sqrt(np.sum(Amat * Amat, axis=axis, keepdims=True))
        return Amat / l2norm

    def sample_normalized_matrix():
        # Uniform(-1, 1) entries, columns scaled to unit L2 norm.
        return l2_normalize(np.random.uniform(-1, 1, (n, n)), axis=0)

    cond_thresh = None
    if weight_matrix_init == "pcl":
        if n_iter_cond_thresh < 1:
            raise ValueError(
                "n_iter_cond_thresh must be at least 1 for weight_matrix_init='pcl'."
            )
        cond_pool = np.sort(
            [
                np.linalg.cond(sample_normalized_matrix())
                for _ in range(n_iter_cond_thresh)
            ]
        )  # Ascending order
        # Clamp so that e.g. cond_thresh_ratio=1.0 selects the largest pool
        # entry instead of indexing one past the end.
        thresh_idx = int(n_iter_cond_thresh * cond_thresh_ratio)
        thresh_idx = min(max(thresh_idx, 0), n_iter_cond_thresh - 1)
        cond_thresh = cond_pool[thresh_idx]
        print("condition number threshold: {0:f}".format(cond_thresh))

    layers = []
    for i in range(n_layers):
        lin_layer = nn.Linear(n, n, bias=False)

        if weight_matrix_init == "pcl":
            # Rejection sampling: redraw until the matrix is conditioned at
            # least as well as the threshold.
            while True:
                weight_matrix = sample_normalized_matrix()
                cond_a = np.linalg.cond(weight_matrix)
                if not cond_a > cond_thresh:
                    break
            print(f"layer {i+1}/{n_layers}, condition number: {cond_a}")
            lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32)
        elif weight_matrix_init == "rvs":
            weight_matrix = ortho_group.rvs(n)
            lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32)
        elif weight_matrix_init == "expand":
            # Keep the default nn.Linear initialization.
            pass
        else:
            raise ValueError(f"weight matrix {weight_matrix_init} not implemented")

        layers.append(lin_layer)

        # No activation after the final linear layer.
        if i < n_layers - 1:
            layers.append(act_cls(**act_kwargs))

    mixing_net = nn.Sequential(*layers)

    # fix parameters
    for p in mixing_net.parameters():
        p.requires_grad = False

    return mixing_net
def construct_invertible_flow(
    n: int,
    coupling_block: Union[Literal["gin", "glow"]] = "gin",
    num_nodes: int = 8,
    node_size_factor: int = 1,
):
    """
    Build an invertible mixing function backed by a flow-based network.

    Args:
        n: Dimensionality of the input and output data.
        coupling_block: Coupling method to use to combine nodes.
        num_nodes: Depth of the flow network.
        node_size_factor: Hidden units per node.

    Returns:
        The object returned by ``encoders.get_flow`` for these settings.
    """
    # n is passed twice; presumably once as the input and once as the output
    # dimensionality, which are documented above as being equal.
    # NOTE(review): the meaning of the positional ``False`` flag is defined by
    # encoders.get_flow and is not visible here - confirm against that helper.
    flow_args = (n, n, False, coupling_block, num_nodes, node_size_factor)
    return encoders.get_flow(*flow_args)