Skip to content

Commit

Permalink
Merge pull request #924 from keflavich/argmax_rays
Browse files Browse the repository at this point in the history
BUGFIX: Argmax/argmin raywise needs to return int
  • Loading branch information
e-koch authored Oct 18, 2024
2 parents 9919364 + 25127dc commit 790acad
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
16 changes: 12 additions & 4 deletions spectral_cube/spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ def argmax(self, axis=None, how='auto', **kwargs):
"""
return self.apply_numpy_function(np.nanargmax, fill=-np.inf,
reduce=False, projection=False,
how=how, axis=axis, **kwargs)
how=how, axis=axis,
**kwargs)

@aggregation_docstring
@warn_slow
Expand All @@ -808,7 +809,8 @@ def argmin(self, axis=None, how='auto', **kwargs):
"""
return self.apply_numpy_function(np.nanargmin, fill=np.inf,
reduce=False, projection=False,
how=how, axis=axis, **kwargs)
how=how, axis=axis,
**kwargs)

def _argmaxmin_world(self, axis, method, **kwargs):
'''
Expand Down Expand Up @@ -997,7 +999,8 @@ def _cube_on_cube_operation(self, function, cube, equivalencies=[], **kwargs):

def apply_function(self, function, axis=None, weights=None, unit=None,
projection=False, progressbar=False,
update_function=None, keep_shape=False, **kwargs):
update_function=None, keep_shape=False,
**kwargs):
"""
Apply a function to valid data along the specified axis or to the whole
cube, optionally using a weight array that is the same shape (or at
Expand Down Expand Up @@ -1054,7 +1057,12 @@ def apply_function(self, function, axis=None, weights=None, unit=None,
nz = self.shape[axis] if keep_shape else 1

# allocate memory for output array
out = np.empty([nz, nx, ny]) * np.nan
# check dtype first (for argmax/argmin)
result = function(np.arange(3, dtype=self._data.dtype), **kwargs)
if 'int' in str(result.dtype):
out = np.zeros([nz, nx, ny], dtype=result.dtype)
else:
out = np.empty([nz, nx, ny]) * np.nan

if progressbar:
progressbar = ProgressBar(nx*ny)
Expand Down
10 changes: 10 additions & 0 deletions spectral_cube/tests/test_spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,16 @@ def test_argmin(self):
self._check_numpy(self.c.argmin, d, np.nanargmin)
self.c = self.d = None

def test_arg_rays(self, use_dask):
"""
regression test: argmax must have integer dtype
"""
if not use_dask:
result = self.c.argmax(how='ray')
assert 'int' in str(result.dtype)
result = self.c.argmin(how='ray')
assert 'int' in str(result.dtype)

@pytest.mark.parametrize('iterate_rays', (True,False))
def test_median(self, iterate_rays, use_dask):
# Make sure that medians ignore empty/bad/NaN values
Expand Down

0 comments on commit 790acad

Please sign in to comment.