From 974e1989ffaa31bd519f829b7bc8c19a1880c9a1 Mon Sep 17 00:00:00 2001 From: Tom Birdsong Date: Tue, 7 Nov 2023 09:29:34 -0500 Subject: [PATCH] WIP: autoformatting --- src/itk_dreg/base/image_block_interface.py | 73 ++++-- src/itk_dreg/base/registration_interface.py | 10 +- src/itk_dreg/block/convert.py | 35 ++- src/itk_dreg/block/dask.py | 20 +- src/itk_dreg/block/image.py | 17 +- src/itk_dreg/elastix/register.py | 132 +++++----- src/itk_dreg/elastix/serialize.py | 61 +++-- src/itk_dreg/elastix/util.py | 39 +-- src/itk_dreg/itk.py | 31 ++- src/itk_dreg/register.py | 239 +++++++++++------- src/itk_dreg/stitch_reduce/__init__.py | 2 +- src/itk_dreg/stitch_reduce/dreg.py | 127 ++++++---- .../stitch_reduce/matrix_transform.py | 17 +- src/itk_dreg/stitch_reduce/transform.py | 47 ++-- .../stitch_reduce/transform_collection.py | 165 +++++++----- 15 files changed, 576 insertions(+), 439 deletions(-) diff --git a/src/itk_dreg/base/image_block_interface.py b/src/itk_dreg/base/image_block_interface.py index e17c204..cae23db 100644 --- a/src/itk_dreg/base/image_block_interface.py +++ b/src/itk_dreg/base/image_block_interface.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import numpy as np -import chunk from dataclasses import dataclass from enum import IntEnum from typing import Optional, List @@ -59,14 +58,20 @@ class BlockInfo: @property def ndim(self) -> int: if len(self.chunk_index) != len(self.array_slice): - raise ValueError('Observed mismatch between chunk and slice index dimensions') + raise ValueError( + "Observed mismatch between chunk and slice index dimensions" + ) return len(self.chunk_index) - + @property def shape(self) -> List[int]: - if any([slice_val.step and slice_val.step != 1 for slice_val in self.array_slice]): + if any( + [slice_val.step and slice_val.step != 1 for slice_val in self.array_slice] + ): print() - raise ValueError('Illegal step size in `BlockInfo`, expected step size of 1') + raise ValueError( + "Illegal step size in `BlockInfo`, expected step size of 1" + ) return [slice_val.stop - slice_val.start for slice_val in self.array_slice] @@ -94,12 +99,12 @@ class BlockPairRegistrationResult: """Encapsulate result of fixed-to-moving registration over one block pair.""" def __init__( - self, - status: BlockRegStatus, - transform: Optional[TransformType]=None, - transform_domain: Optional[ImageType]=None, - inv_transform: Optional[TransformType]=None, - inv_transform_domain: Optional[ImageType]=None + self, + status: BlockRegStatus, + transform: Optional[TransformType] = None, + transform_domain: Optional[ImageType] = None, + inv_transform: Optional[TransformType] = None, + inv_transform_domain: Optional[ImageType] = None, ): """ :param status: Status code indicating registration success or failure. @@ -120,22 +125,39 @@ def __init__( `inv_transform_domain` must be available if and only if `inv_transform` is available. """ if status == BlockRegStatus.SUCCESS and not transform: - raise ValueError(f'Pairwise registration indicated success ({status})' - f' but no forward transform was provided') + raise ValueError( + f"Pairwise registration indicated success ({status})" + f" but no forward transform was provided" + ) if transform and not transform_domain: - raise ValueError(f'Pairwise registration returned incomplete forward transform:' - f' failed to provide forward transform domain') + raise ValueError( + "Pairwise registration returned incomplete forward transform:" + " failed to provide forward transform domain" + ) if transform_domain and itk.template(transform_domain)[0] != itk.Image: - raise TypeError(f'Received invalid transform domain type: {type(transform_domain)}') - if transform_domain and np.product(transform_domain.GetLargestPossibleRegion().GetSize()) == 0: - raise ValueError(f'Received invalid transform domain with size 0') + raise TypeError( + f"Received invalid transform domain type: {type(transform_domain)}" + ) + if ( + transform_domain + and np.product(transform_domain.GetLargestPossibleRegion().GetSize()) == 0 + ): + raise ValueError("Received invalid transform domain with size 0") if inv_transform and not inv_transform_domain: - raise ValueError(f'Pairwise registration returned incomplete inverse transform:' - f' failed to provide inverse transform domain') + raise ValueError( + "Pairwise registration returned incomplete inverse transform:" + " failed to provide inverse transform domain" + ) if inv_transform_domain and itk.template(inv_transform_domain)[0] != itk.Image: - raise TypeError(f'Received invalid transform domain type: {type(transform_domain)}') - if inv_transform_domain and np.product(inv_transform_domain.GetLargestPossibleRegion().GetSize()) == 0: - raise ValueError(f'Received invalid transform domain with size 0') + raise TypeError( + f"Received invalid transform domain type: {type(transform_domain)}" + ) + if ( + inv_transform_domain + and np.product(inv_transform_domain.GetLargestPossibleRegion().GetSize()) + == 0 + ): + raise ValueError("Received invalid transform domain with size 0") self.status = status self.transform = transform self.transform_domain = transform_domain @@ -150,13 +172,14 @@ class LocatedBlockResult: information describing how the fixed subimage in registration relates to the greater fixed image. """ - result:BlockPairRegistrationResult + + result: BlockPairRegistrationResult """ The result of pairwise subimage registration. May include extended information for specific implementations. """ - fixed_info:BlockInfo + fixed_info: BlockInfo """ Oriented representation of the fixed image block over which pairwise registration was performed to produce the encapsulated diff --git a/src/itk_dreg/base/registration_interface.py b/src/itk_dreg/base/registration_interface.py index 05f0f0b..2ec62ea 100644 --- a/src/itk_dreg/base/registration_interface.py +++ b/src/itk_dreg/base/registration_interface.py @@ -2,15 +2,15 @@ import itk from abc import ABC, abstractmethod -from typing import Optional, Iterable, Tuple +from typing import Iterable from .image_block_interface import ( BlockPairRegistrationResult, LocatedBlockResult, RegistrationTransformResult, - BlockInfo + BlockInfo, ) -from .itk_typing import ImageType, ImageReaderType, ImageRegionType, TransformType +from .itk_typing import ImageType, ImageReaderType, TransformType """ Defines extensible components to extend with concrete implementations. @@ -69,7 +69,7 @@ def __call__( moving_subimage: ImageType, initial_transform: TransformType, block_info: BlockInfo, - **kwargs + **kwargs, ) -> BlockPairRegistrationResult: """ Run image-to-image pairwise registration. @@ -115,7 +115,7 @@ def __call__( block_results: Iterable[LocatedBlockResult], fixed_reader_ctor: ConstructReaderMethod, initial_transform: itk.Transform, - **kwargs + **kwargs, ) -> RegistrationTransformResult: """ :param block_results: An iterable collection of subimages in fixed space diff --git a/src/itk_dreg/block/convert.py b/src/itk_dreg/block/convert.py index 91be69e..8f03a81 100644 --- a/src/itk_dreg/block/convert.py +++ b/src/itk_dreg/block/convert.py @@ -41,6 +41,7 @@ # reversed (np.flip) from NumPy conventional access order. # + def arr_to_continuous_index(index: Union[List, npt.ArrayLike]) -> itk.ContinuousIndex: r""" Convert Python array-like representation of a continuous index into @@ -55,7 +56,7 @@ def arr_to_continuous_index(index: Union[List, npt.ArrayLike]) -> itk.Continuous ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ValueError: Expecting a sequence of int (or long) ``` - + :param arr: The list or array representing a continuous index. The list or array must be a one dimensional collection of floating point scalar values. :returns: An `itk.ContinuousIndex` object representing the index. @@ -64,12 +65,13 @@ def arr_to_continuous_index(index: Union[List, npt.ArrayLike]) -> itk.Continuous """ index = np.array(index, dtype=np.float32) if index.ndim != 1: - raise ValueError(f'Expected 1D input index but received {index}') + raise ValueError(f"Expected 1D input index but received {index}") itk_index = itk.ContinuousIndex[itk.D, len(index)]() for dim, val in enumerate(index): itk_index.SetElement(dim, float(val)) return itk_index + def estimate_bounding_box( physical_region: npt.ArrayLike, transform: itk.Transform ) -> npt.ArrayLike: @@ -97,9 +99,7 @@ def estimate_bounding_box( physical_region[:, 0], physical_region[:, 1], physical_region[:, 2] ) ): - pt = np.array( - transform.TransformPoint([float(val) for val in (i, j, k)]) - ) + pt = np.array(transform.TransformPoint([float(val) for val in (i, j, k)])) arr[index, :] = pt return np.array([np.min(arr, axis=0), np.max(arr, axis=0)]) @@ -127,8 +127,7 @@ def block_to_physical_size( ) return np.abs( - ref_image.TransformIndexToPhysicalPoint(block_index) - - itk.origin(ref_image) + ref_image.TransformIndexToPhysicalPoint(block_index) - itk.origin(ref_image) ) @@ -137,9 +136,7 @@ def physical_to_block_size( ) -> npt.ArrayLike: """Convert from physical size to corresponding voxel size""" return np.abs( - ref_image.TransformPhysicalPointToIndex( - itk.origin(ref_image) + physical_size - ) + ref_image.TransformPhysicalPointToIndex(itk.origin(ref_image) + physical_size) ) @@ -147,7 +144,9 @@ def block_to_physical_region( block_region: npt.ArrayLike, ref_image: itk.Image, transform: itk.Transform = None, - estimate_bounding_box_method: Callable[[npt.ArrayLike, itk.Transform], npt.ArrayLike] = estimate_bounding_box, + estimate_bounding_box_method: Callable[ + [npt.ArrayLike, itk.Transform], npt.ArrayLike + ] = estimate_bounding_box, ) -> npt.ArrayLike: """Convert from voxel region to corresponding physical space""" # block region is a 2x3 matrix where row 0 is the lower bound and row 1 is the upper bound @@ -160,11 +159,11 @@ def block_to_physical_region( adjusted_block_region = block_region - HALF_VOXEL_STEP - index_to_physical_func = ( - lambda row: ref_image.TransformContinuousIndexToPhysicalPoint( + def index_to_physical_func(row: npt.ArrayLike) -> npt.ArrayLike: + return ref_image.TransformContinuousIndexToPhysicalPoint( arr_to_continuous_index(row) ) - ) + physical_region = np.apply_along_axis( index_to_physical_func, 1, adjusted_block_region ) @@ -190,14 +189,12 @@ def physical_to_block_region( # block region is a 2x3 matrix where row 0 is the lower bound and row 1 is the upper bound assert physical_region.ndim == 2 and physical_region.shape == (2, 3) - physical_to_index_func = ( + def physical_to_index_func(row: npt.ArrayLike) -> npt.ArrayLike: lambda row: ref_image.TransformPhysicalPointToContinuousIndex( [float(val) for val in row] ) - ) - block_region = np.apply_along_axis( - physical_to_index_func, 1, physical_region - ) + + block_region = np.apply_along_axis(physical_to_index_func, 1, physical_region) adjusted_block_region = np.array( [np.min(block_region, axis=0), np.max(block_region, axis=0)] ) diff --git a/src/itk_dreg/block/dask.py b/src/itk_dreg/block/dask.py index 8300297..ebc13d6 100644 --- a/src/itk_dreg/block/dask.py +++ b/src/itk_dreg/block/dask.py @@ -6,22 +6,22 @@ from .convert import get_target_block_size -def rechunk_to_physical_block(array:dask.array.core.Array, - array_image:itk.Image, - reference_array:dask.array.core.Array, - reference_image:itk.Image) -> dask.array.core.Array: + +def rechunk_to_physical_block( + array: dask.array.core.Array, + array_image: itk.Image, + reference_array: dask.array.core.Array, + reference_image: itk.Image, +) -> dask.array.core.Array: """ Rechunk the given array/image so that the physical size of each block approximately matches the physical size of chunks in the reference array/image. """ # Determine approximate physical size of a reference image chunk itk_ref_chunksize = np.flip(reference_array.chunksize) - arr_chunksize = \ - get_target_block_size( - itk_ref_chunksize, - reference_image, - array_image + arr_chunksize = get_target_block_size( + itk_ref_chunksize, reference_image, array_image ) - + # Rechunk the input array to match physical chunk size return dask.array.rechunk(array, chunks=np.flip(arr_chunksize)) diff --git a/src/itk_dreg/block/image.py b/src/itk_dreg/block/image.py index ff3cf83..1f58b12 100644 --- a/src/itk_dreg/block/image.py +++ b/src/itk_dreg/block/image.py @@ -7,7 +7,6 @@ """ import logging -import itertools from typing import List, Union import itk @@ -85,9 +84,9 @@ def block_to_itk_image( def physical_region_to_itk_image( physical_region: npt.ArrayLike, spacing: List[float], - direction: Union[itk.Matrix,npt.ArrayLike], + direction: Union[itk.Matrix, npt.ArrayLike], extend_beyond: bool = True, - image_type: type[itk.Image]=None + image_type: type[itk.Image] = None, ) -> itk.Image: """ Represent a physical region as an unallocated itk.Image object. @@ -115,9 +114,11 @@ def physical_region_to_itk_image( orientation mapping from I,J,K to X,Y,Z axes. """ direction = np.array(direction) - image_type = image_type or itk.Image[itk.F,3] - assert not np.any(np.isclose(spacing, 0)), f'Invalid spacing: {spacing}' - assert np.all((direction == 0) | (direction == 1) | (direction == -1)), f'Invalid direction: {direction}' + image_type = image_type or itk.Image[itk.F, 3] + assert not np.any(np.isclose(spacing, 0)), f"Invalid spacing: {spacing}" + assert np.all( + (direction == 0) | (direction == 1) | (direction == -1) + ), f"Invalid direction: {direction}" # Set up unit vectors mapping from voxel to physical space voxel_step_vecs = np.matmul(np.array(direction), np.eye(3) * spacing) @@ -135,9 +136,7 @@ def physical_region_to_itk_image( np.max(physical_region, axis=0) - np.min(physical_region, axis=0) ) / np.abs(physical_step) output_grid_size = ( - np.ceil(output_grid_size_f) - if extend_beyond - else np.floor(output_grid_size_f) + np.ceil(output_grid_size_f) if extend_beyond else np.floor(output_grid_size_f) ) centerpoint = np.mean(physical_region, axis=0) diff --git a/src/itk_dreg/elastix/register.py b/src/itk_dreg/elastix/register.py index f940779..bbd5494 100644 --- a/src/itk_dreg/elastix/register.py +++ b/src/itk_dreg/elastix/register.py @@ -2,60 +2,65 @@ import os import logging -from dataclasses import dataclass -from typing import List, Tuple, Dict, Optional +from typing import List, Optional import dask.distributed import itk import numpy as np from itk_dreg.base.itk_typing import ImageType, TransformType -from itk_dreg.base.image_block_interface import \ - BlockPairRegistrationResult, BlockRegStatus, BlockPairRegistrationResult, BlockInfo +from itk_dreg.base.image_block_interface import ( + BlockPairRegistrationResult, + BlockRegStatus, + BlockInfo, +) from itk_dreg.base.registration_interface import BlockPairRegistrationMethod import itk_dreg.block.convert as block_convert import itk_dreg.block.image as block_image -from itk_dreg.elastix.serialize import list_to_parameter_object, SerializableParameterObjectType -from itk_dreg.elastix.util import ParameterObjectType, get_elx_itk_transforms +from itk_dreg.elastix.serialize import ( + list_to_parameter_object, + SerializableParameterObjectType, +) +from itk_dreg.elastix.util import get_elx_itk_transforms logger = logging.getLogger(__name__) -worker_logger = logging.getLogger('distributed.worker') +worker_logger = logging.getLogger("distributed.worker") + class ElastixRegistrationResult(BlockPairRegistrationResult): """ Block pair registration result extended with Elastix-specific results. """ - def __init__(self, - registration_method:itk.ElastixRegistrationMethod, - **kwargs - ): + + def __init__(self, registration_method: itk.ElastixRegistrationMethod, **kwargs): """ :param registration_method: The filter used to run registration. """ super().__init__(kwargs) - self.registration_method=registration_method + self.registration_method = registration_method """ ITKElastix registration implementation for `itk-dreg` registration framework """ + class ElastixDRegBlockPairRegistrationMethod(BlockPairRegistrationMethod): def __call__( - self, - # ITK-DReg inherited parameters - fixed_subimage: ImageType, - moving_subimage: ImageType, - initial_transform: TransformType, - block_info: BlockInfo, - # Elastix-DReg parameters - log_directory:Optional[str], - elx_parameter_object_serial: SerializableParameterObjectType, - itk_transform_types: List[type], - preprocess_initial_transform:bool=False, #TODO - **kwargs - ) -> BlockPairRegistrationResult: + self, + # ITK-DReg inherited parameters + fixed_subimage: ImageType, + moving_subimage: ImageType, + initial_transform: TransformType, + block_info: BlockInfo, + # Elastix-DReg parameters + log_directory: Optional[str], + elx_parameter_object_serial: SerializableParameterObjectType, + itk_transform_types: List[type], + preprocess_initial_transform: bool = False, # TODO + **kwargs, + ) -> BlockPairRegistrationResult: """ Compute a series of ITKElastix transforms mapping from the moving image to the fixed image. @@ -85,53 +90,55 @@ def __call__( - The registration result status. :raises RuntimeError: If the Elastix registration procedure encounters an error. """ - elx_parameter_object = \ - list_to_parameter_object(elx_parameter_object_serial) - dask.distributed.print('Entering Elastix registration') - worker_logger.error('Entering Elastix registration') - LOG_FILENAME = 'elxLog.txt' + elx_parameter_object = list_to_parameter_object(elx_parameter_object_serial) + dask.distributed.print("Entering Elastix registration") + worker_logger.error("Entering Elastix registration") + LOG_FILENAME = "elxLog.txt" block_log_directory = None if log_directory: - block_log_directory = f'{log_directory}/{"-".join(map(str, block_info.chunk_index))}' + block_log_directory = ( + f'{log_directory}/{"-".join(map(str, block_info.chunk_index))}' + ) os.makedirs(block_log_directory, exist_ok=True) - logger.debug(f'{block_info.chunk_index}: ' - f'Elastix logs will be written to {block_log_directory}') - - if not "ElastixRegistrationMethod" in dir(itk): - raise KeyError( - "Elastix methods not found, please pip install itk-elastix" + logger.debug( + f"{block_info.chunk_index}: " + f"Elastix logs will be written to {block_log_directory}" ) + if "ElastixRegistrationMethod" not in dir(itk): + raise KeyError("Elastix methods not found, please pip install itk-elastix") + if initial_transform and itk_transform_types[-1]: # Cannot directly convert an external init ITK transfrom from Elastix itk_transform_types.append(None) - logger.debug(f'{block_info.chunk_index}: ' - f"Register with parameter object:{elx_parameter_object}") + logger.debug( + f"{block_info.chunk_index}: " + f"Register with parameter object:{elx_parameter_object}" + ) - #preprocess_initial_transform = \ + # preprocess_initial_transform = \ # initial_transform and itk.BSplineTransform[itk.D,3,3] in itk_transform_types if preprocess_initial_transform: - worker_logger.warning(f'{block_info.chunk_index}: ' - 'Resampling fixed image') + worker_logger.warning( + f"{block_info.chunk_index}: " "Resampling fixed image" + ) # B-spline requires Jacobian which is not supported by AdvancedExternalTransform # so we must resample the image ourselves and discard the initial transform - physical_region=block_image.image_to_physical_region( - fixed_subimage.GetBufferedRegion(), - fixed_subimage, - initial_transform + physical_region = block_image.image_to_physical_region( + fixed_subimage.GetBufferedRegion(), fixed_subimage, initial_transform ) - logger.debug(f'Resampling fixed image to initial domain {physical_region}') + logger.debug(f"Resampling fixed image to initial domain {physical_region}") fixed_subimage = itk.resample_image_filter( fixed_subimage, - transform=initial_transform.GetInverseTransform(), # TODO + transform=initial_transform.GetInverseTransform(), # TODO use_reference_image=True, reference_image=block_image.physical_region_to_itk_image( physical_region=physical_region, spacing=itk.spacing(fixed_subimage), direction=np.array(fixed_subimage.GetDirection()), - extend_beyond=True - ) + extend_beyond=True, + ), ) itk_transform_types = itk_transform_types[:-1] @@ -144,26 +151,29 @@ def __call__( ) if initial_transform and not preprocess_initial_transform: - logger.debug(f'{block_info.chunk_index}: ' - f'initial transform {str(initial_transform)}') + logger.debug( + f"{block_info.chunk_index}: " + f"initial transform {str(initial_transform)}" + ) registration_method.SetExternalInitialTransform(initial_transform) # If we are debugging, make the buffered subimages available on disk # for later review if block_log_directory and logger.getEffectiveLevel() == logging.DEBUG: import itk_dreg.itk - logger.info(f'Writing buffered subimages to {block_log_directory}') + + logger.info(f"Writing buffered subimages to {block_log_directory}") try: itk_dreg.itk.write_buffered_region( image=fixed_subimage, - filepath=f'{block_log_directory}/fixed_subimage.mha' + filepath=f"{block_log_directory}/fixed_subimage.mha", ) itk_dreg.itk.write_buffered_region( image=moving_subimage, - filepath=f'{block_log_directory}/moving_subimage.mha' + filepath=f"{block_log_directory}/moving_subimage.mha", ) except Exception as e: - logger.warning(f'Failed to write to {block_log_directory}: {e}') + logger.warning(f"Failed to write to {block_log_directory}: {e}") if block_log_directory: registration_method.SetLogToFile(True) @@ -171,7 +181,7 @@ def __call__( registration_method.SetLogFileName(LOG_FILENAME) # Run registration with `itk-elastix`, may take a few minutes - logger.info(f'{block_info.chunk_index}: Running pairwise registration') + logger.info(f"{block_info.chunk_index}: Running pairwise registration") registration_method.Update() # Get the ITKElastix result as a composite transform. @@ -188,11 +198,13 @@ def __call__( physical_region=block_convert.image_to_physical_region( image_region=fixed_subimage.GetBufferedRegion(), ref_image=fixed_subimage, - src_transform=None if preprocess_initial_transform else initial_transform + src_transform=None + if preprocess_initial_transform + else initial_transform, ), spacing=itk.spacing(fixed_subimage), direction=fixed_subimage.GetDirection(), - extend_beyond=True + extend_beyond=True, ) return ElastixRegistrationResult( @@ -201,5 +213,5 @@ def __call__( inv_transform=None, inv_transform_domain=None, status=BlockRegStatus.SUCCESS, - registration_method=registration_method + registration_method=registration_method, ) diff --git a/src/itk_dreg/elastix/serialize.py b/src/itk_dreg/elastix/serialize.py index dcc7793..328d3f6 100644 --- a/src/itk_dreg/elastix/serialize.py +++ b/src/itk_dreg/elastix/serialize.py @@ -1,50 +1,47 @@ #!/usr/bin/env python3 -import os import logging -from typing import List, Tuple, Dict, Optional, Any +from typing import List, Tuple, Dict, Any import dask.distributed import itk -import numpy as np -import itk_dreg.block.image as block_image logger = logging.getLogger(__name__) ParameterObjectType = itk.ParameterObject -ParameterMapType = Any # FIXME itk.elxParameterObjectPython.mapstringvectorstring -SerializableParameterMapType = Dict[str, Tuple[str,str]] +ParameterMapType = Any # FIXME itk.elxParameterObjectPython.mapstringvectorstring +SerializableParameterMapType = Dict[str, Tuple[str, str]] SerializableParameterObjectType = List[SerializableParameterMapType] def parameter_map_to_dict( - parameter_map:ParameterMapType + parameter_map: ParameterMapType, ) -> SerializableParameterMapType: """ Convert an ITKElastix parameter map to a pickleable dictionary """ - return {k:v for (k,v) in parameter_map.items()} + return {k: v for (k, v) in parameter_map.items()} -def dict_to_parameter_map( - val: SerializableParameterMapType -) -> ParameterMapType: + +def dict_to_parameter_map(val: SerializableParameterMapType) -> ParameterMapType: """ Convert an ITKElastix parameter map to a pickleable dictionary """ # Eagerly load ITKElastix definitions so that mapstringvectorstring is available # TODO investigate a cleaner approach for this - _ = itk.ParameterObject + _ = itk.ParameterObject parameter_map = itk.elxParameterObjectPython.mapstringvectorstring() - for k,v in val.items(): + for k, v in val.items(): parameter_map[k] = v if not parameter_map[k] == v: - raise ValueError(f'Failed to set parameter map value: {k}, {v}') + raise ValueError(f"Failed to set parameter map value: {k}, {v}") return parameter_map + def parameter_object_to_list( - parameter_object:ParameterObjectType, + parameter_object: ParameterObjectType, ) -> SerializableParameterObjectType: """ Convert an ITKElastix parameter object to a pickleable collection. @@ -53,15 +50,20 @@ def parameter_object_to_list( """ result = [] for map_index in range(parameter_object.GetNumberOfParameterMaps()): - result.append(parameter_map_to_dict(parameter_object.GetParameterMap(map_index))) + result.append( + parameter_map_to_dict(parameter_object.GetParameterMap(map_index)) + ) return result + def list_to_parameter_object( - elastix_parameter_map_vals: SerializableParameterObjectType + elastix_parameter_map_vals: SerializableParameterObjectType, ) -> ParameterObjectType: parameter_object = itk.ParameterObject.New() for elastix_parameter_map_params in elastix_parameter_map_vals: - parameter_object.AddParameterMap(dict_to_parameter_map(elastix_parameter_map_params)) + parameter_object.AddParameterMap( + dict_to_parameter_map(elastix_parameter_map_params) + ) return parameter_object @@ -72,26 +74,33 @@ def list_to_parameter_object( are not serializable (pickleable) by default. This block monkeypatches the appropriate classes to be serializable in dask initialization as a short-term fix. """ + + def get_mapstringvectorstring_state(self) -> SerializableParameterMapType: return parameter_map_to_dict(self) -def set_mapstringvectorstring_state(self, new_state:SerializableParameterMapType): + + +def set_mapstringvectorstring_state(self, new_state: SerializableParameterMapType): other = dict_to_parameter_map(new_state) self.clear() self.swap(other) + def get_itkparameterobject_state(self) -> SerializableParameterMapType: return parameter_object_to_list(self) -def set_itkparameterobject_state(self, new_state:SerializableParameterMapType): + + +def set_itkparameterobject_state(self, new_state: SerializableParameterMapType): other = list_to_parameter_object(new_state) try: other_pm = other.GetParameterMaps() self.SetParameterMaps(other_pm) except Exception as e: - dask.distributed.print(f'itk_dreg.elastix error: {e}') + dask.distributed.print(f"itk_dreg.elastix error: {e}") -#FIXME -#setattr(itk.elxParameterObjectPython.mapstringvectorstring, '__getstate__', get_mapstringvectorstring_state) -#setattr(itk.elxParameterObjectPython.mapstringvectorstring, '__setstate__,', set_mapstringvectorstring_state) -#setattr(itk.ParameterObject, '__getstate__', get_itkparameterobject_state) -#setattr(itk.ParameterObject, '__setstate__', set_itkparameterobject_state) +# FIXME +# setattr(itk.elxParameterObjectPython.mapstringvectorstring, '__getstate__', get_mapstringvectorstring_state) +# setattr(itk.elxParameterObjectPython.mapstringvectorstring, '__setstate__,', set_mapstringvectorstring_state) +# setattr(itk.ParameterObject, '__getstate__', get_itkparameterobject_state) +# setattr(itk.ParameterObject, '__setstate__', set_itkparameterobject_state) diff --git a/src/itk_dreg/elastix/util.py b/src/itk_dreg/elastix/util.py index 5d75528..78b87c3 100644 --- a/src/itk_dreg/elastix/util.py +++ b/src/itk_dreg/elastix/util.py @@ -1,21 +1,20 @@ #!/usr/bin/env python3 -import os import logging -from typing import List, Tuple, Dict, Optional, Any +from typing import List, Tuple, Dict, Any import itk -import numpy as np import itk_dreg.block.image as block_image logger = logging.getLogger(__name__) ParameterObjectType = itk.ParameterObject -ParameterMapType = Any # FIXME itk.elxParameterObjectPython.mapstringvectorstring -SerializableParameterMapType = Dict[str, Tuple[str,str]] +ParameterMapType = Any # FIXME itk.elxParameterObjectPython.mapstringvectorstring +SerializableParameterMapType = Dict[str, Tuple[str, str]] SerializableParameterObjectType = List[SerializableParameterMapType] + def compute_initial_translation( source_image: itk.Image, target_image: itk.Image ) -> itk.TranslationTransform[itk.D, 3]: @@ -33,6 +32,7 @@ def compute_initial_translation( return translation_transform + def get_elx_itk_transforms( registration_method: itk.ElastixRegistrationMethod, itk_transform_types: List[itk.Transform], @@ -51,20 +51,14 @@ def get_elx_itk_transforms( f"and {len(itk_transform_types)} ITK transforms to convert to" ) - itk_composite_transform = itk.CompositeTransform[ - value_type, dimension - ].New() + itk_composite_transform = itk.CompositeTransform[value_type, dimension].New() try: - for transform_index, itk_transform_type in enumerate( - itk_transform_types - ): + for transform_index, itk_transform_type in enumerate(itk_transform_types): if not itk_transform_type: # skip on None continue - elx_transform = registration_method.GetNthTransform( - transform_index - ) + elx_transform = registration_method.GetNthTransform(transform_index) itk_base_transform = registration_method.ConvertToItkTransform( elx_transform ) @@ -83,9 +77,7 @@ def get_elx_parameter_maps( """ Return a series of transform parameter results from Elastix registration """ - transform_parameter_object = ( - registration_method.GetTransformParameterObject() - ) + transform_parameter_object = registration_method.GetTransformParameterObject() output_parameter_maps = [ transform_parameter_object.GetParameterMap(parameter_map_index) for parameter_map_index in range( @@ -100,18 +92,15 @@ def make_default_elx_parameter_object() -> itk.ParameterObject: Generate a default set of parameters for Elastix registration """ parameter_object = itk.ParameterObject.New() - parameter_object.AddParameterMap( - parameter_object.GetDefaultParameterMap("rigid") - ) - parameter_object.AddParameterMap( - parameter_object.GetDefaultParameterMap("affine") - ) + parameter_object.AddParameterMap(parameter_object.GetDefaultParameterMap("rigid")) + parameter_object.AddParameterMap(parameter_object.GetDefaultParameterMap("affine")) bspline_map = parameter_object.GetDefaultParameterMap("bspline") bspline_map["FinalGridSpacingInPhysicalUnits"] = ("0.5000",) parameter_object.AddParameterMap(bspline_map) return parameter_object + def flatten_composite_transform( transform: itk.Transform, ) -> itk.CompositeTransform[itk.D, 3]: @@ -134,7 +123,7 @@ def _flatten_composite_transform_recursive( t = None try: t = itk.CompositeTransform[itk.D, 3].cast(transform) - except RuntimeError as e: + except RuntimeError: return [transform] transform_list = [] @@ -142,4 +131,4 @@ def _flatten_composite_transform_recursive( transform_list += [ *_flatten_composite_transform_recursive(t.GetNthTransform(index)) ] - return transform_list \ No newline at end of file + return transform_list diff --git a/src/itk_dreg/itk.py b/src/itk_dreg/itk.py index a944cb1..35677b6 100644 --- a/src/itk_dreg/itk.py +++ b/src/itk_dreg/itk.py @@ -8,23 +8,23 @@ - Debugging """ -from typing import Union, List +from typing import List import dask import dask.array as da import itk -import numpy as np -import numpy.typing as npt from itk_dreg.base.itk_typing import ImageReaderType, FloatImage3DType -def make_reader(filepath:str, - imageio:itk.ImageIOBase=None, - image_type:type[itk.Image]=FloatImage3DType) -> ImageReaderType: +def make_reader( + filepath: str, + imageio: itk.ImageIOBase = None, + image_type: type[itk.Image] = FloatImage3DType, +) -> ImageReaderType: """ Create an ITK image reader with initialized metadata. - + :param filepath: The local or remote file to read. :param imageio: Explicitly specifies how the image should be read. If `imageio` is not provided then ITK will attempt to determine @@ -42,11 +42,13 @@ def make_reader(filepath:str, reader.UpdateOutputInformation() return reader -def make_dask_array(image_reader:ImageReaderType, - chunk_size:List[int]=None) -> da.Array: + +def make_dask_array( + image_reader: ImageReaderType, chunk_size: List[int] = None +) -> da.Array: """ Create a chunked, unbuffered array representing an image buffer. - + :param image_reader: The `itk.ImageFileReader` image source. TODO: Verify that `image_reader` never actually buffers data for the lazy dask array. @@ -54,7 +56,7 @@ def make_dask_array(image_reader:ImageReaderType, buffered region but nonzero requested region. Investigate relation to NumPy bridge and determine whether this is a necessary requirement or can be worked around. - :chunk_size: The requested size of each subdivided region in the + :chunk_size: The requested size of each subdivided region in the result array. Default is 128 along each side. :returns: A subdivided `dask.Array` representing the unbuffered input image voxel region. @@ -65,13 +67,16 @@ def make_dask_array(image_reader:ImageReaderType, dimension = image_reader.GetOutput().GetImageDimension() chunk_size = chunk_size or [128] * dimension delayed_np_array = dask.delayed(itk.array_view_from_image)(image_reader.GetOutput()) - delayed_dask_array = da.from_delayed(delayed_np_array, image_reader.GetOutput().shape, dtype=pixel_type) + delayed_dask_array = da.from_delayed( + delayed_np_array, image_reader.GetOutput().shape, dtype=pixel_type + ) return delayed_dask_array.rechunk(chunk_size) + def write_buffered_region(image: itk.Image, filepath: str): """ Write out only the buffered region of an image to disk. - + :param image: The image to write. `image.GetBufferedRegion()` may or may not differ from `image.GetLargestPossibleRegion()`. diff --git a/src/itk_dreg/register.py b/src/itk_dreg/register.py index 762db32..01b1f35 100644 --- a/src/itk_dreg/register.py +++ b/src/itk_dreg/register.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import math -import os import logging import functools import itertools @@ -15,15 +14,24 @@ from .block import image as block_image from .block import convert as block_convert -from .base.image_block_interface import \ - BlockInfo, BlockRegStatus, LocatedBlock, BlockPairRegistrationResult, \ - RegistrationTransformResult, RegistrationResult, LocatedBlockResult -from .base.registration_interface import \ - ConstructReaderMethod, BlockPairRegistrationMethod, ReduceResultsMethod - -#logger = logging.getLogger(__name__) -logger = logging.getLogger('distributed.worker') -worker_logger = logging.getLogger('distributed.worker') +from .base.image_block_interface import ( + BlockInfo, + BlockRegStatus, + LocatedBlock, + BlockPairRegistrationResult, + RegistrationTransformResult, + RegistrationResult, + LocatedBlockResult, +) +from .base.registration_interface import ( + ConstructReaderMethod, + BlockPairRegistrationMethod, + ReduceResultsMethod, +) + +# logger = logging.getLogger(__name__) +logger = logging.getLogger("distributed.worker") +worker_logger = logging.getLogger("distributed.worker") """ Entry point for ITK-DReg multiresolution registration framework. @@ -44,24 +52,24 @@ def register_images( overlap_factors: Optional[List[float]] = None, # debugging parameters debug_iter: Optional[Iterator[bool]] = None, - dry_run:bool=False, - **kwargs + dry_run: bool = False, + **kwargs, ) -> RegistrationResult: """ Register blocks of an input image. - + This is the main entry point into `itk-dreg` registration infrastructure. """ - worker_logger.info('Entering registration') + worker_logger.info("Entering registration") overlap_factors = overlap_factors or [0] * fixed_da.ndim debug_iter = debug_iter or itertools.repeat(False) if dry_run: - raise NotImplementedError('Registeration dry run is not yet implemented.') + raise NotImplementedError("Registeration dry run is not yet implemented.") # Subdivide into subimage tasks according to the fixed image fixed_block_info_iterable = list(iterate_block_info(fixed_da)) - - worker_logger.info(f'Got iterable fixed block info {fixed_block_info_iterable}') + + worker_logger.info(f"Got iterable fixed block info {fixed_block_info_iterable}") # Register each subimage pair delayed_block_results = [ @@ -73,14 +81,16 @@ def register_images( block_registration_method=block_registration_method, overlap_factors=overlap_factors, write_debug=next(debug_iter), - **kwargs + **kwargs, ) for fixed_block_info in fixed_block_info_iterable ] delayed_located_block_results = [ - dask.delayed(LocatedBlockResult)(fixed_info=fixed_block_info,result=result) - for fixed_block_info, result in zip(iter(fixed_block_info_iterable), delayed_block_results) + dask.delayed(LocatedBlockResult)(fixed_info=fixed_block_info, result=result) + for fixed_block_info, result in zip( + iter(fixed_block_info_iterable), delayed_block_results + ) ] # Postprocess pairwise registration results into a single `itk.Transform` @@ -88,7 +98,7 @@ def register_images( block_results=delayed_located_block_results, fixed_reader_ctor=fixed_reader_ctor, initial_transform=initial_transform, - **kwargs + **kwargs, ) # Compose status codes for output @@ -98,37 +108,46 @@ def register_images( results=delayed_block_results, ) - def compose_output(composed_transform_result:RegistrationTransformResult, results:npt.ArrayLike): - return RegistrationResult(composed_transform_result,results) + def compose_output( + composed_transform_result: RegistrationTransformResult, results: npt.ArrayLike + ): + return RegistrationResult(composed_transform_result, results) + return dask.delayed(compose_output)(composed_transform_result, block_reg_results) def iterate_block_info(arr: dask.array) -> Iterator[BlockInfo]: # TODO testing - return (BlockInfo(chunk_loc, array_slices) - for chunk_loc, array_slices - in zip(np.ndindex(*arr.numblocks), - dask.array.core.slices_from_chunks(arr.chunks))) - + return ( + BlockInfo(chunk_loc, array_slices) + for chunk_loc, array_slices in zip( + np.ndindex(*arr.numblocks), dask.array.core.slices_from_chunks(arr.chunks) + ) + ) -def iterate_lazy_chunks(arr: dask.array) -> Iterator[Tuple[BlockInfo, dask.array.core.Array]]: +def iterate_lazy_chunks( + arr: dask.array, +) -> Iterator[Tuple[BlockInfo, dask.array.core.Array]]: """Return an iterator over input array chunks with metadata. - + Adapted from `dask-image`: https://github.com/dask/dask-image/blob/adcb217de766dd6fef99895ed1a33bf78a97d14b/dask_image/ndmeasure/__init__.py#L299 """ - block_info = (BlockInfo(chunk_loc, array_slices) - for chunk_loc, array_slices - in zip(np.ndindex(*arr.numblocks), - dask.array.core.slices_from_chunks(arr.chunks))) + block_info = ( + BlockInfo(chunk_loc, array_slices) + for chunk_loc, array_slices in zip( + np.ndindex(*arr.numblocks), dask.array.core.slices_from_chunks(arr.chunks) + ) + ) return ( - LocatedBlock(loc,arr) - for loc, arr - in zip( + LocatedBlock(loc, arr) + for loc, arr in zip( block_info, - map(functools.partial(operator.getitem, arr), - dask.array.core.slices_from_chunks(arr.chunks)) + map( + functools.partial(operator.getitem, arr), + dask.array.core.slices_from_chunks(arr.chunks), + ), ) ) @@ -143,73 +162,86 @@ def register_subimage( block_registration_method: BlockPairRegistrationMethod, overlap_factors: Optional[List[float]] = None, # debug parameters - default_result:Optional[BlockPairRegistrationResult]=None, - dry_run:bool=False, - **kwargs + default_result: Optional[BlockPairRegistrationResult] = None, + dry_run: bool = False, + **kwargs, ) -> BlockPairRegistrationResult: """ Callback to register one moving block to a fixed image subregion. - + `register_subimage` fetches voxel data representing initialized, physically aligned image subregions and then calls into a provided registration callback for the actual registration process. """ import dask.distributed - #FIXME - #logger.setLevel(logging.DEBUG) - dask.distributed.print(f'logger level: {logger.level} {logger.getEffectiveLevel()}') - #dask.distributed.print(f'Entering subimage with block {block_info}') - #worker_logger.info(f'Entering "register subimage" with block {block_info}') - - default_result = default_result or \ - BlockPairRegistrationResult( - transform=None, - transform_domain=None, - inv_transform=None, - inv_transform_domain=None, - status=BlockRegStatus.FAILURE) + + # FIXME + # logger.setLevel(logging.DEBUG) + dask.distributed.print(f"logger level: {logger.level} {logger.getEffectiveLevel()}") + # dask.distributed.print(f'Entering subimage with block {block_info}') + # worker_logger.info(f'Entering "register subimage" with block {block_info}') + + default_result = default_result or BlockPairRegistrationResult( + transform=None, + transform_domain=None, + inv_transform=None, + inv_transform_domain=None, + status=BlockRegStatus.FAILURE, + ) overlap_factors = overlap_factors or [0] * block_info.ndim if dry_run: - raise NotImplementedError('Dry run is not yet implemented.') + raise NotImplementedError("Dry run is not yet implemented.") if not block_info: - logger.error('Could not register block: no block info provided') + logger.error("Could not register block: no block info provided") return default_result # Parse dask inputs - if any([block_slice.step and block_slice.step != 1 - for block_slice in block_info.array_slice]): - logger.warning('Unexpected dask array slice step detected, proceeding with step size == 1 voxel') + if any( + [ + block_slice.step and block_slice.step != 1 + for block_slice in block_info.array_slice + ] + ): + logger.warning( + "Unexpected dask array slice step detected, proceeding with step size == 1 voxel" + ) chunk_loc_str = [str(x) for x in block_info.chunk_index] - start_index = [block_slice.start for block_slice in block_info.array_slice] # NumPy access order - padding = [math.ceil(data_len * overlap_factor * 0.5) - for data_len, overlap_factor in zip(block_info.shape, overlap_factors)] # NumPy access order + start_index = [ + block_slice.start for block_slice in block_info.array_slice + ] # NumPy access order + padding = [ + math.ceil(data_len * overlap_factor * 0.5) + for data_len, overlap_factor in zip(block_info.shape, overlap_factors) + ] # NumPy access order block_region = itk.ImageRegion[int(block_info.ndim)]( [int(val) for val in np.flip(start_index)], - [int(val) for val in np.flip(block_info.shape)] - ) # ITK access order + [int(val) for val in np.flip(block_info.shape)], + ) # ITK access order padded_region = itk.ImageRegion[int(block_info.ndim)]( [int(val) for val in np.flip(np.array(start_index) - padding)], - [int(val) for val in np.flip(np.array(block_info.shape) + 2 * np.array(padding))] - ) # ITK access order + [ + int(val) + for val in np.flip(np.array(block_info.shape) + 2 * np.array(padding)) + ], + ) # ITK access order # Represent physical position of fixed voxel block fixed_reader = fixed_reader_ctor() padded_region.Crop(fixed_reader.GetOutput().GetLargestPossibleRegion()) - if ( - not fixed_reader.GetOutput().GetLargestPossibleRegion()\ - .IsInside(padded_region) - ): + if not fixed_reader.GetOutput().GetLargestPossibleRegion().IsInside(padded_region): logger.warning( f"{chunk_loc_str} -> " f"Fixed padded region {padded_region} lies outside {fixed_reader.GetOutput().GetLargestPossibleRegion()}" ) return default_result - logger.debug(f"{chunk_loc_str} -> " - f'Fixed block has unpadded region {block_region} and padded region {padded_region}') - + logger.debug( + f"{chunk_loc_str} -> " + f"Fixed block has unpadded region {block_region} and padded region {padded_region}" + ) + fixed_reader.GetOutput().SetRequestedRegion(padded_region) fixed_reader.Update() fixed_block_image = itk.extract_image_filter( @@ -217,9 +249,11 @@ def register_subimage( extraction_region=fixed_reader.GetOutput().GetBufferedRegion(), ) if fixed_block_image.GetBufferedRegion() != padded_region: - logger.warning(f'Expected fixed block buffered region {padded_region}' - f'but read in {fixed_block_image.GetBufferedRegion}') - fixed_block_image.SetRequestedRegion(block_region) # ROI for registration + logger.warning( + f"Expected fixed block buffered region {padded_region}" + f"but read in {fixed_block_image.GetBufferedRegion}" + ) + fixed_block_image.SetRequestedRegion(block_region) # ROI for registration logger.debug( f"{chunk_loc_str} -> " @@ -244,7 +278,7 @@ def register_subimage( src_transform=initial_transform, crop_to_target=True, ) - + moving_padded_block_region = block_convert.get_target_block_region( block_region=block_convert.image_to_block_region( fixed_block_image.GetBufferedRegion() @@ -254,7 +288,9 @@ def register_subimage( src_transform=initial_transform, crop_to_target=True, ) - moving_padded_region = block_convert.block_to_image_region(moving_padded_block_region) + moving_padded_region = block_convert.block_to_image_region( + moving_padded_block_region + ) if ( not moving_reader.GetOutput() @@ -267,9 +303,11 @@ def register_subimage( f"largest possible region {moving_reader.GetOutput().GetLargestPossibleRegion()}" ) return default_result - logger.debug(f'{chunk_loc_str}: ' - f'Moving unpadded region: {block_convert.block_to_image_region(moving_block_region)},' - f' moving padded region: {moving_padded_region}') + logger.debug( + f"{chunk_loc_str}: " + f"Moving unpadded region: {block_convert.block_to_image_region(moving_block_region)}," + f" moving padded region: {moving_padded_region}" + ) moving_reader.GetOutput().SetRequestedRegion(moving_padded_region) moving_reader.Update() @@ -278,16 +316,16 @@ def register_subimage( extraction_region=moving_reader.GetOutput().GetBufferedRegion(), ) - #TODO determine root cause + # TODO determine root cause # Handle case where crop in fixed padded region can cause 1-voxel difference at target unpadded border moving_unpadded_region = block_convert.block_to_image_region(moving_block_region) moving_unpadded_region.Crop(moving_padded_region) moving_block_image.SetRequestedRegion(moving_unpadded_region) logger.debug( f"{chunk_loc_str} -> " - f'Moving subimage largest {moving_block_image.GetLargestPossibleRegion()}'\ - f' buffered {moving_block_image.GetBufferedRegion()}'\ - f' requested {moving_block_image.GetRequestedRegion()}' + f"Moving subimage largest {moving_block_image.GetLargestPossibleRegion()}" + f" buffered {moving_block_image.GetBufferedRegion()}" + f" requested {moving_block_image.GetRequestedRegion()}" ) logger.debug( @@ -301,7 +339,7 @@ def register_subimage( logger.warning(f"{chunk_loc_str} -> no signal observed in moving block") return default_result except RuntimeError as e: - logger.error(f'{chunk_loc_str}: {e}') + logger.error(f"{chunk_loc_str}: {e}") raise e try: @@ -311,25 +349,32 @@ def register_subimage( moving_subimage=moving_block_image, initial_transform=initial_transform, block_info=block_info, - **kwargs + **kwargs, ) if not issubclass(type(registration_result), BlockPairRegistrationResult): - raise TypeError(f'Received incompatible registration result of type' - f' {type(registration_result)}: {registration_result}') + raise TypeError( + f"Received incompatible registration result of type" + f" {type(registration_result)}: {registration_result}" + ) - logger.info(f'{chunk_loc_str} -> Registration completed with status {registration_result.status}') + logger.info( + f"{chunk_loc_str} -> Registration completed with status {registration_result.status}" + ) return registration_result except Exception as e: import traceback - logger.warning(f'{chunk_loc_str} -> {e}') + + logger.warning(f"{chunk_loc_str} -> {e}") traceback.print_exc() return default_result - -def compose_block_status_output(blocks_shape:List[int], - block_loc_list:Iterator[BlockInfo], - results:List[BlockPairRegistrationResult]) -> npt.ArrayLike: + +def compose_block_status_output( + blocks_shape: List[int], + block_loc_list: Iterator[BlockInfo], + results: List[BlockPairRegistrationResult], +) -> npt.ArrayLike: """Compose status codes from pairwise registration into an ND array.""" results_arr = np.zeros(blocks_shape, dtype=np.uint8) for block_loc, result in zip(block_loc_list, results): diff --git a/src/itk_dreg/stitch_reduce/__init__.py b/src/itk_dreg/stitch_reduce/__init__.py index 7bc1425..f348d33 100644 --- a/src/itk_dreg/stitch_reduce/__init__.py +++ b/src/itk_dreg/stitch_reduce/__init__.py @@ -1,4 +1,4 @@ """ ITK implementation of the `itk_dreg` interface to reduce transform results by sampling and/or stitching deformation fields -""" \ No newline at end of file +""" diff --git a/src/itk_dreg/stitch_reduce/dreg.py b/src/itk_dreg/stitch_reduce/dreg.py index c132b9f..f6f4b1a 100644 --- a/src/itk_dreg/stitch_reduce/dreg.py +++ b/src/itk_dreg/stitch_reduce/dreg.py @@ -1,23 +1,30 @@ #!/usr/bin/env python3 -from email.policy import default import logging -from typing import Union, Iterable, Tuple, List, Type, Optional +from typing import Iterable, List, Optional import itk import numpy as np -import numpy.typing as npt -from scipy.spatial.transform import Rotation -from itk_dreg.base.image_block_interface import LocatedBlockResult, LocatedBlock, BlockRegStatus, RegistrationTransformResult -from itk_dreg.base.itk_typing import TransformType, ImageType -from itk_dreg.base.registration_interface import ReduceResultsMethod, ConstructReaderMethod +from itk_dreg.base.image_block_interface import ( + LocatedBlockResult, + BlockRegStatus, + RegistrationTransformResult, +) +from itk_dreg.base.itk_typing import TransformType +from itk_dreg.base.registration_interface import ( + ReduceResultsMethod, + ConstructReaderMethod, +) import itk_dreg.block.convert as block_convert -import itk_dreg.block.image as block_image from .transform_collection import TransformCollection, TransformEntry from .transform import collection_to_deformation_field_transform -from .matrix_transform import estimate_euler_transform_consensus, itk_matrix_transform_to_matrix, to_itk_euler_transform +from .matrix_transform import ( + estimate_euler_transform_consensus, + itk_matrix_transform_to_matrix, + to_itk_euler_transform, +) """ Adapter methods to use transform utilities in ITK-DReg framework. @@ -26,7 +33,7 @@ """ logger = logging.getLogger(__name__) - + class StitchReduceResultsMethod(ReduceResultsMethod): def __call__( @@ -34,67 +41,76 @@ def __call__( block_results: Iterable[LocatedBlockResult], fixed_reader_ctor: ConstructReaderMethod, initial_transform: itk.Transform, - stitch_grid_scale_factors: Optional[List[float]] = [1.0,1.0,1.0], - **kwargs + stitch_grid_scale_factors: Optional[List[float]] = [1.0, 1.0, 1.0], + **kwargs, ) -> RegistrationTransformResult: target_reader = fixed_reader_ctor() forward_transform = reduce_to_deformation_field_transform( block_results=block_results, reference_image=target_reader.GetOutput(), initial_transform=initial_transform, - scale_factors=stitch_grid_scale_factors + scale_factors=stitch_grid_scale_factors, ) return RegistrationTransformResult( - transform=forward_transform, - inv_transform=None + transform=forward_transform, inv_transform=None ) class EulerConsensusReduceResultsMethod(ReduceResultsMethod): def __call__( - self, - block_results: Iterable[LocatedBlockResult], - **kwargs + self, block_results: Iterable[LocatedBlockResult], **kwargs ) -> RegistrationTransformResult: - samples_arr = np.zeros([0,4,4], dtype=np.float32) - for result in map(lambda located_result : located_result.result, block_results): - logger.debug(f'Attempting to reduce transform {result.transform}') + samples_arr = np.zeros([0, 4, 4], dtype=np.float32) + for result in map(lambda located_result: located_result.result, block_results): + logger.debug(f"Attempting to reduce transform {result.transform}") transform = result.transform - if type(result.transform) == itk.CompositeTransform[itk.D,3] and result.transform.GetNumberOfTransforms() == 1: - transform = itk.Euler3DTransform[itk.D].cast(transform.GetNthTransform(0)) + if ( + type(result.transform) == itk.CompositeTransform[itk.D, 3] + and result.transform.GetNumberOfTransforms() == 1 + ): + transform = itk.Euler3DTransform[itk.D].cast( + transform.GetNthTransform(0) + ) if transform and type(transform) != itk.Euler3DTransform[itk.D]: - raise TypeError(f'Could not get rigid consensus with transform type {type(transform)}') + raise TypeError( + f"Could not get rigid consensus with transform type {type(transform)}" + ) if transform and result.status == BlockRegStatus.SUCCESS: - samples_arr = np.vstack((samples_arr, - np.expand_dims(itk_matrix_transform_to_matrix(transform), 0))) + samples_arr = np.vstack( + ( + samples_arr, + np.expand_dims(itk_matrix_transform_to_matrix(transform), 0), + ) + ) rigid_consensus_mat = estimate_euler_transform_consensus(samples_arr) return RegistrationTransformResult( - transform=to_itk_euler_transform(rigid_consensus_mat), - inv_transform=None + transform=to_itk_euler_transform(rigid_consensus_mat), inv_transform=None ) -def reduce_to_deformation_field_transform( - block_results:Iterable[LocatedBlockResult], - reference_image:itk.Image[itk.F,3], - initial_transform:TransformType, - scale_factors:List[float] = [10,10,10], - default_transform:itk.Transform = None -) -> itk.DisplacementFieldTransform[itk.D,3]: +def reduce_to_deformation_field_transform( + block_results: Iterable[LocatedBlockResult], + reference_image: itk.Image[itk.F, 3], + initial_transform: TransformType, + scale_factors: List[float] = [10, 10, 10], + default_transform: itk.Transform = None, +) -> itk.DisplacementFieldTransform[itk.D, 3]: """ Resample from a set of block registration results into a deformation field transform. """ - default_transform = default_transform or itk.TranslationTransform[itk.D,3].New() + default_transform = default_transform or itk.TranslationTransform[itk.D, 3].New() organized_transforms = TransformCollection( blend_method=TransformCollection.blend_distance_weighted_mean ) for located_result in block_results: if located_result.result.status == BlockRegStatus.SUCCESS: - organized_transforms.push(TransformEntry( - transform=located_result.result.transform, - domain=located_result.result.transform_domain - )) + organized_transforms.push( + TransformEntry( + transform=located_result.result.transform, + domain=located_result.result.transform_domain, + ) + ) continue else: # TODO estimate the physical domain for the failed block and @@ -120,7 +136,7 @@ def reduce_to_deformation_field_transform( # if located_result.status != BlockRegStatus.SUCCESS: # block_transforms.push(TransformEntry(DEFAULT_TRANSFORM, transform_domain_image)) # continue - + # # TODO: can we extend to get spatial metadata for any block? # if (type(located_result.transform) == itk.DisplacementFieldTransform[itk.D, 3] or\ # type(located_result.transform) == itk.DisplacementFieldTransform[itk.F, 3]): @@ -133,23 +149,23 @@ def reduce_to_deformation_field_transform( # TransformEntry(located_result.transform, transform_domain_image) # ) if not organized_transforms.transforms: - raise ValueError('Failed to compose at least one transform for sampling') - logger.debug(f'Collected domains: {organized_transforms.domains}') + raise ValueError("Failed to compose at least one transform for sampling") + logger.debug(f"Collected domains: {organized_transforms.domains}") physical_domains = [ block_convert.block_to_physical_region( block_region=block_convert.image_to_block_region( image_region=domain.GetLargestPossibleRegion() ), - ref_image=domain + ref_image=domain, ) for domain in organized_transforms.domains ] - logger.debug(f'Physical domains: {physical_domains}') - return collection_to_deformation_field_transform(\ - organized_transforms, reference_image, initial_transform, scale_factors) + logger.debug(f"Physical domains: {physical_domains}") + return collection_to_deformation_field_transform( + organized_transforms, reference_image, initial_transform, scale_factors + ) - class TransformCollectionReduceResultsMethod(ReduceResultsMethod): """ Return a transform collection of results. @@ -157,18 +173,19 @@ class TransformCollectionReduceResultsMethod(ReduceResultsMethod): `transform_collection` does not yet extend `itk.Transform`. This should not be used in production. """ + def __call__(self, block_results: Iterable[LocatedBlockResult], **kwargs): organized_transforms = TransformCollection( blend_method=TransformCollection.blend_distance_weighted_mean ) for located_result in block_results: if located_result.result.status == BlockRegStatus.SUCCESS: - organized_transforms.push(TransformEntry( - transform=located_result.result.transform, - domain=located_result.result.transform_domain - )) + organized_transforms.push( + TransformEntry( + transform=located_result.result.transform, + domain=located_result.result.transform_domain, + ) + ) return RegistrationTransformResult( - transform=organized_transforms, - inv_transform=None + transform=organized_transforms, inv_transform=None ) - diff --git a/src/itk_dreg/stitch_reduce/matrix_transform.py b/src/itk_dreg/stitch_reduce/matrix_transform.py index f55128e..29fe986 100644 --- a/src/itk_dreg/stitch_reduce/matrix_transform.py +++ b/src/itk_dreg/stitch_reduce/matrix_transform.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 -import logging -from typing import Union, Iterable, Tuple, List, Type, Optional +from typing import Union import itk import numpy as np @@ -9,7 +8,7 @@ from scipy.spatial.transform import Rotation # from registration_methods import register_elastix -from itk_dreg.base.image_block_interface import BlockPairRegistrationResult, LocatedBlock, BlockRegStatus + def itk_matrix_transform_to_matrix( t: Union[itk.Euler3DTransform[itk.D], itk.AffineTransform[itk.D, 3]] @@ -20,16 +19,18 @@ def itk_matrix_transform_to_matrix( output_arr[:3, 3] = np.array(t.GetTranslation()) return output_arr -def to_itk_euler_transform(mat:npt.ArrayLike) -> itk.Euler3DTransform[itk.D]: + +def to_itk_euler_transform(mat: npt.ArrayLike) -> itk.Euler3DTransform[itk.D]: transform = itk.Euler3DTransform[itk.D].New() - transform.SetMatrix(np_to_itk_matrix(mat[:3,:3])) - transform.Translate(mat[:3,3]) + transform.SetMatrix(np_to_itk_matrix(mat[:3, :3])) + transform.Translate(mat[:3, 3]) return transform + def postprocess_block_matrix_transform( t: Union[itk.Euler3DTransform[itk.D], itk.AffineTransform[itk.D, 3]] ) -> npt.ArrayLike: - return t # do nothing + return t # do nothing def np_to_itk_matrix(arr: npt.ArrayLike) -> itk.Matrix[itk.D, 3, 3]: @@ -106,4 +107,4 @@ def average_translation(translations: npt.ArrayLike) -> npt.ArrayLike: """Compute linear average of translation vectors""" assert translations.ndim == 2 assert translations.shape[1] == 3 - return np.mean(translations, axis=0) \ No newline at end of file + return np.mean(translations, axis=0) diff --git a/src/itk_dreg/stitch_reduce/transform.py b/src/itk_dreg/stitch_reduce/transform.py index e278b8b..c20c6ac 100644 --- a/src/itk_dreg/stitch_reduce/transform.py +++ b/src/itk_dreg/stitch_reduce/transform.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import logging -from typing import Union, Iterable, Tuple, List, Type, Optional +from typing import List import itk import numpy as np @@ -26,7 +26,7 @@ # scale_factors:List[float]) -> TransformType: # """ # Convert an ITK transform block alignment result to a displacement field transform. - + # The output transform is defined over the domain of the input target block image. # target_block_image is a partially buffered ITK image. @@ -57,11 +57,12 @@ # displacement_field=displacement_field # ) + def collection_to_deformation_field_transform( transform_collection: TransformCollection, reference_image: itk.Image, initial_transform: TransformType, - scale_factors:List[float] + scale_factors: List[float], ) -> itk.DisplacementFieldTransform[itk.D, 3]: """ Stitch multiple input displacement fields into one transform. @@ -72,47 +73,53 @@ def collection_to_deformation_field_transform( Assumptions: - input physical regions cover output physical region - domain overlap is handled in TransformCollection - """ + """ dimension = reference_image.GetImageDimension() - DEFAULT_VALUE = itk.Vector[itk.D,dimension]([0] * dimension) + DEFAULT_VALUE = itk.Vector[itk.D, dimension]([0] * dimension) # Get oriented, unallocated image representing the requested bounds output_field = block_image.physical_region_to_itk_image( physical_region=block_convert.image_to_physical_region( image_region=reference_image.GetLargestPossibleRegion(), ref_image=reference_image, - src_transform=initial_transform + src_transform=initial_transform, ), - spacing=[spacing * scale_factor - for spacing, scale_factor - in zip(itk.spacing(reference_image), scale_factors)], + spacing=[ + spacing * scale_factor + for spacing, scale_factor in zip( + itk.spacing(reference_image), scale_factors + ) + ], direction=reference_image.GetDirection(), extend_beyond=True, - image_type=itk.Image[itk.Vector[itk.D,dimension],dimension] + image_type=itk.Image[itk.Vector[itk.D, dimension], dimension], ) output_field.Allocate() output_field.FillBuffer(DEFAULT_VALUE) - logger.info(f'Output field has size {itk.size(output_field)}' - f' and domain ' - f'{block_image.image_to_physical_region(output_field.GetBufferedRegion(), output_field)}') + logger.info( + f"Output field has size {itk.size(output_field)}" + f" and domain " + f"{block_image.image_to_physical_region(output_field.GetBufferedRegion(), output_field)}" + ) - #TODO serial bottleneck, parallelization required. + # TODO serial bottleneck, parallelization required. # To be resolved with `TransformCollection` ITK C++ implementation. for x, y, z in np.ndindex(tuple(itk.size(output_field))): if np.random.uniform() < 0.005: - logger.debug(f'Sampling displacement at voxel [{x},{y},{z}]') + logger.debug(f"Sampling displacement at voxel [{x},{y},{z}]") index = [int(x), int(y), int(z)] physical_point = output_field.TransformIndexToPhysicalPoint(index) try: - output_field.SetPixel(index, - transform_collection.transform_point(physical_point) - physical_point) + output_field.SetPixel( + index, + transform_collection.transform_point(physical_point) - physical_point, + ) except ValueError as e: - logger.debug(f'Failed to sample displacement at [{x},{y},{z}]: {e}') + logger.debug(f"Failed to sample displacement at [{x},{y},{z}]: {e}") continue vector_type = itk.template(output_field)[1][0] scalar_type = itk.template(vector_type)[1][0] - output_transform = itk.DisplacementFieldTransform[scalar_type,dimension].New() + output_transform = itk.DisplacementFieldTransform[scalar_type, dimension].New() output_transform.SetDisplacementField(output_field) return output_transform - diff --git a/src/itk_dreg/stitch_reduce/transform_collection.py b/src/itk_dreg/stitch_reduce/transform_collection.py index dcecdd1..025139d 100644 --- a/src/itk_dreg/stitch_reduce/transform_collection.py +++ b/src/itk_dreg/stitch_reduce/transform_collection.py @@ -12,16 +12,17 @@ logger = logging.getLogger(__name__) + @dataclass class TransformEntry: - transform:Type[itk.Transform] - domain:Optional[itk.Image] + transform: Type[itk.Transform] + domain: Optional[itk.Image] class TransformCollection: """ Represent a collection of (possibly bounded) itk.Transform(s). - + A single point to transform may fall within multiple transform domains. In that case a simple average is taken of point output candidates. @@ -31,30 +32,34 @@ class TransformCollection: TODO: Re-implement in ITK C++ to inherit from `itk.Transform` for use in filters and for improved performance. """ + @property def transforms(self): - return [entry.transform - for entry in self.transform_and_domain_list] + return [entry.transform for entry in self.transform_and_domain_list] @property def domains(self): - return [entry.domain - for entry in self.transform_and_domain_list] + return [entry.domain for entry in self.transform_and_domain_list] @staticmethod - def _bounds_contains(bounds:npt.ArrayLike, pt:npt.ArrayLike) -> bool: - return np.all(np.min(bounds, axis=0) <= pt) and np.all(np.max(bounds, axis=0) >= pt) - + def _bounds_contains(bounds: npt.ArrayLike, pt: npt.ArrayLike) -> bool: + return np.all(np.min(bounds, axis=0) <= pt) and np.all( + np.max(bounds, axis=0) >= pt + ) + @staticmethod - def blend_simple_mean(input_pt:itk.Point, region_contributors:List[TransformEntry]) -> npt.ArrayLike: + def blend_simple_mean( + input_pt: itk.Point, region_contributors: List[TransformEntry] + ) -> npt.ArrayLike: pts = [ - entry.transform.TransformPoint(input_pt) - for entry in region_contributors + entry.transform.TransformPoint(input_pt) for entry in region_contributors ] return np.mean(pts, axis=0) - + @classmethod - def blend_distance_weighted_mean(cls, input_pt:itk.Point, region_contributors:List[TransformEntry]) -> npt.ArrayLike: + def blend_distance_weighted_mean( + cls, input_pt: itk.Point, region_contributors: List[TransformEntry] + ) -> npt.ArrayLike: """ Blending method to weight transform results by their proximity to the edge of the corresponding transform domain. @@ -64,15 +69,23 @@ def blend_distance_weighted_mean(cls, input_pt:itk.Point, region_contributors:Li """ MIN_WEIGHT = 1e-9 - point_candidates = [entry.transform.TransformPoint(input_pt) for entry in region_contributors] - weights = [cls._physical_distance_from_edge(input_pt, entry.domain)[0] if entry.domain else MIN_WEIGHT - for entry in region_contributors] + point_candidates = [ + entry.transform.TransformPoint(input_pt) for entry in region_contributors + ] + weights = [ + cls._physical_distance_from_edge(input_pt, entry.domain)[0] + if entry.domain + else MIN_WEIGHT + for entry in region_contributors + ] if np.any([w < 0 for w in weights]): - logger.error('Detected at least one negative weight indicating' - ' a point unexpectedly lies outside a contributing region.' - ' May impact transform blending results.') - + logger.error( + "Detected at least one negative weight indicating" + " a point unexpectedly lies outside a contributing region." + " May impact transform blending results." + ) + # Treat weights lying on an edge as if they were very small step inside the edge. # Domains are considered inclusive at bounds, meaning a single point candidate # at the boundary of a domain is a valid candidate and should be included in weighted averaging. @@ -81,7 +94,9 @@ def blend_distance_weighted_mean(cls, input_pt:itk.Point, region_contributors:Li return np.average(point_candidates, axis=0, weights=interior_weights) @classmethod - def _physical_distance_from_edge(cls, input_pt:itk.Point, domain:itk.Image) -> Tuple[float, int]: + def _physical_distance_from_edge( + cls, input_pt: itk.Point, domain: itk.Image + ) -> Tuple[float, int]: """ Estimate unsigned minimum physical distance to closest domain side. @@ -93,7 +108,9 @@ def _physical_distance_from_edge(cls, input_pt:itk.Point, domain:itk.Image) -> T 1. The zero-indexed axis to travel to reach the nearest edge. """ # Set up unit vectors mapping from voxel to physical space - voxel_step_vecs = np.matmul(np.array(domain.GetDirection()), np.eye(3) * itk.spacing(domain)) + voxel_step_vecs = np.matmul( + np.array(domain.GetDirection()), np.eye(3) * itk.spacing(domain) + ) physical_step = np.ravel( np.take_along_axis( voxel_step_vecs, @@ -103,18 +120,19 @@ def _physical_distance_from_edge(cls, input_pt:itk.Point, domain:itk.Image) -> T ) assert physical_step.ndim == 1 and physical_step.shape[0] == 3 assert np.all(physical_step) - + pixel_axis_dists = cls._pixel_distance_from_edge(input_pt, domain) physical_axis_dists = [ np.linalg.norm(axis_dist * physical_step) - for axis_dist, physical_step - in zip(pixel_axis_dists, physical_step) + for axis_dist, physical_step in zip(pixel_axis_dists, physical_step) ] arg_min = np.argmin(np.abs(physical_axis_dists)) return physical_axis_dists[arg_min], arg_min - + @staticmethod - def _pixel_distance_from_edge(input_pt:itk.Point, domain:itk.Image) -> npt.ArrayLike: + def _pixel_distance_from_edge( + input_pt: itk.Point, domain: itk.Image + ) -> npt.ArrayLike: """ Estimate signed voxel distance to each image side. @@ -122,58 +140,73 @@ def _pixel_distance_from_edge(input_pt:itk.Point, domain:itk.Image) -> npt.Array https://github.com/InsightSoftwareConsortium/ITKMontage/blob/master/include/itkTileMergeImageFilter.hxx#L217 """ VOXEL_HALF_STEP = [0.5] * 3 - dist_to_lower_bound = np.array(domain.TransformPhysicalPointToContinuousIndex(input_pt)) -\ - (np.array(domain.GetLargestPossibleRegion().GetIndex()) - VOXEL_HALF_STEP) - dist_to_upper_bound = np.array(domain.GetLargestPossibleRegion().GetSize()) - dist_to_lower_bound + dist_to_lower_bound = np.array( + domain.TransformPhysicalPointToContinuousIndex(input_pt) + ) - (np.array(domain.GetLargestPossibleRegion().GetIndex()) - VOXEL_HALF_STEP) + dist_to_upper_bound = ( + np.array(domain.GetLargestPossibleRegion().GetSize()) - dist_to_lower_bound + ) pixel_dists = np.array([dist_to_lower_bound, dist_to_upper_bound]) - axis_mins = np.ravel(np.take_along_axis(pixel_dists, - np.expand_dims(np.argmin(np.abs(pixel_dists), axis=0), axis=0), - axis=0)) + axis_mins = np.ravel( + np.take_along_axis( + pixel_dists, + np.expand_dims(np.argmin(np.abs(pixel_dists), axis=0), axis=0), + axis=0, + ) + ) return axis_mins - - + @staticmethod - def _resolve_displacements(vecs:List[itk.Vector]) -> npt.ArrayLike: + def _resolve_displacements(vecs: List[itk.Vector]) -> npt.ArrayLike: return np.mean(vecs, axis=0) - + @staticmethod - def _validate_entry(entry:TransformEntry) -> None: - if not issubclass(type(entry.transform), itk.Transform[itk.D,3,3]) and\ - not issubclass(type(entry.transform), itk.Transform[itk.F,3,3]): - raise TypeError(f'Bad entry transform type: {type(entry.transform)}') - - if entry.domain and\ - itk.template(entry.domain)[0] != itk.Image and\ - itk.template(entry.domain)[0] != itk.VectorImage: - raise TypeError(f'Bad entry domain type: {type(entry.domain)}') - - def __init__(self, - transform_and_domain_list:List[Type[TransformEntry]]=None, - blend_method:Callable[[itk.Point, List[TransformEntry]], itk.Point]=None,): + def _validate_entry(entry: TransformEntry) -> None: + if not issubclass( + type(entry.transform), itk.Transform[itk.D, 3, 3] + ) and not issubclass(type(entry.transform), itk.Transform[itk.F, 3, 3]): + raise TypeError(f"Bad entry transform type: {type(entry.transform)}") + + if ( + entry.domain + and itk.template(entry.domain)[0] != itk.Image + and itk.template(entry.domain)[0] != itk.VectorImage + ): + raise TypeError(f"Bad entry domain type: {type(entry.domain)}") + + def __init__( + self, + transform_and_domain_list: List[Type[TransformEntry]] = None, + blend_method: Callable[[itk.Point, List[TransformEntry]], itk.Point] = None, + ): transform_and_domain_list = transform_and_domain_list or [] for entry in transform_and_domain_list: TransformCollection._validate_entry(entry) - self.blend_method = blend_method if blend_method else TransformCollection.blend_distance_weighted_mean + self.blend_method = ( + blend_method + if blend_method + else TransformCollection.blend_distance_weighted_mean + ) self.transform_and_domain_list = transform_and_domain_list - - def push(self, entry:Type[TransformEntry]) -> None: + + def push(self, entry: Type[TransformEntry]) -> None: TransformCollection._validate_entry(entry) self.transform_and_domain_list.append(entry) - - def transform_point(self, pt:itk.Point[itk.F,3]) -> npt.ArrayLike: + + def transform_point(self, pt: itk.Point[itk.F, 3]) -> npt.ArrayLike: region_contributors = [ entry for entry in self.transform_and_domain_list - if not entry.domain or TransformCollection._bounds_contains( - block_image.get_sample_bounds(entry.domain), - pt + if not entry.domain + or TransformCollection._bounds_contains( + block_image.get_sample_bounds(entry.domain), pt ) ] if not region_contributors: - raise ValueError(f'No candidates found: {pt} lies outside all transform domains') - return itk.Point[itk.F,3](self.blend_method(pt, region_contributors)) - - def TransformPoint(self, pt:itk.Point[itk.F,3]) -> npt.ArrayLike: - return self.transform_point(pt) - + raise ValueError( + f"No candidates found: {pt} lies outside all transform domains" + ) + return itk.Point[itk.F, 3](self.blend_method(pt, region_contributors)) + def TransformPoint(self, pt: itk.Point[itk.F, 3]) -> npt.ArrayLike: + return self.transform_point(pt)