-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
65 lines (57 loc) · 2.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity
def sample_simplex(d):
    """Return one sample drawn uniformly from the probability simplex in R^d.

    The sample is obtained by normalising ``d`` independent Exp(1) variables,
    each generated as ``-log(U)`` with ``U`` uniform on the unit interval.

    Parameters
    ----------
    d : int
        Number of coordinates of the returned vector.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(d,)`` with non-negative entries that sum to 1.
    """
    uniforms = np.random.uniform(size=d)
    # np.random.uniform samples the half-open interval [0, 1): an exact 0 is
    # possible and would give -log(0) = inf, turning the normalised sample
    # into NaN. Clamping to the smallest positive normal float keeps every
    # term finite without altering any non-zero draw.
    uniforms = np.maximum(uniforms, np.finfo(float).tiny)
    exponentials = -np.log(uniforms)
    return exponentials / np.sum(exponentials)
def bures_wasserstein(mean_1, mean_2, cov_1, cov_2):
    """Return the squared OT (Bures-Wasserstein) distance between two Gaussians.

    For N(mean_1, cov_1) and N(mean_2, cov_2) the returned value is

        ||mean_1 - mean_2||^2
        + Tr(cov_1 + cov_2 - 2 * (cov_1^{1/2} cov_2 cov_1^{1/2})^{1/2}).

    Parameters
    ----------
    mean_1, mean_2 : numpy.ndarray
        Mean vectors with identical shapes; the first axis has length ``d``.
    cov_1, cov_2 : numpy.ndarray
        Symmetric covariance matrices of shape ``(d, d)``.

    Returns
    -------
    float
        The squared distance, clamped below at 0.

    Raises
    ------
    AssertionError
        If the shapes of the means / covariances are inconsistent.
    """
    assert mean_1.shape == mean_2.shape
    assert cov_1.shape == cov_2.shape
    d = mean_1.shape[0]
    assert cov_2.shape == (d, d)

    # Symmetric square root of cov_1; eigenvalues that come out slightly
    # negative because of rounding are clipped to zero before the sqrt.
    eigvals_1, eigvecs_1 = np.linalg.eigh(cov_1)
    eigvals_1 = np.clip(eigvals_1, 0., None)
    sqrt_1 = (eigvecs_1 * np.sqrt(eigvals_1)).dot(eigvecs_1.T)

    # Tr((sqrt_1 cov_2 sqrt_1)^{1/2}) equals the sum of the square roots of
    # the eigenvalues, so the matrix square root never has to be rebuilt.
    cross = sqrt_1.dot(cov_2).dot(sqrt_1)
    cross_eigvals = np.clip(np.linalg.eigvalsh(cross), 0., None)
    trace_sqrt_cross = np.sum(np.sqrt(cross_eigvals))

    mean_term = np.linalg.norm(mean_1 - mean_2) ** 2
    cov_term = np.trace(cov_1) + np.trace(cov_2) - 2 * trace_sqrt_cross
    # Rounding can push the result marginally below zero for (near-)identical
    # Gaussians; a squared distance must stay non-negative.
    return np.maximum(mean_term + cov_term, 0.)
def euclidean_proj_simplex(v, s=1):
    """Compute the Euclidean projection on a positive simplex.

    Solves ``min_w 0.5 * ||w - v||^2`` subject to ``sum(w) = s`` and
    ``w >= 0`` by sorting ``v`` and thresholding it.

    Parameters
    ----------
    v : numpy.ndarray
        1-D vector to project.
    s : int or float, optional
        Radius of the simplex (the value the entries must sum to); must be
        strictly positive. Defaults to 1.

    Returns
    -------
    numpy.ndarray
        The projection of ``v``. When ``v`` already lies on the simplex,
        ``v`` itself (not a copy) is returned.

    Raises
    ------
    AssertionError
        If ``s`` is not strictly positive.
    ValueError
        If ``v`` is not 1-D.
    """
    # Adrien Gaidon - INRIA - 2011
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # check if we are already on the simplex
    # (np.all replaces np.alltrue, which no longer exists in NumPy >= 2.0)
    if v.sum() == s and np.all(v >= 0):
        # best projection: itself!
        return v
    # get the array of cumulative sums of a sorted (decreasing) copy of v
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u)
    # get the number of > 0 components of the optimal solution
    rho = np.nonzero(u * np.arange(1, n + 1) > (cssv - s))[0][-1]
    # compute the Lagrange multiplier associated to the simplex constraint
    theta = (cssv[rho] - s) / (rho + 1.0)
    # compute the projection by thresholding v using theta
    w = (v - theta).clip(min=0)
    return w
def projection_Omega(matrix, k, max_iter_Dykstra=10):
    """Project the matrix onto {0 <= Omega <= I with Trace(Omega)=k}.

    The input is symmetrised and eigen-decomposed; Dykstra's alternating
    projection scheme is then run on the eigenvalues only (cap at 1, then
    Euclidean projection onto the simplex of radius ``k``), and the matrix
    is rebuilt from the projected eigenvalues and the original eigenvectors.

    Parameters
    ----------
    matrix : numpy.ndarray
        Square matrix to project.
    k : int or float
        Target trace, passed as the simplex radius.
    max_iter_Dykstra : int, optional
        Number of Dykstra iterations. Defaults to 10.

    Returns
    -------
    numpy.ndarray
        The projected matrix.
    """
    d = matrix.shape[0]
    symmetric = 0.5 * (matrix + matrix.T)
    eigvals, eigvecs = np.linalg.eigh(symmetric)

    x = eigvals
    # Dykstra correction terms, one per constraint set.
    corr_cap = np.zeros(d)
    corr_simplex = np.zeros(d)
    for _ in range(max_iter_Dykstra):
        # Step 1: enforce the upper bound (eigenvalues <= 1).
        shifted = x + corr_cap
        y = np.minimum(shifted, 1.)
        corr_cap = shifted - y
        # Step 2: enforce non-negativity and sum == k.
        target = y + corr_simplex
        x = euclidean_proj_simplex(target, s=k)
        corr_simplex = target - x

    # Remove any remaining negative eigenvalue before rebuilding the matrix.
    x = np.where(x < 0., 0., x)
    return np.dot(np.dot(eigvecs, np.diag(x)), eigvecs.T)