[WIP] add mass feature to nx.kl_div and harmonize kl computation in the toolbox #654

Merged 5 commits on Jul 10, 2024
1 change: 1 addition & 0 deletions RELEASES.md
@@ -3,6 +3,7 @@
## 0.9.5dev

#### New features
+- Add feature `mass=True` for `nx.kl_div` (PR #654)

#### Closed issues

42 changes: 29 additions & 13 deletions ot/backend.py
@@ -944,16 +944,17 @@ def eigh(self, a):
"""
raise NotImplementedError()

def kl_div(self, p, q, eps=1e-16):
def kl_div(self, p, q, mass=False, eps=1e-16):
r"""
Computes the Kullback-Leibler divergence.
Computes the (Generalized) Kullback-Leibler divergence.

This function follows the api from :any:`scipy.stats.entropy`.

Parameter eps is used to avoid numerical errors and is added in the log.

.. math::
KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
KL(p,q) = \langle \mathbf{p}, log(\mathbf{p} / \mathbf{q} + eps \rangle
+ \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
"""
@@ -1352,8 +1353,11 @@ def sqrtm(self, a):
    def eigh(self, a):
        return np.linalg.eigh(a)

-    def kl_div(self, p, q, eps=1e-16):
-        return np.sum(p * np.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = np.sum(p * np.log(p / q + eps))
+        if mass:
+            value = value + np.sum(q - p)
+        return value

    def isfinite(self, a):
        return np.isfinite(a)
@@ -1751,8 +1755,11 @@ def sqrtm(self, a):
    def eigh(self, a):
        return jnp.linalg.eigh(a)

-    def kl_div(self, p, q, eps=1e-16):
-        return jnp.sum(p * jnp.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = jnp.sum(p * jnp.log(p / q + eps))
+        if mass:
+            value = value + jnp.sum(q - p)
+        return value

    def isfinite(self, a):
        return jnp.isfinite(a)
@@ -2238,8 +2245,11 @@ def sqrtm(self, a):
    def eigh(self, a):
        return torch.linalg.eigh(a)

-    def kl_div(self, p, q, eps=1e-16):
-        return torch.sum(p * torch.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = torch.sum(p * torch.log(p / q + eps))
+        if mass:
+            value = value + torch.sum(q - p)
+        return value

    def isfinite(self, a):
        return torch.isfinite(a)
@@ -2639,8 +2649,11 @@ def sqrtm(self, a):
    def eigh(self, a):
        return cp.linalg.eigh(a)

-    def kl_div(self, p, q, eps=1e-16):
-        return cp.sum(p * cp.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = cp.sum(p * cp.log(p / q + eps))
+        if mass:
+            value = value + cp.sum(q - p)
+        return value

    def isfinite(self, a):
        return cp.isfinite(a)
@@ -3063,8 +3076,11 @@ def sqrtm(self, a):
    def eigh(self, a):
        return tf.linalg.eigh(a)

-    def kl_div(self, p, q, eps=1e-16):
-        return tnp.sum(p * tnp.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = tnp.sum(p * tnp.log(p / q + eps))
+        if mass:
+            value = value + tnp.sum(q - p)
+        return value

    def isfinite(self, a):
        return tnp.isfinite(a)
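
All six backends implement the same two-branch pattern, so a standalone NumPy sketch is enough to see what mass=True adds (this is an illustration, not the POT source; the vectors p and q below are made up):

import numpy as np

def kl_div(p, q, mass=False, eps=1e-16):
    # plain KL term: <p, log(p / q + eps)>
    value = np.sum(p * np.log(p / q + eps))
    if mass:
        # generalized KL: add the mass difference <q - p, 1>
        value = value + np.sum(q - p)
    return value

p = np.array([0.2, 0.5, 0.3])   # sums to 1
q = np.array([0.8, 0.8, 0.4])   # sums to 2, i.e. unnormalized
print(kl_div(p, q))             # ≈ -0.599: plain KL can go negative here
print(kl_div(p, q, mass=True))  # ≈ 0.401: generalized KL stays >= 0

When p and q carry the same total mass the extra term vanishes, so mass=True is a strict generalization of the default behavior.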
4 changes: 2 additions & 2 deletions ot/bregman/_barycenter.py
@@ -364,7 +364,7 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
        log = {'err': []}

    M = - M / reg
-    logA = nx.log(A + 1e-15)
+    logA = nx.log(A + 1e-16)
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    err = 1
    for ii in range(numItermax):
@@ -702,7 +702,7 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
        log = {'err': []}

    M = - M / reg
-    logA = nx.log(A + 1e-15)
+    logA = nx.log(A + 1e-16)
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    c = nx.zeros(dim, type_as=A)
    err = 1
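
Switching the barycenter code from 1e-15 to 1e-16 only moves the floor that the log assigns to zero weights, but it lines the constant up with the default eps of kl_div. A quick sketch of the numerical effect (plain NumPy, made-up weights):

import numpy as np

A = np.array([0.0, 0.5, 0.5])
print(np.log(A + 1e-15)[0])  # ≈ -34.54: the old floor for zero entries
print(np.log(A + 1e-16)[0])  # ≈ -36.84: the new floor, matching kl_div's eps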
8 changes: 2 additions & 6 deletions ot/coot.py
@@ -139,10 +139,6 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat
    Advances in Neural Information Processing Systems, 33 (2020).
    """

-    def compute_kl(p, q):
-        kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
-        return kl
-
    # Main function

    if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
@@ -245,9 +241,9 @@ def compute_kl(p, q):
            coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
            # Entropic part
            if eps_samp != 0:
-                coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
+                coot = coot + eps_samp * nx.kl_div(pi_samp, wxy_samp)
            if eps_feat != 0:
-                coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
+                coot = coot + eps_feat * nx.kl_div(pi_feat, wxy_feat)
            list_coot.append(coot)

            if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
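
The removed helper and nx.kl_div are not written identically: compute_kl masked zeros in p with log(p + 1.0 * (p == 0)) and subtracted <p, log q>, while the harmonized call guards with eps inside log(p / q). On non-negative inputs the two agree up to eps, which a sketch on the NumPy backend shows (made-up arrays; compute_kl is copied from the deleted lines, and the NumpyBackend import path is taken from ot.backend):

import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()

def compute_kl(p, q):  # the deleted local helper
    return nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))

p = np.array([0.0, 0.5, 0.5])
q = np.array([0.25, 0.25, 0.5])
print(compute_kl(p, q))  # ≈ 0.3466: zero entries masked out of the log
print(nx.kl_div(p, q))   # ≈ 0.3466: zero entries guarded by eps instead

Since nx.kl_div defaults to mass=False, the value of the entropic penalty in coot is unchanged by this substitution.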
8 changes: 4 additions & 4 deletions ot/gromov/_utils.py
@@ -109,7 +109,7 @@ def h2(b):
            return 2 * b
    elif loss_fun == 'kl_loss':
        def f1(a):
-            return a * nx.log(a + 1e-15) - a
+            return a * nx.log(a + 1e-16) - a

        def f2(b):
            return b
@@ -118,7 +118,7 @@ def h1(a):
            return a

        def h2(b):
-            return nx.log(b + 1e-15)
+            return nx.log(b + 1e-16)
    else:
        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

@@ -502,7 +502,7 @@ def h2(b):
            return 2 * b
    elif loss_fun == 'kl_loss':
        def f1(a):
-            return a * nx.log(a + 1e-15) - a
+            return a * nx.log(a + 1e-16) - a

        def f2(b):
            return b
@@ -511,7 +511,7 @@ def h1(a):
            return a

        def h2(b):
-            return nx.log(b + 1e-15)
+            return nx.log(b + 1e-16)
    else:
        raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
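
In the Gromov utilities the KL loss enters through the factorization L(a, b) = f1(a) + f2(b) - h1(a) h2(b) used throughout POT's Gromov solvers, and with the factors above this recovers the pointwise generalized KL up to the eps guard. A standalone check (made-up scalars):

import numpy as np

eps = 1e-16
f1 = lambda a: a * np.log(a + eps) - a
f2 = lambda b: b
h1 = lambda a: a
h2 = lambda b: np.log(b + eps)

a, b = 0.3, 0.7
decomposed = f1(a) + f2(b) - h1(a) * h2(b)
pointwise = a * np.log(a / b) - a + b  # generalized KL for scalars
print(decomposed, pointwise)  # both ≈ 0.1458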

16 changes: 16 additions & 0 deletions test/test_utils.py
@@ -620,3 +620,19 @@ def test_label_normalization(nx):
    # labels are shifted but the shift is expected
    y_normalized_start = ot.utils.label_normalization(y, start=1)
    np.testing.assert_array_equal(y, y_normalized_start)
+
+
+def test_kl_div(nx):
+    n = 10
+    rng = np.random.RandomState(0)
+    # test on non-negative tensors
+    x = rng.randn(n)
+    x = x - x.min() + 1e-5
+    y = rng.randn(n)
+    y = y - y.min() + 1e-5
+    xb = nx.from_numpy(x)
+    yb = nx.from_numpy(y)
+    kl = nx.kl_div(xb, yb)
+    kl_mass = nx.kl_div(xb, yb, True)
+    recovered_kl = kl_mass - nx.sum(yb - xb)
+    np.testing.assert_allclose(kl, recovered_kl)