diff --git a/python/io.cpp b/python/io.cpp index fc300c3a..5e6c8757 100644 --- a/python/io.cpp +++ b/python/io.cpp @@ -126,5 +126,10 @@ PYBIND11_MODULE(_io, m) .def(py::init(), py::arg("path"), py::arg("threads") = 0) .def(py::init&, int>(), py::arg("files"), - py::arg("threads") = 0); + py::arg("threads") = 0) + .def("_load_frames", &SectorStreamMultiPassThreadedReader::loadFrames) + .def("_num_frames_per_scan", + &SectorStreamMultiPassThreadedReader::numFramesPerScan) + .def("_scan_dimensions", + &SectorStreamMultiPassThreadedReader::scanDimensions); } diff --git a/python/stempy/io/__init__.py b/python/stempy/io/__init__.py index 86bdc048..d9b9efd3 100644 --- a/python/stempy/io/__init__.py +++ b/python/stempy/io/__init__.py @@ -64,7 +64,89 @@ class SectorThreadedReader(ReaderMixin, _threaded_reader): pass class SectorThreadedMultiPassReader(ReaderMixin, _threaded_multi_pass_reader): - pass + @property + def num_frames_per_scan(self): + """Get the number of frames per scan + + It will be cached if it has already been computed. If not, all of + the header files must be read (which can take some time), and then + it will check how many frames each position has. + """ + return self._num_frames_per_scan() + + @property + def scan_shape(self): + """Get the scan shape + + If it hasn't been done already, one header file will be read to + obtain this info. + """ + scan_dimensions = self._scan_dimensions() + # We treat the "shape" as having reversed axes than the "dimensions" + return scan_dimensions[::-1] + + def read_frames(self, scan_position, frames_slice=None): + """Read frames from the specified scan position and return them + + The scan_position is either a tuple of a valid position in the + scan_shape, or an integer that is a flattened index of the position. + + The frames_slice object will be used as an index in numpy to + determine which frames need to be read. If None, all frames + will be returned. + + Returns a list of blocks for the associated frames. + """ + if isinstance(scan_position, (list, tuple)): + # Unravel the scan position + scan_shape = self.scan_shape + if (any(not 0 <= scan_position[i] < scan_shape[i] + for i in range(len(scan_position)))): + raise IndexError( + f'Invalid position {scan_position} ' + f'for scan_shape {scan_shape}' + ) + + image_number = scan_position[0] * scan_shape[1] + scan_position[1] + else: + # Should just be an integer representing the image number + image_number = scan_position + + single_index_frame = False + + # First, get the number of frames per scan + num_frames = self.num_frames_per_scan + + # Create an arange containing all frame positions + frames = np.arange(num_frames) + + if frames_slice is not None: + if isinstance(frames_slice, (int, np.integer)): + frames_slice = [frames_slice] + single_index_frame = True + + # Slice into the frames object + try: + frames = frames[frames_slice] + except IndexError: + msg = ( + f'frames_slice "{frames_slice}" is invalid for ' + f'num_frames "{num_frames}"' + ) + raise IndexError(msg) + + blocks = [] + + raw_blocks = self._load_frames(image_number, frames) + for b in raw_blocks: + block = namedtuple('Block', ['header', 'data']) + block._block = b + block.header = b.header + block.data = np.array(b, copy=False)[0] + blocks.append(block) + + return blocks[0] if single_index_frame else blocks + def get_hdf5_reader(h5file): # the initialization is at the io.cpp diff --git a/stempy/reader.h b/stempy/reader.h index a3011d44..4292a915 100644 --- a/stempy/reader.h +++ b/stempy/reader.h @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -15,6 +16,7 @@ #include #include #include +#include #include #include @@ -514,22 +516,80 @@ class SectorStreamMultiPassThreadedReader : public SectorStreamThreadedReader template std::future readAll(Functor& f); + // Read frames at a single scan position. + // This function will first read all of the headers if they have not been + // read yet (so that it may find the scan position that is requested). + // It will then read the specified frames at the frame position. + template + void readFrames(Functor& func, uint32_t imageNumber, + const std::vector& frameIndices); + + // Read the frames and return the associated blocks. + // This is the same as the readFrames() function, but its functor just + // saves a vector of the blocks. + std::vector loadFrames(uint32_t imageNumber, + const std::vector& frameIndices); + + Dimensions2D scanDimensions() + { + readFirstHeader(); + return m_scanDimensions; + } + + uint32_t numFramesPerScan() + { + if (m_numFramesPerScan != 0) { + return m_numFramesPerScan; + } + + // It hasn't been computed yet. We need to read the headers. + initializeScanMap(); + + uint32_t numFrames = 0; + for (size_t i = 0; i < m_scanMap.size(); ++i) { + if (m_scanMap[i].size() > numFrames) { + numFrames = m_scanMap[i].size(); + } + } + + m_numFramesPerScan = numFrames; + return numFrames; + } + private: ScanMap m_scanMap; + uint32_t m_scanNumber = 0; uint32_t m_scanMapOffset = 0; uint32_t m_scanMapSize = 0; uint32_t m_streamsOffset = 0; uint32_t m_streamsSize = 0; + uint32_t m_numFramesPerScan = 0; + Dimensions2D m_scanDimensions; + // atomic to keep track of the header or frame being processed std::atomic m_processed = { 0 }; - // Mutex to lock the map of frames at each scan position std::vector> m_scanPositionMutexes; void readHeaders(); template - void processFrames(Functor& func, Header header); + void processFrames(Functor& func); + + // Process a single frame + template + void processFrame(Functor& func, uint32_t imageNumber, uint32_t frameNumber, + std::array& frameMap); + + // Initialize the thread pool if needed (does nothing if already initialized) + void initializePool(); + + // Read the first header to save some settings internally. + // Does nothing if we have already read the first header. + void readFirstHeader(); + + // Initialize the scan map if needed (does nothing if already initialized) + void initializeScanMap(); #ifdef USE_MPI int m_rank; @@ -543,11 +603,55 @@ class SectorStreamMultiPassThreadedReader : public SectorStreamThreadedReader #endif }; +template +void SectorStreamMultiPassThreadedReader::processFrame( + Functor& func, uint32_t imageNumber, uint32_t frameNumber, + std::array& frameMap) +{ + Block b; + b.header.version = version(); + b.header.scanNumber = m_scanNumber; + b.header.scanDimensions = m_scanDimensions; + b.header.imagesInBlock = 1; + b.header.frameNumber = frameNumber; + b.header.imageNumbers.resize(1); + b.header.imageNumbers[0] = imageNumber; + b.header.complete.resize(1); + + b.header.frameDimensions = FRAME_DIMENSIONS; + + b.data.reset(new uint16_t[b.header.frameDimensions.first * + b.header.frameDimensions.second], + std::default_delete()); + std::fill(b.data.get(), + b.data.get() + + b.header.frameDimensions.first * b.header.frameDimensions.second, + 0); + + short sectors = 0; + for (int j = 0; j < 4; j++) { + auto& sectorLocation = frameMap[j]; + + if (sectorLocation.sectorStream != nullptr) { + auto sectorStream = sectorLocation.sectorStream; + std::unique_lock lock(*sectorStream->mutex.get()); + sectorStream->stream->seekg(sectorLocation.offset); + readSectorData(*sectorStream->stream, b, j); + sectors++; + } + } + + // Mark if the frame is complete + b.header.complete[0] = sectors == 4; + + // Finally process the frame + func(b); +} + // Read the FrameMaps for scan and reconstruct the frame before performing the // processing template -void SectorStreamMultiPassThreadedReader::processFrames(Functor& func, - Header header) +void SectorStreamMultiPassThreadedReader::processFrames(Functor& func) { while (m_processed < m_scanMapOffset + m_scanMapSize) { auto imageNumber = m_processed++; @@ -559,59 +663,38 @@ void SectorStreamMultiPassThreadedReader::processFrames(Functor& func, auto& frameMaps = m_scanMap[imageNumber]; - // Iterate over frame maps for this scan position + // Find the frame numbers for these maps, and loop over them in order. + std::vector frameNumbers; for (const auto& f : frameMaps) { - auto frameNumber = f.first; - auto& frameMap = f.second; - - Block b; - b.header.version = version(); - b.header.scanNumber = header.scanNumber; - b.header.scanDimensions = header.scanDimensions; - b.header.imagesInBlock = 1; - b.header.frameNumber = frameNumber; - b.header.imageNumbers.resize(1); - b.header.imageNumbers[0] = imageNumber; - b.header.complete.resize(1); - - b.header.frameDimensions = FRAME_DIMENSIONS; - - b.data.reset(new uint16_t[b.header.frameDimensions.first * - b.header.frameDimensions.second], - std::default_delete()); - std::fill(b.data.get(), - b.data.get() + b.header.frameDimensions.first * - b.header.frameDimensions.second, - 0); - - short sectors = 0; - for (int j = 0; j < 4; j++) { - auto& sectorLocation = frameMap[j]; - - if (sectorLocation.sectorStream != nullptr) { - auto sectorStream = sectorLocation.sectorStream; - std::unique_lock lock(*sectorStream->mutex.get()); - sectorStream->stream->seekg(sectorLocation.offset); - readSectorData(*sectorStream->stream, b, j); - sectors++; - } - } + frameNumbers.push_back(f.first); + } - // Mark if the frame is complete - b.header.complete[0] = sectors == 4; + // Now sort them to make sure that we loop over them in a consistent order. + std::sort(frameNumbers.begin(), frameNumbers.end()); - // Finally process the frame - func(b); + // Iterate over frame maps for this scan position + for (size_t i = 0; i < frameNumbers.size(); ++i) { + auto frameNumber = frameNumbers[i]; + auto& frameMap = frameMaps[frameNumber]; + processFrame(func, imageNumber, frameNumber, frameMap); } } } -template -std::future SectorStreamMultiPassThreadedReader::readAll(Functor& func) +inline void SectorStreamMultiPassThreadedReader::initializePool() { - m_pool = std::make_unique(m_threads); + if (!m_pool) { + m_pool = std::make_unique(m_threads); + } +} + +inline void SectorStreamMultiPassThreadedReader::readFirstHeader() +{ + if (m_scanMapSize != 0) { + // We must have already read the first header. Don't do it again. + return; + } - // Read one header to get scan size auto stream = m_streams[0].stream.get(); auto header = readHeader(*stream); // Reset the stream @@ -619,7 +702,20 @@ std::future SectorStreamMultiPassThreadedReader::readAll(Functor& func) // Resize the vector to hold the frame sector locations for the scan m_scanMapSize = header.scanDimensions.first * header.scanDimensions.second; - m_scanMap.clear(); + m_scanDimensions = header.scanDimensions; + m_scanNumber = header.scanNumber; +} + +inline void SectorStreamMultiPassThreadedReader::initializeScanMap() +{ + initializePool(); + readFirstHeader(); + + if (!m_scanMap.empty()) { + // It has already been initialized. Just return. + return; + } + m_scanMap.resize(m_scanMapSize); // Allocate the mutexes @@ -657,11 +753,74 @@ std::future SectorStreamMultiPassThreadedReader::readAll(Functor& func) // Reset counter m_processed = m_scanMapOffset; +} + +template +void SectorStreamMultiPassThreadedReader::readFrames( + Functor& func, uint32_t imageNumber, + const std::vector& frameIndices) +{ + // This will only initialize the scan map if it hasn't already been + // initialized. + initializeScanMap(); + + if (imageNumber >= m_scanMap.size()) { + std::ostringstream msg; + msg << "Image number " << imageNumber << " is out of bounds! " + << "There are " << m_scanMap.size() << " scans.\n"; + throw std::invalid_argument(msg.str()); + } + + auto& frameMaps = m_scanMap[imageNumber]; + + // Find the frame numbers for these maps, and loop over them in order. + std::vector frameNumbers; + for (const auto& f : frameMaps) { + frameNumbers.push_back(f.first); + } + + // Now sort them to make sure that we loop over them in a consistent order. + std::sort(frameNumbers.begin(), frameNumbers.end()); + + for (size_t i = 0; i < frameIndices.size(); ++i) { + auto idx = frameIndices[i]; + if (idx >= frameNumbers.size()) { + std::ostringstream msg; + msg << "Frame index " << idx << " for image number " << imageNumber + << " is out of bounds! " + << "The number of frames is: " << frameNumbers.size(); + throw std::invalid_argument(msg.str()); + } + + auto frameNumber = frameNumbers[idx]; + auto& frameMap = frameMaps[frameNumber]; + processFrame(func, imageNumber, frameNumber, frameMap); + } +} + +inline std::vector SectorStreamMultiPassThreadedReader::loadFrames( + uint32_t imageNumber, const std::vector& frameIndices) +{ + std::vector ret; + + // For the functor, we will just append the blocks and return them. + auto functor = [&ret](Block& b) { ret.push_back(b); }; + + readFrames(functor, imageNumber, frameIndices); + + return ret; +} + +template +std::future SectorStreamMultiPassThreadedReader::readAll(Functor& func) +{ + initializePool(); + initializeScanMap(); // Now enqueue lambda's to read the frames and run processing for (int i = 0; i < m_threads; i++) { - m_futures.emplace_back(m_pool->enqueue( - [this, &func, header]() { processFrames(func, header); })); + m_futures.emplace_back( + m_pool->enqueue([this, &func]() { processFrames(func); })); } // Return a future that is resolved once the processing is complete