adjust other code for new interface
tbalke committed Nov 10, 2021
1 parent 71deb29 commit 1b12a7c
Showing 8 changed files with 14 additions and 26 deletions.
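
The changes below adapt call sites to the loss constructors' new keyword: the diagonal weight operator is now passed as W rather than weight_op, and downstream code reads the weight array via W.diagonal. A minimal sketch of the new call pattern (not part of the commit; an Identity forward operator and small shapes are assumed purely for illustration):

import scico.numpy as snp
from scico import linop, loss

y = snp.ones((16,))                # data
A = linop.Identity(y.shape)        # stand-in forward operator
weights = 0.5 * snp.ones(y.shape)  # non-negative weights

# Old interface: loss.WeightedSquaredL2Loss(y=y, A=A, weight_op=linop.Diagonal(weights))
# New interface:
f = loss.WeightedSquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))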
2 changes: 1 addition & 1 deletion examples/scripts/ct_astra_weighted_tv_admm.py
@@ -139,7 +139,7 @@ def postprocess(x):

weights = counts / Io # scale by Io to balance the data vs regularization term
W = linop.Diagonal(snp.sqrt(weights))
-f = loss.WeightedSquaredL2Loss(y=y, A=A, weight_op=W)
+f = loss.WeightedSquaredL2Loss(y=y, A=A, W=W)

admm_weighted = ADMM(
f=f,
4 changes: 1 addition & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
@@ -91,9 +91,7 @@
ρ = 100 # ADMM penalty parameter
σ = density * 0.2 # denoiser sigma

-weight_op = Diagonal(weights)
-
-f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
+f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

4 changes: 1 addition & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
@@ -91,9 +91,7 @@
ρ = 10 # ADMM penalty parameter
σ = density * 0.26 # denoiser sigma

-weight_op = Diagonal(weights)
-
-f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
+f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

2 changes: 1 addition & 1 deletion scico/admm.py
@@ -208,7 +208,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:

if self.admm.f is not None:
if isinstance(self.admm.f, WeightedSquaredL2Loss):
-ATWy = self.admm.f.A.adj(self.admm.f.weight_op @ self.admm.f.y)
+ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal @ self.admm.f.y)
rhs += 2.0 * self.admm.f.scale * ATWy
else:
ATy = self.admm.f.A.adj(self.admm.f.y)
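
With the new interface the ADMM subproblem solver reads the weight array directly from the loss as f.W.diagonal instead of applying the stored weight_op. A small sketch of what these attributes expose (illustrative only, not part of the commit; an Identity forward operator is assumed):

import scico.numpy as snp
from scico import linop, loss

w = snp.array([1.0, 2.0, 3.0, 4.0])
y = snp.ones((4,))
f = loss.WeightedSquaredL2Loss(y=y, A=linop.Identity(y.shape), W=linop.Diagonal(w))

print(type(f.W))     # the Diagonal weight operator stored by the constructor
print(f.W.diagonal)  # the underlying weight array read in the RHS computation above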
11 changes: 2 additions & 9 deletions scico/linop/radon_svmbir.py
@@ -28,7 +28,7 @@
from scico.loss import WeightedSquaredL2Loss
from scico.typing import JaxArray, Shape

-from ._linop import Diagonal, LinearOperator
+from ._linop import LinearOperator


class ParallelBeamProjector(LinearOperator):
@@ -120,19 +120,12 @@ def __init__(self, *args, **kwargs):
"to instantiate a `SVMBIRWeightedSquaredL2Loss`."
)

-if not isinstance(self.weight_op, Diagonal):
-raise ValueError(
-f"`weight_op` must be `Diagonal` but instead got {type(self.weight_op)}"
-)
-
-self.weights = self.weight_op.diagonal
-
self.has_prox = True

def prox(self, v: JaxArray, lam: float) -> JaxArray:
v = v.reshape(self.A.svmbir_input_shape)
y = self.y.reshape(self.A.svmbir_output_shape)
-weights = self.weights.reshape(self.A.svmbir_output_shape)
+weights = self.W.diagonal.reshape(self.A.svmbir_output_shape)
sigma_p = snp.sqrt(lam)
result = svmbir.recon(
np.array(y),
2 changes: 1 addition & 1 deletion scico/loss.py
@@ -211,7 +211,7 @@ def __init__(
if W is None:
self.W = linop.Identity(y.shape)
elif isinstance(W, linop.Diagonal):
-if np.all(W.diagonal >= 0):
+if snp.all(W.diagonal >= 0):
self.W = W
else:
raise Exception(f"The weights, W.diagonal, must be non-negative.")
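
The hunk above shows the constructor's handling of the new W argument: None falls back to an Identity weighting, a Diagonal with a non-negative diagonal is stored as-is, and negative weights raise. A short sketch of those paths (not part of the commit; assumes W defaults to None and uses small illustrative shapes):

import scico.numpy as snp
from scico import linop, loss

y = snp.ones((8,))
A = linop.Identity(y.shape)

f0 = loss.WeightedSquaredL2Loss(y=y, A=A)  # W=None -> Identity weights
f1 = loss.WeightedSquaredL2Loss(y=y, A=A, W=linop.Diagonal(snp.ones(y.shape)))

try:
    loss.WeightedSquaredL2Loss(y=y, A=A, W=linop.Diagonal(-snp.ones(y.shape)))
except Exception as e:
    print(e)  # "The weights, W.diagonal, must be non-negative."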
4 changes: 2 additions & 2 deletions scico/test/linop/test_radon_svmbir.py
@@ -89,6 +89,6 @@ def test_prox_weights(Nx, Ny, num_angles, num_channels, is_3d):

# test with weights
weights, _ = scico.random.uniform(sino.shape, dtype=im.dtype)
-D = scico.linop.Diagonal(weights)
-f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, weight_op=D)
+W = scico.linop.Diagonal(weights)
+f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, W=W)
prox_test(v, f, f.prox, alpha=0.25)
11 changes: 5 additions & 6 deletions scico/test/test_functional.py
@@ -339,7 +339,7 @@ def setup_method(self):
W = 0.1 * W + 1.0
self.Ao = linop.MatrixOperator(A)
self.Do = linop.Diagonal(D)
-self.Wo = linop.Diagonal(W)
+self.W = linop.Diagonal(W)
self.y, key = randn((n,), key=key, dtype=dtype)
self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval
scalar, key = randn((1,), key=key, dtype=dtype)
@@ -377,14 +377,14 @@ def test_squared_l2(self):
pf = prox_test(self.v, L_d, L_d.prox, 0.75)

def test_weighted_squared_l2(self):
-L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, weight_op=self.Wo)
+L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W)
assert L.is_smooth == True
assert L.has_eval == True
assert L.has_prox == False # not diagonal

# test eval
np.testing.assert_allclose(
-L(self.v), 0.5 * ((self.Wo @ (self.Ao @ self.v - self.y)) ** 2).sum()
+L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum()
)

cL = self.scalar * L
@@ -393,16 +393,15 @@ def test_weighted_squared_l2(self):
assert cL(self.v) == self.scalar * L(self.v)

# SquaredL2 with Diagonal linop has a prox
-Wo = self.Wo
-L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, weight_op=Wo)
+L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W)

assert L_d.is_smooth == True
assert L_d.has_eval == True
assert L_d.has_prox == True

# test eval
np.testing.assert_allclose(
-L_d(self.v), 0.5 * ((self.Wo @ (self.Do @ self.v - self.y)) ** 2).sum()
+L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum()
)

cL = self.scalar * L_d
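
The updated test expressions apply W once to the squared residual, i.e. they check the loss against 0.5 * (A v - y)^T W (A v - y) rather than 0.5 * ||W (A v - y)||^2. A plain-NumPy check of that identity for a diagonal W (independent of scico; illustrative values only):

import numpy as np

rng = np.random.default_rng(0)
n = 5
w = rng.uniform(size=n)         # non-negative diagonal weights
A = rng.normal(size=(n, n))
v = rng.normal(size=n)
y = rng.normal(size=n)

r = A @ v - y
lhs = 0.5 * (w * r**2).sum()          # elementwise form used in the updated tests
rhs = 0.5 * (r @ np.diag(w) @ r)      # quadratic-form definition
assert np.allclose(lhs, rhs)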
