Skip to content

Commit

Permalink
ENH: Make models inherit from base model
Browse files Browse the repository at this point in the history
Make models inherit from base model.
  • Loading branch information
jhlegarreta committed Apr 30, 2024
1 parent f68f822 commit df51602
Showing 1 changed file with 112 additions and 80 deletions.
192 changes: 112 additions & 80 deletions src/eddymotion/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,21 @@
from dipy.core.gradients import gradient_table
from joblib import Parallel, delayed

a_min_S0 = 1e-5 # Should these be made method params?
a_max_S0 = 1.0
bval_upper_cap = 1000
bval_lower_th = 50
bval_upper_th = 10000
percentil_percentage = 75
timeframe_midpoint_tol = 1e-2 # in seconds


def _exec_fit(model, data, chunk=None):
retval = model.fit(data)
return retval, chunk


def _exec_predict(model, gradient, chunk=None, **kwargs):
def _exec_predict_dwi(model, gradient, chunk=None, **kwargs):
"""Propagate model parameters and call predict."""
return np.squeeze(model.predict(gradient, S0=kwargs.pop("S0", None))), chunk

Expand Down Expand Up @@ -86,51 +94,18 @@ class BaseModel:
__slots__ = (
"_model",
"_mask",
"_S0",
"_b_max",
"_models",
"_datashape",
)
_modelargs = ()

def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):
def __init__(self, mask=None, **kwargs):
"""Base initialization."""

# Setup B0 map
self._S0 = None
if S0 is not None:
self._S0 = np.clip(
S0.astype("float32") / S0.max(),
a_min=1e-5,
a_max=1.0,
)
self._model = None

# Setup brain mask
self._mask = mask
if mask is None and S0 is not None:
self._mask = self._S0 > np.percentile(self._S0, 35)

# Cap b-values, if requested
self._b_max = None
if b_max and b_max > 1000:
# Saturate b-values at b_max, since signal stops dropping
gtab[-1, gtab[-1] > b_max] = b_max
# A possibly good alternative is completely remove very high b-values
# bval_mask = gtab[-1] < b_max
# data = data[..., bval_mask]
# gtab = gtab[:, bval_mask]
self._b_max = b_max

kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs}

model_str = getattr(self, "_model_class", None)
if not model_str:
raise TypeError("No model defined")

from importlib import import_module

module_name, class_name = model_str.rsplit(".", 1)
self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs)

self._datashape = None
self._models = None
Expand Down Expand Up @@ -166,12 +141,73 @@ def fit(self, data, n_jobs=None, **kwargs):

self._model = None # Preempt further actions on the model

def predict(self, gradient, **kwargs):
def predict(self, *args, **kwargs):
pass


class BaseDWIModel(BaseModel):
"""Interface and default methods for DWI models."""

__slots__ = (
"_gtab",
"_S0",
"_b_max",
)

def __init__(self, gtab, S0=None, b_max=None, **kwargs):
"""Initialization.
Parameters
----------
gtab : :obj:`numpy.ndarray`
An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and
columns are b-vector components and corresponding b-value, respectively.
S0 : :obj:`numpy.ndarray`
:math:`S_{0}` signal.
b_max : :obj:`int`
Maximum value to cap b-values.
"""

super().__init__(**kwargs)

# Setup B0 map
self._S0 = None
if S0 is not None:
self._S0 = np.clip(
S0.astype("float32") / S0.max(),
a_min=a_min_S0,
a_max=a_max_S0,
)

# Cap b-values, if requested
self._gtab = gtab
self._b_max = None
if b_max and b_max > bval_upper_cap:
# Saturate b-values at b_max, since signal stops dropping
self._gtab[-1, self._gtab[-1] > b_max] = b_max
# A possibly good alternative is completely remove very high b-values
# bval_mask = gtab[-1] < b_max
# data = data[..., bval_mask]
# gtab = gtab[:, bval_mask]
self._b_max = b_max

kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs}

model_str = getattr(self, "_model_class", None)
if not model_str:
raise TypeError("No model defined")

from importlib import import_module

module_name, class_name = model_str.rsplit(".", 1)
self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs)

def predict(self, index, **kwargs):
"""Predict asynchronously chunk-by-chunk the diffusion signal."""
if self._b_max is not None:
gradient[-1] = min(gradient[-1], self._b_max)
index[-1] = min(index[-1], self._b_max)

gradient = _rasb2dipy(gradient)
self._gtab = _rasb2dipy(self._gtab)

S0 = None
if self._S0 is not None:
Expand All @@ -184,7 +220,7 @@ def predict(self, gradient, **kwargs):
n_models = len(self._models) if self._model is None and self._models else 1

if n_models == 1:
predicted, _ = _exec_predict(self._model, gradient, S0=S0, **kwargs)
predicted, _ = _exec_predict_dwi(self._model, self._gtab, S0=S0, **kwargs)
else:
S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models

Expand All @@ -193,7 +229,7 @@ def predict(self, gradient, **kwargs):
# Parallelize process with joblib
with Parallel(n_jobs=n_models) as executor:
results = executor(
delayed(_exec_predict)(model, gradient, S0=S0[i], chunk=i, **kwargs)
delayed(_exec_predict_dwi)(model, self._gtab, S0=S0[i], chunk=i, **kwargs)
for i, model in enumerate(self._models)
)
for subprediction, index in results:
Expand All @@ -210,27 +246,25 @@ def predict(self, gradient, **kwargs):
return retval


class TrivialB0Model:
class TrivialB0Model(BaseDWIModel):
"""A trivial model that returns a *b=0* map always."""

__slots__ = ("_S0",)

def __init__(self, S0=None, **kwargs):
def __init__(self, **kwargs):
"""Implement object initialization."""
if S0 is None:
raise ValueError("S0 must be provided")
super().__init__(**kwargs)

self._S0 = S0
if self._S0 is None:
raise ValueError("S0 must be provided")

def fit(self, *args, **kwargs):
def fit(self, data, **kwargs):
"""Do nothing."""

def predict(self, gradient, **kwargs):
def predict(self, *_, **kwargs):
"""Return the *b=0* map."""
return self._S0


class AverageDWModel:
class AverageDWModel(BaseDWIModel):
"""A trivial model that returns an average map."""

__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat")
Expand All @@ -241,43 +275,41 @@ def __init__(self, **kwargs):
Parameters
----------
gtab : :obj:`~numpy.ndarray`
An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and
columns are b-vector components and corresponding b-value, respectively.
th_low : :obj:`~numbers.Number`
th_low : :obj:`numbers.Number`
A lower bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
th_high : :obj:`~numbers.Number`
th_high : :obj:`numbers.Number`
An upper bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
bias : :obj:`bool`
Whether the overall distribution of each diffusion weighted image will be
standardized and centered around the global 75th percentile.
standardized and centered around the global ``percentil_percentage`` percentile.
stat : :obj:`str`
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
"""
self._th_low = kwargs.get("th_low", 50)
self._th_high = kwargs.get("th_high", 10000)
super().__init__(**kwargs)

self._th_low = kwargs.get("th_low", bval_lower_th)
self._th_high = kwargs.get("th_high", bval_upper_th)
self._bias = kwargs.get("bias", True)
self._stat = kwargs.get("stat", "median")
self._data = None

def fit(self, data, **kwargs):
"""Calculate the average."""
gtab = kwargs.pop("gtab", None)
# Select the interval of b-values for which DWIs will be averaged
b_mask = (
((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high))
if gtab is not None
((self._gtab[3] >= self._th_low) & (self._gtab[3] <= self._th_high))
if self._gtab is not None
else np.ones((data.shape[-1],), dtype=bool)
)
shells = data[..., b_mask]

# Regress out global signal differences
if self._bias:
centers = np.median(shells, axis=(0, 1, 2))
reference = np.percentile(centers[centers >= 1.0], 75)
reference = np.percentile(centers[centers >= 1.0], percentil_percentage)
centers[centers < 1.0] = reference
drift = reference / centers
shells = shells * drift
Expand All @@ -287,17 +319,17 @@ def fit(self, data, **kwargs):
# Calculate the average
self._data = avg_func(shells, axis=-1)

def predict(self, gradient, **kwargs):
def predict(self, *_, **kwargs):
"""Return the average map."""
return self._data


class PETModel:
class PETModel(BaseModel):
"""A PET imaging realignment model based on B-Spline approximation."""

__slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_mask", "_shape", "_n_ctrl")
__slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl")

def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, **kwargs):
def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs):
"""
Create the B-Spline interpolating matrix.
Expand All @@ -314,18 +346,19 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3,
model.
"""
super.__init__(**kwargs)

if timepoints is None or xlim is None:
raise TypeError("timepoints must be provided in initialization")

self._order = order
self._mask = mask

self._x = np.array(timepoints, dtype="float32")
self._xlim = xlim

if self._x[0] < 1e-2:
if self._x[0] < timeframe_midpoint_tol:
raise ValueError("First frame midpoint should not be zero or negative")
if self._x[-1] > (self._xlim - 1e-2):
if self._x[-1] > (self._xlim - timeframe_midpoint_tol):
raise ValueError("Last frame midpoint should not be equal or greater than duration")

# Calculate index coordinates in the B-Spline grid
Expand All @@ -334,10 +367,9 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3,
# B-Spline knots
self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")

self._shape = None
self._coeff = None

def fit(self, data, *args, **kwargs):
def fit(self, data, **kwargs):
"""Fit the model."""
from scipy.interpolate import BSpline
from scipy.sparse.linalg import cg
Expand All @@ -347,7 +379,7 @@ def fit(self, data, *args, **kwargs):
timepoints = kwargs.get("timepoints", None) or self._x
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl

self._shape = data.shape[:3]
self._datashape = data.shape[:3]

# Convert data into V (voxels) x T (timepoints)
data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask]
Expand All @@ -368,27 +400,27 @@ def fit(self, data, *args, **kwargs):

self._coeff = np.array([r[0] for r in results])

def predict(self, timepoint, **kwargs):
"""Return the *b=0* map."""
def predict(self, index, **kwargs):
"""Return the corrected volume using B-spline interpolation."""
from scipy.interpolate import BSpline

# Project sample timing into B-Spline coordinates
x = (timepoint / self._xlim) * self._n_ctrl
x = (index / self._xlim) * self._n_ctrl
A = BSpline.design_matrix(x, self._t, k=self._order)

# A is 1 (num. timepoints) x C (num. coeff)
# self._coeff is V (num. voxels) x K - 4
predicted = np.squeeze(A @ self._coeff.T)

if self._mask is None:
return predicted.reshape(self._shape)
return predicted.reshape(self._datashape)

retval = np.zeros(self._shape, dtype="float32")
retval = np.zeros(self._datashape, dtype="float32")
retval[self._mask] = predicted
return retval


class DTIModel(BaseModel):
class DTIModel(BaseDWIModel):
"""A wrapper of :obj:`dipy.reconst.dti.TensorModel`."""

_modelargs = (
Expand All @@ -402,7 +434,7 @@ class DTIModel(BaseModel):
_model_class = "dipy.reconst.dti.TensorModel"


class DKIModel(BaseModel):
class DKIModel(BaseDWIModel):
"""A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel`."""

_modelargs = DTIModel._modelargs
Expand Down

0 comments on commit df51602

Please sign in to comment.