Skip to content

Commit

Permalink
Adds new enums
Browse files Browse the repository at this point in the history
  • Loading branch information
marksgraham committed Oct 16, 2023
1 parent 8a70678 commit 7aaeab0
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
4 changes: 4 additions & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
from .enums import (
AdversarialIterationEvents,
AdversarialKeys,
AlgoKeys,
Average,
BlendMode,
Expand Down Expand Up @@ -46,6 +48,8 @@
MetricReduction,
NdimageMode,
NumpyPadMode,
OrderingTransformations,
OrderingType,
PatchKeys,
PostFix,
ProbMapKeys,
Expand Down
47 changes: 47 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@

import random
from enum import Enum
from typing import TYPE_CHECKING

from monai.config import IgniteInfo
from monai.utils import deprecated
from monai.utils.module import min_version, optional_import

__all__ = [
"StrEnum",
Expand Down Expand Up @@ -88,6 +91,14 @@ def __repr__(self):
return self.value


if TYPE_CHECKING:
from ignite.engine import EventEnum
else:
EventEnum, _ = optional_import(
"ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
)


class NumpyPadMode(StrEnum):
"""
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
Expand Down Expand Up @@ -692,3 +703,39 @@ class AlgoKeys(StrEnum):
ALGO = "algo_instance"
IS_TRAINED = "is_trained"
SCORE = "best_metric"


class AdversarialKeys(StrEnum):
REALS = "reals"
REAL_LOGITS = "real_logits"
FAKES = "fakes"
FAKE_LOGITS = "fake_logits"
RECONSTRUCTION_LOSS = "reconstruction_loss"
GENERATOR_LOSS = "generator_loss"
DISCRIMINATOR_LOSS = "discriminator_loss"


class AdversarialIterationEvents(EventEnum):
RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed"
GENERATOR_FORWARD_COMPLETED = "generator_forward_completed"
GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed"
GENERATOR_LOSS_COMPLETED = "generator_loss_completed"
GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed"
GENERATOR_MODEL_COMPLETED = "generator_model_completed"
DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed"
DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed"
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"


class OrderingType(StrEnum):
RASTER_SCAN = "raster_scan"
S_CURVE = "s_curve"
RANDOM = "random"


class OrderingTransformations(StrEnum):
ROTATE_90 = "rotate_90"
TRANSPOSE = "transpose"
REFLECT = "reflect"

0 comments on commit 7aaeab0

Please sign in to comment.