Skip to content

Commit

Permalink
[WIP] add mass feature to nx.kl_div and harmonize kl computation in t…
Browse files Browse the repository at this point in the history
…he toolbox (#654)

* add mass feature to nx.kl_div

* test

* test

* fix typo in doc

* fix jax
  • Loading branch information
cedricvincentcuaz authored Jul 10, 2024
1 parent e530985 commit 24ad25c
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 25 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 0.9.5dev

#### New features
- Add feature `mass=True` for `nx.kl_div` (PR #654)

#### Closed issues

Expand Down
42 changes: 29 additions & 13 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,16 +944,17 @@ def eigh(self, a):
"""
raise NotImplementedError()

def kl_div(self, p, q, eps=1e-16):
def kl_div(self, p, q, mass=False, eps=1e-16):
r"""
Computes the Kullback-Leibler divergence.
Computes the (Generalized) Kullback-Leibler divergence.
This function follows the api from :any:`scipy.stats.entropy`.
Parameter eps is used to avoid numerical errors and is added in the log.
.. math::
KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + eps) \rangle
+ \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
"""
Expand Down Expand Up @@ -1352,8 +1353,11 @@ def sqrtm(self, a):
def eigh(self, a):
    """Thin wrapper: return the result of ``np.linalg.eigh(a)`` unchanged."""
    return np.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return np.sum(p * np.log(p / q + eps))
def kl_div(self, p, q, mass=False, eps=1e-16):
    """Compute the (generalized) Kullback-Leibler divergence with NumPy.

    Returns ``sum(p * log(p / q + eps))``. When ``mass`` is true, the mass
    correction ``sum(q - p)`` is added on top.

    Parameters
    ----------
    p, q : array-like
        Operands of the divergence (presumably non-negative arrays of the
        same or broadcastable shape; confirm against callers).
    mass : bool, optional
        If True, add ``sum(q - p)`` to the result. Default is False.
    eps : float, optional
        Constant added to the ratio ``p / q`` inside the logarithm.

    Returns
    -------
    value
        Result of the ``np.sum`` reduction over all entries.
    """
    value = np.sum(p * np.log(p / q + eps))
    if mass:
        # Generalized KL: account for the difference in total mass.
        value = value + np.sum(q - p)
    return value

def isfinite(self, a):
    """Thin wrapper: return the result of ``np.isfinite(a)`` unchanged."""
    return np.isfinite(a)
Expand Down Expand Up @@ -1751,8 +1755,11 @@ def sqrtm(self, a):
def eigh(self, a):
    """Thin wrapper: return the result of ``jnp.linalg.eigh(a)`` unchanged."""
    return jnp.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return jnp.sum(p * jnp.log(p / q + eps))
def kl_div(self, p, q, mass=False, eps=1e-16):
    """Compute the (generalized) Kullback-Leibler divergence with JAX.

    Returns ``sum(p * log(p / q + eps))``. When ``mass`` is true, the mass
    correction ``sum(q - p)`` is added on top.

    Parameters
    ----------
    p, q : array-like
        Operands of the divergence (presumably non-negative arrays of the
        same or broadcastable shape; confirm against callers).
    mass : bool, optional
        If True, add ``sum(q - p)`` to the result. Default is False.
    eps : float, optional
        Constant added to the ratio ``p / q`` inside the logarithm.

    Returns
    -------
    value
        Result of the ``jnp.sum`` reduction over all entries.
    """
    value = jnp.sum(p * jnp.log(p / q + eps))
    if mass:
        # Generalized KL: account for the difference in total mass.
        value = value + jnp.sum(q - p)
    return value

def isfinite(self, a):
    """Thin wrapper: return the result of ``jnp.isfinite(a)`` unchanged."""
    return jnp.isfinite(a)
Expand Down Expand Up @@ -2238,8 +2245,11 @@ def sqrtm(self, a):
def eigh(self, a):
    """Thin wrapper: return the result of ``torch.linalg.eigh(a)`` unchanged."""
    return torch.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return torch.sum(p * torch.log(p / q + eps))
def kl_div(self, p, q, mass=False, eps=1e-16):
    """Compute the (generalized) Kullback-Leibler divergence with PyTorch.

    Returns ``sum(p * log(p / q + eps))``. When ``mass`` is true, the mass
    correction ``sum(q - p)`` is added on top.

    Parameters
    ----------
    p, q : torch.Tensor
        Operands of the divergence (presumably non-negative tensors of the
        same or broadcastable shape; confirm against callers).
    mass : bool, optional
        If True, add ``sum(q - p)`` to the result. Default is False.
    eps : float, optional
        Constant added to the ratio ``p / q`` inside the logarithm.

    Returns
    -------
    value
        Result of the ``torch.sum`` reduction over all entries.
    """
    value = torch.sum(p * torch.log(p / q + eps))
    if mass:
        # Generalized KL: account for the difference in total mass.
        value = value + torch.sum(q - p)
    return value

def isfinite(self, a):
    """Thin wrapper: return the result of ``torch.isfinite(a)`` unchanged."""
    return torch.isfinite(a)
Expand Down Expand Up @@ -2639,8 +2649,11 @@ def sqrtm(self, a):
def eigh(self, a):
    """Thin wrapper: return the result of ``cp.linalg.eigh(a)`` unchanged."""
    return cp.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return cp.sum(p * cp.log(p / q + eps))
def kl_div(self, p, q, mass=False, eps=1e-16):
    """Compute the (generalized) Kullback-Leibler divergence with CuPy.

    Returns ``sum(p * log(p / q + eps))``. When ``mass`` is true, the mass
    correction ``sum(q - p)`` is added on top.

    Parameters
    ----------
    p, q : array-like
        Operands of the divergence (presumably non-negative arrays of the
        same or broadcastable shape; confirm against callers).
    mass : bool, optional
        If True, add ``sum(q - p)`` to the result. Default is False.
    eps : float, optional
        Constant added to the ratio ``p / q`` inside the logarithm.

    Returns
    -------
    value
        Result of the ``cp.sum`` reduction over all entries.
    """
    value = cp.sum(p * cp.log(p / q + eps))
    if mass:
        # Generalized KL: account for the difference in total mass.
        value = value + cp.sum(q - p)
    return value

def isfinite(self, a):
    """Thin wrapper: return the result of ``cp.isfinite(a)`` unchanged."""
    return cp.isfinite(a)
Expand Down Expand Up @@ -3063,8 +3076,11 @@ def sqrtm(self, a):
def eigh(self, a):
    """Thin wrapper: return the result of ``tf.linalg.eigh(a)`` unchanged."""
    return tf.linalg.eigh(a)

def kl_div(self, p, q, eps=1e-16):
return tnp.sum(p * tnp.log(p / q + eps))
def kl_div(self, p, q, mass=False, eps=1e-16):
    """Compute the (generalized) Kullback-Leibler divergence via ``tnp``.

    Returns ``sum(p * log(p / q + eps))``. When ``mass`` is true, the mass
    correction ``sum(q - p)`` is added on top.

    Parameters
    ----------
    p, q : array-like
        Operands of the divergence (presumably non-negative tensors of the
        same or broadcastable shape; confirm against callers).
    mass : bool, optional
        If True, add ``sum(q - p)`` to the result. Default is False.
    eps : float, optional
        Constant added to the ratio ``p / q`` inside the logarithm.

    Returns
    -------
    value
        Result of the ``tnp.sum`` reduction over all entries.
    """
    value = tnp.sum(p * tnp.log(p / q + eps))
    if mass:
        # Generalized KL: account for the difference in total mass.
        value = value + tnp.sum(q - p)
    return value

def isfinite(self, a):
    """Thin wrapper: return the result of ``tnp.isfinite(a)`` unchanged."""
    return tnp.isfinite(a)
Expand Down
4 changes: 2 additions & 2 deletions ot/bregman/_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
log = {'err': []}

M = - M / reg
logA = nx.log(A + 1e-15)
logA = nx.log(A + 1e-16)
log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
err = 1
for ii in range(numItermax):
Expand Down Expand Up @@ -702,7 +702,7 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
log = {'err': []}

M = - M / reg
logA = nx.log(A + 1e-15)
logA = nx.log(A + 1e-16)
log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
c = nx.zeros(dim, type_as=A)
err = 1
Expand Down
8 changes: 2 additions & 6 deletions ot/coot.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat
Advances in Neural Information Processing Systems, 33 (2020).
"""

def compute_kl(p, q):
kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
return kl

# Main function

if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
Expand Down Expand Up @@ -245,9 +241,9 @@ def compute_kl(p, q):
coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
# Entropic part
if eps_samp != 0:
coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
coot = coot + eps_samp * nx.kl_div(pi_samp, wxy_samp)
if eps_feat != 0:
coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
coot = coot + eps_feat * nx.kl_div(pi_feat, wxy_feat)
list_coot.append(coot)

if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
Expand Down
8 changes: 4 additions & 4 deletions ot/gromov/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def h2(b):
return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
return a * nx.log(a + 1e-15) - a
return a * nx.log(a + 1e-16) - a

def f2(b):
return b
Expand All @@ -118,7 +118,7 @@ def h1(a):
return a

def h2(b):
return nx.log(b + 1e-15)
return nx.log(b + 1e-16)
else:
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

Expand Down Expand Up @@ -502,7 +502,7 @@ def h2(b):
return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
return a * nx.log(a + 1e-15) - a
return a * nx.log(a + 1e-16) - a

def f2(b):
return b
Expand All @@ -511,7 +511,7 @@ def h1(a):
return a

def h2(b):
return nx.log(b + 1e-15)
return nx.log(b + 1e-16)
else:
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

Expand Down
16 changes: 16 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,19 @@ def test_label_normalization(nx):
# labels are shifted but the shift is expected
y_normalized_start = ot.utils.label_normalization(y, start=1)
np.testing.assert_array_equal(y, y_normalized_start)


def test_kl_div(nx):
    """Check that ``nx.kl_div(..., mass=True)`` is the plain KL plus ``sum(q - p)``.

    The divergence is computed with and without the mass term on the same
    pair of positive vectors; subtracting ``sum(y - x)`` from the generalized
    value must give back the standard one.
    """
    n = 10
    rng = np.random.RandomState(0)
    # test on non-negative tensors: shift the Gaussian draws so that the
    # smallest entry of each vector is 1e-5.
    x = rng.randn(n)
    x = x - x.min() + 1e-5
    y = rng.randn(n)
    y = y - y.min() + 1e-5
    xb = nx.from_numpy(x)
    yb = nx.from_numpy(y)
    kl = nx.kl_div(xb, yb)
    # Pass `mass` by keyword so the call cannot be confused with `eps`.
    kl_mass = nx.kl_div(xb, yb, mass=True)
    recovered_kl = kl_mass - nx.sum(yb - xb)
    # Convert explicitly: not every backend tensor can be implicitly turned
    # into a NumPy array by `assert_allclose`.
    np.testing.assert_allclose(nx.to_numpy(kl), nx.to_numpy(recovered_kl))

0 comments on commit 24ad25c

Please sign in to comment.