From 1c88ccfa47b85abe9f14dbd38961d4d22ac1da61 Mon Sep 17 00:00:00 2001 From: Alexander Ji Date: Sat, 8 Sep 2018 17:29:17 -0700 Subject: [PATCH] Draft fitter for spline models --- specutils/fitting/spline.py | 258 +++++++++++++++++++++++++++++++++ specutils/tests/test_spline.py | 60 ++++++++ 2 files changed, 318 insertions(+) create mode 100644 specutils/fitting/spline.py create mode 100644 specutils/tests/test_spline.py diff --git a/specutils/fitting/spline.py b/specutils/fitting/spline.py new file mode 100644 index 000000000..da1759642 --- /dev/null +++ b/specutils/fitting/spline.py @@ -0,0 +1,258 @@ +from __future__ import print_function, division, absolute_import + +import numpy as np +from scipy import interpolate + +from astropy.modeling.core import FittableModel, Model +from astropy.modeling.functional_models import Shift +from astropy.modeling.parameters import Parameter +from astropy.modeling.utils import poly_map_domain, comb +from astropy.modeling.fitting import _FitterMeta, fitter_unit_support +from astropy.utils import indent, check_broadcast +from astropy.units import Quantity + + +__all__ = [] + +class SplineModel(FittableModel): + """ + Wrapper around scipy.interpolate.splrep and splev + + Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified, + and scipy.interpolate.LSQUnivariateSpline if knots are specified + + There are two ways to make a spline model. + (1) you have the spline auto-determine knots from the data + (2) you specify the knots + + """ + + linear = False # I think? I have no idea? + col_fit_deriv = False # Not sure what this is + + def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0): + """ + Set up a spline model. + + degree: degree of the spline (default 3) + In scipy fitpack, this is "k" + + smoothing (optional): smoothing value for automatically determining knots + In scipy fitpack, this is "s" + By default, uses a + + knots (optional): spline knots (boundaries of piecewise polynomial) + If not specified, will automatically determine knots based on + degree + smoothing + + extrapolate_mode (optional): how to deal with solution outside of interval. + (see scipy.interpolate.splev) + if 0 (default): return the extrapolated value + if 1, return 0 + if 2, raise a ValueError + if 3, return the boundary value + """ + self._degree = degree + self._smoothing = smoothing + self._knots = self.verify_knots(knots) + self.extrapolate_mode = extrapolate_mode + + ## This is used to evaluate the spline + ## When None, raises an error when trying to evaluate the spline + self._tck = None + + self._param_names = () + + def verify_knots(self, knots): + """ + Basic knot array vetting. + The goal of having this is to enable more useful error messages + than scipy (if needed). + """ + if knots is None: return None + knots = np.array(knots) + assert len(knots.shape) == 1, knots.shape + knots = np.sort(knots) + assert len(np.unique(knots)) == len(knots), knots + return knots + + ############ + ## Getters + ############ + def get_degree(self): + """ Spline degree (k in FITPACK) """ + return self._degree + def get_smoothing(self): + """ Spline smoothing (s in FITPACK) """ + return self._smoothing + def get_knots(self): + """ Spline knots (t in FITPACK) """ + return self._knots + def get_coeffs(self): + """ Spline coefficients (c in FITPACK) """ + if self._tck is not None: + return self._tck[1] + else: + raise RuntimeError("SplineModel has not been fit yet") + + ############ + ## Spline methods: not tested at all + ############ + def derivative(self, n=1): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet") + else: + t, c, k = self._tck + return scipy.interpolate.BSpline.construct_fast( + t,c,k,extrapolate=(self.extrapolate_mode==0)).derivative(n) + def antiderivative(self, n=1): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet") + else: + t, c, k = self._tck + return scipy.interpolate.BSpline.construct_fast( + t,c,k,extrapolate=(self.extrapolate_mode==0)).antiderivative(n) + def integral(self, a, b): + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet") + else: + t, c, k = self._tck + return scipy.interpolate.BSpline.construct_fast( + t,c,k,extrapolate=(self.extrapolate_mode==0)).integral(a,b) + def derivatives(self, x): + raise NotImplementedError + def roots(self): + raise NotImplementedError + + ############ + ## Setters: not really implemented or tested + ############ + def reset_model(self): + """ Resets model so it needs to be refit to be valid """ + self._tck = None + def set_degree(self, degree): + """ Spline degree (k in FITPACK) """ + raise NotImplementedError + self._degree = degree + self.reset_model() + def set_smoothing(self, smoothing): + """ Spline smoothing (s in FITPACK) """ + raise NotImplementedError + self._smoothing = smoothing + self.reset_model() + def set_knots(self, knots): + """ Spline knots (t in FITPACK) """ + raise NotImplementedError + self._knots = self.verify_knots(knots) + self.reset_model() + + def set_model_from_tck(self, tck): + """ + Use output of scipy.interpolate.splrep + """ + self._tck = tck + + def __call__(self, x, der=0): + """ + Evaluate the model with the given inputs. + der is passed to scipy.interpolate.splev + """ + if self._tck is None: + raise RuntimeError("SplineModel has not been fit yet") + return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode) + + #################################### + ######### Stuff below here is stubs + @property + def param_names(self): + """ + Coefficient names generated based on the model's knots and polynomial degree. + Not Implemented + """ + raise NotImplementedError("SplineModel does not currently expose parameters") + return self._param_names + + #def __getattr__(self, attr): + # """ + # Fails right now. Future code: + # # From astropy.modeling.polynomial.PolynomialBase + # if self._param_names and attr in self._param_names: + # return Parameter(attr, default=0.0, model=self) + # raise AttributeError(attr) + # """ + # raise NotImplementedError("SplineModel does not currently expose parameters") + + #def __setattr__(self, attr, value): + # """ + # Fails right now. Future code: + # # From astropy.modeling.polynomial.PolynomialBase + # if attr[0] != '_' and self._param_names and attr in self._param_names: + # param = Parameter(attr, default=0.0, model=self) + # param.__set__(self, value) + # else: + # super().__setattr__(attr, value) + # """ + # raise NotImplementedError("SplineModel does not currently expose parameters") + + def _generate_coeff_names(self): + names = [] + degree, Nknots = self._degree, len(self._knots) + for i in range(Nknots): + for j in range(degree+1): + names.append("k{}_c{}".format(i,j)) + return tuple(names) + + def evaluate(self, *args, **kwargs): + return self(*args, **kwargs) + + + +class SplineFitter(metaclass=_FitterMeta): + """ + Run a spline fit. + """ + def __init__(self): + self.fit_info = {"fp": None, + "ier": None, + "msg": None} + super().__init__() + + def validate_model(self, model): + if not isinstance(model, SplineModel): + raise ValueError("model must be of type SplineModel (currently is {})".format( + type(model))) + + ## TODO do something about units + #@fitter_unit_support + def __call__(self, model, x, y, w=None): + """ + Fit a spline model to data. + Internally uses scipy.interpolate.splrep. + + """ + + self.validate_model(model) + + ## Case (1): fit smoothing spline + if model.get_knots() is None: + tck, fp, ier, msg = interpolate.splrep(x, y, w=w, + t=None, + k=model.get_degree(), + s=model.get_smoothing(), + task=0, full_output=True + ) + ## Case (2): leastsq spline + else: + knots = model.get_knots() + ## TODO some sort of validation that the knots are internal, since + ## this procedure automatically adds knots at the two endpoints + tck, fp, ier, msg = interpolate.splrep(x, y, w=w, + t=knots, + k=model.get_degree(), + s=model.get_smoothing(), + task=-1, full_output=True + ) + + model.set_model_from_tck(tck) + self.fit_info.update({"fp":fp, "ier":ier, "msg":msg}) + diff --git a/specutils/tests/test_spline.py b/specutils/tests/test_spline.py new file mode 100644 index 000000000..84579d75a --- /dev/null +++ b/specutils/tests/test_spline.py @@ -0,0 +1,60 @@ +import astropy.units as u +import numpy as np + +from astropy.modeling import models, fitting +from specutils.fitting.spline import SplineModel, SplineFitter + +from scipy import interpolate + +def make_data(with_errs=True): + """ Arbitrary data """ + np.random.seed(348957) + x = np.linspace(0, 10, 200) + y = (x+1) - (x-5)**2. + 10.*np.exp(-0.5 * ((x-7.)/.5)**2.) + y = (y - np.min(y) + 10.)*10. + if with_errs: + ey = np.sqrt(y) + y = y + np.random.normal(0., ey, y.shape) + w = 1./y + return x, y, w + +def test_spline_fit(): + x, y, w = make_data() + make_plot=False + + # Construct three sets of splines and their scipy equivalents + knots = np.arange(1,10) + models = [SplineModel(), SplineModel(degree=5), SplineModel(knots=knots), SplineModel(smoothing=0)] + labels = ["Deg 3", "Deg 5", "Knots", "Interpolated"] + scipyfit = [interpolate.UnivariateSpline(x,y,w), + interpolate.UnivariateSpline(x,y,w,k=5), + interpolate.LSQUnivariateSpline(x,y,knots,w=w), + interpolate.InterpolatedUnivariateSpline(x,y,w)] + + fitter = SplineFitter() + for model, label, scipymodel in zip(models, labels, scipyfit): + fitter(model, x, y, w) + my_y = model(x) + sci_y = scipymodel(x) + assert np.allclose(my_y, sci_y, atol=1e-6) + + if make_plot: + import matplotlib.pyplot as plt + fig, ax = plt.subplots() + ax.plot(x,y,'k.') + ymin, ymax = np.min(y), np.max(y) + for i,(model, label) in enumerate(zip(models, labels)): + l, = ax.plot(x, model(x), lw=1, label=label) + knots = model.get_knots() + # Hack for now + if knots is None: knots = model._tck[0] + print(knots) + dy = (ymax-ymin)/10. + dy /= i+1. + ax.vlines(knots, ymin, ymin + dy, color=l.get_color(), lw=1) + ax.legend() + plt.show() + +if __name__=="__main__": + test_spline_fit() +