Skip to content

Commit

Permalink
Merge pull request #318 from ercius/fix_remove_flyback
Browse files Browse the repository at this point in the history
Fix remove_flyback algorithm
  • Loading branch information
cjh1 authored Oct 2, 2024
2 parents 08593dc + 7339501 commit 0188853
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
23 changes: 13 additions & 10 deletions python/stempy/io/sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,27 +192,30 @@ def from_hdf5(cls, filepath, keep_flyback=True, **init_kwargs):
scan_positions_group = f['electron_events/scan_positions']
scan_shape = [scan_positions_group.attrs[x] for x in ['Nx', 'Ny']]
frame_shape = [frames.attrs[x] for x in ['Nx', 'Ny']]

if keep_flyback:
data = frames[()] # load the full data set
scan_positions = scan_positions_group[()]
else:
# Generate the original scan indices from the scan_shape
orig_indices = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape)
# Remove the indices of the last column
crop_indices = np.delete(orig_indices, orig_indices[scan_shape[0]-1::scan_shape[0]])
# Load only the data needed
data = frames[crop_indices]
# Reduce the column shape by 1
scan_shape[0] = scan_shape[0] - 1
num = frames.shape[0] // np.prod(scan_shape, dtype=int) # number of frames per probe position
data = np.empty(((scan_shape[0]-1) * scan_shape[1] * num), dtype=object)
new_num_cols = scan_shape[0]-1 # number of columns without flyback
for ii in range(scan_shape[1]):
start = ii*new_num_cols*num # start of cropped data
end = (ii+1)*new_num_cols*num
start2 = ii*new_num_cols*num + num*ii # start of uncropped data
end2 = (ii+1)*new_num_cols*num + num*ii
data[start:end] = frames[start2:end2]
scan_shape = (scan_shape[0]-1, scan_shape[1]) # update scan shape
# Create the proper scan_positions without the flyback column
scan_positions = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape)
scan_positions = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)], scan_shape)

# Load any metadata
metadata = {}
if 'metadata' in f:
load_h5_to_dict(f['metadata'], metadata)

# reverse the scan shape to match expected shape
scan_shape = scan_shape[::-1]

if version >= 3:
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def cropped_multi_frames_v2(cropped_multi_frames_data_v2):
def cropped_multi_frames_v3(cropped_multi_frames_data_v3):
return SparseArray.from_hdf5(cropped_multi_frames_data_v3, dtype=np.uint16)

@pytest.fixture
def cropped_multi_frames_v3_noflyback(cropped_multi_frames_data_v3):
return SparseArray.from_hdf5(cropped_multi_frames_data_v3,
dtype=np.uint16, keep_flyback=False)

@pytest.fixture
def simulate_sparse_array():

Expand Down
12 changes: 7 additions & 5 deletions tests/test_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,11 +710,13 @@ def compare_with_sparse(full, sparse):
assert np.array_equal(m_array[[False, True], 0][0], position_one)


def test_keep_flyback(electron_data_small):
flyback = SparseArray.from_hdf5(electron_data_small, keep_flyback=True)
assert flyback.scan_shape[1] == 50
no_flyback = SparseArray.from_hdf5(electron_data_small, keep_flyback=False)
assert no_flyback.scan_shape[1] == 49
def test_keep_flyback(cropped_multi_frames_v3, cropped_multi_frames_v3_noflyback):
# Test keeping the flyback
assert cropped_multi_frames_v3.scan_shape[1] == 20
assert cropped_multi_frames_v3.num_frames_per_scan == 2
# Test removing the flyback
assert cropped_multi_frames_v3_noflyback.scan_shape[1] == 19
assert cropped_multi_frames_v3_noflyback.num_frames_per_scan == 2

# Test binning until this number
TEST_BINNING_UNTIL = 33
Expand Down

0 comments on commit 0188853

Please sign in to comment.