Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MedicalNetPerceptualSimilarity: Add multi-channel #7568

Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
0250284
MedicalNetPerceptualSimilarity: Add multi-channel
SomeUserName1 Mar 22, 2024
dd84afc
Add channelwise flag to perceputal loss
SomeUserName1 Mar 25, 2024
2a4a551
DCO Remediation Commit for Fabian Klopfer <[email protected]>
SomeUserName1 Mar 25, 2024
94d56cb
Correct code formatting
SomeUserName1 Mar 25, 2024
4664a62
PerceptualLoss: correct summation over features per channel, calculat…
SomeUserName1 Mar 25, 2024
a7aae86
Fixup flake type error
SomeUserName1 Mar 25, 2024
8afd921
Merge branch 'dev' into 7567-multi-channel-medicalnetperceptualsimila…
SomeUserName1 Mar 26, 2024
60cc8ce
Merge branch 'dev' into 7567-multi-channel-medicalnetperceptualsimila…
SomeUserName1 Mar 27, 2024
b20005c
Merge remote-tracking branch 'main/dev' into 7567-multi-channel-medic…
SomeUserName1 Apr 3, 2024
af33a8e
fix typo: channelwise to channel_wise
SomeUserName1 Apr 3, 2024
b2bd8be
correct typo in doc strings and tests too
SomeUserName1 Apr 5, 2024
29e187d
fix formatting
SomeUserName1 Apr 9, 2024
eea576f
Merge branch 'dev' into 7567-multi-channel-medicalnetperceptualsimila…
SomeUserName1 Apr 9, 2024
129e429
Delete tags
SomeUserName1 Apr 10, 2024
ba55fae
Merge branch 'dev' into 7567-multi-channel-medicalnetperceptualsimila…
SomeUserName1 Apr 17, 2024
93f5208
add flag to squeeze and compute mean only on dim 0
SomeUserName1 Apr 18, 2024
220febd
Merge branch 'dev' into 7567-multi-channel-medicalnetperceptualsimila…
KumoLiu Apr 19, 2024
913a2fc
merge two lines into one
SomeUserName1 Apr 19, 2024
248d309
Merge branch 'dev' into 7567-multi-channel-medicalnetperceptualsimila…
SomeUserName1 Apr 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class PerceptualLoss(nn.Module):

The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all
three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss.
MedicalNet networks are only compatible with 3D inputs and support channel-wise loss.

Args:
spatial_dims: number of spatial dimensions.
Expand All @@ -62,6 +63,8 @@ class PerceptualLoss(nn.Module):
pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50".
Defaults to `None`.
channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Defaults to ``False``.
"""

def __init__(
Expand All @@ -74,6 +77,7 @@ def __init__(
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
channel_wise: bool = False,
):
super().__init__()

Expand All @@ -86,6 +90,9 @@ def __init__(
"Argument is_fake_3d must be set to False."
)

if channel_wise and "medicalnet_" not in network_type:
raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")

if network_type.lower() not in list(PercetualNetworkType):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
Expand All @@ -102,7 +109,9 @@ def __init__(
self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
self.perceptual_function = MedicalNetPerceptualSimilarity(
net=network_type, verbose=False, channel_wise=channel_wise
)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
elif network_type == "resnet50":
Expand Down Expand Up @@ -170,9 +179,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss = loss_sagittal + loss_axial + loss_coronal
else:
# 2D and real 3D cases
loss = self.perceptual_function(input, target)
loss = self.perceptual_function(input, target).squeeze()

return torch.mean(loss)
return torch.mean(loss, dim=0)
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved


class MedicalNetPerceptualSimilarity(nn.Module):
Expand All @@ -185,14 +194,20 @@ class MedicalNetPerceptualSimilarity(nn.Module):
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Defaults to ``False``.
"""

def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
def __init__(
self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.eval()

self.channel_wise = channel_wise

for param in self.parameters():
param.requires_grad = False

Expand All @@ -206,20 +221,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Args:
input: 3D input tensor with shape BCDHW.
target: 3D target tensor with shape BCDHW.

"""
input = medicalnet_intensity_normalisation(input)
target = medicalnet_intensity_normalisation(target)

# Get model outputs
outs_input = self.model.forward(input)
outs_target = self.model.forward(target)
feats_per_ch = 0
for ch_idx in range(input.shape[1]):
input_channel = input[:, ch_idx, ...].unsqueeze(1)
target_channel = target[:, ch_idx, ...].unsqueeze(1)

if ch_idx == 0:
outs_input = self.model.forward(input_channel)
outs_target = self.model.forward(target_channel)
feats_per_ch = outs_input.shape[1]
else:
outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1)
outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1)

# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)

results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
feats_diff: torch.Tensor = (feats_input - feats_target) ** 2
if self.channel_wise:
results = torch.zeros(
feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4]
)
for i in range(input.shape[1]):
l_idx = i * feats_per_ch
r_idx = (i + 1) * feats_per_ch
results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)
else:
results = feats_diff.sum(dim=1, keepdim=True)

results = spatial_average_3d(results, keepdim=True)

return results

Expand Down
Loading
Loading