BinaryAUROC raises IndexError with max_fpr > 0 and only true edges #1891

Closed

klieret opened this issue Jul 7, 2023 · 1 comment · Fixed by #1895
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)


klieret commented Jul 7, 2023

🐛 Bug

To Reproduce

from torchmetrics.classification import BinaryAUROC
from torch import Tensor as T

m = BinaryAUROC(max_fpr=0.01)
# target contains only positive samples (no negatives);
# the IndexError is raised here
m(T([0, 0.5, 0.9]), T([1, 1, 1]))

If, in contrast, BinaryAUROC is initialized without max_fpr, you get the behavior from the next section:
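For comparison, a minimal sketch of the same call without max_fpr (this is what produces the warning and return value shown under "Expected behavior" below):

from torchmetrics.classification import BinaryAUROC
from torch import Tensor as T

# Same inputs, but without max_fpr: no exception is raised; the
# "No negative samples in targets" warning is emitted and 0 is returned
m = BinaryAUROC()
m(T([0, 0.5, 0.9]), T([1, 1, 1]))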

Expected behavior

/Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: No negative samples in targets, false positive value should be meaningless. Returning zero tensor in false positive score
  warnings.warn(*args, **kwargs)
0

Actual behavior

An IndexError is raised with the following stack trace:

File ~/micromamba/envs/gnn/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/micromamba/envs/gnn/lib/python3.11/site-packages/torchmetrics/metric.py:236, in Metric.forward(self, *args, **kwargs)
    234     self._forward_cache = self._forward_full_state_update(*args, **kwargs)
    235 else:
--> 236     self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
    238 return self._forward_cache

File ~/micromamba/envs/gnn/lib/python3.11/site-packages/torchmetrics/metric.py:303, in Metric._forward_reduce_state_update(self, *args, **kwargs)
    301 # calculate batch state and compute batch value
    302 self.update(*args, **kwargs)
--> 303 batch_val = self.compute()
    305 # reduce batch and global state
    306 self._update_count = _update_count + 1

File ~/micromamba/envs/gnn/lib/python3.11/site-packages/torchmetrics/metric.py:532, in Metric._wrap_compute.<locals>.wrapped_func(*args, **kwargs)
    524 # compute relies on the sync context manager to gather the states across processes and apply reduction
    525 # if synchronization happened, the current rank accumulated states will be restored to keep
    526 # accumulation going if ``should_unsync=True``,
    527 with self.sync_context(
    528     dist_sync_fn=self.dist_sync_fn,
    529     should_sync=self._to_sync,
    530     should_unsync=self._should_unsync,
    531 ):
--> 532     value = compute(*args, **kwargs)
    533     self._computed = _squeeze_if_scalar(value)
    535 return self._computed

File ~/micromamba/envs/gnn/lib/python3.11/site-packages/torchmetrics/classification/auroc.py:114, in BinaryAUROC.compute(self)
    112 else:
    113     state = self.confmat
--> 114 return _binary_auroc_compute(state, self.thresholds, self.max_fpr)

File ~/micromamba/envs/gnn/lib/python3.11/site-packages/torchmetrics/functional/classification/auroc.py:97, in _binary_auroc_compute(state, thresholds, max_fpr, pos_label)
     95 # Add a single point at max_fpr and interpolate its tpr value
     96 stop = torch.bucketize(max_area, fpr, out_int32=True, right=True)
---> 97 weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
     98 interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight)
     99 tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])

IndexError: index 4 is out of bounds for dimension 0 with size 4
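My guess at the root cause (an assumption from reading the trace, not verified against the torchmetrics internals): with only positive targets the computed fpr curve is all zeros, so torch.bucketize with right=True returns an index one past the end of fpr, and fpr[stop] then indexes out of bounds. A minimal sketch reproducing just that indexing step:

import torch

# With only positive targets the fpr curve degenerates to all zeros
# (hence the "Returning zero tensor" warning above)
fpr = torch.zeros(4)
max_area = torch.tensor(0.01)

# max_area exceeds every boundary value, so bucketize returns an index
# one past the end of fpr ...
stop = torch.bucketize(max_area, fpr, out_int32=True, right=True)
print(stop)  # tensor(4, dtype=torch.int32)

# ... and fpr[stop] then fails exactly as in the trace above
fpr[stop]  # IndexError: index 4 is out of bounds for dimension 0 with size 4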

Environment

  • TorchMetrics version (and how you installed TM): 1.0.0, installed via pip
  • Python & PyTorch versions: 3.11.3 / 2.0.1
  • OS: macOS

Additional context

I'd be happy to work on a fix.
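One possible direction (purely a sketch on my part, not necessarily what the eventual fix in #1895 ended up doing): detect the degenerate all-zero fpr curve before the interpolation step and fall back to the same warn-and-return-zero behavior as the max_fpr-free path, e.g.:

# Hypothetical guard inside _binary_auroc_compute, before the
# bucketize/interpolation step: if there are no false positives at any
# threshold, the partial AUROC is as meaningless as the full one, so
# reuse the existing "no negative samples" warning path and return 0
if fpr[-1] == 0:
    return torch.tensor(0.0)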

klieret added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels on Jul 7, 2023

github-actions bot commented Jul 7, 2023

Hi! Thanks for your contribution, great first issue!
