Skip to content

Commit

Permalink
ENH: Implement Gaussian Process
Browse files Browse the repository at this point in the history
Implement Gaussian Process.
  • Loading branch information
jhlegarreta committed Jun 13, 2024
1 parent 8c0bf36 commit d4f4e75
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"numpy>=1.17.3",
"nest-asyncio>=1.5.1",
"scikit-image>=0.14.2",
"scikit_learn",
"scipy>=1.8.0",
]
dynamic = ["version"]
Expand Down
1 change: 1 addition & 0 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def estimate(
"avg",
"average",
"mean",
"gp",
) or model.lower().startswith("full")

dwmodel = None
Expand Down
2 changes: 2 additions & 0 deletions src/eddymotion/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AverageDWModel,
DKIModel,
DTIModel,
GaussianProcessModel,
ModelFactory,
PETModel,
TrivialB0Model,
Expand All @@ -36,6 +37,7 @@
"AverageDWModel",
"DKIModel",
"DTIModel",
"GaussianProcessModel",
"TrivialB0Model",
"PETModel",
)
67 changes: 67 additions & 0 deletions src/eddymotion/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
from dipy.core.gradients import gradient_table
from joblib import Parallel, delayed
from sklearn.gaussian_process import GaussianProcessRegressor

from eddymotion.exceptions import ModelNotFittedError

Expand Down Expand Up @@ -521,6 +522,72 @@ class DKIModel(BaseDWIModel):
_model_class = "dipy.reconst.dki.DiffusionKurtosisModel"


class GaussianProcessModel(BaseModel):
"""A Gaussian process model for DWI data based on [Andersson15]_."""

__slots__ = (
"_dwi",
"_kernel",
"_gpr",
)

def __init__(self, dwi, kernel, **kwargs):
"""Implement object initialization.
Parameters
----------
dwi : :obj:`~eddymotion.dmri.DWI`
The DWI data.
kernel : :obj:`~sklearn.gaussian_process.kernels.Kernel`
Kernel instance.
"""

self._dwi = dwi
self._kernel = kernel

def fit(self, X, y, *args, **kwargs):
"""Fit the Gaussian process model to the training data.
Parameters
----------
X : :obj:`~numpy.ndarray` of shape (n_samples, n_features)
Feature values for training. For the DWI cae, ``n_samples`` is the
number of diffusion-encoding gradient vectors, and ``n_features``
being 3 (the spatial coordinates).
y : :obj:`~numpy.ndarray` of shape (n_samples,) or (n_samples, n_targets)
Target values: the DWI signal values.
"""

self._gpr = GaussianProcessRegressor(kernel=self._kernel, random_state=0)
self._gpr.fit(X, y)
self._is_fitted = True

def predict(self, X, **kwargs):
"""Predict using the Gaussian process model of the DWI signal, where
``X`` is a diffusion-encoding gradient vector whose DWI data needs to be
estimated.
Parameters
----------
X : :obj:`~numpy.ndarray` of shape (n_samples,)
Query points where the Gaussian process is evaluated: the
diffusion-encoding gradient vectors of interest.
Returns
-------
y_mean : :obj:`~numpy.ndarray` of shape (n_samples,) or (n_samples, n_targets)
Mean of predictive distribution at query points.
y_std : :obj:`~numpy.ndarray` of shape (n_samples,) or (n_samples, n_targets)
Standard deviation of predictive distribution at query points.
"""

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

y_mean, y_std = self._gpr.predict(X, return_std=True)
return y_mean, y_std


def _rasb2dipy(gradient):
gradient = np.asanyarray(gradient)
if gradient.ndim == 1:
Expand Down
19 changes: 19 additions & 0 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

import numpy as np
import pytest
from sklearn.datasets import make_friedman2
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel

from eddymotion import model
from eddymotion.data.dmri import DWI
Expand Down Expand Up @@ -105,6 +107,23 @@ def test_average_model():
assert np.all(tmodel_2000.predict([0, 0, 0]) == 1100)


def test_gp_model(datadir):
dwi = DWI.from_filename(datadir / "dwi.h5")

kernel = DotProduct() + WhiteKernel()

gp = model.GaussianProcessModel(dwi=dwi, kernel=kernel)

assert isinstance(gp, model.GaussianProcessModel)

X, y = make_friedman2(n_samples=500, noise=0, random_state=0)
gp.fit(X, y)
X_qry = X[:2, :]
prediction, _ = gp.predict(X_qry, return_std=True)

assert prediction.shape == (X_qry.shape[0],)


def test_two_initialisations(datadir):
"""Check that the two different initialisations result in the same models"""

Expand Down

0 comments on commit d4f4e75

Please sign in to comment.