Skip to content

Commit

Permalink
[REFACTOR] Minor refactor that checks the GPU on the last line
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco Muñoz committed Oct 14, 2023
1 parent bfaf2d1 commit 1966448
Showing 1 changed file with 59 additions and 52 deletions.
111 changes: 59 additions & 52 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,15 @@ def test_sinkhorn2_variants_device_tf(method):
nx.assert_same_dtype_device(Mb, Gb)
nx.assert_same_dtype_device(Mb, lossb)

# Check that everything happens on the GPU
ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
nx.assert_same_dtype_device(Mb, lossb)

# Check this only if GPU is available
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
nx.assert_same_dtype_device(Mb, lossb)
assert nx.dtype_device(Gb)[1].startswith("GPU")


Expand Down Expand Up @@ -805,7 +807,7 @@ def test_wasserstein_bary_2d_device_tf(method):

# wasserstein
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
else:
Expand All @@ -826,32 +828,34 @@ def test_wasserstein_bary_2d_device_tf(method):
# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
Ab = nx.from_numpy(A)
# Check that everything happens on the GPU
Ab = nx.from_numpy(A)

# wasserstein
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)
# wasserstein
reg = 1e-2
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)
# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

# Check this only if GPU is available
if len(tf.config.list_physical_devices('GPU')) > 0:
assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")


Expand Down Expand Up @@ -970,7 +974,7 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):

# wasserstein
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
else:
Expand All @@ -991,32 +995,35 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
Ab = nx.from_numpy(A)

# Check that everything happens on the GPU
Ab = nx.from_numpy(A)

# wasserstein
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)
# wasserstein
reg = 1e-2
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)
# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

# Check this only if there is a GPU
if len(tf.config.list_physical_devices('GPU')) > 0:
assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")


Expand Down

0 comments on commit 1966448

Please sign in to comment.