Skip to content

Commit

Permalink
various fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Jun 1, 2024
1 parent 2b8d0f6 commit c96af11
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/fmri/operators/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
from mrinufft import get_operator
from modopt.base.backend import get_array_module

try:
from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps
Expand Down Expand Up @@ -171,7 +172,8 @@ def op(self, images):
def adj_op(self, coeffs):
"""Apply Adjoint Operator."""
c = 1 if self.uses_sense else self.n_coils
final_image = np.empty((self.n_frames, c, *self.shape), dtype=np.complex64)
xp = get_array_module(coeffs)
final_image = xp.empty((self.n_frames, c, *self.shape), dtype=np.complex64)
for i in range(len(coeffs)):
final_image[i] = self.fourier_ops[i].adj_op(coeffs[i])
return final_image.squeeze()
Expand Down
1 change: 1 addition & 0 deletions src/fmri/operators/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def __init__(self, linear_op, fourier_op, verbose=0, **kwargs):
n_channels = fourier_op.n_coils if not fourier_op.uses_sense else 1
coef = linear_op.op(np.squeeze(np.zeros((n_channels, *fourier_op.shape))))
self.linear_op_coeffs_shape = coef.shape
self.shape = coef.shape
super().__init__(
self._op_method,
self._trans_op_method,
Expand Down
10 changes: 9 additions & 1 deletion src/fmri/operators/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,15 @@ def __init__(
thresh_range="global",
threshold_estimation="sure",
threshold_scaler=1.0,
synthesis=True,
**kwargs
):
if linear is None:
linear = Identity()
self._n_op_calls = 0
self.cf_shape = coeffs_shape
self._update_period = update_period
self.synthesis = synthesis

if thresh_range not in ["bands", "scale", "global"]:
raise ValueError("Unsupported threshold range.")
Expand Down Expand Up @@ -451,10 +453,16 @@ def _op_method(self, input_data, extra_factor=1.0):
Thresholded data
"""
if not self.synthesis:
input_data = self._linear.op(input_data)
if self._update_period == 0 and self._n_op_calls == 0:
self.weights = self._auto_thresh(input_data)
if self._update_period != 0 and self._n_op_calls % self._update_period == 0:
self.weights = self._auto_thresh(input_data)

self._n_op_calls += 1
return super()._op_method(input_data, extra_factor=extra_factor)
threshed = super()._op_method(input_data, extra_factor=extra_factor)

if not self.synthesis:
return self._linear.adj_op(threshed)
return threshed
158 changes: 157 additions & 1 deletion src/fmri/reconstructors/frame_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
"""

from modopt.base.backend import get_backend
import gc
from functools import cached_property

from modopt.base.backend import get_backend, get_array_module
import numpy as np
import copy
from tqdm.auto import tqdm, trange
Expand All @@ -14,6 +17,10 @@
from .base import BaseFMRIReconstructor
from .utils import OPTIMIZERS, initialize_opt

from modopt.opt.algorithms import POGM
from modopt.opt.linear import Identity
from ..optimizer import AccProxSVRG, MS2GD


class SequentialReconstructor(BaseFMRIReconstructor):
"""Sequential Reconstruction of fMRI data.
Expand Down Expand Up @@ -210,3 +217,152 @@ def reconstruct(

progbar.close()
return final_estimate


class CustomGradAnalysis:
"""Custom Gradient Analysis Operator."""

def __init__(self, fourier_op, obs_data):
self.fourier_op = fourier_op
self.obs_data = obs_data
self.shape = fourier_op.shape

def get_grad(self, x):
"""Get the gradient value"""
self.grad = self.fourier_op.data_consistency(x, self.obs_data)
return self.grad

@cached_property
def spec_rad(self):
return self.fourier_op.get_lipschitz_cst()

def inv_spec_rad(self):
return 1.0 / self.spec_rad

def cost(self, x, *args, **kwargs):
xp = get_array_module(x)
cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data)
if xp != np:
return cost.get()
return cost


class StochasticSequentialReconstructor(BaseFMRIReconstructor):
"""Stochastic Sequential Reconstruction of fMRI data."""

def __init__(
self,
fourier_op,
space_linear_op,
space_prox_op,
progbar_disable=False,
compute_backend="numpy",
**kwargs,
):
super().__init__(fourier_op, space_linear_op, space_prox_op, **kwargs)

self.progbar_disable = progbar_disable
self.compute_backend = compute_backend

def reconstruct(
self,
kspace_data,
x_init=None,
max_iter_per_frame=15,
grad_kwargs=None,
algorithm="accproxsvrg",
progbar_disable=False,
algorithm_kwargs=None,
):
"""Reconstruct using sequential method."""
self.progbar_disable = progbar_disable

if algorithm_kwargs is None:
algorithm_kwargs = {}

xp, _ = get_backend(self.compute_backend)
# Create the gradients operators
grad_list = []
for i, fop in enumerate(self.fourier_op.fourier_ops):
# L = fop.get_lipschitz_cst()

# g = GradSynthesis(
# linear_op=self.space_linear_op,
# fourier_op=fop,
# verbose=self.verbose,
# dtype=kspace_data.dtype,
# lipschitz_cst=L,
# num_check_lips=0, # trust me
# input_data_writeable=True,
# )
# g._obs_data = kspace_data[i, ...]
g = CustomGradAnalysis(fop, kspace_data[i, ...])
grad_list.append(g)

max_lip = max(g.spec_rad for g in grad_list)

if algorithm == "accproxsvrg":

opt = AccProxSVRG(
x=xp.zeros(grad_list[0].shape, dtype="complex64"),
grad_list=grad_list,
prox=self.space_prox_op,
step_size=1.0 / max_lip,
auto_iterate=False,
cost=None,
update_frequency=10,
compute_backend=self.compute_backend,
**algorithm_kwargs,
)

elif algorithm == "m2sg":

opt = MS2GD(
x=xp.zeros(self.fourier_op.shape, dtype="complex64"),
grad_list=grad_list,
prox=self.space_prox_op,
step_size=1.0 / max_lip,
auto_iterate=False,
update_frequency=10,
cost=None,
**algorithm_kwargs,
)

opt.iterate(max_iter=20)

x_anat = opt.x_final.squeeze()

progbar_main = trange(len(kspace_data), disable=self.progbar_disable)
progbar = tqdm(total=max_iter_per_frame, disable=self.progbar_disable)
final_img = np.zeros(
(len(kspace_data), *self.fourier_op.shape),
dtype=self.fourier_op.cpx_dtype,
)
del opt
gc.collect()
for i in progbar_main: # Parallel

opt = POGM(
x_anat,
x_anat,
x_anat,
x_anat,
grad=grad_list[i],
prox=self.space_prox_op,
linear=Identity(),
beta=grad_list[i].inv_spec_rad,
compute_backend=self.compute_backend,
auto_iterate=False,
cost=None,
)
opt.iterate(progbar=progbar, max_iter=max_iter_per_frame)

progbar.reset(total=max_iter_per_frame)
img = opt.x_final

if self.compute_backend == "cupy":
final_img[i] = img.get()
else:
final_img[i] = img

return final_img, x_anat

0 comments on commit c96af11

Please sign in to comment.