diff --git a/starfish/core/image/Filter/reduce.py b/starfish/core/image/Filter/reduce.py index 4a3f4af5e..fca8e5847 100644 --- a/starfish/core/image/Filter/reduce.py +++ b/starfish/core/image/Filter/reduce.py @@ -180,7 +180,7 @@ def run( """ # Apply the reducing function - reduced = stack._data.reduce( + reduced = stack.xarray.reduce( self.func, dim=[Axes(dim).value for dim in self.dims], **self.kwargs) # Add the reduced dims back and align with the original stack diff --git a/starfish/core/image/Filter/zero_by_channel_magnitude.py b/starfish/core/image/Filter/zero_by_channel_magnitude.py index d5ffe8d07..28aac8b4b 100644 --- a/starfish/core/image/Filter/zero_by_channel_magnitude.py +++ b/starfish/core/image/Filter/zero_by_channel_magnitude.py @@ -61,7 +61,7 @@ def run( """ # The default is False, so even if code requests True require config to be True as well verbose = verbose and StarfishConfig().verbose - channels_per_round = stack._data.groupby(Axes.ROUND.value) + channels_per_round = stack.xarray.groupby(Axes.ROUND.value) channels_per_round = tqdm(channels_per_round) if verbose else channels_per_round if not in_place: diff --git a/starfish/core/imagestack/imagestack.py b/starfish/core/imagestack/imagestack.py index 5dec5eb40..3ec62d1de 100644 --- a/starfish/core/imagestack/imagestack.py +++ b/starfish/core/imagestack/imagestack.py @@ -1,10 +1,11 @@ import collections import warnings +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from functools import partial from itertools import product -from json import loads from pathlib import Path +from threading import Lock from typing import ( Any, Callable, @@ -89,6 +90,7 @@ class ImageStack: def __init__(self, data: MPDataArray, tile_data: Optional[TileCollectionData]=None): self._data = data + self._data_loaded = False self._tile_data = tile_data self._log: List[dict] = list() @@ -129,49 +131,38 @@ def from_tile_collection_data(cls, tile_data: TileCollectionData) -> "ImageStack data = 
MPDataArray.from_shape_and_dtype( shape=data_shape, dtype=np.float32, - initial_value=0, + initial_value=np.nan, dims=data_dimensions, coords=data_tick_marks, ) - imagestack = ImageStack(data, tile_data) - # TODO: (ttung) move more of the initialization code above the constructor call. - all_selectors = list(imagestack._iter_axes({Axes.ROUND, Axes.CH, Axes.ZPLANE})) first_selector = all_selectors[0] - tile = tile_data.get_tile(r=first_selector[Axes.ROUND], - ch=first_selector[Axes.CH], - z=first_selector[Axes.ZPLANE]) + first_tile = tile_data.get_tile( + r=first_selector[Axes.ROUND], + ch=first_selector[Axes.CH], + z=first_selector[Axes.ZPLANE]) # Set up coordinates imagestack._data[Coordinates.X.value] = xr.DataArray( - tile.coordinates[Coordinates.X], dims=Axes.X.value) + first_tile.coordinates[Coordinates.X], dims=Axes.X.value) imagestack._data[Coordinates.Y.value] = xr.DataArray( - tile.coordinates[Coordinates.Y], dims=Axes.Y.value) + first_tile.coordinates[Coordinates.Y], dims=Axes.Y.value) # Fill with nan for now, then replace with calculated midpoints imagestack._data[Coordinates.Z.value] = xr.DataArray( - np.full(imagestack.xarray.sizes[Axes.ZPLANE.value], np.nan), + np.full(imagestack._data.sizes[Axes.ZPLANE.value], np.nan), dims=Axes.ZPLANE.value) - # Get coords on first tile, then verify all subsequent tiles are aligned - starting_coords = tile.coordinates - - tile_dtypes = set() - for selector in tqdm(all_selectors): + for selector in all_selectors: tile = tile_data.get_tile( r=selector[Axes.ROUND], ch=selector[Axes.CH], z=selector[Axes.ZPLANE]) - data = tile.numpy_array - tile_dtypes.add(data.dtype) - - data = img_as_float32(data) - imagestack.set_slice(selector=selector, data=data) if not ( np.array_equal( - starting_coords[Coordinates.X], tile.coordinates[Coordinates.X]) + first_tile.coordinates[Coordinates.X], tile.coordinates[Coordinates.X]) and np.array_equal( - starting_coords[Coordinates.Y], tile.coordinates[Coordinates.Y]) + 
def _ensure_data_loaded(self) -> "ImageStack":
    """Load the pixel data of every tile into this ImageStack, at most once.

    All operations that read or write the image data should call this first.  The first call
    fetches every (round, ch, zplane) tile from ``self._tile_data`` on a thread pool, converts
    it with ``img_as_float32`` and writes it into the stack; later calls return immediately.

    Returns
    -------
    ImageStack :
        ``self``, so that the call can be chained.

    Raises
    ------
    TypeError
        If the loaded tiles do not all share the same dtype kind.

    Warns
    -----
    DataFormatWarning
        If the loaded tiles do not all have the same dtype item size.
    """
    if self._data_loaded:
        return self

    if self._tile_data is None:
        # The constructor accepts ``tile_data=None``.  Such a stack has no tiles to fetch from,
        # so whatever is already held in ``self._data`` is the complete data set.
        self._data_loaded = True
        return self

    all_selectors = list(self._iter_axes({Axes.ROUND, Axes.CH, Axes.ZPLANE}))
    lock = Lock()
    pbar = tqdm(total=len(all_selectors))

    def load_by_selector(selector):
        # Fetching and converting a tile happens outside the lock so tiles load concurrently.
        tile = self._tile_data.get_tile(
            r=selector[Axes.ROUND], ch=selector[Axes.CH], z=selector[Axes.ZPLANE])
        data = tile.numpy_array
        tile_dtype = data.dtype  # remember the on-disk dtype before the float32 conversion

        data = img_as_float32(data)
        with lock:
            # setting data is not thread-safe.  from_loader=True keeps set_slice from calling
            # back into _ensure_data_loaded while the load is still in progress.
            self.set_slice(selector=selector, data=data, from_loader=True)

        pbar.update(1)

        return tile_dtype

    try:
        with ThreadPoolExecutor() as tpe:
            # gather all the data types of the tiles to ensure that they are compatible.
            tile_dtypes = set(tpe.map(load_by_selector, all_selectors))
    finally:
        # Close the progress bar even when a worker raised, so a failed load does not leave it
        # open.
        pbar.close()

    tile_dtype_kinds = set(tile_dtype.kind for tile_dtype in tile_dtypes)
    tile_dtype_sizes = set(tile_dtype.itemsize for tile_dtype in tile_dtypes)
    if len(tile_dtype_kinds) != 1:
        raise TypeError("All tiles should have the same kind of dtype")
    if len(tile_dtype_sizes) != 1:
        warnings.warn("Not all tiles have the same precision data", DataFormatWarning)

    # Only mark the stack as loaded once every tile was written and the dtype checks passed.
    self._data_loaded = True

    return self
def test_multiple_tiles_of_different_kind():
    """Loading a stack whose tiles mix dtype kinds (uint32 and float32) must raise TypeError.

    The stack is built outside the ``pytest.raises`` block, so only the explicit load step is
    expected to raise.
    """
    mixed_kind_fetcher = CornerDifferentDtype(np.uint32, np.float32)
    stack = synthetic_stack(
        NUM_ROUND, NUM_CH, NUM_Z, HEIGHT, WIDTH, tile_fetcher=mixed_kind_fetcher)

    with pytest.raises(TypeError):
        stack._ensure_data_loaded()
100644 --- a/starfish/core/multiprocessing/test/test_multiprocessing.py +++ b/starfish/core/multiprocessing/test/test_multiprocessing.py @@ -126,7 +126,7 @@ def test_imagestack_deepcopy(nitems: int=10) -> None: shape = (nitems, 3, 4, 5, 6) dtype = np.float32 source = np.zeros(shape, dtype=np.float32) - imagestack = ImageStack.from_numpy(source) + imagestack = ImageStack.from_numpy(source)._ensure_data_loaded() imagestack_copy = copy.deepcopy(imagestack) _start_process_to_test_shmem( array_holder=imagestack_copy._data._backing_mp_array,