Skip to content

Commit

Permalink
Remove problematic array copy operations (#156)
Browse files · Browse the repository at this point in the history
* Remove problematic array copy operations

* Add tests for BlockArray variables
  • Loading branch information
bwohlberg authored Jan 6, 2022
1 parent ffec40b commit 305b744
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 50 deletions.
4 changes: 2 additions & 2 deletions scico/optimize/_ladmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def z_init(self, x0: Union[JaxArray, BlockArray]):
x0: Starting point for :math:`\mb{x}`.
"""
z = self.C(x0)
z_old = z.copy()
z_old = z
return z, z_old

def u_init(self, x0: Union[JaxArray, BlockArray]):
Expand Down Expand Up @@ -297,7 +297,7 @@ def step(self):
proxarg = self.x - (self.mu / self.nu) * self.C.conj().T(self.C(self.x) - self.z + self.u)
self.x = self.f.prox(proxarg, self.mu, v0=self.x)

self.z_old = self.z.copy()
self.z_old = self.z
Cx = self.C(self.x)
self.z = self.g.prox(Cx + self.u, self.nu, v0=self.z)
self.u = self.u + Cx - self.z
Expand Down
8 changes: 4 additions & 4 deletions scico/optimize/_primaldual.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ def __init__(
dtype = C.input_dtype
x0 = snp.zeros(input_shape, dtype=dtype)
self.x = ensure_on_device(x0)
self.x_old = self.x.copy()
self.x_old = self.x
if z0 is None:
input_shape = C.output_shape
dtype = C.output_dtype
z0 = snp.zeros(input_shape, dtype=dtype)
self.z = ensure_on_device(z0)
self.z_old = self.z.copy()
self.z_old = self.z

def objective(
self,
Expand Down Expand Up @@ -232,8 +232,8 @@ def norm_dual_residual(self) -> float:

def step(self):
"""Perform a single iteration."""
self.x_old = self.x.copy()
self.z_old = self.z.copy()
self.x_old = self.x
self.z_old = self.z
proxarg = self.x - self.tau * self.C.conj().T(self.z)
self.x = self.f.prox(proxarg, self.tau, v0=self.x)
proxarg = self.z + self.sigma * self.C(
Expand Down
87 changes: 65 additions & 22 deletions scico/test/optimize/test_ladmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, random
from scico.blockarray import BlockArray
from scico.optimize import LinearizedADMM


Expand All @@ -12,42 +13,51 @@ def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.λ = 1e0

def test_ladmm(self):
maxiter = 2
μ = 1e-1
ν = 1e-1
A = linop.Identity(self.y.shape)
f = loss.SquaredL2Loss(y=self.y, A=A)
g = (self.λ / 2) * functional.BM3D()
C = linop.Identity(self.y.shape)

self.maxiter = 2
self.μ = 1e-1
self.ν = 1e-1
self.A = linop.Identity(self.y.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g = (self.λ / 2) * functional.BM3D()
self.C = linop.Identity(self.y.shape)

def test_itstat(self):
    """Check that ``itstat_options`` overrides the default reported fields.

    The diff-flattened source contained both the pre-commit (``f=f``) and
    post-commit (``f=self.f``) keyword arguments in each constructor call,
    which is a SyntaxError (duplicate keyword argument); only the
    post-commit form, reading fixtures from ``setup_method`` via ``self``,
    is kept here.
    """
    itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

    def itstat_func(obj):
        # Custom per-iteration statistics: iteration number and elapsed time.
        return (obj.itnum, obj.timer.elapsed())

    # Default configuration: the solver reports four fields.
    ladmm_ = LinearizedADMM(
        f=self.f,
        g=self.g,
        C=self.C,
        mu=self.μ,
        nu=self.ν,
        maxiter=self.maxiter,
    )
    assert len(ladmm_.itstat_object.fieldname) == 4
    # Solver state starts from the zero initializer.
    assert snp.sum(ladmm_.x) == 0.0

    # Custom configuration: only the two requested fields remain.
    ladmm_ = LinearizedADMM(
        f=self.f,
        g=self.g,
        C=self.C,
        mu=self.μ,
        nu=self.ν,
        maxiter=self.maxiter,
        itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
    )
    assert len(ladmm_.itstat_object.fieldname) == 2

def test_callback(self):
ladmm_ = LinearizedADMM(
f=self.f,
g=self.g,
C=self.C,
mu=self.μ,
nu=self.ν,
maxiter=self.maxiter,
)
ladmm_.test_flag = False

def callback(obj):
Expand All @@ -57,6 +67,39 @@ def callback(obj):
assert ladmm_.test_flag


class TestBlockArray:
    """Exercise LinearizedADMM with BlockArray-valued variables."""

    def setup_method(self, method):
        # Deterministic fixture data: a two-component block array of float32.
        np.random.seed(12345)
        blocks = (
            np.random.randn(32, 33).astype(np.float32),
            np.random.randn(17).astype(np.float32),
        )
        self.y = BlockArray.array(blocks)
        self.λ = 1e0
        self.maxiter = 1
        self.μ = 1e-1
        self.ν = 1e-1
        self.A = linop.Identity(self.y.shape)
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g = (self.λ / 2) * functional.L2Norm()
        self.C = linop.Identity(self.y.shape)

    def test_blockarray(self):
        # Solving with a BlockArray problem should yield a BlockArray result.
        solver = LinearizedADMM(
            f=self.f,
            g=self.g,
            C=self.C,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
        )
        result = solver.solve()
        assert isinstance(result, BlockArray)


class TestReal:
def setup_method(self, method):
np.random.seed(12345)
Expand Down
87 changes: 65 additions & 22 deletions scico/test/optimize/test_pdhg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, random
from scico.blockarray import BlockArray
from scico.optimize import PDHG


Expand All @@ -12,42 +13,51 @@ def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.λ = 1e0

def test_pdhg(self):
maxiter = 2
τ = 1e-1
σ = 1e-1
A = linop.Identity(self.y.shape)
f = loss.SquaredL2Loss(y=self.y, A=A)
g = (self.λ / 2) * functional.BM3D()
C = linop.Identity(self.y.shape)

self.maxiter = 2
self.τ = 1e-1
self.σ = 1e-1
self.A = linop.Identity(self.y.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g = (self.λ / 2) * functional.BM3D()
self.C = linop.Identity(self.y.shape)

def test_itstat(self):
    """Check that ``itstat_options`` overrides the default reported fields.

    The diff-flattened source contained both the pre-commit (``f=f``) and
    post-commit (``f=self.f``) keyword arguments in each constructor call,
    which is a SyntaxError (duplicate keyword argument); only the
    post-commit form, reading fixtures from ``setup_method`` via ``self``,
    is kept here.
    """
    itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

    def itstat_func(obj):
        # Custom per-iteration statistics: iteration number and elapsed time.
        return (obj.itnum, obj.timer.elapsed())

    # Default configuration: the solver reports four fields.
    pdhg_ = PDHG(
        f=self.f,
        g=self.g,
        C=self.C,
        tau=self.τ,
        sigma=self.σ,
        maxiter=self.maxiter,
    )
    assert len(pdhg_.itstat_object.fieldname) == 4
    # Solver state starts from the zero initializer.
    assert snp.sum(pdhg_.x) == 0.0

    # Custom configuration: only the two requested fields remain.
    pdhg_ = PDHG(
        f=self.f,
        g=self.g,
        C=self.C,
        tau=self.τ,
        sigma=self.σ,
        maxiter=self.maxiter,
        itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
    )
    assert len(pdhg_.itstat_object.fieldname) == 2

def test_callback(self):
pdhg_ = PDHG(
f=self.f,
g=self.g,
C=self.C,
tau=self.τ,
sigma=self.σ,
maxiter=self.maxiter,
)
pdhg_.test_flag = False

def callback(obj):
Expand All @@ -57,6 +67,39 @@ def callback(obj):
assert pdhg_.test_flag


class TestBlockArray:
    """Exercise PDHG with BlockArray-valued variables."""

    def setup_method(self, method):
        # Deterministic fixture data: a two-component block array of float32.
        np.random.seed(12345)
        blocks = (
            np.random.randn(32, 33).astype(np.float32),
            np.random.randn(17).astype(np.float32),
        )
        self.y = BlockArray.array(blocks)
        self.λ = 1e0
        self.maxiter = 1
        self.τ = 1e-1
        self.σ = 1e-1
        self.A = linop.Identity(self.y.shape)
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g = (self.λ / 2) * functional.L2Norm()
        self.C = linop.Identity(self.y.shape)

    def test_blockarray(self):
        # Solving with a BlockArray problem should yield a BlockArray result.
        solver = PDHG(
            f=self.f,
            g=self.g,
            C=self.C,
            tau=self.τ,
            sigma=self.σ,
            maxiter=self.maxiter,
        )
        result = solver.solve()
        assert isinstance(result, BlockArray)


class TestReal:
def setup_method(self, method):
np.random.seed(12345)
Expand Down

0 comments on commit 305b744

Please sign in to comment.