diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 38207741..d8c3d149 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -28,13 +28,21 @@ from dipy.core.gradients import gradient_table from joblib import Parallel, delayed +a_min_S0 = 1e-5 # Should these be made method params? +a_max_S0 = 1.0 +bval_upper_cap = 1000 +bval_lower_th = 50 +bval_upper_th = 10000 +percentil_percentage = 75 +timeframe_midpoint_tol = 1e-2 # in seconds + def _exec_fit(model, data, chunk=None): retval = model.fit(data) return retval, chunk -def _exec_predict(model, gradient, chunk=None, **kwargs): +def _exec_predict_dwi(model, gradient, chunk=None, **kwargs): """Propagate model parameters and call predict.""" return np.squeeze(model.predict(gradient, S0=kwargs.pop("S0", None))), chunk @@ -86,51 +94,18 @@ class BaseModel: __slots__ = ( "_model", "_mask", - "_S0", - "_b_max", "_models", "_datashape", ) _modelargs = () - def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): + def __init__(self, mask=None, **kwargs): """Base initialization.""" - # Setup B0 map - self._S0 = None - if S0 is not None: - self._S0 = np.clip( - S0.astype("float32") / S0.max(), - a_min=1e-5, - a_max=1.0, - ) + self._model = None # Setup brain mask self._mask = mask - if mask is None and S0 is not None: - self._mask = self._S0 > np.percentile(self._S0, 35) - - # Cap b-values, if requested - self._b_max = None - if b_max and b_max > 1000: - # Saturate b-values at b_max, since signal stops dropping - gtab[-1, gtab[-1] > b_max] = b_max - # A possibly good alternative is completely remove very high b-values - # bval_mask = gtab[-1] < b_max - # data = data[..., bval_mask] - # gtab = gtab[:, bval_mask] - self._b_max = b_max - - kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} - - model_str = getattr(self, "_model_class", None) - if not model_str: - raise TypeError("No model defined") - - from importlib import import_module - - module_name, class_name = model_str.rsplit(".", 1) - self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) self._datashape = None self._models = None @@ -166,12 +141,73 @@ def fit(self, data, n_jobs=None, **kwargs): self._model = None # Preempt further actions on the model - def predict(self, gradient, **kwargs): + def predict(self, *args, **kwargs): + pass + + +class BaseDWIModel(BaseModel): + """Interface and default methods for DWI models.""" + + __slots__ = ( + "_gtab", + "_S0", + "_b_max", + ) + + def __init__(self, gtab, S0=None, b_max=None, **kwargs): + """Initialization. + + Parameters + ---------- + gtab : :obj:`numpy.ndarray` + An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and + columns are b-vector components and corresponding b-value, respectively. + S0 : :obj:`numpy.ndarray` + :math:`S_{0}` signal. + b_max : :obj:`int` + Maximum value to cap b-values. + """ + + super().__init__(**kwargs) + + # Setup B0 map + self._S0 = None + if S0 is not None: + self._S0 = np.clip( + S0.astype("float32") / S0.max(), + a_min=a_min_S0, + a_max=a_max_S0, + ) + + # Cap b-values, if requested + self._gtab = gtab + self._b_max = None + if b_max and b_max > bval_upper_cap: + # Saturate b-values at b_max, since signal stops dropping + self._gtab[-1, self._gtab[-1] > b_max] = b_max + # A possibly good alternative is completely remove very high b-values + # bval_mask = gtab[-1] < b_max + # data = data[..., bval_mask] + # gtab = gtab[:, bval_mask] + self._b_max = b_max + + kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + + model_str = getattr(self, "_model_class", None) + if not model_str: + raise TypeError("No model defined") + + from importlib import import_module + + module_name, class_name = model_str.rsplit(".", 1) + self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) + + def predict(self, index, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" if self._b_max is not None: - gradient[-1] = min(gradient[-1], self._b_max) + index[-1] = min(index[-1], self._b_max) - gradient = _rasb2dipy(gradient) + self._gtab = _rasb2dipy(self._gtab) S0 = None if self._S0 is not None: @@ -184,7 +220,7 @@ def predict(self, gradient, **kwargs): n_models = len(self._models) if self._model is None and self._models else 1 if n_models == 1: - predicted, _ = _exec_predict(self._model, gradient, S0=S0, **kwargs) + predicted, _ = _exec_predict_dwi(self._model, self._gtab, S0=S0, **kwargs) else: S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models @@ -193,7 +229,7 @@ def predict(self, gradient, **kwargs): # Parallelize process with joblib with Parallel(n_jobs=n_models) as executor: results = executor( - delayed(_exec_predict)(model, gradient, S0=S0[i], chunk=i, **kwargs) + delayed(_exec_predict_dwi)(model, self._gtab, S0=S0[i], chunk=i, **kwargs) for i, model in enumerate(self._models) ) for subprediction, index in results: @@ -210,27 +246,25 @@ def predict(self, gradient, **kwargs): return retval -class TrivialB0Model: +class TrivialB0Model(BaseDWIModel): """A trivial model that returns a *b=0* map always.""" - __slots__ = ("_S0",) - - def __init__(self, S0=None, **kwargs): + def __init__(self, **kwargs): """Implement object initialization.""" - if S0 is None: - raise ValueError("S0 must be provided") + super().__init__(**kwargs) - self._S0 = S0 + if self._S0 is None: + raise ValueError("S0 must be provided") - def fit(self, *args, **kwargs): + def fit(self, data, **kwargs): """Do nothing.""" - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the *b=0* map.""" return self._S0 -class AverageDWModel: +class AverageDWModel(BaseDWIModel): """A trivial model that returns an average map.""" __slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat") @@ -241,35 +275,33 @@ def __init__(self, **kwargs): Parameters ---------- - gtab : :obj:`~numpy.ndarray` - An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and - columns are b-vector components and corresponding b-value, respectively. - th_low : :obj:`~numbers.Number` + th_low : :obj:`numbers.Number` A lower bound for the b-value corresponding to the diffusion weighted images that will be averaged. - th_high : :obj:`~numbers.Number` + th_high : :obj:`numbers.Number` An upper bound for the b-value corresponding to the diffusion weighted images that will be averaged. bias : :obj:`bool` Whether the overall distribution of each diffusion weighted image will be - standardized and centered around the global 75th percentile. + standardized and centered around the global ``percentil_percentage`` percentile. stat : :obj:`str` Whether the summary statistic to apply is ``"mean"`` or ``"median"``. """ - self._th_low = kwargs.get("th_low", 50) - self._th_high = kwargs.get("th_high", 10000) + super().__init__(**kwargs) + + self._th_low = kwargs.get("th_low", bval_lower_th) + self._th_high = kwargs.get("th_high", bval_upper_th) self._bias = kwargs.get("bias", True) self._stat = kwargs.get("stat", "median") self._data = None def fit(self, data, **kwargs): """Calculate the average.""" - gtab = kwargs.pop("gtab", None) # Select the interval of b-values for which DWIs will be averaged b_mask = ( - ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) - if gtab is not None + ((self._gtab[3] >= self._th_low) & (self._gtab[3] <= self._th_high)) + if self._gtab is not None else np.ones((data.shape[-1],), dtype=bool) ) shells = data[..., b_mask] @@ -277,7 +309,7 @@ def fit(self, data, **kwargs): # Regress out global signal differences if self._bias: centers = np.median(shells, axis=(0, 1, 2)) - reference = np.percentile(centers[centers >= 1.0], 75) + reference = np.percentile(centers[centers >= 1.0], percentil_percentage) centers[centers < 1.0] = reference drift = reference / centers shells = shells * drift @@ -287,17 +319,17 @@ def fit(self, data, **kwargs): # Calculate the average self._data = avg_func(shells, axis=-1) - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the average map.""" return self._data -class PETModel: +class PETModel(BaseModel): """A PET imaging realignment model based on B-Spline approximation.""" - __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_mask", "_shape", "_n_ctrl") + __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") - def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, **kwargs): + def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): """ Create the B-Spline interpolating matrix. @@ -314,18 +346,19 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, model. """ + super.__init__(**kwargs) + if timepoints is None or xlim is None: raise TypeError("timepoints must be provided in initialization") self._order = order - self._mask = mask self._x = np.array(timepoints, dtype="float32") self._xlim = xlim - if self._x[0] < 1e-2: + if self._x[0] < timeframe_midpoint_tol: raise ValueError("First frame midpoint should not be zero or negative") - if self._x[-1] > (self._xlim - 1e-2): + if self._x[-1] > (self._xlim - timeframe_midpoint_tol): raise ValueError("Last frame midpoint should not be equal or greater than duration") # Calculate index coordinates in the B-Spline grid @@ -334,10 +367,9 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, # B-Spline knots self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") - self._shape = None self._coeff = None - def fit(self, data, *args, **kwargs): + def fit(self, data, **kwargs): """Fit the model.""" from scipy.interpolate import BSpline from scipy.sparse.linalg import cg @@ -347,7 +379,7 @@ def fit(self, data, *args, **kwargs): timepoints = kwargs.get("timepoints", None) or self._x x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl - self._shape = data.shape[:3] + self._datashape = data.shape[:3] # Convert data into V (voxels) x T (timepoints) data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] @@ -368,12 +400,12 @@ def fit(self, data, *args, **kwargs): self._coeff = np.array([r[0] for r in results]) - def predict(self, timepoint, **kwargs): - """Return the *b=0* map.""" + def predict(self, index, **kwargs): + """Return the corrected volume using B-spline interpolation.""" from scipy.interpolate import BSpline # Project sample timing into B-Spline coordinates - x = (timepoint / self._xlim) * self._n_ctrl + x = (index / self._xlim) * self._n_ctrl A = BSpline.design_matrix(x, self._t, k=self._order) # A is 1 (num. timepoints) x C (num. coeff) @@ -381,14 +413,14 @@ def predict(self, timepoint, **kwargs): predicted = np.squeeze(A @ self._coeff.T) if self._mask is None: - return predicted.reshape(self._shape) + return predicted.reshape(self._datashape) - retval = np.zeros(self._shape, dtype="float32") + retval = np.zeros(self._datashape, dtype="float32") retval[self._mask] = predicted return retval -class DTIModel(BaseModel): +class DTIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.""" _modelargs = ( @@ -402,7 +434,7 @@ class DTIModel(BaseModel): _model_class = "dipy.reconst.dti.TensorModel" -class DKIModel(BaseModel): +class DKIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel`.""" _modelargs = DTIModel._modelargs