Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Check if fit has been called prior to predict #166

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/developers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Information on specific functions, classes, and methods.
api/eddymotion.data
api/eddymotion.data.dmri
api/eddymotion.estimator
api/eddymotion.exceptions
api/eddymotion.math
api/eddymotion.model
api/eddymotion.utils
Expand Down
32 changes: 32 additions & 0 deletions src/eddymotion/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2024 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#


class ModelNotFittedError(ValueError, AttributeError):
"""
Exception class to raise if estimator is used before fitting.

This class inherits from both ValueError and AttributeError to help with
exception handling.

"""
41 changes: 40 additions & 1 deletion src/eddymotion/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from dipy.core.gradients import gradient_table
from joblib import Parallel, delayed

from eddymotion.exceptions import ModelNotFittedError


def _exec_fit(model, data, chunk=None):
retval = model.fit(data)
Expand Down Expand Up @@ -90,12 +92,15 @@ class BaseModel:
"_b_max",
"_models",
"_datashape",
"_is_fitted",
)
_modelargs = ()

def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):
"""Base initialization."""

self._is_fitted = False

# Setup B0 map
self._S0 = None
if S0 is not None:
Expand Down Expand Up @@ -134,6 +139,11 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):

self._datashape = None
self._models = None
self._is_fitted = False

@property
def is_fitted(self):
return self._is_fitted

def fit(self, data, n_jobs=None, **kwargs):
"""Fit the model chunk-by-chunk asynchronously"""
Expand Down Expand Up @@ -164,10 +174,15 @@ def fit(self, data, n_jobs=None, **kwargs):
for submodel, index in results:
self._models[index] = submodel

self._is_fitted = True
self._model = None # Preempt further actions on the model

def predict(self, gradient, **kwargs):
"""Predict asynchronously chunk-by-chunk the diffusion signal."""

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

if self._b_max is not None:
gradient[-1] = min(gradient[-1], self._b_max)

Expand Down Expand Up @@ -222,18 +237,24 @@ def __init__(self, S0=None, **kwargs):

self._S0 = S0

@property
def is_fitted(self):
return True

def fit(self, *args, **kwargs):
"""Do nothing."""

def predict(self, gradient, **kwargs):
"""Return the *b=0* map."""

# No need to check fit (if not fitted, has raised already)
return self._S0


class AverageDWModel:
"""A trivial model that returns an average map."""

__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat")
__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted")

def __init__(self, **kwargs):
r"""
Expand Down Expand Up @@ -262,6 +283,7 @@ def __init__(self, **kwargs):
self._bias = kwargs.get("bias", True)
self._stat = kwargs.get("stat", "median")
self._data = None
self._is_fitted = False

def fit(self, data, **kwargs):
"""Calculate the average."""
Expand All @@ -287,8 +309,18 @@ def fit(self, data, **kwargs):
# Calculate the average
self._data = avg_func(shells, axis=-1)

self._is_fitted = self._data is not None

@property
def is_fitted(self):
return self._is_fitted

def predict(self, gradient, **kwargs):
"""Return the average map."""

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

return self._data


Expand Down Expand Up @@ -337,6 +369,10 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3,
self._shape = None
self._coeff = None

@property
def is_fitted(self):
return self._coeff is not None

def fit(self, data, *args, **kwargs):
"""Fit the model."""
from scipy.interpolate import BSpline
Expand Down Expand Up @@ -372,6 +408,9 @@ def predict(self, timepoint, **kwargs):
"""Return the *b=0* map."""
from scipy.interpolate import BSpline

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

# Project sample timing into B-Spline coordinates
x = (timepoint / self._xlim) * self._n_ctrl
A = BSpline.design_matrix(x, self._t, k=self._order)
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from eddymotion import model
from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split
from eddymotion.exceptions import ModelNotFittedError


def test_trivial_model():
Expand Down Expand Up @@ -75,6 +76,9 @@ def test_average_model():
stat="mean",
)

with pytest.raises(ModelNotFittedError):
tmodel_mean.predict([0, 0, 0])

# Verify that fit function returns nothing
assert tmodel_mean.fit(data[..., 1:], gtab=gtab[1:].T) is None

Expand Down Expand Up @@ -121,6 +125,10 @@ def test_two_initialisations(datadir):
bias=False,
stat="mean",
)

with pytest.raises(ModelNotFittedError):
model2.predict(data_test[1])

model2.fit(data_train[0], gtab=data_train[1])
predicted2 = model2.predict(data_test[1])

Expand Down