Skip to content

Commit

Permalink
Hacky fix for Dice score with average set to weighted or none
Browse files Browse the repository at this point in the history
  • Loading branch information
blazejdolicki committed Jan 23, 2023
1 parent 08244a8 commit 5a22643
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/torchmetrics/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,11 @@ def __init__(
self.ignore_index = ignore_index
self.top_k = top_k

if average not in ["micro", "macro", "samples"]:
raise ValueError(f"The `reduce` {average} is not valid.")

if mdmc_average not in [None, "samplewise", "global"]:
raise ValueError(f"The `mdmc_reduce` {mdmc_average} is not valid.")

if average == "macro" and (not num_classes or num_classes < 1):
raise ValueError("When you set `average` as 'macro', you have to provide the number of classes.")
if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
raise ValueError(f"When you set `average` as '{average}', you have to provide the number of classes.")

if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1):
raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")
Expand All @@ -180,7 +177,7 @@ def __init__(
if mdmc_average != "samplewise" and average != "samples":
if average == "micro":
zeros_shape = []
elif average == "macro":
elif average in ["macro", "weighted", "none", None]:
zeros_shape = [num_classes]
else:
raise ValueError(f'Wrong reduce="{average}"')
Expand Down

0 comments on commit 5a22643

Please sign in to comment.