import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.distributions.multivariate_normal import MultivariateNormal


class GMML(nn.Module):
    def __init__(self, input_dim, d, n_class, n_component=1, cov_type="diag", log_stretch_trick=False, **kwargs):
        """A GMM layer, which can be used as the last layer of a classification neural network.

        An illustrative usage sketch follows the class definition.

        Attributes:
            input_dim (int): Input dimension.
            d (int): Reduced number of dimensions after random projection.
            n_class (int): The number of classes.
            n_component (int): The number of Gaussian components per class.
            cov_type (str): The type of covariance matrices. If "diag", diagonal matrices are used, which is
                computationally advantageous. If "full", the model uses full-rank matrices that have high expressive
                capability at the cost of increased computational complexity. If "tril", lower triangular matrices
                are used.
            log_stretch_trick (bool): If True, stretches the clamped weighted log-probabilities by taking the
                logarithm again, i.e., uses -log(-log(p) + 1e-4) instead of log(p). This can help in situations
                where a large portion of the data is classified with maximum confidence (log_prob = 0, prob = 1)
                after the weighted sum.
        """
        super(GMML, self).__init__(**kwargs)
        assert input_dim > 0
        assert d > 0
        assert n_class > 1
        assert n_component > 0
        assert cov_type in ["diag", "full", "tril"]
        self.input_dim = input_dim
        self.d = d
        self.s = n_class
        self.g = n_component
        self.cov_type = cov_type
        self.n_total_component = n_component * n_class
        self.log_stretch_trick = log_stretch_trick
        self.relu = torch.nn.ReLU()
        # Dimensionality reduction: a fixed (non-trainable) random projection from input_dim to d
        self.bottleneck = nn.Linear(self.input_dim, self.d, bias=False)
        self.bottleneck.requires_grad_(False)
        self.bottleneck.weight.data = get_achlioptas(self.input_dim, self.d).transpose(0, 1)
        # Free parameters: component means, mixture weights, and covariance factors
        self.mu_p = Parameter(torch.randn(self.s, self.g, self.d), requires_grad=True)
        self.omega_p = Parameter(torch.ones(self.s, self.g), requires_grad=True)
        sigma_data = torch.eye(self.d).reshape(1, 1, self.d, self.d).repeat(self.s, self.g, 1, 1)
        self.sigma_p = Parameter(sigma_data, requires_grad=True)
        # Sampled (derived) parameters, refreshed by sample_parameters()
        self.omega = None
        self.distribution = None
        with torch.no_grad():
            self.sample_parameters()

    def init_mu(self, mu):
        with torch.no_grad():
            self.mu_p.data = mu

    def init_omega(self, omega):
        with torch.no_grad():
            self.omega_p.data = omega

    def forward(self, x):
        b = x.shape[0]
        # Random projection to d dimensions, then broadcast against all s * g components
        x = self.bottleneck(x)
        x = x.reshape(b, 1, x.shape[1])
        # Weighted component log-densities: log N(x; mu, sigma) + log(omega)
        log_wp = self.distribution.log_prob(x) + self.omega.reshape(self.s * self.g).log()
        # Clamp to at most 0 so that each term behaves like a valid log-probability
        log_wp = -self.relu(-log_wp)
        if self.log_stretch_trick:
            log_wp = -torch.log(-log_wp + 0.0001)
        # Normalize over all components, then sum each class's components in log space
        log_mixture_p = log_wp - log_wp.logsumexp(dim=-1, keepdim=True)
        return log_mixture_p.reshape(b, self.s, self.g).logsumexp(dim=-1)

    def sample_parameters(self):
        # OMEGA - should sum up to 1
        # omega = torch.softmax(self.omega_p, -1) / self.omega_p.shape[-2]
        omega = self.omega_p.data.clone()
        omega -= min(0, omega.min().item())
        omega = omega / omega.sum(dim=-1, keepdim=True)
        omega /= omega.shape[-2]
        self.omega = omega
        # SIGMA - symmetric positive definite
        device = self.sigma_p.device
        tli = torch.tril_indices(row=self.sigma_p.size(-2), col=self.sigma_p.size(-1), offset=-1).to(device)
        tui = torch.triu_indices(row=self.sigma_p.size(-2), col=self.sigma_p.size(-1), offset=1).to(device)
        sigma_p = self.sigma_p
        # sigma_p^T @ sigma_p is symmetric positive semi-definite; adding a small multiple of the mean
        # eigenvalue to the diagonal makes it strictly positive definite
        m = torch.matmul(sigma_p.transpose(-2, -1), sigma_p).to(device)
        sigma = m + torch.diag_embed(0.01 * torch.mean(torch.linalg.eig(m).eigenvalues.real, dim=-1, keepdim=True)
                                     .repeat(1, 1, self.d)).to(device)
        if self.cov_type == "diag":
            sigma[:, :, tli[0], tli[1]] = 0
            sigma[:, :, tui[0], tui[1]] = 0
        elif self.cov_type == "tril":
            sigma[:, :, tui[0], tui[1]] = 0
        # MU - no transformation
        mu = self.mu_p
        # Initialize the normal distributions using the sampled MU and SIGMA
        if self.cov_type == "full":
            self.distribution = MultivariateNormal(mu.flatten(0, 1), covariance_matrix=sigma.flatten(0, 1))
        else:
            self.distribution = MultivariateNormal(mu.flatten(0, 1), scale_tril=sigma.flatten(0, 1))
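

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal forward-pass demo under assumed, arbitrary sizes (input_dim=64, d=16,
# 10 classes, 2 components per class). The layer returns per-class log-probabilities,
# so each output row should exp-sum to roughly 1.
def _demo_forward_pass():
    layer = GMML(input_dim=64, d=16, n_class=10, n_component=2, cov_type="diag")
    x = torch.randn(8, 64)                       # a random batch of 8 feature vectors
    log_class_prob = layer(x)                    # shape (8, 10)
    print("output shape:", tuple(log_class_prob.shape))
    print("per-row probability mass:", log_class_prob.exp().sum(dim=-1))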


def get_achlioptas(n, m, s=3):
    """
    Random projection matrix for dimensionality reduction, from Achlioptas 2001 (Microsoft).
    https://dl.acm.org/doi/pdf/10.1145/375551.375608

    Args:
        n: input data dimension
        m: desired output dimension
        s: 1 / density. Must be greater than or equal to 1 (the density range is (0, 1]).

    Returns: an n x m matrix that can be used for random-projection dimensionality reduction.
    """
    # Each entry is -1 with probability 1/(2s), +1 with probability 1/(2s), and 0 otherwise;
    # the whole matrix is then scaled by sqrt(s) / sqrt(m)
    t = torch.rand(n, m)
    t = t.masked_fill(t.greater(1 - 1 / (2 * s)), -1)
    t = t.masked_fill(t.greater(1 / (2 * s)), 0)
    t = t.masked_fill(t.greater(0), 1)
    t = t * math.sqrt(s) / math.sqrt(m)
    return t
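

if __name__ == "__main__":
    # Illustrative sanity check (an added sketch, not part of the original module): entries of the
    # Achlioptas matrix should take only three values (0 and +/- sqrt(s)/sqrt(m)), with roughly a
    # 1 - 1/s fraction of zeros.
    R = get_achlioptas(64, 16, s=3)
    print("projection shape:", tuple(R.shape))
    print("distinct values:", torch.unique(R))
    print("fraction of zeros:", (R == 0).float().mean().item())

    # Run the illustrative GMML forward-pass sketch defined above.
    _demo_forward_pass()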