Fix train metrics #868

Merged
merged 18 commits on Jul 19, 2024

Changes from all commits
69 changes: 26 additions & 43 deletions GANDLF/compute/forward_pass.py
@@ -1,6 +1,6 @@
import os
import pathlib
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
@@ -51,21 +51,20 @@ def validate_network(
print("*" * 20)
# Initialize a few things
total_epoch_valid_loss = 0
total_epoch_valid_metric = {}
total_epoch_valid_metric: dict[str, Union[float, np.ndarray]] = {}
average_epoch_valid_metric = {}

for metric in params["metrics"]:
if "per_label" in metric:
total_epoch_valid_metric[metric] = []
total_epoch_valid_metric[metric] = np.zeros(1)
else:
total_epoch_valid_metric[metric] = 0

logits_list = []
subject_id_list = []
is_classification = params.get("problem_type") == "classification"
calculate_overall_metrics = (
(params["problem_type"] == "classification")
or (params["problem_type"] == "regression")
params["problem_type"] in {"classification", "regression"}
) and mode == "validation"
is_inference = mode == "inference"

@@ -107,10 +106,8 @@ def validate_network(

# get ground truths for classification problem, validation set
if calculate_overall_metrics:
(
ground_truth_array,
predictions_array,
) = get_ground_truths_and_predictions_tensor(params, "validation_data")
ground_truth_array = []
predictions_array = []

for batch_idx, (subject) in enumerate(
tqdm(valid_dataloader, desc="Looping over " + mode + " data")
@@ -193,6 +190,7 @@

if params["save_output"] or is_inference:
# we divide by scaling factor here because we multiply by it during loss/metric calculation
# TODO: regression-only, right?
outputToWrite += (
str(epoch)
+ ","
@@ -206,23 +204,15 @@
)

if calculate_overall_metrics:
predictions_array[batch_idx] = (
torch.argmax(pred_output[0], 0).cpu().item()
)
ground_truth_array.append(label_ground_truth.item())
# TODO: that's for classification only. What about regression?
predictions_array.append(torch.argmax(pred_output[0], 0).cpu().item())
# # Non network validation related
total_epoch_valid_loss += final_loss.detach().cpu().item()
for metric in final_metric.keys():
if isinstance(total_epoch_valid_metric[metric], list):
if len(total_epoch_valid_metric[metric]) == 0:
total_epoch_valid_metric[metric] = np.array(
final_metric[metric]
)
else:
total_epoch_valid_metric[metric] += np.array(
final_metric[metric]
)
else:
total_epoch_valid_metric[metric] += final_metric[metric]
for metric, metric_val in final_metric.items():
total_epoch_valid_metric[metric] = (
total_epoch_valid_metric[metric] + metric_val
)

else: # for segmentation problems OR regression/classification when no label is present
grid_sampler = torchio.inference.GridSampler(
@@ -315,8 +305,7 @@ def validate_network(

# save outputs
if params["problem_type"] == "segmentation":
output_prediction = aggregator.get_output_tensor()
output_prediction = output_prediction.unsqueeze(0)
output_prediction = aggregator.get_output_tensor().unsqueeze(0)
if params["save_output"]:
img_for_metadata = torchio.ScalarImage(
tensor=subject["1"]["data"].squeeze(0),
@@ -386,16 +375,18 @@ def validate_network(
# final regression output
output_prediction = output_prediction / len(patch_loader)
if calculate_overall_metrics:
predictions_array[batch_idx] = (
# TODO: why argmax for a regression output?
predictions_array.append(
torch.argmax(output_prediction[0], 0).cpu().item()
)
ground_truth_array.append(label_ground_truth.item())
if params["save_output"]:
outputToWrite += (
str(epoch)
+ ","
+ subject["subject_id"][0]
+ ","
+ str(output_prediction)
+ str(output_prediction[0])
+ "\n"
)

@@ -407,7 +398,6 @@ def validate_network(
n.squeeze(), raw_input=image[i].squeeze(-1)
)

output_prediction = output_prediction.squeeze(-1)
if is_inference and is_classification:
logits_list.append(output_prediction)
subject_id_list.append(subject.get("subject_id")[0])
@@ -418,9 +408,8 @@ def validate_network(
if label_ground_truth.shape[0] == 3:
label_ground_truth = label_ground_truth[0, ...].unsqueeze(0)
# we always want the ground truth to be in the same format as the prediction
# add batch dim
label_ground_truth = label_ground_truth.unsqueeze(0)
if label_ground_truth.shape[-1] == 1:
label_ground_truth = label_ground_truth.squeeze(-1)
final_loss, final_metric = get_loss_and_metrics(
image,
label_ground_truth,
@@ -440,17 +429,9 @@ def validate_network(
# loss.cpu().data.item()
total_epoch_valid_loss += final_loss.cpu().item()
for metric in final_metric.keys():
if isinstance(total_epoch_valid_metric[metric], list):
if len(total_epoch_valid_metric[metric]) == 0:
total_epoch_valid_metric[metric] = np.array(
final_metric[metric]
)
else:
total_epoch_valid_metric[metric] += np.array(
final_metric[metric]
)
else:
total_epoch_valid_metric[metric] += final_metric[metric]
total_epoch_valid_metric[metric] = (
total_epoch_valid_metric[metric] + final_metric[metric]
)

if label_ground_truth is not None:
if params["verbose"]:
@@ -486,7 +467,9 @@ def validate_network(
# get overall stats for classification
if calculate_overall_metrics:
average_epoch_valid_metric = overall_stats(
predictions_array, ground_truth_array, params
torch.Tensor(predictions_array),
torch.Tensor(ground_truth_array),
params,
)
average_epoch_valid_metric = print_and_format_metrics(
average_epoch_valid_metric,
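Note (illustrative, not part of the diff): the rewritten accumulation in validate_network relies on NumPy broadcasting — initializing per-label metrics with np.zeros(1) lets the single `total + value` expression handle scalar and per-label metrics alike. A minimal sketch with made-up metric names:

import numpy as np

totals = {"dice_per_label": np.zeros(1), "accuracy": 0}  # per-label -> np.zeros(1), scalar -> 0

for final_metric in (
    {"dice_per_label": [0.8, 0.6, 0.7], "accuracy": 0.9},
    {"dice_per_label": [0.9, 0.5, 0.8], "accuracy": 0.8},
):
    for metric, metric_val in final_metric.items():
        # np.zeros(1) + [0.8, 0.6, 0.7] broadcasts to array([0.8, 0.6, 0.7]),
        # so the same line accumulates scalars and per-label lists alike
        totals[metric] = totals[metric] + metric_val

print(totals["dice_per_label"])  # [1.7 1.1 1.5]
print(totals["accuracy"])        # ~1.7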
66 changes: 35 additions & 31 deletions GANDLF/compute/loss_and_metric.py
@@ -1,5 +1,6 @@
import sys
from typing import Dict, Tuple
import warnings
from typing import Dict, Tuple, Union
from GANDLF.losses import global_losses_dict
from GANDLF.metrics import global_metrics_dict
import torch
@@ -13,7 +14,7 @@ def get_metric_output(
prediction: torch.Tensor,
target: torch.Tensor,
params: dict,
) -> float:
) -> Union[float, list]:
"""
This function computes the metric output for a given metric function, prediction and target.

@@ -36,6 +37,12 @@
if len(temp) > 1:
return temp
else:
# TODO: this branch is an extreme edge case and is buggy.
# The case where a metric returns a list of length 1 is very rare: it happens only when
# the metric returns an Nx... tensor (i.e. without aggregation over elements) and batch_size == N == 1.
# This branch would definitely fail for metrics like
# MulticlassAccuracy(num_classes=3, multidim_average="samplewise")
# Maybe the best solution is to raise an error here if the metric is configured to return samplewise results?
return metric_output.item()
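Note (illustrative): a small repro of the edge case described in the TODO above, assuming torchmetrics as the metric backend (as the MulticlassAccuracy example suggests); shapes are made up:

import torch
from torchmetrics.classification import MulticlassAccuracy

# samplewise averaging returns one value per sample, i.e. a tensor of shape (N,)
metric = MulticlassAccuracy(num_classes=3, multidim_average="samplewise")
preds = torch.randint(0, 3, (1, 16))   # N == batch_size == 1, 16 elements per sample
target = torch.randint(0, 3, (1, 16))
out = metric(preds, target)            # shape (1,): a samplewise result, not a scalar
# len(out.tolist()) == 1, so the branch above calls .item() and silently
# collapses the samplewise result into a plain float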


@@ -115,41 +122,38 @@ def get_loss_and_metrics(
loss_kld = global_losses_dict["kld"](prediction[2], prediction[3])
loss_cycle = global_losses_dict["mse"](prediction[2], prediction[4], None)
loss = 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle
elif deep_supervision_model:
# this is for models that have deep-supervision
for i, _ in enumerate(prediction):
# loss is calculated based on resampled "soft" labels using a pre-defined weights array
loss += (
loss_function(prediction[i], ground_truth_resampled[i], params)
* loss_weights[i]
)
else:
if deep_supervision_model:
# this is for models that have deep-supervision
for i, _ in enumerate(prediction):
# loss is calculated based on resampled "soft" labels using a pre-defined weights array
loss += (
loss_function(prediction[i], ground_truth_resampled[i], params)
* loss_weights[i]
)
else:
loss = loss_function(prediction, target, params)
loss = loss_function(prediction, target, params)
metric_output = {}

# Metrics should be a list
for metric in params["metrics"]:
metric_lower = metric.lower()
metric_output[metric] = 0
if metric_lower in global_metrics_dict:
metric_function = global_metrics_dict[metric_lower]
if sdnet_check:
metric_output[metric] = get_metric_output(
metric_function, prediction[0], target.squeeze(-1), params
if metric_lower not in global_metrics_dict:
warnings.warn("WARNING: Could not find the requested metric '" + metric)
continue

metric_function = global_metrics_dict[metric_lower]
if sdnet_check:
metric_output[metric] = get_metric_output(
metric_function, prediction[0], target.squeeze(-1), params
)
elif deep_supervision_model:
for i, _ in enumerate(prediction):
metric_output[metric] += get_metric_output(
metric_function, prediction[i], ground_truth_resampled[i], params
)
else:
if deep_supervision_model:
for i, _ in enumerate(prediction):
metric_output[metric] += get_metric_output(
metric_function,
prediction[i],
ground_truth_resampled[i],
params,
)

else:
metric_output[metric] = get_metric_output(
metric_function, prediction, target, params
)
else:
metric_output[metric] = get_metric_output(
metric_function, prediction, target, params
)
return loss, metric_output
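Note (illustrative): the deep-supervision branches above accumulate one weighted loss/metric term per output scale. A minimal sketch of that pattern, with placeholder weights and mse_loss standing in for the configured loss_function:

import torch
import torch.nn.functional as F

loss_weights = [1.0, 0.5, 0.25]  # placeholder per-scale weights
prediction = [torch.rand(2, 3, s, s) for s in (64, 32, 16)]               # multi-scale outputs
ground_truth_resampled = [torch.rand(2, 3, s, s) for s in (64, 32, 16)]   # labels resampled to each scale

loss = torch.tensor(0.0)
for i, _ in enumerate(prediction):
    # each scale contributes its weighted loss to the total, mirroring the loop in the diff
    loss += F.mse_loss(prediction[i], ground_truth_resampled[i]) * loss_weights[i]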
75 changes: 45 additions & 30 deletions GANDLF/compute/step.py
@@ -1,4 +1,5 @@
from typing import Optional, Tuple
import warnings
from typing import Optional, Tuple, Union
import torch
import psutil
from .loss_and_metric import get_loss_and_metrics
@@ -7,22 +8,27 @@
def step(
model: torch.nn.Module,
image: torch.Tensor,
label: torch.Tensor,
label: Optional[torch.Tensor],
params: dict,
train: Optional[bool] = True,
) -> Tuple[float, dict, torch.Tensor, torch.Tensor]:
) -> Tuple[float, dict, Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]:
"""
This function performs a single step of training or validation.

Args:
model (torch.nn.Module): The model to process the input image with, it should support appropriate dimensions.
image (torch.Tensor): The input image stack according to requirements.
label (torch.Tensor): The input label for the corresponding image tensor.
image (torch.Tensor): The input image stack according to requirements. (B, C, H, W, D)
label (Optional[torch.Tensor]): The input label for the corresponding image tensor.
If segmentation, (B, C, H, W, D);
if classification / regression (not multilabel), (B, 1);
if classification / regression (multilabel), (B, N_LABELS)

params (dict): The parameters dictionary.
train (Optional[bool], optional): Whether the step is for training or validation. Defaults to True.

Returns:
Tuple[float, dict, torch.Tensor, torch.Tensor]: The loss, metrics, output, and attention map.
Tuple[float, dict, Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]: The loss, metrics, output,
and attention map.
"""
if params["verbose"]:
if torch.cuda.is_available():
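Note (illustrative): the shape conventions from the updated docstring, with made-up sizes:

import torch

image = torch.rand(4, 1, 64, 64, 32)                  # (B, C, H, W, D)
seg_label = torch.randint(0, 2, (4, 1, 64, 64, 32))   # segmentation: (B, C, H, W, D)
reg_label = torch.rand(4, 1)                          # classification / regression: (B, 1)
multilabel = torch.rand(4, 5)                         # multilabel: (B, N_LABELS)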
@@ -44,37 +50,34 @@ def step(
if params["problem_type"] == "segmentation":
if label.shape[1] == 3:
label = label[:, 0, ...].unsqueeze(1)
# this warning should only come up once
if params["print_rgb_label_warning"]:
print(
"WARNING: The label image is an RGB image, only the first channel will be used.",
flush=True,
)
params["print_rgb_label_warning"] = False
warnings.warn(
"The label image is an RGB image, only the first channel will be used."
)

if params["model"]["dimension"] == 2:
label = torch.squeeze(label, -1)
assert len(label) == len(image)

if params["model"]["dimension"] == 2:
image = torch.squeeze(image, -1)
if "value_keys" in params:
if label is not None:
if len(label.shape) > 1:
label = torch.squeeze(label, -1)
image = image.squeeze(-1) # removing depth

# for segmentation, remove the depth dimension from the label;
# for classification / regression, flatten the class / reg label from a list (possible in multilabel) to a scalar
# TODO: the second condition is a crutch - in some cases the label is passed as a 1-d tensor (B,), and if batch size is 1,
# it would be squeezed to a scalar (0-d) tensor and the downstream logic would fail
if label is not None and len(label.shape) != 1:
label = label.squeeze(-1)

if not (train) and params["model"]["type"].lower() == "openvino":
if not train and params["model"]["type"].lower() == "openvino":
output = torch.from_numpy(
model(inputs={params["model"]["IO"][0][0]: image.cpu().numpy()})[
params["model"]["IO"][1][0]
]
)
output = output.to(params["device"])
else:
if params["model"]["amp"]:
with torch.cuda.amp.autocast():
output = model(image)
else:
elif params["model"]["amp"]:
with torch.cuda.amp.autocast():
output = model(image)
else:
output = model(image)

attention_map = None
if "medcam_enabled" in params and params["medcam_enabled"]:
@@ -86,12 +89,24 @@
else:
loss, metric_output = None, None

if len(output) > 1:
output = output[0]

if params["model"]["dimension"] == 2:
output = torch.unsqueeze(output, -1)
if "medcam_enabled" in params and params["medcam_enabled"]:
attention_map = torch.unsqueeze(attention_map, -1)

if not isinstance(output, torch.Tensor):
warnings.warn(
f"Model output is not a Tensor: {type(output)}. Say, `deep_resunet` and `deep_unet` may return "
f"list of tensors on different scales instead of just one prediction Tensor. However due to "
f"GaNDLF architecture it is expected that models return only one tensor. For deep_* models "
f"only the biggeest scale is processed. Use these models with caution till fix is implemented."
)
output = output[0]

if params["model"]["dimension"] == 2 and params["problem_type"] == "segmentation":
# for 2d images where the depth is removed, add it back
output = output.unsqueeze(-1)

assert len(output) == len(
image
), f"Error: output({len(output)}) and batch({len(image)}) have different lengths. Both should be equal to batch size!"
return loss, metric_output, output, attention_map
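Note (illustrative): a toy reproduction of the list-of-tensors case that the new warning in step() guards against; the model is hypothetical:

import warnings
import torch
import torch.nn as nn

class ToyDeepSupervisionNet(nn.Module):
    """Hypothetical deep_*-style model: returns predictions at two scales, biggest first."""

    def forward(self, x: torch.Tensor) -> list:
        full = torch.sigmoid(x)
        half = full[..., ::2, ::2]  # coarser scale
        return [full, half]

output = ToyDeepSupervisionNet()(torch.rand(1, 1, 8, 8))
if not isinstance(output, torch.Tensor):
    warnings.warn("Model returned multiple scales; keeping only the biggest one.")
    output = output[0]  # same guard as in step(): keep the full-resolution prediction
assert output.shape == (1, 1, 8, 8)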