Skip to content

Commit

Permalink
Merge pull request mlcommons#965 from szmazurek/refactor_loss_interfaces
Browse files Browse the repository at this point in the history
Refactor the code related to loss computation
  • Loading branch information
sarthakpati authored Nov 19, 2024
2 parents a1fb3f4 + 96b64e4 commit 709f6ab
Show file tree
Hide file tree
Showing 6 changed files with 432 additions and 2 deletions.
1 change: 0 additions & 1 deletion GANDLF/losses/hybrid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch

from .segmentation import MCD_loss, FocalLoss
from .regression import CCE_Generic, CE, CE_Logits

Expand Down
21 changes: 21 additions & 0 deletions GANDLF/losses/hybrid_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .regression_new import BinaryCrossEntropyLoss, BinaryCrossEntropyWithLogitsLoss
from .segmentation_new import MulticlassDiceLoss, MulticlassFocalLoss
from .loss_interface import AbstractHybridLoss


class DiceCrossEntropyLoss(AbstractHybridLoss):
def _initialize_all_loss_calculators(self):
return [MulticlassDiceLoss(self.params), BinaryCrossEntropyLoss(self.params)]


class DiceCrossEntropyLossLogits(AbstractHybridLoss):
def _initialize_all_loss_calculators(self):
return [
MulticlassDiceLoss(self.params),
BinaryCrossEntropyWithLogitsLoss(self.params),
]


class DiceFocalLoss(AbstractHybridLoss):
def _initialize_all_loss_calculators(self):
return [MulticlassDiceLoss(self.params), MulticlassFocalLoss(self.params)]
153 changes: 153 additions & 0 deletions GANDLF/losses/loss_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
from torch import nn
from abc import ABC, abstractmethod
from typing import List


class AbstractLossFunction(nn.Module, ABC):
def __init__(self, params: dict):
nn.Module.__init__(self)
self.params = params
self.num_classes = len(params["model"]["class_list"])
self._initialize_penalty_weights()

def _initialize_penalty_weights(self):
default_penalty_weights = torch.ones(self.num_classes)
self.penalty_weights = self.params.get(
"penalty_weights", default_penalty_weights
)

@abstractmethod
def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the loss function. To be implemented by child classes.
"""


class AbstractSegmentationLoss(AbstractLossFunction):
"""
Base class for loss funcions that are used for segmentation tasks.
"""

def __init__(self, params: dict):
super().__init__(params)

def _compute_single_class_loss(
self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int
) -> torch.Tensor:
"""Compute loss for a single class."""
loss_value = self._single_class_loss_calculator(
prediction[:, class_idx, ...], target[:, class_idx, ...]
)
return 1 - loss_value

def _optional_loss_operations(self, loss: torch.Tensor) -> torch.Tensor:
"""
Perform addtional operations on the loss value. Defaults to identity operation.
If needed, child classes can override this method. Useful in cases where
for example, the loss value needs to log-transformed or clipped.
"""
return loss

@abstractmethod
def _single_class_loss_calculator(
self, prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Compute loss for a pair of prediction and target tensors. To be implemented by child classes.
"""

def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
accumulated_loss = torch.tensor(0.0, device=prediction.device)

for class_idx in range(self.num_classes):
current_loss = self._compute_single_class_loss(
prediction, target, class_idx
)
accumulated_loss += (
self._optional_loss_operations(current_loss)
* self.penalty_weights[class_idx]
)

accumulated_loss /= self.num_classes

return accumulated_loss


class AbstractRegressionLoss(AbstractLossFunction):
"""
Base class for loss functions that are used for regression and classification tasks.
"""

def __init__(self, params: dict):
super().__init__(params)
self.loss_calculator = self._initialize_loss_function_object()
self.reduction_method = self._initialize_reduction_method()

def _initialize_reduction_method(self) -> str:
"""
Initialize the reduction method for the loss function. Defaults to 'mean'.
"""
loss_params = self.params["loss_function"]
reduction_method = "mean"
if isinstance(loss_params, dict):
reduction_method = loss_params.get("reduction", reduction_method)
assert reduction_method in [
"mean",
"sum",
], f"Invalid reduction method defined for loss function: {reduction_method}. Valid options are ['mean', 'sum']"
return reduction_method

def _calculate_loss_for_single_class(
self, prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Calculate loss for a single class. To be implemented by child classes.
"""
return self.loss_calculator(prediction, target)

@abstractmethod
def _initialize_loss_function_object(self) -> nn.modules.loss._Loss:
"""
Initialize the loss function object used in the forward method. Has to return
callable pytorch loss function object.
"""

def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
accumulated_loss = torch.tensor(0.0, device=prediction.device)
for class_idx in range(self.num_classes):
accumulated_loss += (
self._calculate_loss_for_single_class(
prediction[:, class_idx, ...], target[:, class_idx, ...]
)
* self.penalty_weights[class_idx]
)

accumulated_loss /= self.num_classes

return accumulated_loss


class AbstractHybridLoss(AbstractLossFunction):
"""
Base class for hybrid loss functions that are used for segmentation tasks.
"""

def __init__(self, params: dict):
super().__init__(params)
self.loss_calculators = self._initialize_all_loss_calculators()

@abstractmethod
def _initialize_all_loss_calculators(self) -> List[AbstractLossFunction]:
"""
Each hybrid loss should implement this method, creating all loss functions as a list that
will be used during the forward pass.
"""
pass

def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
accumulated_loss = torch.tensor(0.0, device=prediction.device)
for loss_calculator in self._initialize_all_loss_calculators():
accumulated_loss += loss_calculator(prediction, target)

return accumulated_loss
2 changes: 1 addition & 1 deletion GANDLF/losses/regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from GANDLF.utils import one_hot
from torch.nn import CrossEntropyLoss


def CEL(
Expand Down
64 changes: 64 additions & 0 deletions GANDLF/losses/regression_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
from torch import nn
from .loss_interface import AbstractRegressionLoss


class CrossEntropyLoss(AbstractRegressionLoss):
"""
This class computes the cross entropy loss between two tensors.
"""

def _initialize_loss_function_object(self):
return nn.CrossEntropyLoss(reduction=self.reduction_method)


class BinaryCrossEntropyLoss(AbstractRegressionLoss):
"""
This class computes the binary cross entropy loss between two tensors.
"""

def _initialize_loss_function_object(self):
return nn.BCELoss(reduction=self.reduction_method)


class BinaryCrossEntropyWithLogitsLoss(AbstractRegressionLoss):
"""
This class computes the binary cross entropy loss with logits between two tensors.
"""

def _initialize_loss_function_object(self):
return nn.BCEWithLogitsLoss(reduction=self.reduction_method)


class BaseLossWithScaledTarget(AbstractRegressionLoss):
"""
General interface for the loss functions requiring scaling of the target tensor.
"""

def _initialize_scaling_factor(self):
loss_params: dict = self.params["loss_function"]
self.scaling_factor = loss_params.get("scaling_factor", 1.0)
if isinstance(loss_params, dict):
self.scaling_factor = loss_params.get("scaling_factor", self.scaling_factor)
return self.scaling_factor

def _calculate_loss(self, prediction: torch.Tensor, target: torch.Tensor):
return self.loss_calculator(prediction, target * self.scaling_factor)


class L1Loss(BaseLossWithScaledTarget):
"""
This class computes the L1 loss between two tensors.
"""

def _initialize_loss_function_object(self):
return nn.L1Loss(reduction=self.reduction_method)


class MSELoss(BaseLossWithScaledTarget):
"""
This class computes the mean squared error loss between two tensors.
"""

def _initialize_loss_function_object(self):
return nn.MSELoss(reduction=self.reduction_method)
Loading

0 comments on commit 709f6ab

Please sign in to comment.