Skip to content

Commit

Permalink
Clean up binary mask (#1622)
Browse files Browse the repository at this point in the history
1. Use typed axes/coordinates (`Axes` and `Coordinates` enum members instead of their raw string values) when possible.
2. Have `_get_axes_names` return axes/coordinates in the order they are expected in the numpy arrays.
3. Improve how we render the mask names: each name is now the zero-based label index, zero-padded to the width of the largest index.

Test plan: `pytest starfish/core/binary_mask/test`
  • Loading branch information
Tony Tung authored Nov 13, 2019
1 parent 4afc114 commit fdf8962
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 52 deletions.
39 changes: 19 additions & 20 deletions starfish/core/binary_mask/binary_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
MutableSequence,
Optional,
Sequence,
Set,
Tuple,
Union,
)
Expand All @@ -21,9 +22,9 @@
from skimage.measure import regionprops
from skimage.measure._regionprops import _RegionProperties

from starfish.core.types import Axes, Coordinates
from starfish.core.types import Axes, Coordinates, Number
from .expand import fill_from_mask
from .util import _get_axes_names, AXES_ORDER
from .util import _get_axes_names


def _validate_binary_mask(arr: xr.DataArray):
Expand All @@ -44,13 +45,14 @@ def _validate_binary_mask(arr: xr.DataArray):
raise TypeError(f"expected dtype of bool; got {arr.dtype}")

axes, coords = _get_axes_names(arr.ndim)
dims = set(axes)
dims: Set[str] = set(axis.value for axis in axes)

if dims != set(arr.dims):
raise TypeError(f"missing dimensions '{dims.difference(arr.dims)}'")

if dims.union(coords) != set(arr.coords):
raise TypeError(f"missing coordinates '{dims.union(coords).difference(arr.coords)}'")
dims = dims.union(set(coord.value for coord in coords))
if dims != set(arr.coords):
raise TypeError(f"missing coordinates '{dims.difference(arr.coords)}'")


@dataclass
Expand Down Expand Up @@ -130,7 +132,7 @@ def mask_regionprops(self, mask_id: int) -> _RegionProperties:
image = np.zeros(
shape=tuple(
self.max_shape[axis]
for axis in AXES_ORDER
for axis, _ in zip(*_get_axes_names(3))
if self.max_shape[axis] != 0
),
dtype=np.uint32,
Expand All @@ -139,7 +141,7 @@ def mask_regionprops(self, mask_id: int) -> _RegionProperties:
mask_data.binary_mask,
mask_id + 1,
image,
[axis for axis in AXES_ORDER if self.max_shape[axis] > 0],
[axis for axis, _ in zip(*_get_axes_names(3)) if self.max_shape[axis] != 0],
)
mask_data.region_properties = regionprops(image)
return mask_data.region_properties
Expand All @@ -148,7 +150,7 @@ def mask_regionprops(self, mask_id: int) -> _RegionProperties:
def from_label_image(
cls,
label_image: np.ndarray,
physical_ticks: Dict[Coordinates, Sequence[float]]
physical_ticks: Dict[Coordinates, Sequence[Number]]
) -> "BinaryMaskCollection":
"""Creates binary masks from a label image.
Expand All @@ -157,14 +159,15 @@ def from_label_image(
label_image : int array
Integer array where each integer corresponds to a region.
physical_ticks : Dict[Coordinates, Sequence[float]]
Physical coordinates for each axis.
Physical ticks for each axis.
Returns
-------
masks : BinaryMaskCollection
Masks generated from the label image.
"""
props = regionprops(label_image)
len_max_label = len(str(len(props) - 1))

dims, _ = _get_axes_names(label_image.ndim)

Expand All @@ -173,25 +176,21 @@ def from_label_image(

# for each region (and its properties):
for label, prop in enumerate(props):
# create pixel coordinate labels from the bounding box
# to preserve spatial indexing relative to the original image
coords = {d: list(range(prop.bbox[i], prop.bbox[i + len(dims)]))
# create pixel ticks from the bounding box to preserve spatial indexing relative to the
# original image
coords = {d.value: list(range(prop.bbox[i], prop.bbox[i + len(dims)]))
for i, d in enumerate(dims)}

# create physical coordinate labels by taking the overlapping
# subset from the full span of labels
# create physical ticks by taking the overlapping subset from the full span of labels
for d, c in physical_ticks.items():
axis = d.value[0]
i = dims.index(axis)
coords[d.value] = (axis, c[prop.bbox[i]:prop.bbox[i + len(dims)]])

name = str(label + 1)
name = name.zfill(len(str(len(props)))) # pad with zeros

mask = xr.DataArray(prop.image,
dims=dims,
dims=[dim.value for dim in dims],
coords=coords,
name=name)
name=f"{label:0{len_max_label}d}")
masks.append(mask)

return cls(masks, props)
Expand All @@ -200,7 +199,7 @@ def to_label_image(
self,
shape: Optional[Tuple[int, ...]] = None,
*,
ordering: Sequence[Axes] = AXES_ORDER,
ordering: Sequence[Axes] = (Axes.ZPLANE, Axes.Y, Axes.X),
):
"""Create a label image from the contained masks.
Expand Down
4 changes: 2 additions & 2 deletions starfish/core/binary_mask/expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import xarray as xr

from starfish.types import Axes
from .util import _get_axes_names, AXES_ORDER
from .util import _get_axes_names


def fill_from_mask(
mask: xr.DataArray,
fill_value: int,
result_array: np.ndarray,
axes_order: Sequence[Union[str, Axes]] = AXES_ORDER,
axes_order: Sequence[Union[str, Axes]] = (Axes.ZPLANE, Axes.Y, Axes.X),
):
"""Take a binary mask with labeled axes and write `fill_value` to an array `result_array` where
the binary mask has a True value. The output array is assumed to have a zero origin. The input
Expand Down
24 changes: 12 additions & 12 deletions starfish/core/binary_mask/test/test_binary_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,30 @@ def test_from_label_image():

assert len(masks) == 2

region_1, region_2 = masks
region_0, region_1 = masks

assert region_0.name == '0'
assert region_1.name == '1'
assert region_2.name == '2'

assert np.array_equal(region_1, np.ones((1, 5), dtype=np.bool))
assert np.array_equal(region_0, np.ones((1, 5), dtype=np.bool))
temp = np.ones((2, 2), dtype=np.bool)
temp[-1, -1] = False
assert np.array_equal(region_2, temp)
assert np.array_equal(region_1, temp)

assert np.array_equal(region_1[Axes.Y.value], [0])
assert np.array_equal(region_1[Axes.X.value], [0, 1, 2, 3, 4])
assert np.array_equal(region_0[Axes.Y.value], [0])
assert np.array_equal(region_0[Axes.X.value], [0, 1, 2, 3, 4])

assert np.array_equal(region_2[Axes.Y.value], [3, 4])
assert np.array_equal(region_2[Axes.X.value], [3, 4])
assert np.array_equal(region_1[Axes.Y.value], [3, 4])
assert np.array_equal(region_1[Axes.X.value], [3, 4])

assert np.array_equal(region_1[Coordinates.Y.value],
assert np.array_equal(region_0[Coordinates.Y.value],
physical_ticks[Coordinates.Y][0:1])
assert np.array_equal(region_1[Coordinates.X.value],
assert np.array_equal(region_0[Coordinates.X.value],
physical_ticks[Coordinates.X][0:5])

assert np.array_equal(region_2[Coordinates.Y.value],
assert np.array_equal(region_1[Coordinates.Y.value],
physical_ticks[Coordinates.Y][3:5])
assert np.array_equal(region_2[Coordinates.X.value],
assert np.array_equal(region_1[Coordinates.X.value],
physical_ticks[Coordinates.X][3:5])


Expand Down
28 changes: 12 additions & 16 deletions starfish/core/binary_mask/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,28 @@
from starfish.core.types import Axes, Coordinates


AXES = [a.value for a in Axes if a not in (Axes.ROUND, Axes.CH)]
COORDS = [c.value for c in Coordinates]
AXES_ORDER = Axes.ZPLANE, Axes.Y, Axes.X


def _get_axes_names(ndim: int) -> Tuple[List[str], List[str]]:
"""Get needed axes names given the number of dimensions.
def _get_axes_names(ndim: int) -> Tuple[List[Axes], List[Coordinates]]:
"""Get needed axes and coordinates given the number of dimensions. The axes and coordinates are
returned in the order expected for binary masks. For instance, the first axis/coordinate
should be the first index into the mask.
Parameters
----------
ndim : int
Number of dimensions.
Returns
-------
axes : List[str]
Axes names.
coords : List[str]
Coordinates names.
axes : List[Axes]
Axes.
coords : List[Coordinates]
Coordinates.
"""
if ndim == 2:
axes = [axis for axis in AXES if axis != Axes.ZPLANE.value]
coords = [coord for coord in COORDS if coord != Coordinates.Z.value]
axes = [Axes.Y, Axes.X]
coords = [Coordinates.Y, Coordinates.X]
elif ndim == 3:
axes = AXES
coords = COORDS
axes = [Axes.ZPLANE, Axes.Y, Axes.X]
coords = [Coordinates.Z, Coordinates.Y, Coordinates.X]
else:
raise TypeError('expected 2- or 3-D image')

Expand Down
4 changes: 2 additions & 2 deletions starfish/test/full_pipelines/api/test_iss_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,5 @@ def test_iss_pipeline_cropped_data(tmpdir):
assert pipeline_log[2]['method'] == 'BlobDetector'
assert pipeline_log[3]['method'] == 'PerRoundMaxChannel'

# 28 of the spots are assigned to cell 1 (although most spots do not decode!)
assert np.sum(assigned['cell_id'] == '1') == 28
# 28 of the spots are assigned to cell 0 (although most spots do not decode!)
assert np.sum(assigned['cell_id'] == '0') == 28

0 comments on commit fdf8962

Please sign in to comment.