Skip to content

Commit

Permalink
Merge pull request #280 from OpenChemistry/read-frames
Browse files Browse the repository at this point in the history
Add method for reading frames at a scan position
  • Loading branch information
psavery authored Apr 3, 2023
2 parents 76215f1 + 322e137 commit e8f264a
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 53 deletions.
7 changes: 6 additions & 1 deletion python/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,10 @@ PYBIND11_MODULE(_io, m)
.def(py::init<const std::string&, int>(), py::arg("path"),
py::arg("threads") = 0)
.def(py::init<const std::vector<std::string>&, int>(), py::arg("files"),
py::arg("threads") = 0);
py::arg("threads") = 0)
.def("_load_frames", &SectorStreamMultiPassThreadedReader::loadFrames)
.def("_num_frames_per_scan",
&SectorStreamMultiPassThreadedReader::numFramesPerScan)
.def("_scan_dimensions",
&SectorStreamMultiPassThreadedReader::scanDimensions);
}
84 changes: 83 additions & 1 deletion python/stempy/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,89 @@ class SectorThreadedReader(ReaderMixin, _threaded_reader):
pass

class SectorThreadedMultiPassReader(ReaderMixin, _threaded_multi_pass_reader):
pass
@property
def num_frames_per_scan(self):
"""Get the number of frames per scan
It will be cached if it has already been computed. If not, all of
the header files must be read (which can take some time), and then
it will check how many frames each position has.
"""
return self._num_frames_per_scan()

@property
def scan_shape(self):
"""Get the scan shape
If it hasn't been done already, one header file will be read to
obtain this info.
"""
scan_dimensions = self._scan_dimensions()
# We treat the "shape" as having reversed axes than the "dimensions"
return scan_dimensions[::-1]

def read_frames(self, scan_position, frames_slice=None):
"""Read frames from the specified scan position and return them
The scan_position is either a tuple of a valid position in the
scan_shape, or an integer that is a flattened index of the position.
The frames_slice object will be used as an index in numpy to
determine which frames need to be read. If None, all frames
will be returned.
Returns a list of blocks for the associated frames.
"""
if isinstance(scan_position, (list, tuple)):
# Unravel the scan position
scan_shape = self.scan_shape
if (any(not 0 <= scan_position[i] < scan_shape[i]
for i in range(len(scan_position)))):
raise IndexError(
f'Invalid position {scan_position} '
f'for scan_shape {scan_shape}'
)

image_number = scan_position[0] * scan_shape[1] + scan_position[1]
else:
# Should just be an integer representing the image number
image_number = scan_position

single_index_frame = False

# First, get the number of frames per scan
num_frames = self.num_frames_per_scan

# Create an arange containing all frame positions
frames = np.arange(num_frames)

if frames_slice is not None:
if isinstance(frames_slice, (int, np.integer)):
frames_slice = [frames_slice]
single_index_frame = True

# Slice into the frames object
try:
frames = frames[frames_slice]
except IndexError:
msg = (
f'frames_slice "{frames_slice}" is invalid for '
f'num_frames "{num_frames}"'
)
raise IndexError(msg)

blocks = []

raw_blocks = self._load_frames(image_number, frames)
for b in raw_blocks:
block = namedtuple('Block', ['header', 'data'])
block._block = b
block.header = b.header
block.data = np.array(b, copy=False)[0]
blocks.append(block)

return blocks[0] if single_index_frame else blocks


def get_hdf5_reader(h5file):
# the initialization is at the io.cpp
Expand Down
Loading

0 comments on commit e8f264a

Please sign in to comment.