Commit 5e2d9ea
temp remove MultilabelTopKRecall
Noietch committed Oct 22, 2024
1 parent b8611b3 commit 5e2d9ea
Showing 3 changed files with 47 additions and 40 deletions.
oadp/dp/bbox_heads.py (6 changes: 3 additions & 3 deletions)
@@ -10,7 +10,7 @@
 from todd.models import LossRegistry as LR
 
 from .classifiers import Classifier
-from .utils import MultilabelTopKRecall
+# from .utils import MultilabelTopKRecall
 
 
 class NotWithRegMixin(BBoxHead):
@@ -24,7 +24,7 @@ class BlockMixin(NotWithRegMixin):
 
     def __init__(self, *args, topk: int, loss: todd.Config, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._multilabel_topk_recall = MultilabelTopKRecall(k=topk)
+        # self._multilabel_topk_recall = MultilabelTopKRecall(k=topk)
         self._loss = LR.build(loss)
 
     def loss(
@@ -34,7 +34,7 @@ def loss(
     ) -> dict[str, torch.Tensor]:
         return dict(
             loss_block=self._loss(logits.sigmoid(), targets),
-            recall_block=self._multilabel_topk_recall(logits, targets),
+            # recall_block=self._multilabel_topk_recall(logits, targets),
         )
 
 
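Note: this hunk only disables the recall_block metric; nothing replaces it. If a stand-in is wanted while the sklearn-backed MultilabelTopKRecall is out of the tree, the same label-macro recall over top-K predictions can be computed on-device in pure PyTorch. The sketch below is illustrative and not part of this commit; the function name topk_macro_recall is invented for the example.

import torch


def topk_macro_recall(
    logits: torch.Tensor,   # bs x K, float
    targets: torch.Tensor,  # bs x K, bool
    k: int,
) -> torch.Tensor:
    """Macro recall over the top-k predictions, mirroring
    sklearn.metrics.recall_score(average='macro', zero_division=0)
    restricted to labels that appear at least once."""
    _, indices = logits.topk(k)
    preds = torch.zeros_like(targets).scatter(1, indices, True)
    tp = (preds & targets).sum(0).float()  # true positives per label
    support = targets.sum(0).float()       # positive count per label
    mask = support > 0                     # labels showing up at least once
    if not mask.any():
        return logits.new_tensor(0.0)
    return (tp[mask] / support[mask]).mean() * 100

Unlike the removed implementation, this never leaves the GPU, so it could be logged every iteration without the .cpu()/NumPy round-trip.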
oadp/dp/detectors.py (6 changes: 3 additions & 3 deletions)
@@ -18,7 +18,7 @@
 
 from ..utils import Globals
 from .roi_heads import OADPRoIHead
-from .utils import MultilabelTopKRecall
+# from .utils import MultilabelTopKRecall
 
 SelfDistiller = kd.distillers.SelfDistiller
 StudentMixin = kd.distillers.StudentMixin
@@ -35,7 +35,7 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)
-        self._multilabel_topk_recall = MultilabelTopKRecall(k=topk)
+        # self._multilabel_topk_recall = MultilabelTopKRecall(k=topk)
         self._classifier = MODELS.build(classifier)
         self._loss = LossRegistry.build(loss)
 
@@ -59,7 +59,7 @@ def forward(
             targets[i, label] = True
         return dict(
             loss_global=self._loss(logits.sigmoid(), targets),
-            recall_global=self._multilabel_topk_recall(logits, targets),
+            # recall_global=self._multilabel_topk_recall(logits, targets),
         )
 
 
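For context, the forward shown above fills a bs x num_categories boolean target matrix row by row (targets[i, label] = True) and passes sigmoid probabilities to the registry-built loss; only the recall metric is disabled. A minimal self-contained illustration of that pattern, with toy shapes and a plain BCE standing in for the configured self._loss (both assumptions, not code from the repository):

import torch
import torch.nn.functional as F

# Assumed toy shapes: 2 images, 4 categories, per-image ground-truth labels.
logits = torch.randn(2, 4)
gt_labels = [torch.tensor([0, 2]), torch.tensor([1])]

# Build the bs x K boolean multilabel targets, as in the forward above.
targets = torch.zeros_like(logits, dtype=torch.bool)
for i, labels in enumerate(gt_labels):
    targets[i, labels] = True

# With recall_global commented out, only the loss term remains.
loss_global = F.binary_cross_entropy(logits.sigmoid(), targets.float())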
oadp/dp/utils.py (75 changes: 41 additions & 34 deletions)
@@ -1,46 +1,53 @@
 __all__ = [
-    'MultilabelTopKRecall',
+    # 'MultilabelTopKRecall',
     'NormalizedLinear',
 ]
 
-import sklearn.metrics
+# import sklearn
 import torch
 import torch.nn.functional as F
 from torch import nn
 
 
-class MultilabelTopKRecall(nn.Module):
-
-    def __init__(self, *args, k: int, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self._k = k
-
-    def forward(
-        self,
-        logits: torch.Tensor,
-        targets: torch.Tensor,
-    ) -> torch.Tensor:
-        r"""Compute the multilabel top-K recall.
-
-        Args:
-            logits: :math:`bs \times K`, float.
-            targets: :math:`bs \times K`, bool.
-
-        Returns:
-            One element tensor representing the recall.
-        """
-        _, indices = logits.topk(self._k)
-        preds = torch.zeros_like(targets).scatter(1, indices, 1)
-        # labels showing up at least once
-        labels, = torch.where(targets.sum(0))
-        recall = sklearn.metrics.recall_score(
-            targets.cpu().numpy(),
-            preds.cpu().numpy(),
-            labels=labels.cpu().numpy(),
-            average='macro',
-            zero_division=0,
-        )
-        return logits.new_tensor(recall * 100)
+# class MultilabelTopKRecall(nn.Module):
+
+#     def __init__(self, *args, k: int, **kwargs) -> None:
+#         super().__init__(*args, **kwargs)
+#         self._k = k
+#         self.evaluator = Engine(self.process_function)
+#         self.macro_recall = Recall(average=True)
+#         self.macro_recall.attach(self.evaluator, 'recall')
+
+#     def process_function(self, engine, batch):
+#         y_pred, y = batch
+#         return y_pred, y
+
+#     def forward(
+#         self,
+#         logits: torch.Tensor,
+#         targets: torch.Tensor,
+#     ) -> torch.Tensor:
+#         r"""Compute the multilabel top-K recall.
+
+#         Args:
+#             logits: :math:`bs \times K`, float.
+#             targets: :math:`bs \times K`, bool.
+
+#         Returns:
+#             One element tensor representing the recall.
+#         """
+#         _, indices = logits.topk(self._k)
+#         preds = torch.zeros_like(targets).scatter(1, indices, 1)
+#         # labels showing up at least once
+#         labels, = torch.where(targets.sum(0))
+#         recall = sklearn.metrics.recall_score(
+#             targets.cpu().numpy(),
+#             preds.cpu().numpy(),
+#             labels=labels.cpu().numpy(),
+#             average='macro',
+#             zero_division=0,
+#         )
+#         return logits.new_tensor(recall * 100)
 
 
 class NormalizedLinear(nn.Linear):
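The new constructor lines hint at why the class is only "temp" removed: the plan appears to be to swap sklearn for pytorch-ignite (Engine, Recall). As committed, the commented code would not run even if uncommented: Engine and Recall are never imported, process_function is never fed data, and forward still calls sklearn.metrics.recall_score. Below is a rough sketch of how the ignite pieces could fit together, offered as a guess at the intent rather than code from this repository; note that ignite's multilabel Recall averages over samples, which differs from sklearn's per-label macro average.

import torch
from ignite.engine import Engine
from ignite.metrics import Recall


def process_function(engine, batch):
    # Each batch is a (y_pred, y) pair of binarized bs x K tensors.
    y_pred, y = batch
    return y_pred, y


evaluator = Engine(process_function)
recall = Recall(average=True, is_multilabel=True)  # multilabel mode requires averaging
recall.attach(evaluator, 'recall')

# Binarize the top-k logits, then evaluate a single batch.
logits = torch.randn(4, 10)
targets = torch.randint(0, 2, (4, 10))
_, indices = logits.topk(3)
preds = torch.zeros_like(targets).scatter(1, indices, 1)
state = evaluator.run([(preds, targets)])
print(state.metrics['recall'])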
