diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 1f2608e2..6d7653a7 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -30,15 +30,36 @@ from eddymotion.exceptions import ModelNotFittedError +DEFAULT_MIN_S0 = 1e-5 +"""Minimum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_MAX_S0 = 1.0 +"""Maximum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_MAX_BVALUE = 1000 +"""Maximum allowed value for the b-value.""" + +DEFAULT_LOWB_THRESHOLD = 50 +"""The lower bound for the b-value so that the orientation is considered a DW volume.""" + +DEFAULT_HIGHB_THRESHOLD = 10000 +"""A b-value cap for DWI data.""" + +DEFAULT_CLIP_PERCENTILE = 75 +"""Upper percentile threshold for intensity clipping.""" + +DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2 +"""Time frame tolerance in seconds.""" + def _exec_fit(model, data, chunk=None): retval = model.fit(data) return retval, chunk -def _exec_predict(model, gradient, chunk=None, **kwargs): +def _exec_predict(model, chunk=None, **kwargs): """Propagate model parameters and call predict.""" - return np.squeeze(model.predict(gradient, S0=kwargs.pop("S0", None))), chunk + return np.squeeze(model.predict(**kwargs)), chunk class ModelFactory: @@ -62,7 +83,7 @@ def init(model="DTI", **kwargs): """ if model.lower() in ("s0", "b0"): - return TrivialB0Model(S0=kwargs.pop("S0")) + return TrivialB0Model(S0=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) if model.lower() in ("avg", "average", "mean"): return AverageDWModel(**kwargs) @@ -88,38 +109,190 @@ class BaseModel: __slots__ = ( "_model", "_mask", - "_S0", - "_b_max", "_models", "_datashape", "_is_fitted", + "_modelargs", ) - _modelargs = () - def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): + def __init__(self, mask=None, **kwargs): """Base initialization.""" + # Keep model state + self._model = None # "Main" model + self._models = None # For parallel (chunked) execution + self._is_fitted = False + + # Setup brain mask + self._mask = mask + + self._datashape = None self._is_fitted = False + self._modelargs = () + + @property + def is_fitted(self): + return self._is_fitted + + def fit(self, data, **kwargs): + """Abstract member signature of fit().""" + raise NotImplementedError("Cannot call fit() on a BaseModel instance.") + + def predict(self, *args, **kwargs): + """Abstract member signature of predict().""" + raise NotImplementedError("Cannot call predict() on a BaseModel instance.") + + +class PETModel(BaseModel): + """A PET imaging realignment model based on B-Spline approximation.""" + + __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") + + def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): + """ + Create the B-Spline interpolating matrix. + + Parameters: + ----------- + timepoints : :obj:`list` + The timing (in sec) of each PET volume. + E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., + 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` + + n_ctrl : :obj:`int` + Number of B-Spline control points. If `None`, then one control point every + six timepoints will be used. The less control points, the smoother is the + model. + + """ + super.__init__(**kwargs) + + if timepoints is None or xlim is None: + raise TypeError("timepoints must be provided in initialization") + + self._order = order + + self._x = np.array(timepoints, dtype="float32") + self._xlim = xlim + + if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: + raise ValueError("First frame midpoint should not be zero or negative") + if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): + raise ValueError("Last frame midpoint should not be equal or greater than duration") + + # Calculate index coordinates in the B-Spline grid + self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 + + # B-Spline knots + self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") + + self._coeff = None + + @property + def is_fitted(self): + return self._coeff is not None + + def fit(self, data, **kwargs): + """Fit the model.""" + from scipy.interpolate import BSpline + from scipy.sparse.linalg import cg + + n_jobs = kwargs.pop("n_jobs", None) or 1 + + timepoints = kwargs.get("timepoints", None) or self._x + x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl + + self._datashape = data.shape[:3] + + # Convert data into V (voxels) x T (timepoints) + data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] + + # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) + A = BSpline.design_matrix(x, self._t, k=self._order) + AT = A.T + ATdotA = AT @ A + + # One single CPU - linear execution (full model) + if n_jobs == 1: + self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) + return + + # Parallelize process with joblib + with Parallel(n_jobs=n_jobs) as executor: + results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) + + self._coeff = np.array([r[0] for r in results]) + + def predict(self, index=None, **kwargs): + """Return the corrected volume using B-spline interpolation.""" + from scipy.interpolate import BSpline + + if index is None: + raise ValueError("A timepoint index to be simulated must be provided.") + + if not self._is_fitted: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + # Project sample timing into B-Spline coordinates + x = (index / self._xlim) * self._n_ctrl + A = BSpline.design_matrix(x, self._t, k=self._order) + + # A is 1 (num. timepoints) x C (num. coeff) + # self._coeff is V (num. voxels) x K - 4 + predicted = np.squeeze(A @ self._coeff.T) + + if self._mask is None: + return predicted.reshape(self._datashape) + + retval = np.zeros(self._datashape, dtype="float32") + retval[self._mask] = predicted + return retval + + +class BaseDWIModel(BaseModel): + """Interface and default methods for DWI models.""" + + __slots__ = ( + "_gtab", + "_S0", + "_b_max", + "_model_class", # Defining a model class, DIPY models are instantiated automagically + "_modelargs", + ) + + def __init__(self, gtab, S0=None, b_max=None, **kwargs): + """Initialization. + + Parameters + ---------- + gtab : :obj:`numpy.ndarray` + An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and + columns are b-vector components and corresponding b-value, respectively. + S0 : :obj:`numpy.ndarray` + :math:`S_{0}` signal. + b_max : :obj:`int` + Maximum value to cap b-values. + + """ + + super().__init__(**kwargs) + # Setup B0 map self._S0 = None if S0 is not None: self._S0 = np.clip( S0.astype("float32") / S0.max(), - a_min=1e-5, - a_max=1.0, + a_min=DEFAULT_MIN_S0, + a_max=DEFAULT_MAX_S0, ) - # Setup brain mask - self._mask = mask - if mask is None and S0 is not None: - self._mask = self._S0 > np.percentile(self._S0, 35) - # Cap b-values, if requested + self._gtab = gtab self._b_max = None - if b_max and b_max > 1000: + if b_max and b_max > DEFAULT_MAX_BVALUE: # Saturate b-values at b_max, since signal stops dropping - gtab[-1, gtab[-1] > b_max] = b_max + self._gtab[-1, self._gtab[-1] > b_max] = b_max # A possibly good alternative is completely remove very high b-values # bval_mask = gtab[-1] < b_max # data = data[..., bval_mask] @@ -128,22 +301,16 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + # DIPY models (or one with a fully-compliant interface) model_str = getattr(self, "_model_class", None) - if not model_str: - raise TypeError("No model defined") - - from importlib import import_module - - module_name, class_name = model_str.rsplit(".", 1) - self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) + if model_str: + from importlib import import_module - self._datashape = None - self._models = None - self._is_fitted = False - - @property - def is_fitted(self): - return self._is_fitted + module_name, class_name = model_str.rsplit(".", 1) + self._model = getattr( + import_module(module_name), + class_name, + )(_rasb2dipy(gtab), **kwargs) def fit(self, data, n_jobs=None, **kwargs): """Fit the model chunk-by-chunk asynchronously""" @@ -177,14 +344,19 @@ def fit(self, data, n_jobs=None, **kwargs): self._is_fitted = True self._model = None # Preempt further actions on the model - def predict(self, gradient, **kwargs): + def predict(self, gradient=None, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" + if gradient is None: + raise ValueError("A gradient to be simulated (b-vector, b-value) must be provided") + if not self._is_fitted: raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - if self._b_max is not None: - gradient[-1] = min(gradient[-1], self._b_max) + gradient = np.array(gradient) # Tuples are unmutable + + # Cap the b-value if b_max is defined + gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) gradient = _rasb2dipy(gradient) @@ -199,7 +371,7 @@ def predict(self, gradient, **kwargs): n_models = len(self._models) if self._model is None and self._models else 1 if n_models == 1: - predicted, _ = _exec_predict(self._model, gradient, S0=S0, **kwargs) + predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0})) else: S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models @@ -208,7 +380,11 @@ def predict(self, gradient, **kwargs): # Parallelize process with joblib with Parallel(n_jobs=n_models) as executor: results = executor( - delayed(_exec_predict)(model, gradient, S0=S0[i], chunk=i, **kwargs) + delayed(_exec_predict)( + model, + chunk=i, + **(kwargs | {"gtab": gradient, "S0": S0[i]}), + ) for i, model in enumerate(self._models) ) for subprediction, index in results: @@ -225,33 +401,31 @@ def predict(self, gradient, **kwargs): return retval -class TrivialB0Model: +class TrivialB0Model(BaseDWIModel): """A trivial model that returns a *b=0* map always.""" - __slots__ = ("_S0",) - - def __init__(self, S0=None, **kwargs): + def __init__(self, **kwargs): """Implement object initialization.""" - if S0 is None: - raise ValueError("S0 must be provided") + super().__init__(**kwargs) - self._S0 = S0 + if self._S0 is None: + raise ValueError("S0 must be provided") @property def is_fitted(self): return True - def fit(self, *args, **kwargs): + def fit(self, data, **kwargs): """Do nothing.""" - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the *b=0* map.""" # No need to check fit (if not fitted, has raised already) return self._S0 -class AverageDWModel: +class AverageDWModel(BaseDWIModel): """A trivial model that returns an average map.""" __slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted") @@ -262,24 +436,24 @@ def __init__(self, **kwargs): Parameters ---------- - gtab : :obj:`~numpy.ndarray` - An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and - columns are b-vector components and corresponding b-value, respectively. - th_low : :obj:`~numbers.Number` + th_low : :obj:`numbers.Number` A lower bound for the b-value corresponding to the diffusion weighted images that will be averaged. - th_high : :obj:`~numbers.Number` + th_high : :obj:`numbers.Number` An upper bound for the b-value corresponding to the diffusion weighted images that will be averaged. bias : :obj:`bool` Whether the overall distribution of each diffusion weighted image will be - standardized and centered around the global 75th percentile. + standardized and centered around the + :data:`src.eddymotion.model.base.DEFAULT_CLIP_PERCENTILE` percentile. stat : :obj:`str` Whether the summary statistic to apply is ``"mean"`` or ``"median"``. """ - self._th_low = kwargs.get("th_low", 50) - self._th_high = kwargs.get("th_high", 10000) + super().__init__(**kwargs) + + self._th_low = kwargs.get("th_low", DEFAULT_LOWB_THRESHOLD) + self._th_high = kwargs.get("th_high", DEFAULT_HIGHB_THRESHOLD) self._bias = kwargs.get("bias", True) self._stat = kwargs.get("stat", "median") self._data = None @@ -287,7 +461,10 @@ def __init__(self, **kwargs): def fit(self, data, **kwargs): """Calculate the average.""" - gtab = kwargs.pop("gtab", None) + + if (gtab := kwargs.pop("gtab", None)) is None: + raise ValueError("A gradient table must be provided.") + # Select the interval of b-values for which DWIs will be averaged b_mask = ( ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) @@ -299,7 +476,7 @@ def fit(self, data, **kwargs): # Regress out global signal differences if self._bias: centers = np.median(shells, axis=(0, 1, 2)) - reference = np.percentile(centers[centers >= 1.0], 75) + reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE) centers[centers < 1.0] = reference drift = reference / centers shells = shells * drift @@ -308,14 +485,13 @@ def fit(self, data, **kwargs): avg_func = np.median if self._stat == "median" else np.mean # Calculate the average self._data = avg_func(shells, axis=-1) - - self._is_fitted = self._data is not None + self._is_fitted = True @property def is_fitted(self): return self._is_fitted - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the average map.""" if not self._is_fitted: @@ -324,110 +500,7 @@ def predict(self, gradient, **kwargs): return self._data -class PETModel: - """A PET imaging realignment model based on B-Spline approximation.""" - - __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_mask", "_shape", "_n_ctrl") - - def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, **kwargs): - """ - Create the B-Spline interpolating matrix. - - Parameters: - ----------- - timepoints : :obj:`list` - The timing (in sec) of each PET volume. - E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., - 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` - - n_ctrl : :obj:`int` - Number of B-Spline control points. If `None`, then one control point every - six timepoints will be used. The less control points, the smoother is the - model. - - """ - if timepoints is None or xlim is None: - raise TypeError("timepoints must be provided in initialization") - - self._order = order - self._mask = mask - - self._x = np.array(timepoints, dtype="float32") - self._xlim = xlim - - if self._x[0] < 1e-2: - raise ValueError("First frame midpoint should not be zero or negative") - if self._x[-1] > (self._xlim - 1e-2): - raise ValueError("Last frame midpoint should not be equal or greater than duration") - - # Calculate index coordinates in the B-Spline grid - self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 - - # B-Spline knots - self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") - - self._shape = None - self._coeff = None - - @property - def is_fitted(self): - return self._coeff is not None - - def fit(self, data, *args, **kwargs): - """Fit the model.""" - from scipy.interpolate import BSpline - from scipy.sparse.linalg import cg - - n_jobs = kwargs.pop("n_jobs", None) or 1 - - timepoints = kwargs.get("timepoints", None) or self._x - x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl - - self._shape = data.shape[:3] - - # Convert data into V (voxels) x T (timepoints) - data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] - - # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) - A = BSpline.design_matrix(x, self._t, k=self._order) - AT = A.T - ATdotA = AT @ A - - # One single CPU - linear execution (full model) - if n_jobs == 1: - self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) - return - - # Parallelize process with joblib - with Parallel(n_jobs=n_jobs) as executor: - results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) - - self._coeff = np.array([r[0] for r in results]) - - def predict(self, timepoint, **kwargs): - """Return the *b=0* map.""" - from scipy.interpolate import BSpline - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - # Project sample timing into B-Spline coordinates - x = (timepoint / self._xlim) * self._n_ctrl - A = BSpline.design_matrix(x, self._t, k=self._order) - - # A is 1 (num. timepoints) x C (num. coeff) - # self._coeff is V (num. voxels) x K - 4 - predicted = np.squeeze(A @ self._coeff.T) - - if self._mask is None: - return predicted.reshape(self._shape) - - retval = np.zeros(self._shape, dtype="float32") - retval[self._mask] = predicted - return retval - - -class DTIModel(BaseModel): +class DTIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.""" _modelargs = ( @@ -441,7 +514,7 @@ class DTIModel(BaseModel): _model_class = "dipy.reconst.dti.TensorModel" -class DKIModel(BaseModel): +class DKIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel`.""" _modelargs = DTIModel._modelargs diff --git a/test/test_model.py b/test/test_model.py index c5f319bf..7c7a906f 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -29,22 +29,32 @@ from eddymotion.data.dmri import DWI from eddymotion.data.splitting import lovo_split from eddymotion.exceptions import ModelNotFittedError +from eddymotion.model.base import DEFAULT_MAX_S0, DEFAULT_MIN_S0 def test_trivial_model(): """Check the implementation of the trivial B0 model.""" + rng = np.random.default_rng(1234) + # Should not allow initialization without a B0 with pytest.raises(ValueError): model.TrivialB0Model(gtab=np.eye(4)) - _S0 = np.random.normal(size=(10, 10, 10)) + _S0 = rng.normal(size=(2, 2, 2)) + + _clipped_S0 = np.clip( + _S0.astype("float32") / _S0.max(), + a_min=DEFAULT_MIN_S0, + a_max=DEFAULT_MAX_S0, + ) - tmodel = model.TrivialB0Model(gtab=np.eye(4), S0=_S0) + tmodel = model.TrivialB0Model(gtab=np.eye(4), S0=_clipped_S0) - assert tmodel.fit() is None + data = None + assert tmodel.fit(data) is None - assert np.all(_S0 == tmodel.predict((1, 0, 0))) + assert np.all(_clipped_S0 == tmodel.predict((1, 0, 0))) def test_average_model(): @@ -106,6 +116,7 @@ def test_two_initialisations(datadir): # Direct initialisation model1 = model.AverageDWModel( + gtab=data_train[1], S0=dmri_dataset.bzero, th_low=100, th_high=1000,