Skip to content

Commit

Permalink
Fix NaN handling in cube fitting (spacetelescope#3191)
Browse files Browse the repository at this point in the history
* Fix NaN handling in cube fitting

Remove debugging prints, add comment for context

Codestyle, changelog

Remove unit conversion stuff to just fix NaN handling

* Move changelog

* Add test

* Codestyle

* Adding test for spectrum with existing mask

* Fix a couple bugs with a loading a masked cube

* Linking fix for case with extra components like mask

* Test mask addition behavior with a subset instead

* Codestyle

* Ignore UserWarning for Ubuntu tests

* Update jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py

Co-authored-by: P. L. Lim <[email protected]>

* Do this as in above test

---------

Co-authored-by: P. L. Lim <[email protected]>
  • Loading branch information
rosteen and pllim authored Sep 13, 2024
1 parent 6b83d77 commit d26971f
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 6 deletions.
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ Cubeviz

- No longer incorrectly swap RA and Dec axes when loading Spectrum1D objects. [#3133]


Imviz
^^^^^

Expand Down Expand Up @@ -239,6 +238,8 @@ Bug Fixes
Cubeviz
^^^^^^^

- Fixed fitting a model to the entire cube when NaNs are present. [#3191]

Imviz
^^^^^

Expand Down
5 changes: 3 additions & 2 deletions jdaviz/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,9 @@ def _link_new_data(self, reference_data=None, data_to_be_linked=None):
return

elif self.config == 'cubeviz' and linked_data.ndim == 1:
ref_wavelength_component = dc[0].components[-2]
ref_flux_component = dc[0].components[-1]
# Don't want to use negative indices in case there are extra components like a mask
ref_wavelength_component = dc[0].components[5]
ref_flux_component = dc[0].components[6]
linked_wavelength_component = dc[-1].components[1]
linked_flux_component = dc[-1].components[-1]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,9 @@ def _extract_from_aperture(self, cube, uncert_cube, aperture,
wcs = cube.meta['_orig_spec'].wcs.spectral
elif hasattr(cube.coords, 'spectral'):
wcs = cube.coords.spectral
elif hasattr(cube.coords, 'spectral_wcs'):
# This is the attribute for a PaddedSpectrumWCS in the 3D case
wcs = cube.coords.spectral_wcs
else:
wcs = None

Expand Down
4 changes: 2 additions & 2 deletions jdaviz/configs/default/plugins/model_fitting/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def initialize(self, instance, x, y):
instance: `~astropy.modeling.Model`
The initialized model.
"""
y_mean = np.mean(y)
y_mean = np.nanmean(y)
x_range = x[-1] - x[0]
position = x_range / 2.0 + x[0]

Expand Down Expand Up @@ -190,7 +190,7 @@ def initialize(self, instance, x, y):

# width can be estimated by the weighted
# 2nd moment of the X coordinate.
dx = x - np.mean(x)
dx = x - np.nanmean(x)
fwhm = 2 * np.sqrt(np.sum((dx * dx) * y) / np.sum(y))

# amplitude is derived from area.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def _initialize_model_component(self, model_comp, comp_label, poly_order=None):
masked_spectrum.flux[~mask] if mask is not None else masked_spectrum.flux)

# need to loop over parameters again as the initializer may have overridden
# the original default value
# the original default value.
for param_name in get_model_parameters(model_cls, new_model["model_kwargs"]):
param_quant = getattr(initialized_model, param_name)
new_model["parameters"].append({"name": param_name,
Expand Down Expand Up @@ -917,6 +917,12 @@ def _fit_model_to_cube(self, add_data):

# Apply masks from selected spectral subset
spec = self._apply_subset_masks(spec, self.spectral_subset)
# Also mask out NaNs for fitting. Simply adding filter_non_finite to the cube fit
# didn't work out of the box, so doing this for now.
if spec.mask is None:
spec.mask = np.isnan(spec.flux)
else:
spec.mask = spec.mask | np.isnan(spec.flux)

try:
fitted_model, fitted_spectrum = fit_model_to_spectrum(
Expand Down
39 changes: 39 additions & 0 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from astropy.nddata import StdDevUncertainty
from astropy.tests.helper import assert_quantity_allclose
from astropy.wcs import WCS
from glue.core.roi import XRangeROI
from numpy.testing import assert_allclose, assert_array_equal
from specutils.spectra import Spectrum1D

Expand Down Expand Up @@ -378,3 +379,41 @@ def test_incompatible_units(specviz_helper, spectrum1d):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit(add_data=True)


def test_cube_fit_with_nans(cubeviz_helper):
flux = np.ones((7, 8, 9)) * u.nJy
flux[:, :, 0] = np.nan
spec = Spectrum1D(flux=flux)
cubeviz_helper.load_data(spec, data_label="test")

mf = cubeviz_helper.plugins["Model Fitting"]
mf.cube_fit = True
mf.create_model_component("Const1D")
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit()
result = cubeviz_helper.app.data_collection['model']
assert np.all(result.get_component("flux").data == 1)


def test_cube_fit_with_subset_and_nans(cubeviz_helper):
# Also test with existing mask
flux = np.ones((7, 8, 9)) * u.nJy
flux[:, :, 0] = np.nan
spec = Spectrum1D(flux=flux)
spec.flux[5, 5, 7] = 10 * u.nJy
cubeviz_helper.load_data(spec, data_label="test")

sv = cubeviz_helper.app.get_viewer('spectrum-viewer')
sv.apply_roi(XRangeROI(0, 5))

mf = cubeviz_helper.plugins["Model Fitting"]
mf.cube_fit = True
mf.spectral_subset = 'Subset 1'
mf.create_model_component("Const1D")
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit()
result = cubeviz_helper.app.data_collection['model']
assert np.all(result.get_component("flux").data == 1)

0 comments on commit d26971f

Please sign in to comment.