Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PDHG solver with support for non-linear operators #322

Merged
merged 20 commits into from
Aug 12, 2022
Merged

Conversation

bwohlberg
Copy link
Collaborator

Add PDHG solver with support for non-linear operators.

@bwohlberg bwohlberg added the enhancement New feature or request label Aug 3, 2022
@codecov
Copy link

codecov bot commented Aug 4, 2022

Codecov Report

Merging #322 (5aebd9c) into main (45ea985) will increase coverage by 0.05%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #322      +/-   ##
==========================================
+ Coverage   93.82%   93.88%   +0.05%     
==========================================
  Files          54       55       +1     
  Lines        3418     3451      +33     
==========================================
+ Hits         3207     3240      +33     
  Misses        211      211              
Flag Coverage Δ
unittests 93.88% <100.00%> (+0.05%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
scico/linop/_linop.py 98.09% <ø> (ø)
scico/_generic_operators.py 92.10% <100.00%> (+0.17%) ⬆️
scico/examples.py 96.00% <100.00%> (+0.12%) ⬆️
scico/operator/__init__.py 100.00% <100.00%> (ø)
scico/operator/_func.py 100.00% <100.00%> (ø)
scico/optimize/_primaldual.py 100.00% <100.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

x_mag /= x_mag.max()
# Create reference image with structured magnitude and random phase
x_gt = x_mag * snp.exp(-1j * scico.random.randn(x_mag.shape, seed=0)[0])
x_gt = jax.device_put(x_gt) # convert to jax type, push to device
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Michael-T-McCann: Is this really necessary? Since x_gt is created by an snp function, it's already a DeviceArray.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, remove it.

@@ -273,12 +273,33 @@ def jvp(self, primals, tangents):

return jax.jvp(self, primals, tangents)

def jhvp(self, *primals):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Michael-T-McCann: Thoughts on the name of the method? I'm open to suggestions if you think this one is not so great.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind it given that jax has settled on jvp and vjp.

@bwohlberg bwohlberg enabled auto-merge (squash) August 12, 2022 21:07
@bwohlberg bwohlberg merged commit 9005b0f into main Aug 12, 2022
@bwohlberg bwohlberg deleted the brendt/nlpdhg branch August 23, 2022 19:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants