Skip to content

Commit

Permalink
feat: inline import for faster discoverability.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Dec 5, 2023
1 parent d06d2e5 commit 860d893
Showing 1 changed file with 45 additions and 19 deletions.
64 changes: 45 additions & 19 deletions src/simfmri/reconstructors/pysap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,46 @@
and returns a reconstructed fMRI array.
"""
from __future__ import annotations
from typing import Literal
from typing import Literal, Protocol
import logging
import warnings
import numpy as np


from fmri.operators.fourier import FFT_Sense, RepeatOperator, PooledgpuNUFFTSpaceFourier
from fmri.operators.fourier import CartesianSpaceFourier, SpaceFourierBase
from modopt.opt.linear import LinearParent
from modopt.opt.proximity import ProximityParent
from mrinufft.operators import get_operator
from mrinufft.operators.stacked import traj3d2stacked
from mrinufft.trajectories.density import voronoi

from .base import BaseReconstructor
from simfmri.simulation import SimData

logger = logging.getLogger(__name__)


def _get_stacked_operator(backend: str, sim: SimData) -> RepeatOperator:
class SpaceFourier(Protocol):
"""Fourier operator interface."""

n_frames: int
shape: tuple[int]
n_coils: int
uses_sense: bool

def op(self, x: np.ndarray) -> np.ndarray:
"""Apply the Fourier operator."""
...

def adj_op(self, x: np.ndarray) -> np.ndarray:
"""Apply the adjoint of the Fourier operator."""
...


def _get_stacked_operator(backend: str, sim: SimData) -> SpaceFourier:
from mrinufft.operators.stacked import traj3d2stacked
from mrinufft.trajectories.density import voronoi
from mrinufft.operators import get_operator

from fmri.operators.fourier import (
RepeatOperator,
)

nufft_backend = backend.split("-")[1]
frame_ops = []
Ns = sim.extra_infos["traj_params"]["n_samples"]
Expand Down Expand Up @@ -60,10 +79,19 @@ def _get_stacked_operator(backend: str, sim: SimData) -> RepeatOperator:

def get_fourier_operator(
sim: SimData, cartesian_repeat: bool = False, **kwargs: None
) -> RepeatOperator | CartesianSpaceFourier:
) -> SpaceFourier:
"""Return a Fourier operator for the given simulation."""
kwargs = kwargs.copy() if kwargs is not None else {}

from fmri.operators.fourier import CartesianSpaceFourier, SpaceFourierBase
from mrinufft.operators import get_operator

from fmri.operators.fourier import (
FFT_Sense,
RepeatOperator,
PooledgpuNUFFTSpaceFourier,
)

density = True
backend = sim.extra_infos.get("operator", "fft")
logger.info(f"fourier backend is {backend}")
Expand Down Expand Up @@ -197,9 +225,7 @@ def setup(self, sim: SimData) -> None:
self.fourier_op, space_linear_op, space_prox_op, optimizer="pogm"
)

def reconstruct(
self, sim: SimData, fourier_op: SpaceFourierBase | None = None
) -> np.ndarray:
def reconstruct(self, sim: SimData, fourier_op: None = None) -> np.ndarray:
"""Reconstruct with Sequential."""
if fourier_op is not None:
self.fourier_op = fourier_op
Expand Down Expand Up @@ -236,7 +262,7 @@ def __init__(
time_linear_op: LinearParent = None,
time_prox_op: ProximityParent = None,
space_prox_op: ProximityParent = None,
fourier_op: SpaceFourierBase = None,
fourier_op: SpaceFourier = None,
):
super().__init__()
self.lambda_l = lambda_l
Expand Down Expand Up @@ -265,11 +291,11 @@ def setup(self, sim: SimData) -> None:
if self.fourier_op is None:
self.fourier_op = get_fourier_operator(sim, cartesian_repeat=False)

logger.debug(f"Space Fourier operator initialized")
logger.debug("Space Fourier operator initialized")
if self.time_linear_op is None:
self.time_linear_op = TimeFourier(time_axis=0)

logger.debug(f"Time Fourier operator initialized")
logger.debug("Time Fourier operator initialized")
if self.lambda_s == "sure":
adj_data = self.fourier_op.adj_op(sim.kspace_data)
sure_thresh = np.zeros(np.prod(adj_data.shape[1:]))
Expand All @@ -285,23 +311,23 @@ def setup(self, sim: SimData) -> None:
self.time_linear_op, self.lambda_s, thresh_type="soft"
)

logger.debug(f"Prox Time operator initialized")
logger.debug("Prox Time operator initialized")
if self.space_prox_op is None:
self.space_prox_op = FlattenSVT(
self.lambda_l, initial_rank=10, thresh_type="soft-rel"
)
logger.debug(f"Prox Space operator initialized")
logger.debug("Prox Space operator initialized")

self.reconstructor = LowRankPlusSparseReconstructor(
self.fourier_op,
space_prox_op=self.space_prox_op,
time_prox_op=self.time_prox_op,
cost="auto",
)
logger.debug(f"Reconstructor initialized")
logger.debug("Reconstructor initialized")

def reconstruct(
self, sim: SimData, fourier_op: SpaceFourierBase | None = None
self, sim: SimData, fourier_op: SpaceFourier | None = None
) -> np.ndarray:
"""Reconstruct using LowRank+Sparse Method."""
if fourier_op is not None:
Expand Down

0 comments on commit 860d893

Please sign in to comment.