Skip to content

Commit

Permalink
Remove jax.device_put calls
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 31, 2023
1 parent ecf328c commit faa28f9
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 53 deletions.
4 changes: 2 additions & 2 deletions scico/test/linop/test_radon_svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_prox_cg(
mask = np.ones(im.shape) > 0

W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32")
W = jax.device_put(W)
W = snp.array(W)
λ = 0.01

if is_masked:
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_approx_prox(

y = A @ im
W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32")
W = jax.device_put(W)
W = snp.array(W)
λ = 0.01

v, _ = scico.random.normal(im.shape, dtype=im.dtype)
Expand Down
28 changes: 14 additions & 14 deletions scico/test/linop/test_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def test_construct(self, jit):
H = VerticalStack([A, B], jit=jit)

# in general, returns a BlockArray
A = Convolve(jax.device_put(np.ones((3, 3))), (7, 11))
B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
A = Convolve(snp.ones((3, 3)), (7, 11))
B = Convolve(snp.ones((2, 2)), (7, 11))
H = VerticalStack([A, B], jit=jit)
x = np.ones((7, 11))
y = H @ x
Expand All @@ -39,8 +39,8 @@ def test_construct(self, jit):
assert np.allclose(y[1], B @ x)

# by default, collapse to jax array when possible
A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
A = Convolve(snp.ones((2, 2)), (7, 11))
B = Convolve(snp.ones((2, 2)), (7, 11))
H = VerticalStack([A, B], jit=jit)
x = np.ones((7, 11))
y = H @ x
Expand All @@ -51,8 +51,8 @@ def test_construct(self, jit):
assert np.allclose(y[1], B @ x)

# let user turn off collapsing
A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
A = Convolve(snp.ones((2, 2)), (7, 11))
B = Convolve(snp.ones((2, 2)), (7, 11))
H = VerticalStack([A, B], collapse=False, jit=jit)
x = np.ones((7, 11))
y = H @ x
Expand All @@ -62,27 +62,27 @@ def test_construct(self, jit):
@pytest.mark.parametrize("jit", [False, True])
def test_adjoint(self, collapse, jit):
# general case
A = Convolve(jax.device_put(np.ones((3, 3))), (7, 11))
B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
A = Convolve(snp.ones((3, 3)), (7, 11))
B = Convolve(snp.ones((2, 2)), (7, 11))
H = VerticalStack([A, B], collapse=collapse, jit=jit)
adjoint_test(H, self.key)

# collapsable case
A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
A = Convolve(snp.ones((2, 2)), (7, 11))
B = Convolve(snp.ones((2, 2)), (7, 11))
H = VerticalStack([A, B], collapse=collapse, jit=jit)
adjoint_test(H, self.key)

@pytest.mark.parametrize("collapse", [False, True])
@pytest.mark.parametrize("jit", [False, True])
def test_algebra(self, collapse, jit):
# adding
A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
A = Convolve(snp.ones((2, 2)), (7, 11))
B = Convolve(snp.ones((2, 2)), (7, 11))
H = VerticalStack([A, B], collapse=collapse, jit=jit)

A = Convolve(jax.device_put(np.random.rand(2, 2)), (7, 11))
B = Convolve(jax.device_put(np.random.rand(2, 2)), (7, 11))
A = Convolve(snp.array(np.random.rand(2, 2)), (7, 11))
B = Convolve(snp.array(np.random.rand(2, 2)), (7, 11))
G = VerticalStack([A, B], collapse=collapse, jit=jit)

x = np.ones((7, 11))
Expand Down
24 changes: 11 additions & 13 deletions scico/test/optimize/test_admm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy as np

import jax

import pytest

import scico.numpy as snp
Expand All @@ -20,7 +18,7 @@
class TestMisc:
def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(16, 17).astype(np.float32))
self.y = snp.array(np.random.randn(16, 17).astype(np.float32))

def test_admm(self):
maxiter = 2
Expand Down Expand Up @@ -112,14 +110,14 @@ def setup_method(self, method):
MB = 5
N = 6
# Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
Amx = np.random.randn(MA, N)
Bmx = np.random.randn(MB, N)
y = np.random.randn(MA)
Amx = np.random.randn(MA, N).astype(np.float32)
Bmx = np.random.randn(MB, N).astype(np.float32)
y = np.random.randn(MA).astype(np.float32)
𝛼 = np.pi # sort of random number chosen to test non-default scale factor
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.𝛼 = 𝛼
self.λ = λ
# Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y
Expand Down Expand Up @@ -219,16 +217,16 @@ def setup_method(self, method):
MB = 5
N = 6
# Set up arrays for problem argmin (𝛼/2) ||A x - y||_W^2 + (λ/2) ||B x||_2^2
Amx = np.random.randn(MA, N)
W = np.abs(np.random.randn(MA, 1))
Bmx = np.random.randn(MB, N)
y = np.random.randn(MA)
Amx = np.random.randn(MA, N).astype(np.float32)
W = np.abs(np.random.randn(MA, 1).astype(np.float32))
Bmx = np.random.randn(MB, N).astype(np.float32)
y = np.random.randn(MA).astype(np.float32)
𝛼 = np.pi # sort of random number chosen to test non-default scale factor
λ = np.e
self.Amx = Amx
self.W = jax.device_put(W)
self.W = snp.array(W)
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.𝛼 = 𝛼
self.λ = λ
# Solution of problem is given by linear system
Expand Down
8 changes: 3 additions & 5 deletions scico/test/optimize/test_ladmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy as np

import jax

import pytest

import scico.numpy as snp
Expand All @@ -13,7 +11,7 @@
class TestMisc:
def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
self.maxiter = 2
self.μ = 1e-1
self.ν = 1e-1
Expand Down Expand Up @@ -122,7 +120,7 @@ def setup_method(self, method):
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.λ = λ
# Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
Expand Down Expand Up @@ -161,7 +159,7 @@ def setup_method(self, method):
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.λ = λ
# Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
Expand Down
8 changes: 3 additions & 5 deletions scico/test/optimize/test_padmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy as np

import jax

import pytest

import scico.numpy as snp
Expand All @@ -13,7 +11,7 @@
class TestMisc:
def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
self.maxiter = 2
self.ρ = 1e0
self.μ = 1e0
Expand Down Expand Up @@ -199,7 +197,7 @@ def setup_method(self, method):
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.λ = λ
# Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
Expand Down Expand Up @@ -267,7 +265,7 @@ def setup_method(self, method):
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.λ = λ
# Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
Expand Down
8 changes: 3 additions & 5 deletions scico/test/optimize/test_pdhg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy as np

import jax

import pytest

import scico.numpy as snp
Expand All @@ -13,7 +11,7 @@
class TestMisc:
def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
self.maxiter = 2
self.τ = 1e-1
self.σ = 1e-1
Expand Down Expand Up @@ -128,7 +126,7 @@ def setup_method(self, method):
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.λ = λ
# Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
Expand Down Expand Up @@ -189,7 +187,7 @@ def setup_method(self, method):
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
self.y = jax.device_put(y)
self.y = snp.array(y)
self.λ = λ
# Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
Expand Down
7 changes: 4 additions & 3 deletions scico/test/optimize/test_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

import scico.numpy as snp
from scico import functional, linop, loss, random
from scico.optimize import PGM, AcceleratedPGM
from scico.optimize.pgm import (
Expand All @@ -20,9 +21,9 @@ def setup_method(self, method):
M = 5
N = 4
# Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
Amx = np.random.randn(M, N)
Amx = np.random.randn(M, N).astype(np.float32)
Bmx = np.identity(N)
y = jax.device_put(np.random.randn(M))
y = snp.array(np.random.randn(M).astype(np.float32))
λ = 1e0
self.Amx = Amx
self.y = y
Expand Down Expand Up @@ -196,7 +197,7 @@ def setup_method(self, method):
# Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||x||_2^2
Amx, key = random.randn((M, N), dtype=np.complex64, key=None)
Bmx = np.identity(N)
y = jax.device_put(np.random.randn(M))
y = snp.array(np.random.randn(M))
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
Expand Down
12 changes: 6 additions & 6 deletions scico/test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def setup_method(self, method):

def test_wrap_func_and_grad(self):
N = 8
A = jax.device_put(np.random.randn(N, N))
x = jax.device_put(np.random.randn(N))
A = snp.array(np.random.randn(N, N))
x = snp.array(np.random.randn(N))

f = lambda x: 0.5 * snp.linalg.norm(A @ x) ** 2

Expand Down Expand Up @@ -117,10 +117,10 @@ def test_preconditioned_cg(self):
def test_lstsq_func(self):
N = 24
M = 32
Ac = jax.device_put(np.random.randn(N, M).astype(np.float32))
Ac = snp.array(np.random.randn(N, M).astype(np.float32))
Am = Ac.dot(Ac.T)
A = Am.dot
x = jax.device_put(np.random.randn(N).astype(np.float32))
x = snp.array(np.random.randn(N).astype(np.float32))
b = Am.dot(x)
x0 = snp.zeros((N,), dtype=np.float32)
tol = 1e-6
Expand All @@ -134,9 +134,9 @@ def test_lstsq_func(self):
def test_lstsq_op(self):
N = 32
M = 24
Ac = jax.device_put(np.random.randn(N, M).astype(np.float32))
Ac = snp.array(np.random.randn(N, M).astype(np.float32))
A = linop.MatrixOperator(Ac)
x = jax.device_put(np.random.randn(M).astype(np.float32))
x = snp.array(np.random.randn(M).astype(np.float32))
b = Ac.dot(x)
tol = 1e-7
try:
Expand Down

0 comments on commit faa28f9

Please sign in to comment.