Skip to content

Commit

Permalink
Merge branch 'new-apis_v0.1.0-dev' into add_logging_final_version
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati authored Jul 19, 2024
2 parents f857e65 + b6dfe2d commit ca64876
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
1 change: 0 additions & 1 deletion GANDLF/entrypoints/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-

import argparse
import sys
import click
from deprecated import deprecated
from typing import Optional
Expand Down
31 changes: 23 additions & 8 deletions GANDLF/metrics/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union

import torch
import torchmetrics as tm
import torch.nn.functional as F
Expand All @@ -9,9 +7,7 @@
from GANDLF.utils.generic import determine_classification_task_type


def overall_stats(
prediction: torch.Tensor, target: torch.Tensor, params: dict
) -> dict[str, Union[float, list]]:
def overall_stats(prediction: torch.Tensor, target: torch.Tensor, params: dict) -> dict:
"""
Generates a dictionary of metrics calculated on the overall prediction and ground truths.
Expand All @@ -27,6 +23,26 @@ def overall_stats(
params["problem_type"] == "classification"
), "Only classification is supported for these stats"

def __convert_tensor_to_int(input_tensor: torch.Tensor) -> torch.Tensor:
"""
Convert the input tensor to integer format.
Args:
input_tensor (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The tensor converted to integer format.
"""
return_tensor = input_tensor
if return_tensor.dtype != torch.long or return_tensor.dtype != torch.int:
return_tensor = return_tensor.long()
return return_tensor

# this is needed for a few metrics
# ensure that predictions and target are in integer format
prediction_wrap = __convert_tensor_to_int(prediction)
target_wrap = __convert_tensor_to_int(target)

# this is needed for auroc
# ensure that predictions are in integer format
prediction_wrap = prediction
Expand Down Expand Up @@ -85,13 +101,12 @@ def overall_stats(
),
}
for metric_name, calculator in calculators.items():
avg_typed_metric_name = f"{metric_name}_{average_type_key}"
if "auroc" in metric_name:
output_metrics[metric_name] = get_output_from_calculator(
predictions_prob, target, calculator
predictions_prob, target_wrap, calculator
)
else:
output_metrics[avg_typed_metric_name] = get_output_from_calculator(
output_metrics[metric_name] = get_output_from_calculator(
prediction, target, calculator
)

Expand Down

0 comments on commit ca64876

Please sign in to comment.