Skip to content

Commit

Permalink
Merge pull request #166 from jhlegarreta/CheckModelsAreFitBeforePredi…
Browse files Browse the repository at this point in the history
…cting

ENH: Check if `fit` has been called prior to `predict`
  • Loading branch information
oesteban authored Jun 8, 2024
2 parents 2b07b0e + e65c250 commit 59600ee
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/developers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Information on specific functions, classes, and methods.
api/eddymotion.data
api/eddymotion.data.dmri
api/eddymotion.estimator
api/eddymotion.exceptions
api/eddymotion.math
api/eddymotion.model
api/eddymotion.utils
Expand Down
32 changes: 32 additions & 0 deletions src/eddymotion/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2024 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#


class ModelNotFittedError(ValueError, AttributeError):
"""
Exception class to raise if estimator is used before fitting.
This class inherits from both ValueError and AttributeError to help with
exception handling.
"""
41 changes: 40 additions & 1 deletion src/eddymotion/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from dipy.core.gradients import gradient_table
from joblib import Parallel, delayed

from eddymotion.exceptions import ModelNotFittedError


def _exec_fit(model, data, chunk=None):
retval = model.fit(data)
Expand Down Expand Up @@ -90,12 +92,15 @@ class BaseModel:
"_b_max",
"_models",
"_datashape",
"_is_fitted",
)
_modelargs = ()

def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):
"""Base initialization."""

self._is_fitted = False

# Setup B0 map
self._S0 = None
if S0 is not None:
Expand Down Expand Up @@ -134,6 +139,11 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):

self._datashape = None
self._models = None
self._is_fitted = False

@property
def is_fitted(self):
return self._is_fitted

def fit(self, data, n_jobs=None, **kwargs):
"""Fit the model chunk-by-chunk asynchronously"""
Expand Down Expand Up @@ -164,10 +174,15 @@ def fit(self, data, n_jobs=None, **kwargs):
for submodel, index in results:
self._models[index] = submodel

self._is_fitted = True
self._model = None # Preempt further actions on the model

def predict(self, gradient, **kwargs):
"""Predict asynchronously chunk-by-chunk the diffusion signal."""

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

if self._b_max is not None:
gradient[-1] = min(gradient[-1], self._b_max)

Expand Down Expand Up @@ -222,18 +237,24 @@ def __init__(self, S0=None, **kwargs):

self._S0 = S0

@property
def is_fitted(self):
return True

def fit(self, *args, **kwargs):
"""Do nothing."""

def predict(self, gradient, **kwargs):
"""Return the *b=0* map."""

# No need to check fit (if not fitted, has raised already)
return self._S0


class AverageDWModel:
"""A trivial model that returns an average map."""

__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat")
__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted")

def __init__(self, **kwargs):
r"""
Expand Down Expand Up @@ -262,6 +283,7 @@ def __init__(self, **kwargs):
self._bias = kwargs.get("bias", True)
self._stat = kwargs.get("stat", "median")
self._data = None
self._is_fitted = False

def fit(self, data, **kwargs):
"""Calculate the average."""
Expand All @@ -287,8 +309,18 @@ def fit(self, data, **kwargs):
# Calculate the average
self._data = avg_func(shells, axis=-1)

self._is_fitted = self._data is not None

@property
def is_fitted(self):
return self._is_fitted

def predict(self, gradient, **kwargs):
"""Return the average map."""

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

return self._data


Expand Down Expand Up @@ -337,6 +369,10 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3,
self._shape = None
self._coeff = None

@property
def is_fitted(self):
return self._coeff is not None

def fit(self, data, *args, **kwargs):
"""Fit the model."""
from scipy.interpolate import BSpline
Expand Down Expand Up @@ -372,6 +408,9 @@ def predict(self, timepoint, **kwargs):
"""Return the *b=0* map."""
from scipy.interpolate import BSpline

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

# Project sample timing into B-Spline coordinates
x = (timepoint / self._xlim) * self._n_ctrl
A = BSpline.design_matrix(x, self._t, k=self._order)
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from eddymotion import model
from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split
from eddymotion.exceptions import ModelNotFittedError


def test_trivial_model():
Expand Down Expand Up @@ -75,6 +76,9 @@ def test_average_model():
stat="mean",
)

with pytest.raises(ModelNotFittedError):
tmodel_mean.predict([0, 0, 0])

# Verify that fit function returns nothing
assert tmodel_mean.fit(data[..., 1:], gtab=gtab[1:].T) is None

Expand Down Expand Up @@ -121,6 +125,10 @@ def test_two_initialisations(datadir):
bias=False,
stat="mean",
)

with pytest.raises(ModelNotFittedError):
model2.predict(data_test[1])

model2.fit(data_train[0], gtab=data_train[1])
predicted2 = model2.predict(data_test[1])

Expand Down

0 comments on commit 59600ee

Please sign in to comment.