diff --git a/scico/solver.py b/scico/solver.py index 663d92dbc..a20e02261 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -63,6 +63,7 @@ import numpy as np import jax +import jax.experimental.host_callback as hcb import scico.numpy as snp from scico.numpy import BlockArray @@ -210,11 +211,6 @@ def minimize( x0_shape = x0.shape x0_dtype = x0.dtype x0 = x0.ravel() # if x0 is a BlockArray it will become a DeviceArray here - if isinstance(x0, jax.interpreters.xla.DeviceArray): - dev = x0.device_buffer.device() # device for x0; used to put result back in place - x0 = np.array(x0).astype(float) - else: - dev = None # Run the SciPy minimizer if method in ( @@ -229,25 +225,32 @@ def minimize( min_func = _wrap_func(func_, x0_shape, x0_dtype) jac = False - res = spopt.minimize( - min_func, - x0=x0, - args=args, - jac=jac, - method=method, - options=options, + res = spopt.OptimizeResult({"x": None}) + + def fun(x0): + nonlocal res # To use the external res and update side effect + res = spopt.minimize( + min_func, + x0=x0, + args=args, + jac=jac, + method=method, + options=options, + ) # Returns OptimizeResult with x0 as ndarray + return res.x.astype(x0_dtype) + + # HCB call with side effects to get the OptimizeResult on the same device it was called + res.x = hcb.call( + fun, + arg=x0, + result_shape=x0, # From Jax-docs: This can be an object that has .shape and .dtype attributes ) - # un-vectorize the output array, put on device + # un-vectorize the output array from spopt.minimize res.x = snp.reshape( res.x, x0_shape ) # if x0 was originally a BlockArray then res.x is converted back to one here - res.x = res.x.astype(x0_dtype) - - if dev: - res.x = jax.device_put(res.x, dev) - if iscomplex: res.x = _join_real_imag(res.x) diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index c40aa7afd..c6a4f9f67 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -179,6 +179,24 @@ def f(x): assert out.x.shape == x.shape np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) + # Check if minimize returns the object to the proper device + devices = jax.devices() + + # For default device: + x0 = jax.device_put(snp.zeros_like(x), devices[0]) + out = solver.minimize(f, x0=x0, method=method) + assert out.x.device() == devices[0] + assert out.x.shape == x0.shape + np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) + + # If there are more than one device present: + if len(devices) > 1: + x0 = jax.device_put(snp.zeros_like(x), devices[1]) + out = solver.minimize(f, x0=x0, method=method) + assert out.x.device() == devices[1] + assert out.x.shape == x0.shape + np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) + def test_split_join_array(): x, key = random.randn((4, 4), dtype=np.complex64)