Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of intensity clipping transform: bot hard and soft approaches #7535

Merged
merged 21 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
33af45c
Implementation of intensity clipping transform: bot hard clipping and…
Lucas-rbnt Mar 12, 2024
8d77605
correct soft_clip typing
Lucas-rbnt Mar 12, 2024
be08465
Merge branch 'dev' into percentile-clipper
KumoLiu Mar 22, 2024
ffc303c
Merge branch 'dev' of github.com:Lucas-rbnt/MONAI into percentile-cli…
Lucas-rbnt Mar 29, 2024
fc57eda
clarification on docstring, add 3d tests
Lucas-rbnt Mar 29, 2024
700b947
Merge branch 'Project-MONAI:dev' into percentile-clipper
Lucas-rbnt Mar 29, 2024
189cad4
Merge branch 'percentile-clipper' of github.com:Lucas-rbnt/MONAI into…
Lucas-rbnt Mar 29, 2024
800a051
fixing typo in ClipIntensityPercentile argument
Lucas-rbnt Mar 30, 2024
f5fc217
Merge branch 'dev' of github.com:Lucas-rbnt/MONAI into percentile-cli…
Lucas-rbnt Apr 2, 2024
b82ec99
Adding possibility to return percentiles in tensor metainfo
Lucas-rbnt Apr 2, 2024
37b38a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2024
3468369
fixing flake8 linting error
Lucas-rbnt Apr 2, 2024
72835e7
Merge branch 'percentile-clipper' of github.com:Lucas-rbnt/MONAI into…
Lucas-rbnt Apr 2, 2024
d75e32a
changing percentiles meta info to clipping values and fixing mypy errors
Lucas-rbnt Apr 4, 2024
ca2828f
Merge branch 'dev' of github.com:Lucas-rbnt/MONAI into percentile-cli…
Lucas-rbnt Apr 4, 2024
6dacece
fixing typo with ./runtests.sh --autofix
Lucas-rbnt Apr 4, 2024
6d5187c
typing correction, docstring clarification, and change in attribute n…
Lucas-rbnt Apr 4, 2024
6eb1eb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
89a4f5f
mypy compliant and dealing with potential float or int in clipping va…
Lucas-rbnt Apr 5, 2024
1d4fe82
Merge branch 'percentile-clipper' of github.com:Lucas-rbnt/MONAI into…
Lucas-rbnt Apr 5, 2024
500036a
Merge branch 'dev' into percentile-clipper
KumoLiu Apr 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ Intensity
:members:
:special-members: __call__

`ClipIntensityPercentiles`
""""""""""""""""""""""""""
.. autoclass:: ClipIntensityPercentiles
:members:
:special-members: __call__

`RandScaleIntensity`
""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensity.png
Expand Down Expand Up @@ -1405,6 +1411,12 @@ Intensity (Dict)
:members:
:special-members: __call__

`ClipIntensityPercentilesd`
"""""""""""""""""""""""""""
.. autoclass:: ClipIntensityPercentilesd
:members:
:special-members: __call__

`RandScaleIntensityd`
"""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensityd.png
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from .croppad.functional import crop_func, crop_or_pad_nd, pad_func, pad_nd
from .intensity.array import (
AdjustContrast,
ClipIntensityPercentiles,
ComputeHoVerMaps,
DetectEnvelope,
ForegroundMask,
Expand Down Expand Up @@ -135,6 +136,9 @@
AdjustContrastd,
AdjustContrastD,
AdjustContrastDict,
ClipIntensityPercentilesd,
ClipIntensityPercentilesD,
ClipIntensityPercentilesDict,
ComputeHoVerMapsd,
ComputeHoVerMapsD,
ComputeHoVerMapsDict,
Expand Down
115 changes: 114 additions & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array, soft_clip
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
from monai.utils.enums import TransformBackends
from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple
Expand All @@ -54,6 +54,7 @@
"NormalizeIntensity",
"ThresholdIntensity",
"ScaleIntensityRange",
"ClipIntensityPercentiles",
"AdjustContrast",
"RandAdjustContrast",
"ScaleIntensityRangePercentiles",
Expand Down Expand Up @@ -1007,6 +1008,118 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return ret


class ClipIntensityPercentiles(Transform):
"""
Apply clip based on the intensity distribution of input image.
If `sharpness_factor` is provided, the intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291

Soft clipping preserves the order of the values and maintains the gradient everywhere.
For example:

.. code-block:: python
:emphasize-lines: 11, 22

image = torch.Tensor(
[[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]]])

# Hard clipping from lower and upper image intensity percentiles
hard_clipper = ClipIntensityPercentiles(30, 70)
print(hard_clipper(image))
metatensor([[[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.]]])


# Soft clipping from lower and upper image intensity percentiles
soft_clipper = ClipIntensityPercentiles(30, 70, 10.)
print(soft_clipper(image))
metatensor([[[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000]]])

See Also:

- :py:class:`monai.transforms.ScaleIntensityRangePercentiles`
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
lower: float | None,
upper: float | None,
sharpness_factor: float | None = None,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
dongyang0122 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Args:
lower: lower intensity percentile.
upper: upper intensity percentile.
sharpness_factor: if not None, the intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)).
defaults to None.
channel_wise: if True, compute intensity percentile and normalize every channel separately.
default to False.
dtype: output data type, if None, same as input image. defaults to float32.
"""
if lower is None and upper is None:
raise ValueError("lower or upper percentiles must be provided")
if lower is not None and (lower < 0.0 or lower > 100.0):
raise ValueError("Percentiles must be in the range [0, 100]")
if upper is not None and (upper < 0.0 or upper > 100.0):
raise ValueError("Percentiles must be in the range [0, 100]")
if upper is not None and lower is not None and upper < lower:
raise ValueError("upper must be greater than or equal to lower")
if sharpness_factor is not None and sharpness_factor <= 0:
raise ValueError("sharpness_factor must be greater than 0")

self.lower = lower
self.upper = upper
self.sharpness_factor = sharpness_factor
self.channel_wise = channel_wise
self.dtype = dtype

def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if self.sharpness_factor is not None:
lower_percentile = percentile(img, self.lower) if self.lower is not None else None
upper_percentile = percentile(img, self.upper) if self.upper is not None else None
img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
else:
lower_percentile = percentile(img, self.lower) if self.lower is not None else percentile(img, 0)
upper_percentile = percentile(img, self.upper) if self.upper is not None else percentile(img, 100)
img = clip(img, lower_percentile, upper_percentile)

img = convert_to_tensor(img, track_meta=False)
return img

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
if self.channel_wise:
img_t = torch.stack([self._normalize(img=d) for d in img_t]) # type: ignore
else:
img_t = self._normalize(img=img_t)

return convert_to_dst_type(img_t, dst=img)[0]


class AdjustContrast(Transform):
"""
Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::
Expand Down
35 changes: 35 additions & 0 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.data.meta_obj import get_track_meta
from monai.transforms.intensity.array import (
AdjustContrast,
ClipIntensityPercentiles,
ComputeHoVerMaps,
ForegroundMask,
GaussianSharpen,
Expand Down Expand Up @@ -77,6 +78,7 @@
"NormalizeIntensityd",
"ThresholdIntensityd",
"ScaleIntensityRanged",
"ClipIntensityPercentilesd",
"AdjustContrastd",
"RandAdjustContrastd",
"ScaleIntensityRangePercentilesd",
Expand Down Expand Up @@ -122,6 +124,8 @@
"ThresholdIntensityDict",
"ScaleIntensityRangeD",
"ScaleIntensityRangeDict",
"ClipIntensityPercentilesD",
"ClipIntensityPercentilesDict",
"AdjustContrastD",
"AdjustContrastDict",
"RandAdjustContrastD",
Expand Down Expand Up @@ -886,6 +890,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ClipIntensityPercentilesd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ClipIntensityPercentiles`.
Clip the intensity values of input image to a specific range based on the intensity distribution of the input.
If `sharpness_factor` is provided, the intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor) * softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
"""

def __init__(
self,
keys: KeysCollection,
lower: float | None,
upper: float | None,
sharpness_factor: float | None = None,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.scaler = ClipIntensityPercentiles(
lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype
)

def __call__(self, data: dict) -> dict:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
return d


class AdjustContrastd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AdjustContrast`.
Expand Down Expand Up @@ -1929,6 +1963,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
NormalizeIntensityD = NormalizeIntensityDict = NormalizeIntensityd
ThresholdIntensityD = ThresholdIntensityDict = ThresholdIntensityd
ScaleIntensityRangeD = ScaleIntensityRangeDict = ScaleIntensityRanged
ClipIntensityPercentilesD = ClipIntensityPercentilesDict = ClipIntensityPercentilesd
AdjustContrastD = AdjustContrastDict = AdjustContrastd
RandAdjustContrastD = RandAdjustContrastDict = RandAdjustContrastd
ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd
Expand Down
37 changes: 37 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
nonzero,
ravel,
searchsorted,
softplus,
unique,
unravel_index,
where,
Expand Down Expand Up @@ -131,9 +132,45 @@
"resolves_modes",
"has_status_keys",
"distance_transform_edt",
"soft_clip",
]


def soft_clip(
arr: NdarrayOrTensor,
sharpness_factor: float = 1.0,
minv: NdarrayOrTensor | float | int | None = None,
maxv: NdarrayOrTensor | float | int | None = None,
dtype: DtypeLike | torch.dtype = np.float32,
) -> NdarrayOrTensor:
"""
Apply soft clip to the input array or tensor.
The intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291

To perform one-sided clipping, set either minv or maxv to None.
Args:
arr: input array to clip.
sharpness_factor: the sharpness of the soft clip function, default to 1.
minv: minimum value of target clipped array.
maxv: maximum value of target clipped array.
dtype: if not None, convert input array to dtype before computation.

"""

if dtype is not None:
arr, *_ = convert_data_type(arr, dtype=dtype)

v = arr
if minv is not None:
v = v + softplus(-sharpness_factor * (arr - minv)) / sharpness_factor
if maxv is not None:
v = v - softplus(sharpness_factor * (arr - maxv)) / sharpness_factor

return v


def rand_choice(prob: float = 0.5) -> bool:
"""
Returns True if a randomly chosen number is less than or equal to `prob`, by default this is a 50/50 chance.
Expand Down
15 changes: 15 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,24 @@
"median",
"mean",
"std",
"softplus",
]


def softplus(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""stable softplus through `np.logaddexp` with equivalent implementation for torch.

Args:
x: array/tensor.

Returns:
Softplus of the input.
"""
if isinstance(x, np.ndarray):
return np.logaddexp(np.zeros_like(x), x)
return torch.logaddexp(torch.zeros_like(x), x)


def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool:
"""`np.allclose` with equivalent implementation for torch."""
b, *_ = convert_to_dst_type(b, a, wrap_sequence=True)
Expand Down
Loading
Loading