diff --git a/src/deepali/utils/vtk/image.py b/src/deepali/utils/vtk/image.py index e536c50..c3d5bfb 100644 --- a/src/deepali/utils/vtk/image.py +++ b/src/deepali/utils/vtk/image.py @@ -8,7 +8,6 @@ vtkImageData, vtkImageStencilData, vtkImageStencilToImage, - vtkMatrixToLinearTransform, vtkPolyData, vtkPolyDataToImageStencil, vtkTransformPolyDataFilter, @@ -41,30 +40,33 @@ def surface_mesh_grid(*mesh: vtkPolyData, resolution: Optional[float] = None) -> ) -def surface_image_stencil(mesh: vtkPolyData, grid: Grid) -> vtkImageStencilData: - r"""Convert vtkPolyData surface mesh to image stencil.""" - max_index = [n - 1 for n in grid.size().tolist()] - - rot = np.eye(4, dtype=np.float) - rot[:3, :3] = np.array(grid.direction).reshape(3, 3) - rot = numpy_to_vtk_matrix4x4(rot) - - transform = vtkMatrixToLinearTransform() - transform.SetInput(rot) +def surface_image_stencil(mesh: vtkPolyData, grid: GridAttrs) -> vtkImageStencilData: + r"""Convert vtkPolyData surface mesh to image stencil.""" + # Create the transform + transform = vtkTransform() + transform.Translate(grid.center) + transform.Concatenate(numpy_to_vtk_matrix4x4(grid.dcm.T)) # type: ignore + transform.Translate(tuple(-x for x in grid.center)) + # Apply the transform to the polydata transformer = vtkTransformPolyDataFilter() transformer.SetInputData(mesh) transformer.SetTransform(transform) - - converter = vtkPolyDataToImageStencil() - converter.SetInputConnection(transformer.GetOutputPort()) - converter.SetOutputOrigin(grid.origin().tolist()) - converter.SetOutputSpacing(grid.spacing().tolist()) - converter.SetOutputWholeExtent([0, max_index[0], 0, max_index[1], 0, max_index[2]]) - converter.Update() - + transformer.Update() + + # Convert the transformed polydata to an image stencil + grid_extent = [0, grid.size[0] - 1, 0, grid.size[1] - 1, 0, grid.size[2] - 1] + grid_no_direction = GridAttrs(size=grid.size, spacing=grid.spacing, center=grid.center) + polydata_to_stencil = vtkPolyDataToImageStencil() + polydata_to_stencil.SetInputConnection(transformer.GetOutputPort()) + polydata_to_stencil.SetOutputOrigin(grid_no_direction.origin) + polydata_to_stencil.SetOutputSpacing(grid.spacing) + polydata_to_stencil.SetOutputWholeExtent(grid_extent) + polydata_to_stencil.Update() + + # Get the output stencil stencil = vtkImageStencilData() - stencil.DeepCopy(converter.GetOutput()) + stencil.DeepCopy(polydata_to_stencil.GetOutput()) return stencil