diff --git a/python/stempy/io/sparse_array.py b/python/stempy/io/sparse_array.py index ab2624a8..f8c402a6 100644 --- a/python/stempy/io/sparse_array.py +++ b/python/stempy/io/sparse_array.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence import copy from functools import wraps import inspect @@ -11,6 +12,11 @@ def _format_axis(func): @wraps(func) def wrapper(self, axis=None, *args, **kwargs): + if axis == 'scan': + axis = self.scan_axes + elif axis == 'frame': + axis = self.frame_axes + if axis is None: axis = self._default_axis elif not isinstance(axis, (list, tuple)): @@ -516,7 +522,7 @@ def bin_scans(self, bin_factor, in_place=False): # No need to modify the data, just the scan shape and the scan # positions. new_scan_shape = tuple(x // bin_factor for x in self.scan_shape) - all_positions = np.arange(self._scan_shape_flat[0]) + all_positions = self.scan_positions all_positions_reshaped = all_positions.reshape( new_scan_shape[0], bin_factor, new_scan_shape[1], bin_factor) @@ -643,49 +649,12 @@ def bin_frames(self, bin_factor, in_place=False): return SparseArray(**kwargs) def __getitem__(self, key): - # Make sure it is a list - if not isinstance(key, (list, tuple)): - key = [key] - else: - key = list(key) - - # Add any missing slices - while len(key) < len(self.shape): - key += [NONE_SLICE] - - # Convert all to slices - non_slice_indices = () - for i, item in enumerate(key): - if not isinstance(item, slice): - if item >= self.shape[i] or item < -self.shape[i]: - raise IndexError(f'index {item} is out of bounds for ' - f'axis {i} with size {self.shape[i]}') - - non_slice_indices += (i,) - if item == -1: - # slice(-1, 0) will not work, since negative and - # positive numbers are treated differently. - # Instead, set the next number to the last number. - next_num = self.shape[i] - else: - next_num = item + 1 - - key[i] = slice(item, next_num) - - scan_indices = np.arange(len(self.scan_shape)) - is_single_frame = all(x in non_slice_indices for x in scan_indices) + scan_slices, frame_slices = self._split_slices(key) - kwargs = { - 'slices': key, - 'non_slice_indices': non_slice_indices, - } - - if is_single_frame or not self.sparse_slicing: - f = self._slice_dense + if self.sparse_slicing: + return self._slice_sparse(scan_slices, frame_slices) else: - f = self._slice_sparse - - return f(**kwargs) + return self._slice_dense(scan_slices, frame_slices) @property def scan_positions(self): @@ -714,6 +683,26 @@ def num_frames_per_scan(self): """ return self.data.shape[1] + @property + def scan_axes(self): + """Get the axes for the scan positions + + :return: the axes for the scan positions + :rtype: tuple(int) + """ + num = len(self.scan_shape) + return tuple(range(num)) + + @property + def frame_axes(self): + """Get the axes for the frame positions + + :return: the axes for the frame positions + :rtype: tuple(int) + """ + start = len(self.scan_shape) + return tuple(range(start, len(self.shape))) + @property def _frame_shape_flat(self): return (np.prod(self.frame_shape),) @@ -727,11 +716,111 @@ def _default_axis(self): return tuple(np.arange(len(self.shape))) def _is_scan_axes(self, axis): - shape = self.shape - if len(shape) == 3 and tuple(axis) == (0,): - return True + return tuple(sorted(axis)) == self.scan_axes - return len(shape) == 4 and tuple(sorted(axis)) == (0, 1) + def _split_slices(self, slices): + """Split the slices into scan slices and frame slices + + This will also perform some validation to make sure there are no + issues. + + Returns `scan_slices, frame_slices`. + """ + def validate_if_advanced_indexing(obj, max_ndim, is_scan): + """Check if the obj is advanced indexing + (and convert to ndarray if it is) + + If it is advanced indexing, ensure the number of dimensions do + not exceed the provided max. + + returns the obj (converted to an ndarray if advanced indexing) + """ + if isinstance(obj, Sequence): + # If it's a sequence, it is advanced indexing. + # Convert to ndarray. + obj = np.asarray(obj) + + if isinstance(obj, np.ndarray): + # If it is a numpy array, it is advanced indexing. + # Ensure that there are not too many dimensions. + if obj.ndim > max_ndim: + msg = 'Too many advanced indexing dimensions.' + if is_scan: + msg += ( + ' Cannot perform advanced indexing on both the ' + 'scan positions and the frame positions ' + 'simultaneously' + ) + raise IndexError(msg) + + return obj + + if not isinstance(slices, tuple): + # Wrap it with a tuple to simplify + slices = (slices,) + + if not slices: + # It's an empty tuple + return tuple(), tuple() + + # Figure out which slices belong to which parts. + # The first slice should definitely be part of the scan. + first_slice = slices[0] + first_slice = validate_if_advanced_indexing(first_slice, + len(self.scan_shape), + is_scan=True) + + frame_start = 1 + if len(self.scan_shape) > 1: + # We might have 2 slices for the scan shape + if not isinstance(first_slice, + np.ndarray) or first_slice.ndim == 1: + # We have another scan slice + frame_start += 1 + + if frame_start == 2 and len(slices) > 1: + # Validate the second scan slice. + second_slice = validate_if_advanced_indexing(slices[1], 1, + is_scan=True) + scan_slices = (first_slice, second_slice) + else: + scan_slices = (first_slice,) + + # If there are frame indices, validate them too + frame_slices = tuple() + for i in range(frame_start, len(slices)): + max_ndim = frame_start + 2 - i + if max_ndim == 0: + raise IndexError('Too many indices for frame positions') + frame_slices += (validate_if_advanced_indexing(slices[i], max_ndim, + is_scan=False),) + + # Verify that we are not doing advanced indexing on both the scan + # positions and frame positions simultaneously. + if (any(isinstance(x, np.ndarray) for x in scan_slices) and + any(isinstance(x, np.ndarray) for x in frame_slices)): + msg = ( + 'Cannot perform advanced indexing on both scan positions ' + 'and frame positions simultaneously' + ) + raise IndexError(msg) + + # Verify that if there are any 2D advanced indexing arrays, they + # must be boolean and of the same shape. + first_frame_slice = (slice(None) if not frame_slices + else frame_slices[0]) + for i, to_check in enumerate((first_slice, first_frame_slice)): + req_shape = self.scan_shape if i == 0 else self.frame_shape + if (isinstance(to_check, np.ndarray) and to_check.ndim == 2 and + (to_check.dtype != np.bool_ or to_check.shape != req_shape)): + msg = ( + '2D advanced indexing is only allowed for boolean arrays ' + 'that match either the scan shape or the frame shape ' + '(whichever it is indexing into)' + ) + raise IndexError(msg) + + return scan_slices, frame_slices def _sparse_frames(self, indices): if not isinstance(indices, (list, tuple)): @@ -749,109 +838,70 @@ def _sparse_frames(self, indices): return self.data[scan_ind] - def _slice_dense(self, slices, non_slice_indices=None): - # non_slice_indices indicate which indices should be squeezed - # out of the result. - if non_slice_indices is None: - non_slice_indices = [] - - if all(x == NONE_SLICE for x in slices) and not self.allow_full_expand: - raise FullExpansionDenied('Full expansion is not allowed') - - data_shape = self.shape + def _slice_dense(self, scan_slices, frame_slices): + new_scan_shape = np.empty(self.scan_shape, + dtype=bool)[scan_slices].shape + new_frame_shape = np.empty(self.frame_shape, + dtype=bool)[frame_slices].shape - def slice_range(ind): - # Get the range generated by this slice - return range(*slices[ind].indices(data_shape[ind])) + result_shape = new_scan_shape + new_frame_shape - # Determine the shape of the result - result_shape = () - for i in range(len(data_shape)): - result_shape += (len(slice_range(i)),) + if result_shape == self.shape and not self.allow_full_expand: + raise FullExpansionDenied('Full expansion is not allowed') # Create the result result = np.zeros(result_shape, dtype=self.dtype) + all_positions = self.scan_positions.reshape(self.scan_shape) + sliced = all_positions[scan_slices].ravel() + + scan_indices = np.unravel_index(sliced, self.scan_shape) + scan_indices = np.array(scan_indices).T + # We will currently expand whole frames at a time - expand_num = len(self.scan_shape) - - # Lists to use in the recursion - current_indices = [] - result_indices = [] - - def iterate(): - ind = len(current_indices) - result_indices.append(0) - for i in slice_range(ind): - current_indices.append(i) - if len(current_indices) == expand_num: - output = self._expand(current_indices) - # This could be faster if we only expand what we need - output = output[tuple(slices[-output.ndim:])] - result[tuple(result_indices)] = output - else: - iterate() - result_indices[-1] += 1 - current_indices.pop() - result_indices.pop() - - iterate() - - # Squeeze out the non-slice indices - return result.squeeze(axis=non_slice_indices) - - def _slice_sparse(self, slices, non_slice_indices=None): - # non_slice_indices indicate which indices should be squeezed - # out of the result. - if non_slice_indices is None: - non_slice_indices = [] - - if len(slices) != len(self.shape): - raise Exception('Slices must be same length as shape') - - if any(not isinstance(x, slice) for x in slices): - raise Exception('All slices must be slice objects') - - scan_slices = tuple(slices[:len(self.scan_shape)]) - frame_slices = tuple(slices[len(self.scan_shape):]) - - scan_shape_modified = any(x != NONE_SLICE for x in scan_slices) - frame_shape_modified = any(x != NONE_SLICE for x in frame_slices) - - def slice_range(slice, length): - # Get the range generated by this slice - return range(*slice.indices(length)) + for i, indices in enumerate(scan_indices): + output = self._expand(tuple(indices)) + # This could be faster if we only expand what we need + output = output[frame_slices] + result_indices = np.unravel_index(i, new_scan_shape) + result[tuple(result_indices)] = output + + return result + + def _slice_sparse(self, scan_slices, frame_slices): + scan_shape_modified = any(not _is_none_slice(x) for x in scan_slices) + frame_shape_modified = any(not _is_none_slice(x) for x in frame_slices) if scan_shape_modified: - new_scan_shape = () - for i, (s, length) in enumerate(zip(scan_slices, self.scan_shape)): - if i not in non_slice_indices: - new_scan_shape += (len(slice_range(s, length)),) - - all_positions = np.arange(self._scan_shape_flat[0]).reshape( - self.scan_shape) - positions_to_keep = all_positions[scan_slices].ravel() + all_positions = self.scan_positions.reshape(self.scan_shape) + sliced = all_positions[scan_slices] + + if isinstance(sliced, np.integer): + # Everything was sliced except one frame. + # Return the dense frame instead. + return self._slice_dense(scan_slices, frame_slices) + + new_scan_shape = sliced.shape + positions_to_keep = sliced.ravel() new_frames = self.data[positions_to_keep] else: new_scan_shape = self.scan_shape new_frames = self.data if frame_shape_modified: - new_frame_shape = () - for i, (s, length) in enumerate(zip(frame_slices, - self.frame_shape)): - if i + len(self.scan_shape) not in non_slice_indices: - new_frame_shape += (len(slice_range(s, length)),) + + # Map old frame indices to new ones. Invalid values will be -1. + frame_indices = np.arange(self._frame_shape_flat[0]).reshape( + self.frame_shape) + sliced = frame_indices[frame_slices] + new_frame_shape = sliced.shape if not new_frame_shape: # We don't support an empty frame shape. # Just return the dense array instead. - return self._slice_dense(slices, non_slice_indices) + return self._slice_dense(scan_slices, frame_slices) - # Map old frame indices to new ones. Invalid values will be -1. - frame_indices = np.arange(self._frame_shape_flat[0]).reshape( - self.frame_shape) - valid_flat_frame_indices = frame_indices[frame_slices].ravel() + valid_flat_frame_indices = sliced.ravel() new_frame_indices_map = np.full(self._frame_shape_flat, -1) new_indices = np.arange(len(valid_flat_frame_indices)) @@ -921,6 +971,42 @@ def _expand(self, indices): max_length = len(self.shape) + 1 raise ValueError(f'0 < len(indices) < {max_length} is required') + @staticmethod + def _is_advanced_indexing(obj): + """Look at the object to see if it is considered advanced indexing + + We will follow the logic taken from here: + https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing + + "Advanced indexing is triggered when the selection object, obj, is a + non-tuple sequence object, an ndarray (of data type integer or bool), + or a tuple with at least one sequence object or ndarray (of data + type integer or bool)." + """ + + def is_int_or_bool_ndarray(x): + """Check if x is an ndarray of type int or bool""" + if not isinstance(x, np.ndarray): + return False + + return issubclass(x.dtype.type, (np.integer, np.bool_)) + + if not isinstance(obj, tuple) and isinstance(obj, Sequence): + return True + + if is_int_or_bool_ndarray(obj): + return True + + if isinstance(obj, tuple): + return any(isinstance(x, Sequence) or is_int_or_bool_ndarray(x) + for x in obj) + + return False + + +def _is_none_slice(x): + return isinstance(x, slice) and x == NONE_SLICE + def _warning(msg): print(f'Warning: {msg}', file=sys.stderr) diff --git a/tests/test_sparse_array.py b/tests/test_sparse_array.py index f4526ae0..3fd70377 100644 --- a/tests/test_sparse_array.py +++ b/tests/test_sparse_array.py @@ -97,6 +97,26 @@ def test_sparse_slicing(sparse_array_small, full_array_small): assert len(was_tested) >= 5 +def test_frame_slicing(sparse_array_small, full_array_small): + array = sparse_array_small + full = full_array_small + + # These are common use cases we want to make sure are supported + array.allow_full_expand = True + array.sparse_slicing = True + + # First test the sparse slicing version + sliced = array[:, :, 100:200, 200:300] + sliced.sparse_slicing = False + + assert np.array_equal(sliced[:], full[:, :, 100:200, 200:300]) + + # Now test the dense slicing version + array.sparse_slicing = False + assert np.array_equal(array[:, :, 100:200, 200:300], + full[:, :, 100:200, 200:300]) + + def test_arithmetic(sparse_array_small, full_array_small): array = sparse_array_small full = full_array_small @@ -241,9 +261,9 @@ def run_it(sparse_slicing): assert array[0:1, :, :, :].shape == full[0:1, :, :, :].shape assert array[0:2, :, :, :].shape == full[0:2, :, :, :].shape - assert array[:, 0:, :, :].shape == full[:, 0:, :, :].shape assert array[:, 0:1, :, :].shape == full[:, 0:1, :, :].shape assert array[:, 0:2, :, :].shape == full[:, 0:2, :, :].shape + assert array[:, 0:3, :, :].shape == full[:, 0:3, :, :].shape assert array[:, :, 0, :].shape == full[:, :, 0, :].shape assert array[:, :, 0:1, :].shape == full[:, :, 0:1, :].shape @@ -514,6 +534,143 @@ def test_data_conversion(cropped_multi_frames_v1, cropped_multi_frames_v2, assert array.num_frames_per_scan == 2 +def test_advanced_indexing(sparse_array_small, full_array_small): + array = sparse_array_small + full = full_array_small + + # First test advanced integer indexing + to_test = [ + [[0]], + [[1], [14]], + [[10], [-1]], + [[10, 7, 3], [7, 0, 9]], + [[3, 2, 1], [0, 1, 2]], + [[-1, 4, -3], [-3, 8, 7]], + [[7], 4], + [3, [8]], + [[-1, 4, -3], slice(2, 18, 2)], + [3, 4, [3, 2, 1], [9, 8, 7]], + [slice(None), slice(None), [5, 45, 32], [-3, -7, 23]], + [slice(None), slice(2, 18, 2), slice(0, 100, 3), [-5, 1, 2]], + ] + + def compare_with_sparse(full, sparse): + # This will fully expand the sparse array and perform a comparison + try: + if isinstance(sparse, SparseArray): + # This will be a SparseArray except for single frame cases + prev_allow_full_expand = sparse.allow_full_expand + prev_sparse_slicing = sparse.sparse_slicing + + sparse.allow_full_expand = True + sparse.sparse_slicing = False + + return np.array_equal(full, sparse[:]) + finally: + if isinstance(sparse, SparseArray): + sparse.allow_full_expand = prev_allow_full_expand + sparse.sparse_slicing = prev_sparse_slicing + + # First test sparse slicing + array.sparse_slicing = True + for index in to_test: + index = tuple(index) + assert compare_with_sparse(full[index], array[index]) + + # Now test dense slicing + array.sparse_slicing = False + for index in to_test: + index = tuple(index) + assert np.array_equal(full[index], array[index]) + + # Now test boolean indexing. + # We'll make some random masks for this. + rng = np.random.default_rng(0) + + num_tries = 3 + for i in range(num_tries): + scan_mask = rng.choice([True, False], array.scan_shape) + frame_mask = rng.choice([True, False], array.frame_shape) + + array.sparse_slicing = True + assert compare_with_sparse(full[scan_mask], array[scan_mask]) + assert compare_with_sparse(full[scan_mask[0]], array[scan_mask[0]]) + + assert compare_with_sparse(full[:, :, frame_mask], array[:, :, frame_mask]) + assert compare_with_sparse(full[:, :, frame_mask[0]], array[:, :, frame_mask[0]]) + + array.sparse_slicing = False + assert np.array_equal(full[scan_mask], array[scan_mask]) + assert np.array_equal(full[scan_mask[0]], array[scan_mask[0]]) + + assert np.array_equal(full[:, :, frame_mask], array[:, :, frame_mask]) + assert np.array_equal(full[:, :, frame_mask[0]], array[:, :, frame_mask[0]]) + + # These should raise index errors + with pytest.raises(IndexError): + # Can't do advanced indexing for both the scan and the frame + # simultaneously. + array[scan_mask, frame_mask] + + with pytest.raises(IndexError): + # Too many dimensions for the scan mask. + array[scan_mask[:, :, np.newaxis]] + + with pytest.raises(IndexError): + # We don't allow 2D arrays that aren't boolean masks, because + # an extra dimension gets added, which doesn't make sense for us. + array[[[1, 2], [5, 4]]] + + # Now test a few arithmetic operations combined with the indexing + array.sparse_slicing = True + assert np.array_equal(array[scan_mask].sum(axis='scan'), + full[scan_mask].sum(axis=(0,))) + assert np.array_equal(array[scan_mask].min(axis='scan'), + full[scan_mask].min(axis=(0,))) + assert np.array_equal(array[scan_mask].max(axis='scan'), + full[scan_mask].max(axis=(0,))) + assert np.allclose(array[scan_mask].mean(axis='scan'), + full[scan_mask].mean(axis=(0,))) + + # Now do some basic tests with multiple frames per scan + data = np.empty((2, 2), dtype=object) + data[0][0] = np.array([0]) + data[0][1] = np.array([0, 2]) + data[1][0] = np.array([0]) + data[1][1] = np.array([0, 1]) + kwargs = { + 'data': data, + 'scan_shape': (2, 1), + 'frame_shape': (2, 2), + 'sparse_slicing': False, + 'allow_full_expand': True, + } + # The multiple frames per scan array + m_array = SparseArray(**kwargs) + + # These are our expected expansions of the positions + position_zero = np.array([[2, 0], [1, 0]]) + position_one = np.array([[2, 1], [0, 0]]) + + # Verify our assumption is true + assert np.array_equal(m_array[0, 0], position_zero) + assert np.array_equal(m_array[1, 0], position_one) + + # Now test some simple fancy indexing + assert np.array_equal(m_array[[0], [0]][0], position_zero) + assert np.array_equal(m_array[[0, 1], [0]][0], position_zero) + assert np.array_equal(m_array[[0, 1], [0]][1], position_one) + assert np.array_equal(m_array[[1, 0], [0]][0], position_one) + assert np.array_equal(m_array[[1, 0], [0]][1], position_zero) + assert np.array_equal(m_array[0, 0, [0, 1], [0, 0]], [2, 1]) + assert np.array_equal(m_array[0, 0, [[True, False], [True, False]]], + [2, 1]) + assert np.array_equal(m_array[0, 0, [[False, True], [False, True]]], + [0, 0]) + assert np.array_equal(m_array[[True, False], 0][0], position_zero) + assert np.array_equal(m_array[[False, True], 0][0], position_one) + + # Test binning until this number TEST_BINNING_UNTIL = 33 @@ -538,6 +695,6 @@ def test_data_conversion(cropped_multi_frames_v1, cropped_multi_frames_v2, slice(None, None, 5)), (slice(3, None, 2), slice(5, None, 5), slice(4, None, 4), slice(20, None, 5)), - (slice(None, None, -1), slice(20, 4, -2), slice(4, None, -3), + (slice(20, 4, -2), slice(None, None, -1), slice(4, None, -3), slice(100, 3, -5)), ]