From d4f4e75e7c4676460d5424c565ea116290fa96b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sun, 7 Apr 2024 10:36:04 -0400 Subject: [PATCH] ENH: Implement Gaussian Process Implement Gaussian Process. --- pyproject.toml | 1 + src/eddymotion/estimator.py | 1 + src/eddymotion/model/__init__.py | 2 + src/eddymotion/model/base.py | 67 ++++++++++++++++++++++++++++++++ test/test_model.py | 19 +++++++++ 5 files changed, 90 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6faf5185..8d876958 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "numpy>=1.17.3", "nest-asyncio>=1.5.1", "scikit-image>=0.14.2", + "scikit_learn", "scipy>=1.8.0", ] dynamic = ["version"] diff --git a/src/eddymotion/estimator.py b/src/eddymotion/estimator.py index 5c28f113..d274cfd8 100644 --- a/src/eddymotion/estimator.py +++ b/src/eddymotion/estimator.py @@ -120,6 +120,7 @@ def estimate( "avg", "average", "mean", + "gp", ) or model.lower().startswith("full") dwmodel = None diff --git a/src/eddymotion/model/__init__.py b/src/eddymotion/model/__init__.py index 3c44e2ad..b64712ff 100644 --- a/src/eddymotion/model/__init__.py +++ b/src/eddymotion/model/__init__.py @@ -26,6 +26,7 @@ AverageDWModel, DKIModel, DTIModel, + GaussianProcessModel, ModelFactory, PETModel, TrivialB0Model, @@ -36,6 +37,7 @@ "AverageDWModel", "DKIModel", "DTIModel", + "GaussianProcessModel", "TrivialB0Model", "PETModel", ) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 6d7653a7..1d5d2768 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -27,6 +27,7 @@ import numpy as np from dipy.core.gradients import gradient_table from joblib import Parallel, delayed +from sklearn.gaussian_process import GaussianProcessRegressor from eddymotion.exceptions import ModelNotFittedError @@ -521,6 +522,72 @@ class DKIModel(BaseDWIModel): _model_class = "dipy.reconst.dki.DiffusionKurtosisModel" +class GaussianProcessModel(BaseModel): + """A Gaussian process model for DWI data based on [Andersson15]_.""" + + __slots__ = ( + "_dwi", + "_kernel", + "_gpr", + ) + + def __init__(self, dwi, kernel, **kwargs): + """Implement object initialization. + + Parameters + ---------- + dwi : :obj:`~eddymotion.dmri.DWI` + The DWI data. + kernel : :obj:`~sklearn.gaussian_process.kernels.Kernel` + Kernel instance. + """ + + self._dwi = dwi + self._kernel = kernel + + def fit(self, X, y, *args, **kwargs): + """Fit the Gaussian process model to the training data. + + Parameters + ---------- + X : :obj:`~numpy.ndarray` of shape (n_samples, n_features) + Feature values for training. For the DWI cae, ``n_samples`` is the + number of diffusion-encoding gradient vectors, and ``n_features`` + being 3 (the spatial coordinates). + y : :obj:`~numpy.ndarray` of shape (n_samples,) or (n_samples, n_targets) + Target values: the DWI signal values. + """ + + self._gpr = GaussianProcessRegressor(kernel=self._kernel, random_state=0) + self._gpr.fit(X, y) + self._is_fitted = True + + def predict(self, X, **kwargs): + """Predict using the Gaussian process model of the DWI signal, where + ``X`` is a diffusion-encoding gradient vector whose DWI data needs to be + estimated. + + Parameters + ---------- + X : :obj:`~numpy.ndarray` of shape (n_samples,) + Query points where the Gaussian process is evaluated: the + diffusion-encoding gradient vectors of interest. + + Returns + ------- + y_mean : :obj:`~numpy.ndarray` of shape (n_samples,) or (n_samples, n_targets) + Mean of predictive distribution at query points. + y_std : :obj:`~numpy.ndarray` of shape (n_samples,) or (n_samples, n_targets) + Standard deviation of predictive distribution at query points. + """ + + if not self._is_fitted: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + y_mean, y_std = self._gpr.predict(X, return_std=True) + return y_mean, y_std + + def _rasb2dipy(gradient): gradient = np.asanyarray(gradient) if gradient.ndim == 1: diff --git a/test/test_model.py b/test/test_model.py index 7c7a906f..23a097b0 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -24,6 +24,8 @@ import numpy as np import pytest +from sklearn.datasets import make_friedman2 +from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel from eddymotion import model from eddymotion.data.dmri import DWI @@ -105,6 +107,23 @@ def test_average_model(): assert np.all(tmodel_2000.predict([0, 0, 0]) == 1100) +def test_gp_model(datadir): + dwi = DWI.from_filename(datadir / "dwi.h5") + + kernel = DotProduct() + WhiteKernel() + + gp = model.GaussianProcessModel(dwi=dwi, kernel=kernel) + + assert isinstance(gp, model.GaussianProcessModel) + + X, y = make_friedman2(n_samples=500, noise=0, random_state=0) + gp.fit(X, y) + X_qry = X[:2, :] + prediction, _ = gp.predict(X_qry, return_std=True) + + assert prediction.shape == (X_qry.shape[0],) + + def test_two_initialisations(datadir): """Check that the two different initialisations result in the same models"""