Skip to content

Commit

Permalink
Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Nov 7, 2024
1 parent 6132d57 commit 3c2d8c1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
12 changes: 6 additions & 6 deletions src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,12 +866,12 @@ def _get_classes(self) -> list:

def _get_coco_format(
self,
labels: list[torch.Tensor],
boxes: Optional[list[torch.Tensor]] = None,
masks: Optional[list[torch.Tensor]] = None,
scores: Optional[list[torch.Tensor]] = None,
crowds: Optional[list[torch.Tensor]] = None,
area: Optional[list[torch.Tensor]] = None,
labels: list[Tensor],
boxes: Optional[list[Tensor]] = None,
masks: Optional[list[Tensor]] = None,
scores: Optional[list[Tensor]] = None,
crowds: Optional[list[Tensor]] = None,
area: Optional[list[Tensor]] = None,
) -> dict:
"""Transforms and returns all cached targets or predictions in COCO format.
Expand Down
25 changes: 13 additions & 12 deletions src/torchmetrics/regression/csi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional

import torch
from torch import Tensor

from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update
from torchmetrics.metric import Metric
Expand All @@ -40,17 +41,17 @@ class CriticalSuccessIndex(Metric):
Example:
>>> import torch
>>> from torchmetrics.regression import CriticalSuccessIndex
>>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> x = Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> csi = CriticalSuccessIndex(0.5)
>>> csi(x, y)
tensor(0.3333)
Example:
>>> import torch
>>> from torchmetrics.regression import CriticalSuccessIndex
>>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> x = Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
>>> csi(x, y)
tensor([0.3333, 0.3333])
Expand All @@ -60,12 +61,12 @@ class CriticalSuccessIndex(Metric):
is_differentiable: bool = False
higher_is_better: bool = True

hits: torch.Tensor
misses: torch.Tensor
false_alarms: torch.Tensor
hits_list: list[torch.Tensor]
misses_list: list[torch.Tensor]
false_alarms_list: list[torch.Tensor]
hits:Tensor
misses: Tensor
false_alarms: Tensor
hits_list: list[Tensor]
misses_list: list[Tensor]
false_alarms_list: list[Tensor]

def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand All @@ -84,7 +85,7 @@ def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **
self.add_state("misses_list", default=[], dist_reduce_fx="cat")
self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat")

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
hits, misses, false_alarms = _critical_success_index_update(
preds, target, self.threshold, self.keep_sequence_dim
Expand All @@ -98,7 +99,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
self.misses_list.append(misses)
self.false_alarms_list.append(false_alarms)

def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""Compute critical success index over state."""
if self.keep_sequence_dim is None:
hits = self.hits
Expand Down

0 comments on commit 3c2d8c1

Please sign in to comment.