Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replaced gpu<->cpu copies with host_callback in minimize. #253

Merged
merged 20 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
4c8740e
Replaced gpu<->cpu copies with host_callback in minimize.
FernandoDavis Mar 16, 2022
fa381b2
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
FernandoDavis Mar 16, 2022
893fd99
Fixed minimize method in solver using hcb and side effects. Fixed fai…
FernandoDavis Mar 27, 2022
d2c40b7
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
FernandoDavis Mar 27, 2022
f1b735e
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
bwohlberg Apr 12, 2022
6dfed69
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
bwohlberg May 9, 2022
eafa25b
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
bwohlberg Jul 14, 2022
cd954be
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
FernandoDavis Sep 5, 2022
c995d78
Fixed hcb call with device. Solver tests are fixed and should test de…
FernandoDavis Sep 5, 2022
e3ebdc4
Corrected syntax for fun. Added device logic for call_with_device. Ad…
FernandoDavis Sep 7, 2022
65011c7
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
bwohlberg Sep 7, 2022
2b58a56
Merge branch 'FernandoDavis/Manual-gpu-copy-to-HCB' of github.com:lan…
FernandoDavis Sep 9, 2022
187eb00
Fixed device issues. Added test cases for devices and cleaned up code.
FernandoDavis Sep 9, 2022
642dc34
Merge branch 'main' into FernandoDavis/Manual-gpu-copy-to-HCB
bwohlberg Sep 13, 2022
d7ca4ba
Fixed hcb call to properly handle device management for minimize. Sol…
FernandoDavis Sep 13, 2022
95499e2
Merge branch 'FernandoDavis/Manual-gpu-copy-to-HCB' of github.com:lan…
FernandoDavis Sep 13, 2022
5841b89
Cleaned up comments.
FernandoDavis Sep 13, 2022
ad1508e
Fixing dtype of result from hcb float32 vs float64
FernandoDavis Sep 13, 2022
ef3919d
Returned proper reshaped x0 object in hcb
FernandoDavis Sep 13, 2022
d6e0d30
Fixing lint/black.
FernandoDavis Sep 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions scico/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import numpy as np

import jax
import jax.experimental.host_callback as hcb

import scico.numpy as snp
from scico.numpy import BlockArray
Expand Down Expand Up @@ -210,11 +211,6 @@ def minimize(
x0_shape = x0.shape
x0_dtype = x0.dtype
x0 = x0.ravel() # if x0 is a BlockArray it will become a DeviceArray here
if isinstance(x0, jax.interpreters.xla.DeviceArray):
dev = x0.device_buffer.device() # device for x0; used to put result back in place
x0 = np.array(x0).astype(float)
else:
dev = None

# Run the SciPy minimizer
if method in (
Expand All @@ -229,13 +225,28 @@ def minimize(
min_func = _wrap_func(func_, x0_shape, x0_dtype)
jac = False

res = spopt.minimize(
min_func,
x0=x0,
args=args,
jac=jac,
method=method,
options=options,
res = Optional[Any]

def fun(x0):
nonlocal res # To use the external res and update side effect
res = spopt.minimize(
min_func,
x0=x0,
args=args,
jac=jac,
method=method,
options=options,
) # Returns OptimizeResult
return res.x.astype(x0.dtype) # Return for host_callback

if isinstance(x0, jax.interpreters.xla.DeviceArray):
dev = x0.device_buffer.device() # Save to return x0 to its proper device later

# hcb call with side effects to get the OptimizeResult on the same device it was called
hcb.call(
fun,
arg=x0,
result_shape=x0,
)
FernandoDavis marked this conversation as resolved.
Show resolved Hide resolved

# un-vectorize the output array, put on device
Expand All @@ -245,8 +256,8 @@ def minimize(

res.x = res.x.astype(x0_dtype)
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved

if dev:
res.x = jax.device_put(res.x, dev)
if isinstance(x0, jax.interpreters.xla.DeviceArray):
res.x = jax.device_put(res.x, dev) # Return x0/res.x to its device
FernandoDavis marked this conversation as resolved.
Show resolved Hide resolved

if iscomplex:
res.x = _join_real_imag(res.x)
Expand Down
24 changes: 21 additions & 3 deletions scico/test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,24 @@ def f(x):
assert out.x.shape == x.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)

# Check if minimize returns the object to the proper device
devices = jax.devices()

# For default device:
x0 = jax.device_put(snp.zeros_like(x), devices[0])
out = solver.minimize(f, x0=x0, method=method)
assert out.x.device_buffer.device() == devices[0]
assert out.x.shape == x0.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)

# If there are more than one device present:
if len(devices) > 1:
x0 = jax.device_put(snp.zeros_like(x), devices[1])
out = solver.minimize(f, x0=x0, method=method)
assert out.x.device_buffer.device() == devices[1]
assert out.x.shape == x0.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)


def test_split_join_array():
x, key = random.randn((4, 4), dtype=np.complex64)
Expand Down Expand Up @@ -206,22 +224,22 @@ def test_split_join_blockarray():


def test_bisect():
f = lambda x: x**3
f = lambda x: x ** 3
x, info = solver.bisect(f, -snp.ones((5, 1)), snp.ones((5, 1)), full_output=True)
assert snp.sum(snp.abs(x)) == 0.0
assert info["iter"] == 0
x = solver.bisect(f, -2.0 * snp.ones((5, 3)), snp.ones((5, 3)), xtol=1e-5, ftol=1e-5)
assert snp.max(snp.abs(x)) <= 1e-5
assert snp.max(snp.abs(f(x))) <= 1e-5
c, key = random.randn((5, 1), dtype=np.float32)
f = lambda x, c: x**3 - c**3
f = lambda x, c: x ** 3 - c ** 3
x = solver.bisect(f, -snp.abs(c) - 1, snp.abs(c) + 1, args=(c,), xtol=1e-5, ftol=1e-5)
assert snp.max(snp.abs(x - c)) <= 1e-5
assert snp.max(snp.abs(f(x, c))) <= 1e-5


def test_golden():
f = lambda x: x**2
f = lambda x: x ** 2
x, info = solver.golden(f, -snp.ones((5, 1)), snp.ones((5, 1)), full_output=True)
assert snp.max(snp.abs(x)) <= 1e-7
x = solver.golden(f, -2.0 * snp.ones((5, 3)), snp.ones((5, 3)), xtol=1e-5)
Expand Down