Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Port MONAI Generative utils #7134

Merged
merged 17 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ State Cacher
------------
.. automodule:: monai.utils.state_cacher
:members:

Component store
---------------
.. autoclass:: monai.utils.component_store.ComponentStore
:members:
6 changes: 6 additions & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
from .enums import (
AdversarialIterationEvents,
AdversarialKeys,
AlgoKeys,
Average,
BlendMode,
Expand Down Expand Up @@ -47,6 +49,8 @@
MetricReduction,
NdimageMode,
NumpyPadMode,
OrderingTransformations,
OrderingType,
PatchKeys,
PostFix,
ProbMapKeys,
Expand Down Expand Up @@ -95,6 +99,8 @@
str2bool,
str2list,
to_tuple_of_dictionaries,
unsqueeze_left,
unsqueeze_right,
zip_with,
)
from .module import (
Expand Down
65 changes: 65 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@

import random
from enum import Enum
from typing import TYPE_CHECKING

from monai.config import IgniteInfo
from monai.utils import deprecated
from monai.utils.module import min_version, optional_import

__all__ = [
"StrEnum",
Expand Down Expand Up @@ -88,6 +91,14 @@ def __repr__(self):
return self.value


if TYPE_CHECKING:
from ignite.engine import EventEnum
else:
EventEnum, _ = optional_import(
"ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
)


class NumpyPadMode(StrEnum):
"""
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
Expand Down Expand Up @@ -692,3 +703,57 @@ class AlgoKeys(StrEnum):
ALGO = "algo_instance"
IS_TRAINED = "is_trained"
SCORE = "best_metric"


class AdversarialKeys(StrEnum):
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
"""
Keys used by the AdversarialTrainer.
`REALS` are real images from the batch.
`FAKES` are fake images generated by the generator. Are the same as PRED.
`REAL_LOGITS` are logits of the discriminator for the real images.
`FAKE_LOGIT` are logits of the discriminator for the fake images.
`RECONSTRUCTION_LOSS` is the loss value computed by the reconstruction loss function.
`GENERATOR_LOSS` is the loss value computed by the generator loss function. It is the
discriminator loss for the fake images. That is backpropagated through the generator only.
`DISCRIMINATOR_LOSS` is the loss value computed by the discriminator loss function. It is the
discriminator loss for the real images and the fake images. That is backpropagated through the
discriminator only.
"""

REALS = "reals"
REAL_LOGITS = "real_logits"
FAKES = "fakes"
FAKE_LOGITS = "fake_logits"
RECONSTRUCTION_LOSS = "reconstruction_loss"
GENERATOR_LOSS = "generator_loss"
DISCRIMINATOR_LOSS = "discriminator_loss"


class AdversarialIterationEvents(EventEnum):
"""
Keys used to define events as used in the AdversarialTrainer.
"""

RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed"
GENERATOR_FORWARD_COMPLETED = "generator_forward_completed"
GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed"
GENERATOR_LOSS_COMPLETED = "generator_loss_completed"
GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed"
GENERATOR_MODEL_COMPLETED = "generator_model_completed"
DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed"
DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed"
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"


class OrderingType(StrEnum):
RASTER_SCAN = "raster_scan"
S_CURVE = "s_curve"
RANDOM = "random"


class OrderingTransformations(StrEnum):
ROTATE_90 = "rotate_90"
TRANSPOSE = "transpose"
REFLECT = "reflect"
10 changes: 10 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool:
sqrt_num = [int(math.sqrt(_num)) for _num in num]
ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)]
return ensure_tuple(ret) == num


def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
"""Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(...,) + (None,) * (ndim - arr.ndim)]


def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
"""Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(None,) * (ndim - arr.ndim)]
71 changes: 71 additions & 0 deletions tests/test_squeeze_unsqueeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.utils import unsqueeze_left, unsqueeze_right

RIGHT_CASES = [
(np.random.rand(3, 4).astype(np.float32), 5, (3, 4, 1, 1, 1)),
(torch.rand(3, 4).type(torch.float32), 5, (3, 4, 1, 1, 1)),
(np.random.rand(3, 4).astype(np.float64), 5, (3, 4, 1, 1, 1)),
(torch.rand(3, 4).type(torch.float64), 5, (3, 4, 1, 1, 1)),
(np.random.rand(3, 4).astype(np.int32), 5, (3, 4, 1, 1, 1)),
(torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)),
]


LEFT_CASES = [
(np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)),
(torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)),
(np.random.rand(3, 4).astype(np.float64), 5, (1, 1, 1, 3, 4)),
(torch.rand(3, 4).type(torch.float64), 5, (1, 1, 1, 3, 4)),
(np.random.rand(3, 4).astype(np.int32), 5, (1, 1, 1, 3, 4)),
(torch.rand(3, 4).type(torch.int32), 5, (1, 1, 1, 3, 4)),
]
ALL_CASES = [
(np.random.rand(3, 4), 2, (3, 4)),
(np.random.rand(3, 4), 0, (3, 4)),
(np.random.rand(3, 4), -1, (3, 4)),
(np.array(3), 4, (1, 1, 1, 1)),
(np.array(3), 0, ()),
(np.random.rand(3, 4).astype(np.int32), 2, (3, 4)),
(np.random.rand(3, 4).astype(np.int32), 0, (3, 4)),
(np.random.rand(3, 4).astype(np.int32), -1, (3, 4)),
(np.array(3).astype(np.int32), 4, (1, 1, 1, 1)),
(np.array(3).astype(np.int32), 0, ()),
(torch.rand(3, 4), 2, (3, 4)),
(torch.rand(3, 4), 0, (3, 4)),
(torch.rand(3, 4), -1, (3, 4)),
(torch.tensor(3), 4, (1, 1, 1, 1)),
(torch.tensor(3), 0, ()),
(torch.rand(3, 4).type(torch.int32), 2, (3, 4)),
(torch.rand(3, 4).type(torch.int32), 0, (3, 4)),
(torch.rand(3, 4).type(torch.int32), -1, (3, 4)),
(torch.tensor(3).type(torch.int32), 4, (1, 1, 1, 1)),
(torch.tensor(3).type(torch.int32), 0, ()),
]


class TestUnsqueeze(unittest.TestCase):
@parameterized.expand(RIGHT_CASES + ALL_CASES)
def test_unsqueeze_right(self, arr, ndim, shape):
self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)

@parameterized.expand(LEFT_CASES + ALL_CASES)
def test_unsqueeze_left(self, arr, ndim, shape):
self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)