Commit

Add training loss dynamics exportation feature for multi-class classification task (#1985)

* Add training loss dynamics tracking module

* Prevent circular import

* Add TestLossDynamicsTrackingMixin

* Refactor code structure

* Fix tests

* Fix tests

* Fix after merge develop

* Fix OTXDataset

* Add integration test

* revert yaml default

* Fix mypy

* Fix mypy in a more neat way

* Use buffer for IBLoss weight

* Fix typo
---------

Signed-off-by: Kim, Vinnam <[email protected]>
vinnamkim authored Apr 10, 2023
1 parent 945ab5b commit 746e6eb
Showing 20 changed files with 570 additions and 40 deletions.
@@ -6,7 +6,7 @@

# pylint: disable=invalid-name, too-many-locals, no-member

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

import numpy as np
from mmcls.core import average_performance, mAP
@@ -20,6 +20,7 @@
from otx.algorithms.common.utils import get_cls_img_indices, get_old_new_img_indices
from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.datasets import DatasetEntity
+from otx.api.entities.id import ID
from otx.api.entities.label import LabelEntity
from otx.api.utils.argument_checks import (
    DatasetParamTypeCheck,
@@ -42,6 +43,7 @@ def __init__(
        self.labels = labels
        self.label_names = [label.name for label in self.labels]
        self.label_idx = {label.id: i for i, label in enumerate(labels)}
+        self.idx_to_label_id = {v: k for k, v in self.label_idx.items()}
        self.empty_label = empty_label
        self.class_acc = False

@@ -90,19 +92,25 @@ def __getitem__(self, index: int):

        height, width = item.height, item.width

+        gt_label = self.gt_labels[index]
        data_info = dict(
            dataset_item=item,
            width=width,
            height=height,
            index=index,
-            gt_label=self.gt_labels[index],
+            gt_label=gt_label,
            ignored_labels=ignored_labels,
+            entity_id=getattr(item, "id_", None),
+            label_id=self._get_label_id(gt_label),
        )

        if self.pipeline is None:
            return data_info
        return self.pipeline(data_info)
+
+    def _get_label_id(self, gt_label: np.ndarray) -> Union[ID, List[ID]]:
+        return self.idx_to_label_id.get(gt_label.item(), ID())

    def get_gt_labels(self):
        """Get all ground-truth labels (categories).
@@ -285,6 +293,9 @@ def evaluate(

        return eval_results

+    def _get_label_id(self, gt_label: np.ndarray) -> Union[ID, List[ID]]:
+        return [self.idx_to_label_id.get(idx, ID()) for idx, v in enumerate(gt_label) if v == 1]
+

@DATASETS.register_module()
class OTXHierarchicalClsDataset(OTXMultilabelClsDataset):
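To make the two `_get_label_id` variants above concrete, here is a minimal, self-contained sketch; the label IDs and counts are hypothetical, only the lookup logic mirrors the diff:

```python
import numpy as np

# Hypothetical mapping, mirroring self.label_idx / self.idx_to_label_id above.
label_idx = {"10": 0, "11": 1, "12": 2}  # OTX label ID -> contiguous class index
idx_to_label_id = {v: k for k, v in label_idx.items()}

# Multi-class: gt_label is a scalar class index -> a single label ID.
gt_label = np.array(1)
print(idx_to_label_id.get(gt_label.item()))  # -> 11

# Multi-label: gt_label is a one-hot vector -> a list of label IDs for the hot entries.
gt_label = np.array([1, 0, 1])
print([idx_to_label_id.get(idx) for idx, v in enumerate(gt_label) if v == 1])  # -> ['10', '12']
```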
@@ -0,0 +1,140 @@
"""Module defining Mix-in class of SAMClassifier."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import datumaro as dm
import numpy as np
import pandas as pd

from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.dataset_item import DatasetItemEntityWithID
from otx.api.entities.datasets import DatasetEntity
from otx.core.data.noisy_label_detection import LossDynamicsTracker, LossDynamicsTrackingMixin

logger = get_logger()


class SAMClassifierMixin:
"""SAM-enabled BaseClassifier mix-in."""

def train_step(self, data, optimizer=None, **kwargs):
"""Saving current batch data to compute SAM gradient."""
self.current_batch = data
return super().train_step(data, optimizer, **kwargs)


class MultiClassClsLossDynamicsTracker(LossDynamicsTracker):
"""Loss dynamics tracker for multi-class classification task."""

def __init__(self) -> None:
super().__init__()

def init_with_otx_dataset(self, otx_dataset: DatasetEntity[DatasetItemEntityWithID]) -> None:
"""DatasetEntity should be injected to the tracker for the initialization."""
otx_labels = otx_dataset.get_labels()
label_categories = dm.LabelCategories.from_iterable([label_entity.name for label_entity in otx_labels])
self.otx_label_map = {label_entity.id_: idx for idx, label_entity in enumerate(otx_labels)}

def _convert_anns(item: DatasetItemEntityWithID):
labels = [
dm.Label(label=self.otx_label_map[label.id_])
for ann in item.get_annotations()
for label in ann.get_labels()
]
return labels

self._export_dataset = dm.Dataset.from_iterable(
[
dm.DatasetItem(
id=item.id_,
subset="train",
media=dm.Image(path=item.media.path, size=(item.media.height, item.media.width)),
annotations=_convert_anns(item),
)
for item in otx_dataset
],
infos={"purpose": "noisy_label_detection", "task": "OTX-MultiClassCls"},
categories={dm.AnnotationType.label: label_categories},
)

super().init_with_otx_dataset(otx_dataset)

    def accumulate(self, outputs, iter) -> None:
        """Accumulate training loss dynamics for each training step."""
        entity_ids = outputs["entity_ids"]
        label_ids = np.squeeze(outputs["label_ids"])
        loss_dyns = outputs["loss_dyns"]

        for entity_id, label_id, loss_dyn in zip(entity_ids, label_ids, loss_dyns):
            self._loss_dynamics[(entity_id, label_id)].append((iter, loss_dyn))

    def export(self, output_path: str) -> None:
        """Export loss dynamics statistics to Datumaro format."""
        df = pd.DataFrame.from_dict(
            {
                k: (np.array([iter for iter, _ in arr]), np.array([value for _, value in arr]))
                for k, arr in self._loss_dynamics.items()
            },
            orient="index",
            columns=["iters", "loss_dynamics"],
        )

        for (entity_id, label_id), row in df.iterrows():
            item = self._export_dataset.get(entity_id, "train")
            for ann in item.annotations:
                if isinstance(ann, dm.Label) and ann.label == self.otx_label_map[label_id]:
                    ann.attributes = row.to_dict()

        self._export_dataset.export(output_path, format="datumaro")
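The export above produces a standard Datumaro dataset whose `Label` annotations carry `iters` / `loss_dynamics` attributes. A minimal consumption sketch, not part of this commit; the path and the ranking heuristic are illustrative:

```python
import datumaro as dm
import numpy as np

# Load the dataset written by export(); "./loss_dyns" is a placeholder path.
dataset = dm.Dataset.import_from("./loss_dyns", format="datumaro")

# Rank samples by the mean of their loss trajectory; a persistently high
# loss is a common signal for a noisy label.
scores = {}
for item in dataset:
    for ann in item.annotations:
        if isinstance(ann, dm.Label) and "loss_dynamics" in ann.attributes:
            scores[(item.id, ann.label)] = float(np.mean(ann.attributes["loss_dynamics"]))

suspects = sorted(scores, key=scores.get, reverse=True)[:10]
print(suspects)
```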


class ClsLossDynamicsTrackingMixin(LossDynamicsTrackingMixin):
    """Mix-in to track loss dynamics during training for classification tasks."""

    def __init__(self, track_loss_dynamics: bool = False, **kwargs):
        if track_loss_dynamics:
            if getattr(self, "multilabel", False) or getattr(self, "hierarchical", False):
                raise RuntimeError("multilabel or hierarchical tasks are not supported now.")

            head_cfg = kwargs.get("head", None)
            loss_cfg = head_cfg.get("loss", None)
            loss_cfg["reduction"] = "none"

        # This should be called after modifying "reduction" config.
        super().__init__(**kwargs)

        # This should be called after super().__init__(),
        # since LossDynamicsTrackingMixin.__init__() creates self._loss_dyns_tracker
        self._loss_dyns_tracker = MultiClassClsLossDynamicsTracker()

    def train_step(self, data, optimizer=None, **kwargs):
        """The iteration step for training.
        If self._track_loss_dynamics = False, just follow BaseClassifier.train_step().
        Otherwise, it steps with tracking loss dynamics.
        """
        if self.loss_dyns_tracker.initialized:
            return self._train_step_with_tracking(data, optimizer, **kwargs)
        return super().train_step(data, optimizer, **kwargs)

    def _train_step_with_tracking(self, data, optimizer=None, **kwargs):
        losses = self(**data)

        loss_dyns = losses["loss"].detach().cpu().numpy()
        assert not np.isscalar(loss_dyns)

        entity_ids = [img_meta["entity_id"] for img_meta in data["img_metas"]]
        label_ids = [img_meta["label_id"] for img_meta in data["img_metas"]]
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            loss_dyns=loss_dyns,
            entity_ids=entity_ids,
            label_ids=label_ids,
            num_samples=len(data["img"].data),
        )

        return outputs
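The `loss_cfg["reduction"] = "none"` override in `__init__` is what makes `losses["loss"]` arrive as a per-sample vector rather than a scalar, which the `assert` above guards. A toy illustration of the difference:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)  # batch of 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 0])

per_sample = F.cross_entropy(logits, target, reduction="none")  # shape (4,): one trajectory point per sample
batch_mean = F.cross_entropy(logits, target, reduction="mean")  # scalar: per-sample information is lost
```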
@@ -11,13 +11,13 @@
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names

-from .sam_classifier_mixin import SAMClassifierMixin
+from .mixin import ClsLossDynamicsTrackingMixin, SAMClassifierMixin

logger = get_logger()


@CLASSIFIERS.register_module()
-class SAMImageClassifier(SAMClassifierMixin, ImageClassifier):
+class SAMImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier):
"""SAM-enabled ImageClassifier."""

def __init__(self, task_adapt=None, **kwargs):
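With the new base-class order, Python's MRO chains the cooperative `train_step()` calls as SAMClassifierMixin -> ClsLossDynamicsTrackingMixin -> ImageClassifier, so the batch is stashed for SAM before loss dynamics are handled. A toy demonstration; the class names are illustrative stand-ins, not the real OTX classes:

```python
class ImageClassifierToy:
    def train_step(self):
        return ["ImageClassifier"]

class TrackingMixinToy:
    def train_step(self):
        return ["TrackingMixin"] + super().train_step()

class SAMMixinToy:
    def train_step(self):
        return ["SAMMixin"] + super().train_step()

# Analogous to SAMImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier)
class SAMImageClassifierToy(SAMMixinToy, TrackingMixinToy, ImageClassifierToy):
    pass

print(SAMImageClassifierToy().train_step())
# ['SAMMixin', 'TrackingMixin', 'ImageClassifier']
```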

This file was deleted.

@@ -14,23 +14,26 @@
class IBLoss(CrossEntropyLoss):
    """IB Loss, Influence-Balanced Loss for Imbalanced Visual Classification, https://arxiv.org/abs/2110.02444."""

-    def __init__(self, num_classes, start=5, alpha=1000.0):
+    def __init__(self, num_classes, start=5, alpha=1000.0, reduction: str = "mean"):
        """Init function of IBLoss.
        Args:
            num_classes (int): Number of classes in dataset
            start (int): Epoch to start finetuning with IB loss
            alpha (float): Hyper-parameter for an adjustment for IB loss re-weighting
+            reduction (str): How to reduce the output. Available options are "none" or "mean". Defaults to 'mean'.
        """
-        super().__init__(loss_weight=1.0)
+        super().__init__(loss_weight=1.0, reduction=reduction)
        if alpha < 0:
            raise ValueError("Alpha for IB loss should be bigger than 0")
        self.alpha = alpha
        self.epsilon = 0.001
        self.num_classes = num_classes
-        self.weight = None
+        self.register_buffer("weight", torch.ones(size=(self.num_classes,)))
        self._start_epoch = start
        self._cur_epoch = 0
+        if reduction not in {"mean", "none"}:
+            raise ValueError(f"reduction={reduction} is not allowed.")

    @property
    def cur_epoch(self):
@@ -48,7 +51,7 @@ def update_weight(self, cls_num_list):
        per_cls_weights = 1.0 / np.array(cls_num_list)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
        per_cls_weights = torch.FloatTensor(per_cls_weights)
-        self.weight = per_cls_weights
+        self.weight.data = per_cls_weights.to(device=self.weight.device)

    def forward(self, x, target, feature):
        """Forward function of IBLoss."""
@@ -58,6 +61,6 @@ def forward(self, x, target, feature):
        feature = torch.sum(torch.abs(feature), 1).reshape(-1, 1)
        scaler = grads * feature.reshape(-1)
        scaler = self.alpha / (scaler + self.epsilon)
-        ce_loss = F.cross_entropy(x, target, weight=self.weight.to(x.get_device()), reduction="none")
+        ce_loss = F.cross_entropy(x, target, weight=self.weight, reduction="none")
        loss = ce_loss * scaler
-        return loss.mean()
+        return loss.mean() if self.reduction == "mean" else loss
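For reference, the per-sample re-weighting that `forward()` implements can be summarized as follows; `grads` is computed in lines elided from this diff, and the notation below is ours, following the paper linked in the class docstring:

$$
\mathcal{L}_{\mathrm{IB}}(x_i, y_i) = \mathrm{CE}(x_i, y_i; w) \cdot \frac{\alpha}{g_i \,\lVert f_i \rVert_1 + \epsilon}
$$

where $g_i$ is the per-sample gradient-magnitude term, $\lVert f_i \rVert_1$ is the L1 norm of the feature vector, and $w$ holds the per-class weights set by `update_weight()`. With `reduction="mean"` the batch mean is returned; with `"none"`, the per-sample vector needed for loss dynamics tracking.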
6 changes: 5 additions & 1 deletion otx/algorithms/classification/adapters/mmcls/task.py
@@ -63,6 +63,7 @@
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
from otx.core.data import caching
+from otx.core.data.noisy_label_detection import LossDynamicsTrackingHook

from .configurer import (
    ClassificationConfigurer,
@@ -141,6 +142,10 @@ def _init_task(self, export: bool = False):  # noqa
        # Update recipe with caching modules
        self._update_caching_modules(self._recipe_cfg.data)

+        # Loss dynamics tracking
+        if getattr(self._hyperparams.algo_backend, "enable_noisy_label_detection", False):
+            LossDynamicsTrackingHook.configure_recipe(self._recipe_cfg, self._output_path)
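For orientation, the hook is wired in through mmcv's standard `custom_hooks` mechanism. The exact fields written by `configure_recipe()` are not shown in this diff, so the following is only a sketch of the conventional registration pattern it builds on; the `output_path` key is an assumption:

```python
# Sketch only: conventional mmcv custom-hook registration. The actual config
# keys written by LossDynamicsTrackingHook.configure_recipe() are not in this diff.
recipe_cfg.custom_hooks.append(dict(type="LossDynamicsTrackingHook", output_path=output_path))
```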

    # pylint: disable=too-many-arguments
    def configure(
        self,
@@ -263,7 +268,6 @@ def _infer_model(
        time_monitor = [hook.time_monitor for hook in cfg.custom_hooks if hook.type == "OTXProgressHook"]
        time_monitor = time_monitor[0] if time_monitor else None
        if time_monitor is not None:
-
            # pylint: disable=unused-argument
            def pre_hook(module, inp):
                time_monitor.on_test_batch_begin(None, None)
@@ -48,7 +48,7 @@ def patch_datasets(
    def update_pipeline(cfg):
        if subset == "train":
            for collect_cfg in get_configs_by_pairs(cfg, dict(type="Collect")):
-                get_meta_keys(collect_cfg)
+                get_meta_keys(collect_cfg, ["entity_id", "label_id"])

    for subset in subsets:
        if subset not in config.data:
9 changes: 9 additions & 0 deletions otx/algorithms/classification/configs/base/configuration.py
@@ -53,6 +53,15 @@ class __AlgoBackend(BaseConfig.BaseAlgoBackendParameters):
        header = string_attribute("Parameters for the MPA algo-backend")
        description = header

+        enable_noisy_label_detection = configurable_boolean(
+            default_value=False,
+            header="Enable loss dynamics tracking for noisy label detection",
+            description="Set to True to enable loss dynamics tracking for each sample to detect noisy labeled samples.",
+            editable=False,
+            visible_in_ui=False,
+            affects_outcome_of=ModelLifecycle.TRAINING,
+        )

    @attrs
    class __POTParameter(BaseConfig.BasePOTParameter):
        """POT-related parameter configurations."""
15 changes: 15 additions & 0 deletions otx/algorithms/classification/configs/configuration.yaml
@@ -370,5 +370,20 @@ algo_backend:
      type: UI_RULES
    visible_in_ui: false
    warning: null
+  enable_noisy_label_detection:
+    affects_outcome_of: TRAINING
+    default_value: false
+    description: Set to True to enable loss dynamics tracking for each sample to detect noisy labeled samples.
+    editable: true
+    header: Enable loss dynamics tracking for noisy label detection
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: true
+    visible_in_ui: false
+    warning: null
  type: PARAMETER_GROUP
  visible_in_ui: true
3 changes: 2 additions & 1 deletion otx/algorithms/common/adapters/mmcv/utils/config_utils.py
@@ -577,10 +577,11 @@ def align_data_config_with_recipe(data_config: ConfigDict, config: Union[Config,
    )


-def get_meta_keys(pipeline_step):
+def get_meta_keys(pipeline_step, add_meta_keys: List[str] = []):
    """Update meta_keys for ignore_labels."""
    meta_keys = list(pipeline_step.get("meta_keys", DEFAULT_META_KEYS))
    meta_keys.append("ignored_labels")
+    meta_keys += add_meta_keys
    pipeline_step["meta_keys"] = set(meta_keys)
    return pipeline_step
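A self-contained rerun of the updated helper; `DEFAULT_META_KEYS` is stubbed here for illustration (the real constant lives in this module), and the mutable default argument mirrors the diff as committed:

```python
from typing import List

DEFAULT_META_KEYS = ["filename", "ori_shape", "img_shape"]  # stub for illustration

def get_meta_keys(pipeline_step, add_meta_keys: List[str] = []):
    """Update meta_keys for ignore_labels."""
    meta_keys = list(pipeline_step.get("meta_keys", DEFAULT_META_KEYS))
    meta_keys.append("ignored_labels")
    meta_keys += add_meta_keys
    pipeline_step["meta_keys"] = set(meta_keys)
    return pipeline_step

step = dict(type="Collect", keys=["img", "gt_label"])
get_meta_keys(step, ["entity_id", "label_id"])
print(step["meta_keys"])  # includes "entity_id" and "label_id" alongside the defaults
```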
