Skip to content

Commit

Permalink
ENH: Add initial label_map support
Browse files Browse the repository at this point in the history
  • Loading branch information
thewtex committed May 26, 2020
1 parent 6064d7a commit d657c79
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 56 deletions.
6 changes: 3 additions & 3 deletions examples/3DImage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ff7e3fd0c36d47829ba3326ebc805032",
"model_id": "57448261fbe849c4bacc9af9a83f9f14",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Viewer(geometries=[], gradient_opacity=0.4, point_sets=[], rendered_image=<itkImagePython.itkImageSS3; proxy o…"
"Viewer(geometries=[], gradient_opacity=0.9, point_sets=[], rendered_image=<itkImagePython.itkImageSS3; proxy o…"
]
},
"metadata": {},
Expand All @@ -61,7 +61,7 @@
],
"source": [
"image = itk.imread(file_name)\n",
"view(image, rotate=True, vmin=4000, vmax=17000, gradient_opacity=0.4)"
"view(image, rotate=True, vmin=4000, vmax=17000, gradient_opacity=0.9)"
]
},
{
Expand Down
148 changes: 110 additions & 38 deletions itkwidgets/widget_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ class Viewer(ViewerParent):
_rendering_image = CBool(
default_value=False,
help="We are currently volume rendering the image.").tag(sync=True)
label_map = ITKImage(
default_value=None,
allow_none=True,
help="Label map for the image.").tag(
sync=False,
**itkimage_serialization)
rendered_label_map = ITKImage(
default_value=None,
allow_none=True).tag(
sync=True,
**itkimage_serialization)
interpolation = CBool(
default_value=True,
help="Use linear interpolation in slicing planes.").tag(sync=True)
Expand Down Expand Up @@ -298,23 +309,28 @@ def __init__(self, **kwargs): # noqa: C901

super(Viewer, self).__init__(**kwargs)

if not self.image:
if not self.image and not self.label_map:
return
dimension = self.image.GetImageDimension()
largest_region = self.image.GetLargestPossibleRegion()
if self.image:
image = self.image
else:
image = self.label_map
dimension = image.GetImageDimension()
largest_region = image.GetLargestPossibleRegion()
size = largest_region.GetSize()

# Cache this so we do not need to recompute on it when resetting the
# roi
self._largest_roi_rendered_image = None
self._largest_roi_rendered_label_map = None
self._largest_roi = np.zeros((2, 3), dtype=np.float64)
if not np.any(self.roi):
largest_index = largest_region.GetIndex()
self.roi[0][:dimension] = np.array(
self.image.TransformIndexToPhysicalPoint(largest_index))
image.TransformIndexToPhysicalPoint(largest_index))
largest_index_upper = largest_index + size
self.roi[1][:dimension] = np.array(
self.image.TransformIndexToPhysicalPoint(largest_index_upper))
image.TransformIndexToPhysicalPoint(largest_index_upper))
self._largest_roi = self.roi.copy()

if dimension == 2:
Expand All @@ -325,32 +341,39 @@ def __init__(self, **kwargs): # noqa: C901
for dim in range(dimension):
if size[dim] > self.size_limit_3d[dim]:
self._downsampling = True
if self._downsampling:
if self._downsampling and self.image:
self.extractor = itk.ExtractImageFilter.New(self.image)
self.shrinker = itk.BinShrinkImageFilter.New(self.extractor)
if self._downsampling and self.label_map:
self.label_map_extractor = itk.ExtractImageFilter.New(self.label_map)
self.label_map_shrinker = itk.ShrinkImageFilter.New(self.label_map_extractor)
self._update_rendered_image()
if self._downsampling:
self.observe(self._on_roi_changed, ['roi'])

self.observe(self._on_reset_crop_requested, ['_reset_crop_requested'])
self.observe(self.update_rendered_image, ['image'])
self.observe(self.update_rendered_image, ['image', 'label_map'])

def _on_roi_changed(self, change=None):
if self._downsampling:
self._update_rendered_image()

def _on_reset_crop_requested(self, change=None):
if change.new is True and self._downsampling:
dimension = self.image.GetImageDimension()
largest_region = self.image.GetLargestPossibleRegion()
if self.image:
image = self.image
else:
image = self.label_map
dimension = image.GetImageDimension()
largest_region = image.GetLargestPossibleRegion()
size = largest_region.GetSize()
largest_index = largest_region.GetIndex()
new_roi = self.roi.copy()
new_roi[0][:dimension] = np.array(
self.image.TransformIndexToPhysicalPoint(largest_index))
image.TransformIndexToPhysicalPoint(largest_index))
largest_index_upper = largest_index + size
new_roi[1][:dimension] = np.array(
self.image.TransformIndexToPhysicalPoint(largest_index_upper))
image.TransformIndexToPhysicalPoint(largest_index_upper))
self._largest_roi = new_roi.copy()
self.roi = new_roi
if change.new is True:
Expand All @@ -359,6 +382,7 @@ def _on_reset_crop_requested(self, change=None):
@debounced(delay_seconds=0.2, method=True)
def update_rendered_image(self, change=None):
self._largest_roi_rendered_image = None
self._largest_roi_rendered_label_map = None
self._largest_roi = np.zeros((2, 3), dtype=np.float64)
self._update_rendered_image()

Expand All @@ -371,7 +395,7 @@ def _find_scale_factors(limit, dimension, size):
return scale_factors

def _update_rendered_image(self):
if self.image is None:
if self.image is None and self.label_map is None:
return
if self._rendering_image:
@yield_for_change(self, '_rendering_image')
Expand All @@ -382,10 +406,14 @@ def f():
self._rendering_image = True

if self._downsampling:
dimension = self.image.GetImageDimension()
index = self.image.TransformPhysicalPointToIndex(
if self.image:
image = self.image
else:
image = self.label_map
dimension = image.GetImageDimension()
index = image.TransformPhysicalPointToIndex(
self.roi[0][:dimension])
upper_index = self.image.TransformPhysicalPointToIndex(
upper_index = image.TransformPhysicalPointToIndex(
self.roi[1][:dimension])
size = upper_index - index

Expand All @@ -396,43 +424,72 @@ def f():
scale_factors = self._find_scale_factors(
self.size_limit_3d, dimension, size)
self._scale_factors = np.array(scale_factors, dtype=np.uint8)
self.shrinker.SetShrinkFactors(scale_factors[:dimension])
if self.image:
self.shrinker.SetShrinkFactors(scale_factors[:dimension])
if self.label_map:
self.label_map_shrinker.SetShrinkFactors(scale_factors[:dimension])

region = itk.ImageRegion[dimension]()
region.SetIndex(index)
region.SetSize(tuple(size))
# Account for rounding
# truncation issues
region.PadByRadius(1)
region.Crop(self.image.GetLargestPossibleRegion())
region.Crop(image.GetLargestPossibleRegion())

self.extractor.SetInput(self.image)
self.extractor.SetExtractionRegion(region)
if self.image:
self.extractor.SetInput(self.image)
self.extractor.SetExtractionRegion(region)
if self.label_map:
self.label_map_extractor.SetInput(self.label_map)
self.label_map_extractor.SetExtractionRegion(region)

size = region.GetSize()

is_largest = False
if np.any(self._largest_roi) and np.all(
self._largest_roi == self.roi):
is_largest = True
if self._largest_roi_rendered_image is not None:
self.rendered_image = self._largest_roi_rendered_image
if self._largest_roi_rendered_image is not None or self._largest_roi_rendered_label_map is not None:
if self.image:
self.rendered_image = self._largest_roi_rendered_image
if self.label_map:
self.rendered_label_map = self._largest_roi_rendered_label_map
return

self.shrinker.UpdateLargestPossibleRegion()
if self.image:
self.shrinker.UpdateLargestPossibleRegion()
if self.label_map:
self.label_map_shrinker.UpdateLargestPossibleRegion()
if is_largest:
self._largest_roi_rendered_image = self.shrinker.GetOutput()
self._largest_roi_rendered_image.DisconnectPipeline()
self._largest_roi_rendered_image.SetOrigin(
self.roi[0][:dimension])
self.rendered_image = self._largest_roi_rendered_image
if self.image:
self._largest_roi_rendered_image = self.shrinker.GetOutput()
self._largest_roi_rendered_image.DisconnectPipeline()
self._largest_roi_rendered_image.SetOrigin(
self.roi[0][:dimension])
self.rendered_image = self._largest_roi_rendered_image
if self.label_map:
self._largest_roi_rendered_label_map = self.label_map_shrinker.GetOutput()
self._largest_roi_rendered_label_map.DisconnectPipeline()
self._largest_roi_rendered_label_map.SetOrigin(
self.roi[0][:dimension])
self.rendered_label_map = self._largest_roi_rendered_label_map
return
shrunk = self.shrinker.GetOutput()
shrunk.DisconnectPipeline()
shrunk.SetOrigin(self.roi[0][:dimension])
self.rendered_image = shrunk
if self.image:
shrunk = self.shrinker.GetOutput()
shrunk.DisconnectPipeline()
shrunk.SetOrigin(self.roi[0][:dimension])
self.rendered_image = shrunk
if self.label_map:
shrunk = self.label_map_shrinker.GetOutput()
shrunk.DisconnectPipeline()
shrunk.SetOrigin(self.roi[0][:dimension])
self.rendered_label_map = shrunk
else:
self.rendered_image = self.image
if self.image:
self.rendered_image = self.image
if self.label_map:
self.rendered_label_map = self.image

@validate('gradient_opacity')
def _validate_gradient_opacity(self, proposal):
Expand Down Expand Up @@ -541,23 +598,31 @@ def _on_geometries_changed(self, change=None):

def roi_region(self):
"""Return the itk.ImageRegion corresponding to the roi."""
dimension = self.image.GetImageDimension()
index = self.image.TransformPhysicalPointToIndex(
if self.image:
image = self.image
else:
image = self.label_map
dimension = image.GetImageDimension()
index = image.TransformPhysicalPointToIndex(
tuple(self.roi[0][:dimension]))
upper_index = self.image.TransformPhysicalPointToIndex(
upper_index = image.TransformPhysicalPointToIndex(
tuple(self.roi[1][:dimension]))
size = upper_index - index
for dim in range(dimension):
size[dim] += 1
region = itk.ImageRegion[dimension]()
region.SetIndex(index)
region.SetSize(tuple(size))
region.Crop(self.image.GetLargestPossibleRegion())
region.Crop(image.GetLargestPossibleRegion())
return region

def roi_slice(self):
"""Return the numpy array slice corresponding to the roi."""
dimension = self.image.GetImageDimension()
if self.image:
image = self.image
else:
image = self.label_map
dimension = image.GetImageDimension()
region = self.roi_region()
index = region.GetIndex()
upper_index = np.array(index) + np.array(region.GetSize())
Expand All @@ -568,6 +633,7 @@ def roi_slice(self):


def view(image=None, # noqa: C901
label_map=None, # noqa: C901
cmap=cm.viridis,
select_roi=False,
interpolation=True,
Expand All @@ -584,7 +650,8 @@ def view(image=None, # noqa: C901
Creates and returns an ipywidget to visualize an image, and/or point sets
and/or geometries .
The image can be 2D or 3D.
The image can be 2D or 3D. A label map that corresponds to the image can
also be provided. The image and label map must have the same size.
The type of the image can be an numpy.array, itk.Image,
vtk.vtkImageData, pyvista.UniformGrid, imglyb.ReferenceGuardingRandomAccessibleInterval,
Expand Down Expand Up @@ -634,6 +701,10 @@ def view(image=None, # noqa: C901
image : array_like, itk.Image, or vtk.vtkImageData
The 2D or 3D image to visualize.
label_map : array_like, itk.Image, or vtk.vtkImageData
The 2D or 3D label map to visualize. If an image is also provided, the
label map must have the same size.
vmin: float, optional, default: None
Value that maps to the minimum of image colormap. Defaults to minimum of
the image pixel buffer.
Expand Down Expand Up @@ -798,6 +869,7 @@ def view(image=None, # noqa: C901
image = images[0]

viewer = Viewer(image=image,
label_map=label_map,
cmap=cmap,
select_roi=select_roi,
interpolation=interpolation,
Expand Down
Loading

0 comments on commit d657c79

Please sign in to comment.