Commit

update

paquiteau committed Oct 16, 2024
1 parent c96af11 commit c0c4692

Showing 4 changed files with 81 additions and 56 deletions.
45 changes: 43 additions & 2 deletions src/fmri/operators/gradient.py
@@ -3,10 +3,13 @@
Adapted from pysap-mri and Modopt libraries.
"""

from functools import cached_property

import numpy as np
import cupy as cp
from modopt.math.matrix import PowerMethod
from modopt.opt.gradient import GradBasic
from modopt.base.backend import get_backend
from modopt.opt.gradient import GradBasic, GradParent
from modopt.base.backend import get_backend, get_array_module


def check_lipschitz_cst(f, x_shape, x_dtype, lipschitz_cst, max_nb_of_iter=10):
Expand Down Expand Up @@ -224,3 +227,41 @@ def _op_method(self, data):

    def _trans_op_method(self, data):
        return self.linear_op.op(self.fourier_op.adj_op(data))


class CustomGradAnalysis(GradParent):
    """Custom Gradient Analysis Operator."""

    def __init__(self, fourier_op, obs_data, obs_data_gpu=None, lazy=True):
        self.fourier_op = fourier_op
        self._grad_data_type = np.complex64
        self._obs_data = obs_data
        if obs_data_gpu is None:
            self.obs_data_gpu = cp.array(obs_data)
        elif isinstance(obs_data_gpu, cp.ndarray):
            self.obs_data_gpu = obs_data_gpu
        else:
            raise ValueError("Invalid data type for obs_data_gpu")
        self.lazy = lazy
        self.shape = fourier_op.shape

    def get_grad(self, x):
        """Get the gradient value."""
        if self.lazy:
            # Refresh the (shared) GPU buffer with this operator's observed data.
            self.obs_data_gpu.set(self.obs_data)
        self.grad = self.fourier_op.data_consistency(x, self.obs_data_gpu)
        return self.grad

    @cached_property
    def spec_rad(self):
        """Spectral radius (Lipschitz constant) of the gradient."""
        return self.fourier_op.get_lipschitz_cst()

    @property
    def inv_spec_rad(self):
        """Inverse spectral radius, usable as a step size."""
        return 1.0 / self.spec_rad

    def cost(self, x, *args, **kwargs):
        """Return the data-fidelity cost ||F(x) - y||."""
        xp = get_array_module(x)
        cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data)
        if xp != np:
            return cost.get()
        return cost
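
As a reading aid, a minimal sketch of how this operator can drive a bare gradient loop on one frame; `fourier_op` and `kspace_frame` (a CuPy array) are illustrative stand-ins, not names from this commit:

    import cupy as cp

    # Illustrative only: plain gradient descent on 0.5 * ||F(x) - y||^2.
    grad_op = CustomGradAnalysis(fourier_op, kspace_frame, lazy=False)
    x = cp.zeros(grad_op.shape, dtype=cp.complex64)
    step = grad_op.inv_spec_rad  # 1 / Lipschitz constant of the gradient
    for _ in range(10):
        x -= step * grad_op.get_grad(x)  # x <- x - step * F^H(F x - y)
    print(grad_op.cost(x))  # data-fidelity norm, returned on the host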
2 changes: 0 additions & 2 deletions src/fmri/operators/weighted.py
@@ -430,8 +430,6 @@ def _auto_thresh(self, input_data):
            weights = self._thresh_scale(weights, self._n_op_calls)
        else:
            weights *= self._thresh_scale
        xp = get_array_module(weights)
        logger.info(xp.unique(weights))
        return weights

    def _op_method(self, input_data, extra_factor=1.0):
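For context, the surviving lines above rely on `_thresh_scale` being either a callable or a plain multiplier; a standalone sketch of that dispatch (illustrative, not the library's code):

    import numpy as np

    # `thresh_scale` is either a callable of (weights, n_calls) or a scalar.
    def scale_weights(weights, thresh_scale, n_op_calls):
        if callable(thresh_scale):
            return thresh_scale(weights, n_op_calls)
        return weights * thresh_scale

    scale_weights(np.ones(3), 0.5, 1)                 # scalar mode
    scale_weights(np.ones(3), lambda w, n: w / n, 2)  # callable mode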
26 changes: 13 additions & 13 deletions src/fmri/optimizer.py
@@ -31,7 +31,7 @@ class AccProxSVRG(SetUp):
    def __init__(
        self,
        x,
        grad_list,
        fourier_op_list,
        prox,
        cost="auto",
        step_size=1.0,
@@ -79,7 +79,7 @@ def _update(self):
        self._v_tld = self.xp.zeros_like(self._v_tld)
        # Compute the average gradient.
        for g in self._grad_ops:
            self._v_tld += g.get_grad(self._x_old)
            self._v_tld += g.get_grad(self._x_tld)
        self._v_tld /= len(self._grad_ops)

        self.xp.copyto(self._x_old, self._x_tld)
@@ -89,16 +89,17 @@
        self.xp.copyto(self._v, self._v_tld)
        self._v *= self.batch_size
        for g in gIk:
            self._v += g.get_grad(self._x_tld)
            self._v -= g.get_grad(self._y)
            self._v -= g.get_grad(self._x_tld)
            self._v += g.get_grad(self._y)
        self._v *= self.step_size / self.batch_size
        self._x_new = self._y - self._v  # Reuse the array
        self.xp.copyto(self._x_new, self._y)
        self._x_new -= self._v  # Reuse the array
        self._x_new = self._prox.op(self._x_new, extra_factor=self.step_size)
        self._v = self._x_new - self._x_old  # Reuse the array

        self._y = self._x_new + self.beta * self._v
        self.xp.copyto(self._v, self._x_new)
        self._v -= self._x_old  # Reuse the array
        self.xp.copyto(self._y, self._x_new)
        self._y += self.beta * self._v
        self.xp.copyto(self._x_old, self._x_new)

        self.xp.copyto(self._x_tld, self._x_new)

        # Test cost function for convergence.
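For intuition, the same accelerated prox-SVRG inner step written out functionally, without the in-place buffer reuse; a sketch assuming each `g` exposes `get_grad` and `prox` exposes `op`:

    def svrg_inner_step(y, x_old, x_tld, v_tld, batch, prox, step_size, beta):
        # Variance-reduced direction, anchored at the snapshot x_tld:
        # v = step * (v_tld + mean_g[grad_g(y) - grad_g(x_tld)])
        v = sum(g.get_grad(y) - g.get_grad(x_tld) for g in batch) / len(batch)
        v = step_size * (v_tld + v)
        x_new = prox.op(y - v, extra_factor=step_size)  # proximal step
        y_next = x_new + beta * (x_new - x_old)         # momentum extrapolation
        return x_new, y_next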
@@ -184,14 +185,13 @@ def __init__(
        super().__init__(**kwargs)

        # Set the initial variable values
        self._check_input_data(x)

        self.step_size = step_size

        self.update_frequency = update_frequency
        self.batch_size = batch_size
        self._grad_ops = grad_list
        self._prox_op = prox
        self._prox = prox

        self._rng = np.random.default_rng(seed)

@@ -213,9 +213,9 @@ def _update(self):
            self._g += g.get_grad(self._x)
        self._g /= len(self._grad_ops)
        self.xp.copyto(self._y, self._x)
        tk = self.rng.randint(1, self.update_frequency)
        tk = self._rng.integers(1, self.update_frequency)
        for _ in range(tk):
            Ak = self.rng.choices(self._grad_ops, k=self.batch_size)
            Ak = self._rng.choice(self._grad_ops, size=self.batch_size, replace=False)
            self.xp.copyto(self._g_sto, self._g)
            self._g_sto *= self.batch_size
            for g in Ak:
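The two RNG changes above move from the legacy spellings (`rng.randint`, `rng.choices`) to the `numpy.random.Generator` API; the corrected calls behave like this:

    import numpy as np

    rng = np.random.default_rng(seed=0)
    tk = rng.integers(1, 10)  # random int in [1, 10); replaces randint
    # Minibatch sampled without replacement; replaces choices(..., k=...)
    batch = rng.choice(np.arange(20), size=4, replace=False)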
64 changes: 25 additions & 39 deletions src/fmri/reconstructors/frame_based.py
@@ -5,6 +5,9 @@
"""

import cupy as cp
import logging

import gc
from functools import cached_property

Expand All @@ -13,14 +16,17 @@
import copy
from tqdm.auto import tqdm, trange

from ..operators.gradient import GradAnalysis, GradSynthesis
from ..operators.gradient import GradAnalysis, GradSynthesis, CustomGradAnalysis
from .base import BaseFMRIReconstructor
from .utils import OPTIMIZERS, initialize_opt

from modopt.opt.algorithms import POGM
from modopt.opt.linear import Identity
from modopt.opt.gradient import GradParent
from ..optimizer import AccProxSVRG, MS2GD

logger = logging.getLogger("pysap-fmri")


class SequentialReconstructor(BaseFMRIReconstructor):
"""Sequential Reconstruction of fMRI data.
@@ -107,6 +113,8 @@ def reconstruct(
            final_estimate[i, ...] = x_iter
        # Progressbar update
        progbar.close()

        logger.info("final prox weight: %s", xp.unique(self.space_prox_op.weights))
        return final_estimate

def _reconstruct_frame(
Expand Down Expand Up @@ -219,34 +227,6 @@ def reconstruct(
        return final_estimate


class CustomGradAnalysis:
    """Custom Gradient Analysis Operator."""

    def __init__(self, fourier_op, obs_data):
        self.fourier_op = fourier_op
        self.obs_data = obs_data
        self.shape = fourier_op.shape

    def get_grad(self, x):
        """Get the gradient value"""
        self.grad = self.fourier_op.data_consistency(x, self.obs_data)
        return self.grad

    @cached_property
    def spec_rad(self):
        return self.fourier_op.get_lipschitz_cst()

    def inv_spec_rad(self):
        return 1.0 / self.spec_rad

    def cost(self, x, *args, **kwargs):
        xp = get_array_module(x)
        cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data)
        if xp != np:
            return cost.get()
        return cost


class StochasticSequentialReconstructor(BaseFMRIReconstructor):
"""Stochastic Sequential Reconstruction of fMRI data."""

@@ -255,12 +235,18 @@ def __init__(
        fourier_op,
        space_linear_op,
        space_prox_op,
        space_prox_op_refine=None,
        progbar_disable=False,
        compute_backend="numpy",
        **kwargs,
    ):
        super().__init__(fourier_op, space_linear_op, space_prox_op, **kwargs)

        if space_prox_op_refine is None:
            self.space_prox_op_refine = space_prox_op
        else:
            self.space_prox_op_refine = space_prox_op_refine

        self.progbar_disable = progbar_disable
        self.compute_backend = compute_backend

@@ -269,6 +255,7 @@ def reconstruct(
        kspace_data,
        x_init=None,
        max_iter_per_frame=15,
        max_iter_stochastic=20,
        grad_kwargs=None,
        algorithm="accproxsvrg",
        progbar_disable=False,
@@ -283,6 +270,7 @@ def reconstruct(
        xp, _ = get_backend(self.compute_backend)
        # Create the gradients operators
        grad_list = []
        tmp_ksp = cp.zeros_like(kspace_data[0])
        for i, fop in enumerate(self.fourier_op.fourier_ops):
            # L = fop.get_lipschitz_cst()

@@ -296,7 +284,7 @@ def reconstruct(
            # input_data_writeable=True,
            # )
            # g._obs_data = kspace_data[i, ...]
            g = CustomGradAnalysis(fop, kspace_data[i, ...])
            g = CustomGradAnalysis(fop, kspace_data[i, ...], obs_data_gpu=tmp_ksp)
            grad_list.append(g)

        max_lip = max(g.spec_rad for g in grad_list)
@@ -307,10 +295,9 @@ def reconstruct(
            x=xp.zeros(grad_list[0].shape, dtype="complex64"),
            grad_list=grad_list,
            prox=self.space_prox_op,
            step_size=1.0 / max_lip,
            step_size=1.0 / (2 * max_lip),
            auto_iterate=False,
            cost=None,
            update_frequency=10,
            compute_backend=self.compute_backend,
            **algorithm_kwargs,
        )
@@ -323,12 +310,11 @@ def reconstruct(
            prox=self.space_prox_op,
            step_size=1.0 / max_lip,
            auto_iterate=False,
            update_frequency=10,
            cost=None,
            **algorithm_kwargs,
        )

        opt.iterate(max_iter=20)
        opt.iterate(max_iter=max_iter_stochastic)

        x_anat = opt.x_final.squeeze()

@@ -348,7 +334,7 @@ def reconstruct(
                x_anat,
                x_anat,
                grad=grad_list[i],
                prox=self.space_prox_op,
                prox=self.space_prox_op_refine,
                linear=Identity(),
                beta=grad_list[i].inv_spec_rad,
                compute_backend=self.compute_backend,
@@ -360,9 +346,9 @@ def reconstruct(
            progbar.reset(total=max_iter_per_frame)
            img = opt.x_final

            if self.compute_backend == "cupy":
                final_img[i] = img.get()
            else:
                final_img[i] = img
            if self.compute_backend == "cupy":
                final_img[i] = img.get().squeeze()
            else:
                final_img[i] = img

        return final_img, x_anat
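
Putting the pieces together, a hypothetical end-to-end call of the two-stage reconstructor; the operator names are assumptions for illustration, not part of this commit:

    # Hypothetical usage (operator names are placeholders).
    rec = StochasticSequentialReconstructor(
        fourier_op,                        # space-time Fourier operator
        space_linear_op,                   # e.g. a wavelet transform
        space_prox_op,                     # sparsity prox for the stochastic stage
        space_prox_op_refine=refine_prox,  # optional prox for per-frame refinement
        compute_backend="cupy",
    )
    # Stage 1: AccProxSVRG estimates a single anatomical image x_anat.
    # Stage 2: POGM refines each temporal frame, warm-started from x_anat.
    frames, x_anat = rec.reconstruct(
        kspace_data,
        max_iter_stochastic=20,
        max_iter_per_frame=15,
    )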
