Skip to content

Commit

Permalink
Merge pull request mlcommons#914 from VukW/binary-auroc-fix
Browse files Browse the repository at this point in the history
Fix binary auroc error
  • Loading branch information
sarthakpati authored Aug 10, 2024
2 parents dc90384 + 28c1e10 commit 72bd1db
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions GANDLF/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,17 @@ def __convert_tensor_to_int(input_tensor: torch.Tensor) -> torch.Tensor:
),
}
for metric_name, calculator in calculators.items():
metric_prediction = prediction
metric_target = target
if "auroc" in metric_name:
output_metrics[metric_name] = get_output_from_calculator(
predictions_prob, target_wrap, calculator
)
else:
output_metrics[metric_name] = get_output_from_calculator(
prediction, target, calculator
)
metric_prediction = predictions_prob
metric_target = target_wrap
if task == "binary":
metric_prediction = predictions_prob[:, 1]

output_metrics[metric_name] = get_output_from_calculator(
metric_prediction, metric_target, calculator
)

# metrics that do not need the "average" parameter
calculators = {
Expand Down

0 comments on commit 72bd1db

Please sign in to comment.