
Deconvolution using Convolution Forward Operator with PDHG throwing __rmul__ not defined error #442

Closed
shnaqvi opened this issue Sep 7, 2023 · 9 comments
Assignees
Labels
question Further information is requested

Comments

@shnaqvi
Contributor

shnaqvi commented Sep 7, 2023

I'm opening this as a standalone question about getting PDHG to work on my simple deblurring problem.

I've set up a synthetic image and blurred it with an anisotropic Gaussian kernel. The optimization problem I'm trying to solve is:
[Image: the optimization problem]
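Based on the solver setup later in this issue (f = loss.SquaredL2Loss(y=Cx, A=C), g = lbd * functional.L21Norm(), and D = linop.FiniteDifference passed to PDHG as its linear operator), the problem is presumably the total-variation-regularized deconvolution below. Here C is the circular convolution with the PSF, D the finite-difference operator, y the blurred image, and λ = lbd; the 1/2 scaling on the data term is an assumption.

```latex
\hat{\mathbf{x}} = \operatorname*{arg\,min}_{\mathbf{x}} \; \tfrac{1}{2} \|\mathbf{y} - C\mathbf{x}\|_2^2 + \lambda \|D\mathbf{x}\|_{2,1}
```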

import numpy as np
import matplotlib.pyplot as plt
import jax
from scipy.stats import multivariate_normal
from scico import linop

#======
# SYNTHETIC DATA

# Create Synthetic Horizontal Stripes Pattern Image
im_s = np.zeros((2748, 3840)).astype(float)
stripe_width, stripe_gap, stripe_start, stripe_end = 50, 50, 500, 500
for y in range(0, im_s.shape[0]-(stripe_start+stripe_end), stripe_width + stripe_gap):
    im_s[stripe_start+y : stripe_start+y + stripe_width, :] = .8

# Restrict the stripes to a disc of radius 1000 pixels about the image center
xx, yy = np.mgrid[0:im_s.shape[0], 0:im_s.shape[1]]
im_ctr = (np.array(im_s.shape)/2).astype(int)
r = np.sqrt((xx - im_ctr[0])**2 + (yy - im_ctr[1])**2)
mask = np.zeros_like(im_s).astype(float)
mask[r < 1000] = 1
im_s *= mask
plt.subplot(131); plt.imshow(im_s); plt.title('Ground Truth');

# Create anisotropic Gaussian Kernel (the covariance matrix must be symmetric)
pos = np.dstack((xx, yy))
rv = multivariate_normal.pdf(pos, im_ctr, [[200, 300], [300, 500]])
psf2 = rv/np.max(rv)
psf2_cropped = psf2[im_ctr[0]-100:im_ctr[0]+101, im_ctr[1]-100:im_ctr[1]+101]
plt.subplot(132); plt.imshow(psf2_cropped); plt.title('PSF zoomed in');

# Convolve and Create Blurred Image
im_jx = jax.device_put(im_s)
psf_jx = jax.device_put(psf2_cropped)

C = linop.CircularConvolve(h=psf_jx, input_shape=im_jx.shape, h_center=[psf_jx.shape[0] // 2, psf_jx.shape[1] // 2])
Cx = C(im_jx)

plt.subplot(133); plt.imshow(Cx); plt.title('Blurred Image')
[Image: Ground Truth, PSF zoomed in, and Blurred Image panels]

Using PDHG throws a TypeError saying that Operation __rmul__ is not defined (full traceback below). I get this error when I pass the forward operator C as the A argument of loss.SquaredL2Loss(), which is then passed to the PDHG constructor. Can you please help me work around this error?

from scico import functional, linop, loss
from scico.optimize import PDHG
from scico.util import device_info

#======
# SOLVER
f = loss.SquaredL2Loss(y=Cx, A=C)
lbd = 5e-1  # regularization parameter (l2,1 norm of the gradient, i.e. TV)
g = lbd * functional.L21Norm()
D = linop.FiniteDifference(input_shape=im_jx.shape, circular=True)

maxiter = 50
tau, sigma = PDHG.estimate_parameters(D, factor=1.5)
solver_pdhg = PDHG(
    f=f,
    g=g,
    C=D,
    tau=tau,
    sigma=sigma,
    maxiter=maxiter,
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x = solver_pdhg.solve()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[47], line 39
     28 solver_pdhg = PDHG(
     29     f=f,
     30     g=g,
   (...)
     35     itstat_options={"display": True, "period": 10},
     36 )
     38 print(f"Solving on {device_info()}\n")
---> 39 x = solver_pdhg.solve()
     40 hist = solver.itstat_object.history(transpose=True)
     42 plt.subplot(121); plt.imshow(im_jx); plt.title('Blurred image')

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/scico/optimize/_common.py:198, in Optimizer.solve(self, callback)
    196 self.timer.start()
    197 for self.itnum in range(self.itnum, self.itnum + self.maxiter):
--> 198     self.step()
    199     if self.nanstop and not self._working_vars_finite():
    200         raise ValueError(
    201             f"NaN or Inf value encountered in working variable in iteration {self.itnum}."
    202             ""
    203         )

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/scico/optimize/_primaldual.py:227, in PDHG.step(self)
...
     50 if np.isscalar(b) or isinstance(b, jax.core.Tracer):
     51     return func(a, b)
---> 53 raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")

TypeError: Operation __rmul__ not defined between  and .
@bwohlberg bwohlberg self-assigned this Sep 8, 2023
@bwohlberg bwohlberg added the bug Something isn't working label Sep 8, 2023
@bwohlberg
Collaborator

You seem to have found a bug. Until we fix it, you should be able to get it working by setting tau = float(tau), and likewise for sigma. However, this is going to be very slow for this problem, because one of the algorithm steps computes the proximal operator of the l2 loss with a non-trivial forward operator, which is currently computed via conjugate gradient (CG). This could be solved more efficiently for your specific type of forward operator, but that is not currently implemented.
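The end of the traceback shows the operator arithmetic accepting the other operand only if np.isscalar(b) is true (or it is a JAX tracer). A minimal sketch of why float() helps, assuming the estimated step sizes arrive as 0-d (singleton) arrays; operand_accepted is a hypothetical stand-in for that check, not a scico function:

```python
import numpy as np


def operand_accepted(b) -> bool:
    """Hypothetical stand-in for the scalar test seen in the traceback.

    The real check also accepts JAX tracers; that branch is omitted here.
    """
    return bool(np.isscalar(b))


# Assumption: the step size is returned as a 0-d (singleton) array.
tau = np.array(0.25)
print(operand_accepted(tau))         # False: a 0-d array is not a scalar
print(operand_accepted(float(tau)))  # True: float() yields a Python scalar
```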

@shnaqvi
Contributor Author

shnaqvi commented Sep 8, 2023

Thanks @bwohlberg! I got unblocked by making the suggested change. But yes, it is really slow: 1-2 minutes per iteration, compared to ADMM, which completed 20 iterations in 50 s.

I didn't follow why you say that the forward operator is non-trivial. We're using linop.CircularConvolve as the forward operator, not a custom operator. Do you mean that scico doesn't have the proximal operator defined for the l2 loss with a linop.CircularConvolve forward operator? Does the proximal operator of the l2 loss change depending on the forward operator used? So if I define a custom forward operator for another space-variant problem I'm interested in, would I have to define a proximal operator for the l2 loss with that operator?

@bwohlberg
Collaborator

"Non-trivial" was perhaps a bit too vague. The prox of the l2 loss is cheap to compute when the forward operator is an identity or diagonal operator. For any other linear operator, it is currently computed by solving a linear system via CG, which is very expensive. It should be straightforward to add support for a fast solution for linop.CircularConvolve, but it still has to be implemented. (Feel free to open a separate issue with that as a feature request.) And yes, if you want to use a custom forward operator within the same optimization framework, the options are to live with the slow CG solution or to extend the prox to support a fast solution for your custom operator. A better alternative would probably be to use scico.optimize.ProximalADMM, which supports a variable splitting that results in the l2 prox being computed without a forward operator.

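To illustrate the fast solution mentioned above for a circular convolution: the prox of λ·(1/2)‖Cx − y‖² requires solving (I + λCᵀC)x = v + λCᵀy, and because the DFT diagonalizes circular convolution, that system can be solved exactly with a few FFTs instead of CG. This is a minimal NumPy sketch, not scico's implementation; it assumes a real kernel h zero-padded to the image size with its origin at index 0 (rather than scico's h_center convention) and a 1/2 scaling of the loss. The function names are hypothetical.

```python
import numpy as np


def circ_conv(x, h):
    """Circular convolution of x with real kernel h (same shape, origin at index 0)."""
    return np.real(np.fft.ifftn(np.fft.fftn(h) * np.fft.fftn(x)))


def prox_sq_l2_circconv(v, y, h, lam):
    """prox of lam * 0.5 * ||circ_conv(x, h) - y||_2^2, evaluated at v.

    Solves (I + lam C^T C) x = v + lam C^T y in the Fourier domain, where
    C^T C becomes multiplication by |H|^2 and C^T y by conj(H) * Y.
    """
    H = np.fft.fftn(h)
    num = np.fft.fftn(v) + lam * np.conj(H) * np.fft.fftn(y)
    den = 1.0 + lam * np.abs(H) ** 2
    return np.real(np.fft.ifftn(num / den))


# Check against a dense linear solve on a small random problem
rng = np.random.default_rng(0)
shape = (6, 5)
h, y, v = (rng.standard_normal(shape) for _ in range(3))
lam = 0.7
n = h.size
# Column j of the dense matrix is C applied to the j-th unit vector
C = np.stack([circ_conv(e.reshape(shape), h).ravel() for e in np.eye(n)], axis=1)
x_dense = np.linalg.solve(np.eye(n) + lam * C.T @ C, v.ravel() + lam * C.T @ y.ravel())
x_fft = prox_sq_l2_circconv(v, y, h, lam)
print(np.allclose(x_fft.ravel(), x_dense))  # True
```

The cost is a handful of FFTs per prox evaluation regardless of conditioning, which is why a dedicated solver for this operator class can be much faster than generic CG.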
@shnaqvi
Contributor Author

shnaqvi commented Sep 9, 2023

Thanks @bwohlberg, I also tried ProximalADMM, replacing the corresponding solver code with the following:

from scico import functional, linop, loss
from scico.optimize import ProximalADMM

#======
# SOLVER

f = loss.SquaredL2Loss(y=Cx, A=C)
lbd = 5e-1  # regularization parameter (l2,1 norm)
g = lbd * functional.L21Norm()
D = linop.FiniteDifference(input_shape=im_jx.shape, circular=True)

maxiter = 50

mu, nu = ProximalADMM.estimate_parameters(D)
solver_padmm = ProximalADMM(
    f=f,
    g=g,
    A=C,
    rho=1e0,
    mu=mu,
    nu=nu,
    x0=Cx,
    maxiter=maxiter,
    itstat_options={"display": True, "period": 10},
)
print("\nProximal ADMM solver")
solver_padmm.solve()
hist_padmm = solver_padmm.itstat_object.history(transpose=True)

But I'm getting the same TypeError involving __rmul__:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[64], line 40
     28 solver_padmm = ProximalADMM(
     29     f=f,
     30     g=g,
   (...)
     37     itstat_options={"display": True, "period": 10},
     38 )
     39 print("\nProximal ADMM solver")
---> 40 solver_padmm.solve()
     41 hist_padmm = solver_padmm.itstat_object.history(transpose=True)
     43 plt.subplot(121); plt.imshow(im_jx); plt.title('Blurred image')

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/scico/optimize/_common.py:198, in Optimizer.solve(self, callback)
    196 self.timer.start()
    197 for self.itnum in range(self.itnum, self.itnum + self.maxiter):
--> 198     self.step()
    199     if self.nanstop and not self._working_vars_finite():
    200         raise ValueError(
    201             f"NaN or Inf value encountered in working variable in iteration {self.itnum}."
    202             ""
    203         )

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/scico/optimize/_padmm.py:351, in ProximalADMM.step(self)
...
     50 if np.isscalar(b) or isinstance(b, jax.core.Tracer):
     51     return func(a, b)
---> 53 raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")

TypeError: Operation __rmul__ not defined between  and .

Is this a bug too? Can you please help me get this running, so I can try out the more complex forward operator with a sum over convolutions? Also, what about PGM: does it also allow the l2 prox to be computed without a forward operator?

@bwohlberg
Collaborator

This is probably the same bug as before, and should be resolvable in the same way.

Note, though, that you haven't set the problem up in a way that avoids the expensive l2 prox. Take a look at this example to see how it should be done.
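A sketch of the kind of reformulation that the code later in this thread follows: move the blur out of the loss and into a stacked linear operator, so that every prox the solver needs is separable and cheap. Notation follows the code (C for the blur, D for finite differences, λ for lbd); the 1/2 scaling of the data term is an assumption, and ProximalADMM is written here in the generic form min f(x) + g(Ax).

```latex
\min_{\mathbf{x}} \; \tfrac{1}{2}\|\mathbf{y} - C\mathbf{x}\|_2^2 + \lambda \|D\mathbf{x}\|_{2,1}
\;=\; \min_{\mathbf{x}} \; f(\mathbf{x}) + g(A\mathbf{x}),
\qquad f = 0, \quad
A = \begin{pmatrix} C \\ D \end{pmatrix}, \quad
g(\mathbf{z}_0, \mathbf{z}_1) = \underbrace{\tfrac{1}{2}\|\mathbf{y} - \mathbf{z}_0\|_2^2}_{g_0(\mathbf{z}_0)}
 + \underbrace{\lambda \|\mathbf{z}_1\|_{2,1}}_{g_1(\mathbf{z}_1)}

\operatorname{prox}_{\sigma g}(\mathbf{z}_0, \mathbf{z}_1)
= \bigl(\operatorname{prox}_{\sigma g_0}(\mathbf{z}_0),\; \operatorname{prox}_{\sigma g_1}(\mathbf{z}_1)\bigr),
\qquad
\operatorname{prox}_{\sigma g_0}(\mathbf{z}_0) = \frac{\mathbf{z}_0 + \sigma \mathbf{y}}{1 + \sigma}
```

The forward operator C then only appears through applications of A and its adjoint, so no linear system involving C has to be solved by CG.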

@bwohlberg
Collaborator

Changing label for this issue since the bug component has been moved to #445.

@bwohlberg bwohlberg added question Further information is requested and removed bug Something isn't working labels Sep 9, 2023
@shnaqvi
Contributor Author

shnaqvi commented Sep 9, 2023

Thanks, I'm trying to reformulate the problem with ProximalADMM.

In the meantime, can you please tell me whether PGM also requires the expensive l2 prox computation?

@shnaqvi
Contributor Author

shnaqvi commented Sep 9, 2023

For my space-variant deconvolution problem: thanks to your mockup of the forward operator formulation, @bwohlberg, I was able to write my forward model as a composition of several operators 👍!

And thanks a lot for helping with the ProximalADMM speedup as well. I was able to formulate my problem in the way suggested in your above-referenced example:

import jax
from scico import functional, linop, loss
from scico.optimize import ProximalADMM

im_jx = jax.device_put(im) #im_s

# U and W are computed elsewhere (not shown in this thread)
u = U[:, :U.shape[1]].T.reshape((U.shape[1], *im_jx.shape))
w = W[:U.shape[1]].reshape((U.shape[1], *im_jx.shape))

u_jx = jax.device_put(u)
w_jx = jax.device_put(w)

# Space-variant blur: sum over k of w_k * (u_k convolved with x)
C = linop.CircularConvolve(h=u_jx, input_shape=im_jx.shape, ndims=2)
Dw = linop.Diagonal(w_jx)
S = linop.Sum(input_shape=Dw.output_shape, axis=0)

Blur = S @ Dw @ C

# ProximalADMM
f = functional.ZeroFunctional()
g0 = loss.SquaredL2Loss(y=im_jx)
lbd = 5.0e-1  # regularization parameter (l2,1 norm of the gradient, i.e. TV)
g1 = lbd * functional.L21Norm()
g = functional.SeparableFunctional((g0, g1))

D = linop.FiniteDifference(input_shape=im_jx.shape, circular=True)
A = linop.VerticalStack((Blur, D))

maxiter = 20  # number of ADMM iterations

mu, nu = ProximalADMM.estimate_parameters(D)
solver_padmm = ProximalADMM(
    f=f,
    g=g,
    A=A,
    rho=5e1,
    mu=float(mu),
    nu=float(nu),
    x0=im_jx,
    maxiter=maxiter,
    itstat_options={"display": True, "period": 10},
)
print("\nProximal ADMM solver")
solver_padmm.solve()
hist_padmm = solver_padmm.itstat_object.history(transpose=True)
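The composed operator Blur above (a sum over k of a Diagonal weighting applied to a batch of circular convolutions) computes y = Σₖ wₖ ⊙ (uₖ ⊛ x). A minimal NumPy sketch of that forward model and its adjoint, with a dot-product test, may help when checking such compositions. It assumes the kernels are given at full image size with their origin at index (0, 0), rather than via scico's h_center convention; blur and blur_adjoint are hypothetical names, not scico functions.

```python
import numpy as np


def blur(x, u, w):
    """Space-variant blur: sum_k w[k] * (u[k] circularly convolved with x).

    x: (M, N) image; u: (K, M, N) real kernels; w: (K, M, N) pixel weights.
    """
    Cx = np.real(np.fft.ifft2(np.fft.fft2(u) * np.fft.fft2(x)))  # (K, M, N) convolutions
    return np.sum(w * Cx, axis=0)                                  # weight, then sum over k


def blur_adjoint(z, u, w):
    """Adjoint: broadcast z over k, weight by w, correlate with u[k], sum over k."""
    Wz = w * z                                                     # (K, M, N)
    CtWz = np.real(np.fft.ifft2(np.conj(np.fft.fft2(u)) * np.fft.fft2(Wz)))
    return np.sum(CtWz, axis=0)                                    # (M, N)


# Dot-product test: <blur(x), z> should equal <x, blur_adjoint(z)>
rng = np.random.default_rng(1)
K, M, N = 3, 8, 7
u, w = rng.random((K, M, N)), rng.random((K, M, N))
x, z = rng.standard_normal((M, N)), rng.standard_normal((M, N))
print(np.isclose(np.vdot(blur(x, u, w), z), np.vdot(x, blur_adjoint(z, u, w))))  # True
```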

Now the solver executes, but it rapidly diverges to inf with rho=5e1 or 5e2. Could you please recommend a range of values for rho that I could try?

Also, my Python kernel crashes on re-run because it takes up a huge amount of memory on the device (after 14 iterations, the Python kernel was using a whopping 23 GB of memory). Do you know whether and how that can be mitigated?

Proximal ADMM solver
Iter  Time      Objective  Prml Rsdl  Dual Rsdl
-----------------------------------------------
   0  2.79e+01  2.886e+13  9.116e+04  7.597e+06
  10  1.59e+02        nan        nan        nan
  14  2.11e+02        nan        nan        nan

@bwohlberg
Collaborator

Closing since remaining questions now appear as separate issues.

bwohlberg added a commit that referenced this issue Oct 24, 2023
* Resolve #442

* Add utility function for checking if an object is a scalar of an array of unit size

* Add tests for operator mult/div by singleton arrays

* Modify conditional for scalar equivalence and corresponding test function

* Typo fix

* Simplify conditional

* Add an assertion
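The commit above describes a utility that treats arrays of unit size as equivalent to scalars, so that multiplying an operator by a singleton array (such as an estimated step size) no longer raises. A hypothetical sketch of such a check (the name is_scalar_like is illustrative and not necessarily the helper added in the commit):

```python
import numpy as np


def is_scalar_like(s) -> bool:
    """Return True for true scalars and for arrays holding exactly one element.

    Hypothetical helper: works for NumPy arrays and for any array type that
    exposes ``size`` and ``shape`` attributes (e.g. JAX arrays).
    """
    if np.isscalar(s):
        return True
    return getattr(s, "size", None) == 1 and hasattr(s, "shape")


tau = np.array(0.25)          # 0-d array, rejected by np.isscalar alone
print(is_scalar_like(tau))    # True
print(is_scalar_like(np.ones(3)))  # False
```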
2 participants