
Commit 400da60

introduce precondition_fun option to improve resolution of implicit function theorem differentiation of the outputs of the sinkhorn algorithm.

Sinkhorn's first order condition (FOC) amounts to ∇ Energy = (marginal of transport) - (marginal fitting penalty) = 0.

This commit introduces the possibility of differentiating instead
precondition_fun(marginal of transport) - precondition_fun(marginal fitting penalty) = 0.

For Sinkhorn, we draw inspiration from https://arxiv.org/pdf/2002.03229.pdf and use by default the function ε log, which provides a numerically more stable approach.
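As a rough sketch of the idea (illustrative only; this is not the code in ott/core/sinkhorn.py, and the cost matrix, potentials and marginal names below are assumptions), the raw and preconditioned first-order conditions for the balanced case could look like:

import jax.numpy as jnp
from jax.scipy.special import logsumexp


def transport_marginal(f, g, cost, eps):
  # Row marginal of the entropic transport plan induced by potentials f, g.
  log_plan = (f[:, None] + g[None, :] - cost) / eps
  return jnp.exp(logsumexp(log_plan, axis=1))


def foc_raw(f, g, cost, a, eps):
  # Raw first-order condition (balanced case): marginal of transport minus a.
  return transport_marginal(f, g, cost, eps) - a


def foc_preconditioned(f, g, cost, a, eps, precondition_fun=None):
  # Same root, but both terms are mapped through precondition_fun; the default
  # eps * log keeps the residual on the scale of the potentials themselves.
  if precondition_fun is None:
    precondition_fun = lambda z: eps * jnp.log(z)
  marg = transport_marginal(f, g, cost, eps)
  return precondition_fun(marg) - precondition_fun(a)

Implicit differentiation linearizes whichever of the two conditions is used at the Sinkhorn solution, so a better-scaled residual translates directly into better-conditioned linear systems.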

PiperOrigin-RevId: 399701681
marcocuturi committed Sep 29, 2021
1 parent 8bdd368 commit 400da60
Showing 6 changed files with 293 additions and 83 deletions.
233 changes: 157 additions & 76 deletions ott/core/sinkhorn.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions ott/tools/sinkhorn_divergence.py
@@ -119,6 +119,13 @@ def _sinkhorn_divergence(
      chg_momentum_from=0,
      anderson_acceleration=0)

  # Since symmetric terms are computed assuming a = b, the linear systems
  # arising in implicit differentiation (if used) of the potentials computed for
  # the symmetric parts should be marked as symmetric.
  linear_solve_kwargs = kwargs_symmetric.pop('linear_solve_kwargs', {})
  linear_solve_kwargs.update(symmetric=True)
  kwargs_symmetric.update(linear_solve_kwargs=linear_solve_kwargs)

  out_xy = sinkhorn.sinkhorn(geometry_xy, a, b, **kwargs)
  out_xx = sinkhorn.sinkhorn(geometry_xx, a, a, **kwargs_symmetric)
  if geometry_yy is None:
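To see why marking these solves as symmetric helps, here is a hedged sketch (the function and argument names are assumptions, not OTT's internal solver): a symmetric Jacobian of the first-order condition can be handed to conjugate gradients directly, whereas a generic one must first be symmetrized, e.g. via the normal equations.

import jax
from jax.scipy.sparse.linalg import cg


def solve_implicit_system(jvp_fn, rhs, symmetric=True, ridge=1e-6):
  # jvp_fn must be a linear map v -> J v, with J the (square) Jacobian of the
  # first-order condition w.r.t. the potential; rhs is the incoming cotangent.
  if symmetric:
    # Symmetric Jacobian: conjugate gradients applies to J directly.
    matvec = lambda v: jvp_fn(v) + ridge * v
    x, _ = cg(matvec, rhs)
  else:
    # Generic Jacobian: solve the (symmetric) normal equations
    # J^T J x = J^T rhs, obtained by transposing the linear map.
    vjp_fn = jax.linear_transpose(jvp_fn, rhs)
    matvec = lambda v: vjp_fn(jvp_fn(v))[0] + ridge * v
    x, _ = cg(matvec, vjp_fn(rhs)[0])
  return x

Declaring symmetry up front skips the extra transpose pass and avoids squaring the condition number of the system.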
2 changes: 1 addition & 1 deletion ott/version.py
@@ -15,4 +15,4 @@

"""Current ott version."""

__version__ = "0.1.15"
__version__ = "0.1.16"
124 changes: 124 additions & 0 deletions tests/core/sinkhorn_differentiability_precond_test.py
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Tests for the Jacobian of optimal potential."""
import functools

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.test_util
from ott.tools import transport


class SinkhornJacobianPreconditioningTest(jax.test_util.JaxTestCase):

  def setUp(self):
    super().setUp()
    self.rng = jax.random.PRNGKey(0)

  @parameterized.product(
      lse_mode=[True, False],
      tau_a=[1.0, .94],
      tau_b=[1.0, .91],
      shape=[(18, 19), (27, 18), (275, 414)],
      arg=[0, 1])
  def test_potential_jacobian_sinkhorn(self, lse_mode, tau_a, tau_b, shape,
                                       arg):
    """Test Jacobian of optimal potential w.r.t. weights and locations."""
    n, m = shape
    dim = 3
    rngs = jax.random.split(self.rng, 7)
    x = jax.random.uniform(rngs[0], (n, dim))
    y = jax.random.uniform(rngs[1], (m, dim))
    a = jax.random.uniform(rngs[2], (n,)) + .2
    b = jax.random.uniform(rngs[3], (m,)) + .2
    a = a / (0.5 * n) if tau_a < 1.0 else a / jnp.sum(a)
    b = b / (0.5 * m) if tau_b < 1.0 else b / jnp.sum(b)
    random_dir = jax.random.uniform(rngs[4], (n,)) / n
    # center projection direction so that < potential , random_dir>
    # is invariant w.r.t additive shifts.
    random_dir = random_dir - jnp.mean(random_dir)
    delta_a = jax.random.uniform(rngs[5], (n,))
    if tau_a == 1.0:
      delta_a = delta_a - jnp.mean(delta_a)
    delta_x = jax.random.uniform(rngs[6], (n, dim))

    # As expected, lse_mode False has a harder time with small epsilon when
    # differentiating.
    epsilon = 0.01 if lse_mode else 0.1

    def loss_from_potential(a,
                            x,
                            precondition_fun=None,
                            linear_solve_kwargs=None):
      out = transport.Transport(
          x,
          y,
          epsilon=epsilon,
          a=a,
          b=b,
          tau_a=tau_a,
          tau_b=tau_b,
          lse_mode=lse_mode,
          precondition_fun=precondition_fun,
          linear_solve_kwargs=linear_solve_kwargs)
      return jnp.sum(random_dir * out._f)

    # Compute implicit gradient
    loss_imp_no_precond = jax.jit(
        jax.value_and_grad(
            functools.partial(
                loss_from_potential,
                precondition_fun=lambda x: x,
                linear_solve_kwargs={
                    'symmetric': True
                }),
            argnums=arg))

    loss_imp_log_precond = jax.jit(
        jax.value_and_grad(loss_from_potential, argnums=arg))

    _, g_imp_np = loss_imp_no_precond(a, x)
    imp_dif_np = jnp.sum(g_imp_np * (delta_a if arg == 0 else delta_x))

    _, g_imp_lp = loss_imp_log_precond(a, x)
    imp_dif_lp = jnp.sum(g_imp_lp * (delta_a if arg == 0 else delta_x))

    # Compute finite difference
    perturb_scale = 1e-4
    a_p = a + perturb_scale * delta_a if arg == 0 else a
    x_p = x if arg == 0 else x + perturb_scale * delta_x
    a_m = a - perturb_scale * delta_a if arg == 0 else a
    x_m = x if arg == 0 else x - perturb_scale * delta_x

    val_p, _ = loss_imp_no_precond(a_p, x_p)
    val_m, _ = loss_imp_no_precond(a_m, x_m)
    fin_dif = (val_p - val_m) / (2 * perturb_scale)
    self.assertAllClose(fin_dif, imp_dif_lp, atol=1e-2, rtol=1e-2)
    self.assertAllClose(fin_dif, imp_dif_np, atol=1e-2, rtol=1e-2)
    self.assertAllClose(imp_dif_np, imp_dif_lp, atol=1e-2, rtol=1e-2)

    # center both if balanced problem testing gradient w.r.t weights
    if tau_a == 1.0 and tau_b == 1.0 and arg == 0:
      g_imp_np = g_imp_np - jnp.mean(g_imp_np)
      g_imp_lp = g_imp_lp - jnp.mean(g_imp_lp)

    self.assertAllClose(g_imp_np, g_imp_lp, atol=1e-2, rtol=1e-2)

if __name__ == '__main__':
  absltest.main()
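The implicit gradients exercised by this test follow the implicit function theorem: the optimality condition F(f, a) = 0 defines the potential f*(a), and linearizing it at the solution yields the gradient of any downstream loss. A minimal sketch of that step, assuming the Jacobian of the condition is symmetric so that conjugate gradients applies (all names are illustrative, not the library's internals):

import jax
from jax.scipy.sparse.linalg import cg


def implicit_grad(foc, f_star, a, grad_f_loss):
  # foc: (f, a) -> residual of the (possibly preconditioned) optimality condition.
  # f_star: potential satisfying foc(f_star, a) = 0 for the current weights a.
  # grad_f_loss: gradient of the downstream loss w.r.t. the potential f.
  # Returns d(loss)/d(a) through the implicitly defined map a -> f_star(a).
  jac_f = lambda v: jax.jvp(lambda f: foc(f, a), (f_star,), (v,))[1]
  # Solve (dF/df) lam = grad_f_loss.
  lam, _ = cg(jac_f, grad_f_loss)
  # Chain rule through the explicit dependence of the condition on a.
  _, vjp_a = jax.vjp(lambda weights: foc(f_star, weights), a)
  return -vjp_a(lam)[0]

The finite-difference comparison in the test then acts as an end-to-end check that this linear-solve-based gradient matches a direct perturbation of the inputs, with and without the ε log preconditioner.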
4 changes: 2 additions & 2 deletions tests/core/sinkhorn_differentiability_test.py
@@ -14,7 +14,7 @@
# limitations under the License.

# Lint as: python3
"""Tests for the Policy."""
"""Tests for the differentiability of reg_ot_cost w.r.t weights/locations."""

from absl.testing import absltest
from absl.testing import parameterized
@@ -26,7 +26,7 @@
from ott.geometry import pointcloud


class SinkhornGradTest(jax.test_util.JaxTestCase):
class SinkhornJacobianTest(jax.test_util.JaxTestCase):

  def setUp(self):
    super().setUp()
6 changes: 2 additions & 4 deletions tests/core/sinkhorn_hessian_test.py
@@ -40,8 +40,6 @@ def test_hessian_sinkhorn(self, lse_mode, tau_a, tau_b, shape, arg):
"""Test hessian w.r.t. weights and locations."""
eps = 1e-3
n, m = shape
# use slightly different parameter to test linear_solve_kwargs
linear_solve_kwargs = {'ridge_kernel': 1.2e-4, 'ridge_identity': .9e-4}

dim = 3
rngs = jax.random.split(self.rng, 6)
@@ -64,8 +62,8 @@ def loss(a, x, implicit):
          lse_mode=lse_mode,
          implicit_differentiation=implicit,
          use_danskin=False,
          linear_solve_kwargs=linear_solve_kwargs,
          threshold=1e-5)
          threshold=1e-4,
          linear_solve_kwargs={'ridge_kernel': 1e-4, 'ridge_identity': 1e-4})
      return out.reg_ot_cost

    delta_a = jax.random.uniform(rngs[4], (n,))
