Skip to content

Commit

Permalink
DOC: docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
tbirdso committed Nov 10, 2023
1 parent abe47c5 commit 735c489
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 109 deletions.
12 changes: 11 additions & 1 deletion src/itk_dreg/reduce_dfield/dreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@


class ReduceToDisplacementFieldMethod(ReduceResultsMethod):
"""
Implements `itk-dreg` registration reduction by composing an
`itk.DisplacementFieldTransform` from pairwise subimage registration results.
"""

def __call__(
self,
block_results: Iterable[LocatedBlockResult],
Expand All @@ -57,6 +62,11 @@ def __call__(


class EulerConsensusReduceResultsMethod(ReduceResultsMethod):
"""
Implements `itk-dreg` registration reduction by composing an
`itk.Euler3DTransform` from a pairwise subimage rigid registration results.
"""

def __call__(
self, block_results: Iterable[LocatedBlockResult], **kwargs
) -> RegistrationTransformResult:
Expand Down Expand Up @@ -139,7 +149,7 @@ class TransformCollectionReduceResultsMethod(ReduceResultsMethod):
"""
Return a transform collection of results.
`transform_collection` does not yet extend `itk.Transform`.
Note (2023.11.10): `transform_collection` does not yet extend `itk.Transform`.
This should not be used in production.
"""

Expand Down
41 changes: 5 additions & 36 deletions src/itk_dreg/reduce_dfield/matrix_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@ def itk_matrix_transform_to_matrix(


def to_itk_euler_transform(mat: npt.ArrayLike) -> itk.Euler3DTransform[itk.D]:
"""
Convert from a NumPy affine matrix to `itk.Euler3DTransform` representation.
:param mat: The input 4x4 affine matrix
:return: The corresponding `itk.Euler3DTransform`.
"""
transform = itk.Euler3DTransform[itk.D].New()
transform.SetMatrix(np_to_itk_matrix(mat[:3, :3]))
transform.Translate(mat[:3, 3])
return transform


def postprocess_block_matrix_transform(
t: Union[itk.Euler3DTransform[itk.D], itk.AffineTransform[itk.D, 3]]
) -> npt.ArrayLike:
return t # do nothing


def np_to_itk_matrix(arr: npt.ArrayLike) -> itk.Matrix[itk.D, 3, 3]:
"""Convert a 3x3 matrix from numpy to ITK format"""
vnl_matrix = itk.Matrix[itk.D, 3, 3]().GetVnlMatrix()
Expand All @@ -54,32 +53,6 @@ def estimate_euler_transform_consensus(transforms: npt.ArrayLike) -> npt.ArrayLi
return average_transform


def estimate_affine_transform_consensus(transforms: npt.ArrayLike) -> npt.ArrayLike:
"""Estimate a mean representation of a list of transform results"""
if transforms.ndim != 3 or transforms.shape[1] != 4 or transforms.shape[2] != 4:
raise ValueError(
f"Expected list of 4x4 affine transforms but received array with shape {transforms.shape}"
)

translations = transforms[:, :3, 3]
scale_factors = np.linalg.norm(transforms[:, :3, :3], axis=1)
rotations = np.divide(
transforms[:, :3, :3], np.tile(np.expand_dims(scale_factors, 1), (1, 3, 1))
)

est_rot = average_rotation(rotations)
est_scale = average_scale_factors(scale_factors)
est_translation = average_translation(translations)

rot_block = est_rot
rot_block = np.multiply(est_rot, np.tile(np.expand_dims(est_scale, 0), (3, 1)))

average_transform = np.eye(4)
average_transform[:3, :3] = rot_block
average_transform[:3, 3] = est_translation
return average_transform


def average_rotation(rotations: npt.ArrayLike) -> npt.ArrayLike:
"""Compute average rotation by way of linear quaternion averaging"""
if rotations.ndim != 3 or rotations.shape[1] != 3 or rotations.shape[2] != 3:
Expand All @@ -99,10 +72,6 @@ def average_rotation(rotations: npt.ArrayLike) -> npt.ArrayLike:
return Rotation.from_quat(accum_quat).as_matrix()


def average_scale_factors(scale_factors: npt.ArrayLike) -> npt.ArrayLike:
return np.mean(scale_factors, axis=0)


def average_translation(translations: npt.ArrayLike) -> npt.ArrayLike:
"""Compute linear average of translation vectors"""
assert translations.ndim == 2
Expand Down
50 changes: 12 additions & 38 deletions src/itk_dreg/reduce_dfield/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,6 @@
using ITK Python, ITKElastix, and Dask.
"""

# def to_displacement_field_transform(
# itk_transform:TransformType,
# initial_transform:TransformType,
# target_block_image:itk.Image[itk.F,3],
# scale_factors:List[float]) -> TransformType:
# """
# Convert an ITK transform block alignment result to a displacement field transform.

# The output transform is defined over the domain of the input target block image.

# target_block_image is a partially buffered ITK image.

# The output displacement field is a fully buffered ITK image over the target block region.
# """
# target_region = block_image.image_to_physical_region(
# image_region=target_block_image.GetBufferedRegion(),
# ref_image=target_block_image,
# src_transform=initial_transform
# )
# ref_image = block_image.physical_region_to_itk_image(
# physical_region=target_region,
# spacing=[spacing * scale_factor
# for spacing, scale_factor in zip(itk.spacing(target_block_image), scale_factors)],
# direction=np.array(target_block_image.GetDirection()),
# extend_beyond=True
# )

# displacement_field = itk.transform_to_displacement_field_filter(
# transform=itk_transform,
# use_reference_image=True,
# reference_image=ref_image
# )
# vector_pixel_type = itk.template(displacement_field)[1][0]
# scalar_type = itk.template(vector_pixel_type)[1][0]
# return itk.DisplacementFieldTransform[scalar_type,3].New(
# displacement_field=displacement_field
# )


def collection_to_deformation_field_transform(
transform_collection: TransformCollection,
Expand All @@ -73,6 +35,18 @@ def collection_to_deformation_field_transform(
Assumptions:
- input physical regions cover output physical region
- domain overlap is handled in TransformCollection
:param transform_collection: The `TransformCollection` to discretely sample into
an `itk.DisplacementFieldTransform` output.
:param reference_image: The image to reference to apply spatial metadata to
the output displacement field image.
:param initial_transform: The initial transform to apply to the reference image
to get appropriate initial positioning for the displacement field image.
:param scale_factors: The desired scale factors to reduce or increase the
size of the displacement field image grid relative to the reference image.
:return: An `itk.DisplacementFieldTransform` discretizing the input
collection of transforms. May be applied in sequence after `initial_transform`
to map from an input to an output image domain.
"""
dimension = reference_image.GetImageDimension()
DEFAULT_VALUE = itk.Vector[itk.D, dimension]([0] * dimension)
Expand Down
91 changes: 81 additions & 10 deletions src/itk_dreg/reduce_dfield/transform_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,25 @@ class TransformCollection:
"""

@property
def transforms(self):
def transforms(self) -> List[itk.Transform]:
"""The `itk.Transform`s in this collection."""
return [entry.transform for entry in self.transform_and_domain_list]

@property
def domains(self):
def domains(self) -> List[Optional[itk.Image]]:
"""The transform domains in this collection."""
return [entry.domain for entry in self.transform_and_domain_list]

@staticmethod
def _bounds_contains(bounds: npt.ArrayLike, pt: npt.ArrayLike) -> bool:
"""
Determines whether a point (X,Y,Z) falls within an axis-aligned physical bounding box.
:param bounds: A 2x3 voxel array representing axis-aligned inclusive
upper and lower bounds in (X,Y,Z) physical space.
:param pt: A physical point (X,Y,Z).
:return: True if the point is contained inside the inclusive physical region, else False.
"""
return np.all(np.min(bounds, axis=0) <= pt) and np.all(
np.max(bounds, axis=0) >= pt
)
Expand All @@ -51,6 +61,21 @@ def _bounds_contains(bounds: npt.ArrayLike, pt: npt.ArrayLike) -> bool:
def blend_simple_mean(
input_pt: itk.Point, region_contributors: List[TransformEntry]
) -> npt.ArrayLike:
"""
Method to blend among multiple possible transform outputs
by performing unweighted averaging of all point candidates.
If the input point falls within the domain overlap of three transforms,
each transform will be applied independently to produce three point candidates
and the output will be the linear sum of each candidate weighted by (1/3).
This is a simple approach that may result in significant discontinuities
at transform domain edges.
:param input_pt: The point to transform.
:param region_contributors: The transforms whose domains include the given point.
:return: The average transformed point.
"""
pts = [
entry.transform.TransformPoint(input_pt) for entry in region_contributors
]
Expand All @@ -61,11 +86,17 @@ def blend_distance_weighted_mean(
cls, input_pt: itk.Point, region_contributors: List[TransformEntry]
) -> npt.ArrayLike:
"""
Blending method to weight transform results by their proximity to the edge of the corresponding transform domain.
Blending method to weight transform results by their proximity
to the edge of the corresponding transform domain.
This blending approach avoids discontinuities at transform domain bounds.
Transforms that have no bounds on the domain over which they apply are weighted minimally.
#TODO Investigate alternatives to consider unbounded/background transform information.
:param input_pt: The point to transform.
:param region_contributors: The transforms whose domains include the given point.
:return: The average transformed point.
"""

MIN_WEIGHT = 1e-9
Expand Down Expand Up @@ -103,6 +134,8 @@ def _physical_distance_from_edge(
Handles domain with isotropic or anisotropic spacing over
non-axis-aligned image domain representation.
:param input_pt: The point to transform.
:param domain: The transform domain to consider.
:returns: Tuple with elements:
0. The physical linear distance to the nearest image edge, and
1. The zero-indexed axis to travel to reach the nearest edge.
Expand Down Expand Up @@ -138,6 +171,11 @@ def _pixel_distance_from_edge(
Inspired by
https://github.com/InsightSoftwareConsortium/ITKMontage/blob/master/include/itkTileMergeImageFilter.hxx#L217
:param input_pt: The point to transform.
:param domain: The transform domain to consider.
:return: The shortest pixel distance to an edge along each axis
in ITK access order (I,J,K)
"""
VOXEL_HALF_STEP = [0.5] * 3
dist_to_lower_bound = np.array(
Expand All @@ -156,12 +194,14 @@ def _pixel_distance_from_edge(
)
return axis_mins

@staticmethod
def _resolve_displacements(vecs: List[itk.Vector]) -> npt.ArrayLike:
return np.mean(vecs, axis=0)

@staticmethod
def _validate_entry(entry: TransformEntry) -> None:
"""
Validate that an input transform entry is valid.
:param entry: The bounded or unbounded transform entry.
:raises TypeError: If either the transform or transform domain type is invalid.
"""
if not issubclass(
type(entry.transform), itk.Transform[itk.D, 3, 3]
) and not issubclass(type(entry.transform), itk.Transform[itk.F, 3, 3]):
Expand All @@ -176,9 +216,17 @@ def _validate_entry(entry: TransformEntry) -> None:

def __init__(
self,
transform_and_domain_list: List[Type[TransformEntry]] = None,
transform_and_domain_list: List[TransformEntry] = None,
blend_method: Callable[[itk.Point, List[TransformEntry]], itk.Point] = None,
):
"""
Initialize a new `TransformCollection`.
:param transform_and_domain_list: The list of transforms and associated transform domains
to inform `TransformCollection` behavior.
:param blend_method: The method to use to blend among output candidates in the case of
overlapping transform domains.
"""
transform_and_domain_list = transform_and_domain_list or []
for entry in transform_and_domain_list:
TransformCollection._validate_entry(entry)
Expand All @@ -189,11 +237,28 @@ def __init__(
)
self.transform_and_domain_list = transform_and_domain_list

def push(self, entry: Type[TransformEntry]) -> None:
def push(self, entry: TransformEntry) -> None:
"""
Add a new bounded or unbounded transform to the underlying collection.
:param entry: The transform and domain to add.
:raises TypeError: If the entry is invalid.
"""
TransformCollection._validate_entry(entry)
self.transform_and_domain_list.append(entry)

def transform_point(self, pt: itk.Point[itk.F, 3]) -> npt.ArrayLike:
def transform_point(self, pt: itk.Point[itk.F, 3]) -> itk.Point[itk.F, 3]:
"""
Transforms an input physical point (X,Y,Z) by the piecewise transform
relationship developed by underlying bounded transforms and the
selected blending method.
:param pt: The physical point (X,Y,Z) to transform.
:return: The transformed point (X,Y,Z) obtained after blending among
point outputs from each viable transform candidate.
:raises ValueError: If the input point does not fall within any
of the transform domains contained within the `TransformCollection`.
"""
region_contributors = [
entry
for entry in self.transform_and_domain_list
Expand All @@ -209,4 +274,10 @@ def transform_point(self, pt: itk.Point[itk.F, 3]) -> npt.ArrayLike:
return itk.Point[itk.F, 3](self.blend_method(pt, region_contributors))

def TransformPoint(self, pt: itk.Point[itk.F, 3]) -> npt.ArrayLike:
"""
`itk.Transform`-like interface to transform a point by the
transform relationship developed by this `TransformCollection` instance.
See `transform_point` documentation.
"""
return self.transform_point(pt)
Loading

0 comments on commit 735c489

Please sign in to comment.