diff --git a/examples/3DImage.ipynb b/examples/3DImage.ipynb index d6c1437f..94364dbf 100644 --- a/examples/3DImage.ipynb +++ b/examples/3DImage.ipynb @@ -47,12 +47,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ff7e3fd0c36d47829ba3326ebc805032", + "model_id": "57448261fbe849c4bacc9af9a83f9f14", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Viewer(geometries=[], gradient_opacity=0.4, point_sets=[], rendered_image= self.size_limit_3d[dim]: self._downsampling = True - if self._downsampling: + if self._downsampling and self.image: self.extractor = itk.ExtractImageFilter.New(self.image) self.shrinker = itk.BinShrinkImageFilter.New(self.extractor) + if self._downsampling and self.label_map: + self.label_map_extractor = itk.ExtractImageFilter.New(self.label_map) + self.label_map_shrinker = itk.ShrinkImageFilter.New(self.label_map_extractor) self._update_rendered_image() if self._downsampling: self.observe(self._on_roi_changed, ['roi']) self.observe(self._on_reset_crop_requested, ['_reset_crop_requested']) - self.observe(self.update_rendered_image, ['image']) + self.observe(self.update_rendered_image, ['image', 'label_map']) def _on_roi_changed(self, change=None): if self._downsampling: @@ -341,16 +360,20 @@ def _on_roi_changed(self, change=None): def _on_reset_crop_requested(self, change=None): if change.new is True and self._downsampling: - dimension = self.image.GetImageDimension() - largest_region = self.image.GetLargestPossibleRegion() + if self.image: + image = self.image + else: + image = self.label_map + dimension = image.GetImageDimension() + largest_region = image.GetLargestPossibleRegion() size = largest_region.GetSize() largest_index = largest_region.GetIndex() new_roi = self.roi.copy() new_roi[0][:dimension] = np.array( - self.image.TransformIndexToPhysicalPoint(largest_index)) + image.TransformIndexToPhysicalPoint(largest_index)) largest_index_upper = largest_index + size new_roi[1][:dimension] = np.array( - self.image.TransformIndexToPhysicalPoint(largest_index_upper)) + image.TransformIndexToPhysicalPoint(largest_index_upper)) self._largest_roi = new_roi.copy() self.roi = new_roi if change.new is True: @@ -359,6 +382,7 @@ def _on_reset_crop_requested(self, change=None): @debounced(delay_seconds=0.2, method=True) def update_rendered_image(self, change=None): self._largest_roi_rendered_image = None + self._largest_roi_rendered_label_map = None self._largest_roi = np.zeros((2, 3), dtype=np.float64) self._update_rendered_image() @@ -371,7 +395,7 @@ def _find_scale_factors(limit, dimension, size): return scale_factors def _update_rendered_image(self): - if self.image is None: + if self.image is None and self.label_map is None: return if self._rendering_image: @yield_for_change(self, '_rendering_image') @@ -382,10 +406,14 @@ def f(): self._rendering_image = True if self._downsampling: - dimension = self.image.GetImageDimension() - index = self.image.TransformPhysicalPointToIndex( + if self.image: + image = self.image + else: + image = self.label_map + dimension = image.GetImageDimension() + index = image.TransformPhysicalPointToIndex( self.roi[0][:dimension]) - upper_index = self.image.TransformPhysicalPointToIndex( + upper_index = image.TransformPhysicalPointToIndex( self.roi[1][:dimension]) size = upper_index - index @@ -396,7 +424,10 @@ def f(): scale_factors = self._find_scale_factors( self.size_limit_3d, dimension, size) self._scale_factors = np.array(scale_factors, dtype=np.uint8) - self.shrinker.SetShrinkFactors(scale_factors[:dimension]) + if self.image: + self.shrinker.SetShrinkFactors(scale_factors[:dimension]) + if self.label_map: + self.label_map_shrinker.SetShrinkFactors(scale_factors[:dimension]) region = itk.ImageRegion[dimension]() region.SetIndex(index) @@ -404,10 +435,14 @@ def f(): # Account for rounding # truncation issues region.PadByRadius(1) - region.Crop(self.image.GetLargestPossibleRegion()) + region.Crop(image.GetLargestPossibleRegion()) - self.extractor.SetInput(self.image) - self.extractor.SetExtractionRegion(region) + if self.image: + self.extractor.SetInput(self.image) + self.extractor.SetExtractionRegion(region) + if self.label_map: + self.label_map_extractor.SetInput(self.label_map) + self.label_map_extractor.SetExtractionRegion(region) size = region.GetSize() @@ -415,24 +450,46 @@ def f(): if np.any(self._largest_roi) and np.all( self._largest_roi == self.roi): is_largest = True - if self._largest_roi_rendered_image is not None: - self.rendered_image = self._largest_roi_rendered_image + if self._largest_roi_rendered_image is not None or self._largest_roi_rendered_label_map is not None: + if self.image: + self.rendered_image = self._largest_roi_rendered_image + if self.label_map: + self.rendered_label_map = self._largest_roi_rendered_label_map return - self.shrinker.UpdateLargestPossibleRegion() + if self.image: + self.shrinker.UpdateLargestPossibleRegion() + if self.label_map: + self.label_map_shrinker.UpdateLargestPossibleRegion() if is_largest: - self._largest_roi_rendered_image = self.shrinker.GetOutput() - self._largest_roi_rendered_image.DisconnectPipeline() - self._largest_roi_rendered_image.SetOrigin( - self.roi[0][:dimension]) - self.rendered_image = self._largest_roi_rendered_image + if self.image: + self._largest_roi_rendered_image = self.shrinker.GetOutput() + self._largest_roi_rendered_image.DisconnectPipeline() + self._largest_roi_rendered_image.SetOrigin( + self.roi[0][:dimension]) + self.rendered_image = self._largest_roi_rendered_image + if self.label_map: + self._largest_roi_rendered_label_map = self.label_map_shrinker.GetOutput() + self._largest_roi_rendered_label_map.DisconnectPipeline() + self._largest_roi_rendered_label_map.SetOrigin( + self.roi[0][:dimension]) + self.rendered_label_map = self._largest_roi_rendered_label_map return - shrunk = self.shrinker.GetOutput() - shrunk.DisconnectPipeline() - shrunk.SetOrigin(self.roi[0][:dimension]) - self.rendered_image = shrunk + if self.image: + shrunk = self.shrinker.GetOutput() + shrunk.DisconnectPipeline() + shrunk.SetOrigin(self.roi[0][:dimension]) + self.rendered_image = shrunk + if self.label_map: + shrunk = self.label_map_shrinker.GetOutput() + shrunk.DisconnectPipeline() + shrunk.SetOrigin(self.roi[0][:dimension]) + self.rendered_label_map = shrunk else: - self.rendered_image = self.image + if self.image: + self.rendered_image = self.image + if self.label_map: + self.rendered_label_map = self.image @validate('gradient_opacity') def _validate_gradient_opacity(self, proposal): @@ -541,10 +598,14 @@ def _on_geometries_changed(self, change=None): def roi_region(self): """Return the itk.ImageRegion corresponding to the roi.""" - dimension = self.image.GetImageDimension() - index = self.image.TransformPhysicalPointToIndex( + if self.image: + image = self.image + else: + image = self.label_map + dimension = image.GetImageDimension() + index = image.TransformPhysicalPointToIndex( tuple(self.roi[0][:dimension])) - upper_index = self.image.TransformPhysicalPointToIndex( + upper_index = image.TransformPhysicalPointToIndex( tuple(self.roi[1][:dimension])) size = upper_index - index for dim in range(dimension): @@ -552,12 +613,16 @@ def roi_region(self): region = itk.ImageRegion[dimension]() region.SetIndex(index) region.SetSize(tuple(size)) - region.Crop(self.image.GetLargestPossibleRegion()) + region.Crop(image.GetLargestPossibleRegion()) return region def roi_slice(self): """Return the numpy array slice corresponding to the roi.""" - dimension = self.image.GetImageDimension() + if self.image: + image = self.image + else: + image = self.label_map + dimension = image.GetImageDimension() region = self.roi_region() index = region.GetIndex() upper_index = np.array(index) + np.array(region.GetSize()) @@ -568,6 +633,7 @@ def roi_slice(self): def view(image=None, # noqa: C901 + label_map=None, # noqa: C901 cmap=cm.viridis, select_roi=False, interpolation=True, @@ -584,7 +650,8 @@ def view(image=None, # noqa: C901 Creates and returns an ipywidget to visualize an image, and/or point sets and/or geometries . - The image can be 2D or 3D. + The image can be 2D or 3D. A label map that corresponds to the image can + also be provided. The image and label map must have the same size. The type of the image can be an numpy.array, itk.Image, vtk.vtkImageData, pyvista.UniformGrid, imglyb.ReferenceGuardingRandomAccessibleInterval, @@ -634,6 +701,10 @@ def view(image=None, # noqa: C901 image : array_like, itk.Image, or vtk.vtkImageData The 2D or 3D image to visualize. + label_map : array_like, itk.Image, or vtk.vtkImageData + The 2D or 3D label map to visualize. If an image is also provided, the + label map must have the same size. + vmin: float, optional, default: None Value that maps to the minimum of image colormap. Defaults to minimum of the image pixel buffer. @@ -798,6 +869,7 @@ def view(image=None, # noqa: C901 image = images[0] viewer = Viewer(image=image, + label_map=label_map, cmap=cmap, select_roi=select_roi, interpolation=interpolation, diff --git a/js/lib/viewer.js b/js/lib/viewer.js index fc0cbc13..61b9bff5 100644 --- a/js/lib/viewer.js +++ b/js/lib/viewer.js @@ -68,6 +68,7 @@ const ViewerModel = widgets.DOMWidgetModel.extend({ _model_module_version: '0.27.5', _view_module_version: '0.27.5', rendered_image: null, + rendered_label_map: null, _rendering_image: false, interpolation: true, cmap: 'Viridis (matplotlib)', @@ -103,6 +104,7 @@ const ViewerModel = widgets.DOMWidgetModel.extend({ }}, { serializers: Object.assign({ rendered_image: { serialize: serialize_itkimage, deserialize: deserialize_itkimage }, + rendered_label_map: { serialize: serialize_itkimage, deserialize: deserialize_itkimage }, _custom_cmap: simplearray_serialization, point_sets: { serialize: serialize_polydata_list, deserialize: deserialize_polydata_list }, geometries: { serialize: serialize_polydata_list, deserialize: deserialize_polydata_list }, @@ -118,11 +120,11 @@ const ViewerModel = widgets.DOMWidgetModel.extend({ }) -const createRenderingPipeline = (domWidgetView, { rendered_image, point_sets, geometries }) => { +const createRenderingPipeline = (domWidgetView, { rendered_image, rendered_label_map, point_sets, geometries }) => { const containerStyle = { position: 'relative', width: '100%', - height: '600px', + height: '700px', minHeight: '400px', minWidth: '400px', margin: '1', @@ -149,10 +151,15 @@ const createRenderingPipeline = (domWidgetView, { rendered_image, point_sets, ge }; let is3D = true let imageData = null + let labelMapData = null if (rendered_image) { imageData = vtkITKHelper.convertItkToVtkImage(rendered_image) is3D = rendered_image.imageType.dimension === 3 } + if (rendered_label_map) { + labelMapData = vtkITKHelper.convertItkToVtkImage(rendered_label_map) + is3D = rendered_label_map.imageType.dimension === 3 + } let pointSets = null if (point_sets) { pointSets = point_sets.map((point_set) => vtk(point_set)) @@ -166,6 +173,7 @@ const createRenderingPipeline = (domWidgetView, { rendered_image, point_sets, ge domWidgetView.model.itkVtkViewer = createViewer(domWidgetView.el, { viewerStyle: viewerStyle, image: imageData, + labelMap: labelMapData, pointSets, geometries: vtkGeometries, use2D: !is3D, @@ -272,18 +280,20 @@ const createRenderingPipeline = (domWidgetView, { rendered_image, point_sets, ge domWidgetView.model.save_changes() } - if (rendered_image) { + if (rendered_image || rendered_label_map) { const interactor = viewProxy.getInteractor() interactor.onEndMouseWheel(cropROIByViewport) interactor.onEndPan(cropROIByViewport) interactor.onEndPinch(cropROIByViewport) - const dataArray = imageData.getPointData().getScalars() - const numberOfComponents = dataArray.getNumberOfComponents() - if (domWidgetView.model.use2D && dataArray.getDataType() === 'Uint8Array' && (numberOfComponents === 3 || numberOfComponents === 4)) { - domWidgetView.model.itkVtkViewer.setColorMap(0, 'Grayscale') - domWidgetView.model.set('cmap', 'Grayscale') - domWidgetView.model.save_changes() + if (rendered_image) { + const dataArray = imageData.getPointData().getScalars() + const numberOfComponents = dataArray.getNumberOfComponents() + if (domWidgetView.model.use2D && dataArray.getDataType() === 'Uint8Array' && (numberOfComponents === 3 || numberOfComponents === 4)) { + domWidgetView.model.itkVtkViewer.setColorMap(0, 'Grayscale') + domWidgetView.model.set('cmap', 'Grayscale') + domWidgetView.model.save_changes() + } } domWidgetView.model.set('_rendering_image', false) domWidgetView.model.save_changes() @@ -321,6 +331,19 @@ function replaceRenderedImage(domWidgetView, rendered_image) { } +function replaceRenderedLabelMap(domWidgetView, rendered_label_map) { + const labelMapData = vtkITKHelper.convertItkToVtkImage(rendered_label_map) + + domWidgetView.model.itkVtkViewer.setLabelMap(labelMapData) + + if (viewProxy.getViewMode() === 'VolumeRendering') { + viewProxy.resetCamera() + } + domWidgetView.model.set('_rendering_image', false) + domWidgetView.model.save_changes() +} + + function replacePointSets(domWidgetView, pointSets) { const vtkPointSets = pointSets.map((pointSet) => vtk(pointSet)) domWidgetView.model.itkVtkViewer.setPointSets(vtkPointSets) @@ -575,6 +598,7 @@ async function decompressPolyData(polyData) { const ViewerView = widgets.DOMWidgetView.extend({ initialize_itkVtkViewer: function() { const rendered_image = this.model.get('rendered_image') + const rendered_label_map = this.model.get('rendered_label_map') this.annotations_changed() if (rendered_image) { this.interpolation_changed() @@ -582,18 +606,20 @@ const ViewerView = widgets.DOMWidgetView.extend({ this.vmin_changed() this.vmax_changed() } - if (rendered_image) { - this.shadow_changed() + if (rendered_image || rendered_label_map) { this.slicing_planes_changed() this.x_slice_changed() this.y_slice_changed() this.z_slice_changed() + } + if (rendered_image) { + this.shadow_changed() this.gradient_opacity_changed() this.blend_changed() } this.ui_collapsed_changed() this.rotate_changed() - if (rendered_image) { + if (rendered_image || rendered_label_map) { this.select_roi_changed() this.scale_factors_changed() } @@ -831,6 +857,7 @@ const ViewerView = widgets.DOMWidgetView.extend({ render: function() { this.model.on('change:rendered_image', this.rendered_image_changed, this) + this.model.on('change:rendered_label_map', this.rendered_label_map_changed, this) this.model.on('change:cmap', this.cmap_changed, this) this.model.on('change:vmin', this.vmin_changed, this) this.model.on('change:vmax', this.vmax_changed, this) @@ -863,6 +890,10 @@ const ViewerView = widgets.DOMWidgetView.extend({ if (rendered_image) { toDecompress.push(decompressImage(rendered_image)) } + const rendered_label_map = this.model.get('rendered_label_map') + if (rendered_label_map) { + toDecompress.push(decompressImage(rendered_label_map)) + } const point_sets = this.model.get('point_sets') if(point_sets && !!point_sets.length) { toDecompress = toDecompress.concat(point_sets.map(decompressPolyData)) @@ -875,8 +906,13 @@ const ViewerView = widgets.DOMWidgetView.extend({ Promise.all(toDecompress).then((decompressedData) => { let index = 0; let decompressedRenderedImage = null + let decompressedRenderedLabelMap = null if (rendered_image) { - decompressedRenderedImage = decompressedData[0] + decompressedRenderedImage = decompressedData[index] + index++ + } + if (rendered_label_map) { + decompressedRenderedLabelMap = decompressedData[index] index++ } let decompressedPointSets = null @@ -890,7 +926,9 @@ const ViewerView = widgets.DOMWidgetView.extend({ index += geometries.length } - return createRenderingPipeline(domWidgetView, { rendered_image: decompressedRenderedImage, + return createRenderingPipeline(domWidgetView, { + rendered_image: decompressedRenderedImage, + rendered_label_map: decompressedRenderedLabelMap, point_sets: decompressedPointSets, geometries: decompressedGeometries }) @@ -906,7 +944,7 @@ const ViewerView = widgets.DOMWidgetView.extend({ if (domWidgetView.model.hasOwnProperty('itkVtkViewer')) { return Promise.resolve(replaceRenderedImage(domWidgetView, decompressed)) } else { - return createRenderingPipeline(domWidgetView, { decompressed }) + return createRenderingPipeline(domWidgetView, { rendered_image: decompressed }) } }) } else { @@ -920,6 +958,29 @@ const ViewerView = widgets.DOMWidgetView.extend({ return Promise.resolve(null) }, + rendered_label_map_changed: function() { + const rendered_label_map = this.model.get('rendered_label_map') + if(rendered_label_map) { + if (!rendered_label_map.data) { + const domWidgetView = this + decompressImage(rendered_label_map).then((decompressed) => { + if (domWidgetView.model.hasOwnProperty('itkVtkViewer')) { + return Promise.resolve(replaceRenderedLabelMap(domWidgetView, decompressed)) + } else { + return createRenderingPipeline(domWidgetView, { rendered_label_map: decompressed }) + } + }) + } else { + if (domWidgetView.model.hasOwnProperty('itkVtkViewer')) { + return Promise.resolve(replaceRenderedLabelMap(this, rendered_label_map)) + } else { + return Promise.resolve(createRenderingPipeline(this, { rendered_label_map })) + } + } + } + return Promise.resolve(null) + }, + point_sets_changed: function() { const point_sets = this.model.get('point_sets') if(point_sets && !!point_sets.length) {