Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Refactor itk-dreg abstract interface for plugin methods #2

Merged
merged 2 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/

46 changes: 0 additions & 46 deletions registration_interface.py

This file was deleted.

1 change: 1 addition & 0 deletions src/itk_dreg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.0.1"
Empty file added src/itk_dreg/base/__init__.py
Empty file.
151 changes: 151 additions & 0 deletions src/itk_dreg/base/image_block_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#!/usr/bin/env python3

from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, List

import dask.array
import numpy.typing as npt

from .itk_typing import ImageType, TransformType

"""
Define common data structures for managing block regions and registration output.
"""


class BlockRegStatus(IntEnum):
"""
Status codes indicating the registration outcome from a single block pair.

TODO: To be extended with more granular error codes for `itk-dreg` infrastructure.
"""

SUCCESS = 0
"""Registration yielded at least a forward transform result."""

FAILURE = 1
"""Registration encountered an unspecified error."""


@dataclass
class BlockInfo:
"""
Header information describing the position of a lazy dask subvolume (block)
in voxel space with respect to a parent volume.

Accessors are in NumPy order: (K,J,I) where K is slowest, I is fastest
"""

chunk_index: List[int]
"""
The chunk position in the greater image volume in terms of chunk counts.
For instance, if traversing 2x2x2 volume along the fastest axis first:
- the 0th chunk would have chunk index (0,0,0),
- the 1st chunk would have chunk index (0,0,1),
- the 7th chunk would have chunk index (1,0,0)
"""

array_slice: List[slice]
"""
The chunk position in the greater image volume in terms of voxel access.
For instance, if a 100x100x100 volume is evenly subdivided into 10x10x10 chunks,
the first chunk would slice along [(0,10,1),(0,10,1),(0,10,1)].
"""


@dataclass
class LocatedBlock:
"""
Combined header and data access to get a lazy dask volume with respect
to a parent volume in voxel space.

Accessors are in NumPy order: (K,J,I) where K is slowest, I is fastest
"""

loc: BlockInfo
"""
The location of the block relative to the parent image voxel array.
"""

arr: dask.array.core.Array
"""
The dask volume for lazy voxel access.
"""


@dataclass
class BlockPairRegistrationResult:
"""Encapsulate result of fixed-to-moving registration over one block pair."""

transform: Optional[TransformType]
"""
The forward transform registration result, if any.
The forward transform maps from moving to fixed space.
"""

transform_domain: Optional[ImageType]
"""
Oriented representation of the domain over which the forward transform is valid.
`transform_domain` has no voxel data and serves as a metadata representation of an
oriented bounding box in physical space.
`transform_domain` must be available if and only if `transform` is available.
"""

inv_transform: Optional[TransformType]
"""
The inverse transform registration result, if any.
The inverse transform maps from fixed to moving space.
If `inv_transform` is available then `transform` must also be available.
"""

inv_transform_domain: Optional[ImageType]
"""
Oriented representation of the domain over which the inverse transform is valid.
`inv_transform_domain` has no voxel data and serves as a metadata representation of an
oriented bounding box in physical space.
`inv_transform_domain` must be available if and only if `inv_transform` is available.
"""

status: BlockRegStatus
"""Status code indicating registration success or failure."""


@dataclass
class RegistrationTransformResult:
"""
Encapsulate result of fixed-to-moving registration over all block pairs.
"""

transform: TransformType
"""
The forward transform resulting from block postprocessing.
The forward transform maps from moving to fixed space.
"""

inv_transform: Optional[TransformType]
"""
The inverse transform registration result from block postprocessing, if any.
The inverse transform maps from fixed to moving space.
If `inv_transform` is available then `transform` must also be available.
"""


@dataclass
class RegistrationResult:
"""
Encapsulate result of fixed-to-moving registration over all block pairs.
"""

transforms: RegistrationTransformResult
"""The forward and inverse transforms resulting from registration."""

status: npt.ArrayLike
"""
`status` is an ND array where each element reflects the status code output
for block pair registration over the corresponding input moving chunk.

`status` has the same shape as the moving input array of chunks.
That is, if the moving input array is subdivided into 2 chunks x 3 chunks x 4 chunks,
`status` will be an array of voxels with shape [2,3,4].
"""
28 changes: 28 additions & 0 deletions src/itk_dreg/base/itk_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

from typing import Union

import itk

"""
Define common "union" type hints for floating-point registration
in two or three dimensional space.
"""

ImagePixelType = itk.F
FloatImage2DType = itk.Image[ImagePixelType, 2]
FloatImage3DType = itk.Image[ImagePixelType, 3]
ImageType = Union[FloatImage2DType, FloatImage3DType]

ImageRegion2DType = itk.ImageRegion[2]
ImageRegion3DType = itk.ImageRegion[3]
ImageRegionType = Union[ImageRegion2DType, ImageRegion3DType]

FloatImage2DReaderType = itk.ImageFileReader[FloatImage2DType]
FloatImage3DReaderType = itk.ImageFileReader[FloatImage3DType]
ImageReaderType = Union[FloatImage2DReaderType, FloatImage3DReaderType]

TransformScalarType = itk.D
TransformType = Union[
itk.Transform[TransformScalarType, 2], itk.Transform[TransformScalarType, 3]
]
158 changes: 158 additions & 0 deletions src/itk_dreg/base/registration_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#!/usr/bin/env python3

import itk
from abc import ABC, abstractmethod
from typing import Optional, Iterable, Tuple

from .image_block_interface import (
BlockPairRegistrationResult,
RegistrationTransformResult,
)
from .itk_typing import ImageType, ImageReaderType, ImageRegionType, TransformType

"""
Defines extensible components to extend with concrete implementations.
"""


class ConstructReaderMethod(ABC):
"""
A method that generates a new `itk.ImageFileReader` for image registration.

ITK provides the `itk.ImageFileReader` mechanism to retrieve all or part of
a spatial image from a provided local or remote image source. `itk-dreg`
registration infrastructure attempts to stream image subregions into memory
at runtime in order to perform block-based pairwise registration without
ever loading an entire image into memory at once.

Extend this class to customize the image reading step for cases such as
to attach domain-specific metadata or to convert from a reference type
not supported by ITK by default.

The resulting `itk.ImageFileReader` object MUST be initialized with metadata
to represent the extent of the underlying data in voxel and physical space.

It is strongly recommended that the resulting `itk.ImageFileReader` is NOT
initialized with underlying voxel data. Voxel regions should be lazily
initialized by `itk-dreg` registration infrastructure to match block
requested regions.

.. code-block:: python

ReaderSource = ConstructReaderMethodSubclass(...)
image_source = ReaderSource()
image = image_source.UpdateLargestPossibleRegion()
"""

@abstractmethod
def __call__(self, **kwargs) -> ImageReaderType:
tbirdso marked this conversation as resolved.
Show resolved Hide resolved
pass


class BlockPairRegistrationMethod(ABC):
"""
A method that registers two spatially located image blocks together.

`fixed_subimage` and `moving_image` inputs are `itk.Image` representations
of block subregions within greater fixed and moving inputs.

Extend this class to implement custom registration method that plugs in
to `itk-dreg` registration infrastructure.
"""

@abstractmethod
def __call__(
self,
fixed_subimage: ImageType,
moving_subimage: ImageType,
initial_transform: TransformType,
**kwargs
) -> BlockPairRegistrationResult:
"""
Run image-to-image pairwise registration.

:param fixed_subimage: The reference fixed subimage.
`fixed_subimage.RequestedRegion` reflects the requested subregion corresponding
to the scheduled dask array chunk.
The initial `fixed_subimage.BufferedRegion` includes the requested region
and possibly an extra padding factor introduced before fetching fixed
image voxel data.
:param moving_subimage: The moving subimage to be registered onto fixed image space.
`moving_subimage.RequestedRegion` reflects the requested subregion
corresponding to the approximate physical bounds of `fixed_subimage.RequestedRegion`
after initialization with `initial_transform`.
The initial `moving_subimage.BufferedRegion` includes the requested region
and possibly an extra padding factor introduced before fetching fixed
image voxel data.
:param initial_transform: The forward transform representing an initial alignment
mapping from fixed to moving image space.
"""
pass


class ReduceResultsMethod(ABC):
"""
A method that reduces a sparse collection of pairwise block registration results
to yield a generalized fixed-to-moving transform.

Extend this class to implement a custom method mapping block results to a general transform.

Possible implementations could include methods for finding global consensus among results,
stitching methods to yield a piecewise transform, or patchwise methods to normalize among
bounded transform domains.
"""

@abstractmethod
def __call__(
self,
block_results: Iterable[Tuple[ImageType, BlockPairRegistrationResult]],
fixed_reader_ctor: ConstructReaderMethod,
initial_transform: itk.Transform,
**kwargs
) -> RegistrationTransformResult:
"""
:param block_results: An iterable collection of subimages in fixed space
and the corresponding registration result for the given subimage.
Subimages are not buffered and represent the subdomains within the
original fixed image space prior to initial transform application.
:param fixed_reader_ctor: Method to create an image reader to stream
part or all of the fixed image.
:param initial_transform: Initial forward transform used in registration.
The forward transform maps from the fixed to moving image.
"""
pass


"""
my_fixed_image = ...
my_moving_image = ...

my_initial_transform = ...

# registration method returns an update to the initial transform

my_transform = itk.register_dreg(
construct_fixed_image_method=my_construct_streaming_reader_method,
construct_moving_image_method=my_construct_streaming_reader_method,
initial_transform=my_initial_transform,
registration_method=my_block_pair_registration_method_subclass,
reduce_method=my_reduce_registration_method_subclass
)

final_transform = itk.CompositeTransform()
final_transform.append_transform(my_initial_transform)
final_transform.append_transform(my_transform)

# we can use the result transform to resample the moving image to fixed image space

interpolator = itk.LinearInterpolateImageFunction.New(my_moving_image)

my_warped_image = itk.resample_image_filter(
my_moving_image,
transform=final_transform,
interpolator=interpolator,
use_reference_image=True,
reference_image=my_fixed_image
)

"""
Loading