[FEATURE] Add Semi-SL multilabel classification algorithm #1805

Merged
merged 22 commits into from
Mar 16, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,56 @@ In the table below the `mAP <https://en.wikipedia.org/w/index.php?title=Informat
| EfficientNet-V2-S | 91.91 | 77.28 | 71.52 | 80.24 |
+-----------------------+-----------------+-----------+------------------+-----------+

.. ************************
.. Semi-supervised Learning
.. ************************
************************
Semi-supervised Learning
************************

.. To be added soon
Semi-SL (Semi-supervised Learning) is a type of machine learning algorithm that uses both labeled and unlabeled data to improve the performance of the model. This is particularly useful when labeled data is limited, expensive or time-consuming to obtain.

To utilize unlabeled data during training, we use the `BarlowTwins loss <https://arxiv.org/abs/2103.03230>`_ as an auxiliary loss for the Semi-SL task. BarlowTwins enforces consistency across augmented versions of the same data (both labeled and unlabeled): each sample is first augmented with `AugMix <https://arxiv.org/abs/1912.02781>`_, and then a strongly augmented sample is generated by applying a pre-defined `RandAugment <https://arxiv.org/abs/1909.13719>`_ strategy on top of the basic augmentation.

.. _mlc_cls_semi_supervised_pipeline:

- ``BarlowTwins loss``: A specific implementation of Semi-SL that combines the use of a consistency loss with strong data augmentations, and a specific optimizer called Sharpness-Aware Minimization (`SAM <https://arxiv.org/abs/2010.01412>`_) to improve the performance of the model.

- ``Adaptive auxiliary loss weighting``: A technique that sets the weight of the auxiliary loss so that the weighted auxiliary loss equals a predefined fraction of the EMA-smoothed main loss value. This keeps the contributions of the two losses aligned across different training phases.

- ``Exponential Moving Average (EMA)``: A technique for maintaining a moving average of the model's parameters, which can improve the generalization performance of the model.

- ``Additional techniques``: We also use several solutions that apply to supervised learning (No Bias Decay, augmentations, early stopping, etc.).
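
The consistency term described in the list above can be sketched in plain Python. This is only a minimal illustration of the Barlow Twins objective from the linked paper, not the ``BarlowTwinsLoss`` class added in this PR; the function name ``barlow_twins_loss``, the parameter name ``off_diag_penalty`` and the toy embeddings are assumptions made for the example.

```python
import math


def barlow_twins_loss(z_a, z_b, off_diag_penalty=1.0 / 128.0, eps=1e-12):
    """Barlow Twins objective for two views given as lists of embedding rows.

    z_a and z_b hold embeddings of two augmented views of the same samples,
    shape (n_samples, n_features). Hypothetical helper, for illustration only.
    """
    n, d = len(z_a), len(z_a[0])

    def standardize(z):
        # Return one list per feature, normalized to zero mean and unit std over the batch.
        columns = []
        for col in zip(*z):
            mean = sum(col) / n
            std = math.sqrt(sum((v - mean) ** 2 for v in col) / n)
            columns.append([(v - mean) / (std + eps) for v in col])
        return columns

    a, b = standardize(z_a), standardize(z_b)
    loss = 0.0
    for i in range(d):
        for j in range(d):
            # Entry (i, j) of the cross-correlation matrix between the two views.
            c_ij = sum(x * y for x, y in zip(a[i], b[j])) / n
            if i == j:
                loss += (1.0 - c_ij) ** 2  # pull the diagonal towards 1 (invariance)
            else:
                loss += off_diag_penalty * c_ij**2  # push off-diagonal towards 0 (decorrelation)
    return loss


# Toy embeddings with two uncorrelated features.
view = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
flipped = [[-v for v in row] for row in view]

same_loss = barlow_twins_loss(view, view)
flipped_loss = barlow_twins_loss(view, flipped)
```

Identical views with decorrelated features give a loss close to zero, while sign-flipped views are penalized on the diagonal terms.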

Please refer to the :doc:`tutorial <../../../tutorials/advanced/semi_sl>` to learn how to train a model with semi-supervised learning.
Training time depends on the number of images and can be up to several times longer than conventional supervised learning.

The table below presents the mAP metric obtained with our pipeline on some public datasets.

+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
| Dataset               | AerialMaritime 3 cls |         | VOC 2007 3 cls |         | COCO 14 5 cls  |         |
+=======================+======================+=========+================+=========+================+=========+
|                       | SL                   | Semi-SL | SL             | Semi-SL | SL             | Semi-SL |
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
| MobileNet-V3-large-1x | 74.28                | 74.41   | 96.34          | 97.29   | 82.39          | 83.77   |
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
| EfficientNet-B0       | 79.59                | 80.91   | 97.75          | 98.59   | 83.24          | 84.19   |
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+
| EfficientNet-V2-S     | 75.91                | 81.91   | 95.65          | 96.43   | 85.19          | 84.24   |
+-----------------------+----------------------+---------+----------------+---------+----------------+---------+

AerialMaritime was sampled with 5 images per class, VOC with 10 images per class, and COCO with 20 images per class.
Additional information about the datasets can be found in the table below.

+-----------------------+----------------+----------------------+
| Dataset               | Labeled images | Unlabeled images     |
+=======================+================+======================+
| AerialMaritime 3 cls  | 10             | 42                   |
+-----------------------+----------------+----------------------+
| VOC 2007 3 cls        | 30             | 798                  |
+-----------------------+----------------+----------------------+
| COCO 14 5 cls         | 95             | 10142                |
+-----------------------+----------------+----------------------+

.. note::
    These results can vary depending on the images selected for each class. Also, since the Semi-SL algorithm is evaluated here in settings with very few labeled images, some models may require larger datasets for better results.
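
For readers unfamiliar with the metric, the sketch below shows one common way to compute mAP for multi-label predictions: the non-interpolated average precision is computed per class and then averaged over classes. The exact variant used to produce the tables above is not stated here, so the function names and the toy scores are assumptions for illustration only.

```python
def average_precision(scores, labels):
    """Non-interpolated AP for one class: mean precision at the rank of each positive."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precisions = []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label:
            hits += 1
            precisions.append(hits / rank)
    # A class without positives contributes 0.0 in this sketch.
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(score_matrix, label_matrix):
    """Average the per-class AP; rows are samples, columns are classes."""
    per_class_scores = zip(*score_matrix)
    per_class_labels = zip(*label_matrix)
    aps = [average_precision(s, l) for s, l in zip(per_class_scores, per_class_labels)]
    return sum(aps) / len(aps)


# Three samples, two classes (toy values).
scores = [[0.9, 0.2], [0.8, 0.7], [0.1, 0.6]]
labels = [[1, 0], [0, 1], [1, 1]]
map_value = mean_average_precision(scores, labels)
```

The value is a fraction in the range 0 to 1; the tables above report it as a percentage.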

.. ************************
.. Self-supervised Learning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@

__train_pipeline = [
    *__common_pipeline,
    dict(type="PILImageToNDArray", keys=["img"]),
    dict(type="PostAug", keys=dict(img_strong=__strong_pipeline)),
    dict(type="PILImageToNDArray", keys=["img", "img_strong"]),
    dict(type="Normalize", **__img_norm_cfg),
    dict(type="ImageToTensor", keys=["img"]),
    dict(type="ImageToTensor", keys=["img", "img_strong"]),
    dict(type="ToTensor", keys=["gt_label"]),
    dict(type="Collect", keys=["img", "gt_label"]),
    dict(type="Collect", keys=["img", "img_strong", "gt_label"]),
]

__unlabeled_pipeline = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""EfficientNet-B0 config for semi-supervised multi-label classification."""

# pylint: disable=invalid-name

_base_ = ["../../../../../recipes/stages/classification/multilabel/semisl.yaml", "../../base/models/efficientnet.py"]

model = dict(
    task="classification",
    type="SemiSLMultilabelClassifier",
    backbone=dict(
        version="b0",
    ),
    head=dict(
        type="SemiLinearMultilabelClsHead",
        use_dynamic_loss_weighting=True,
        unlabeled_coef=0.1,
        in_channels=-1,
        aux_mlp=dict(hid_channels=0, out_channels=1024),
        normalized=True,
        scale=7.0,
        loss=dict(type="AsymmetricAngularLossWithIgnore", gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
        aux_loss=dict(
            type="BarlowTwinsLoss",
            off_diag_penality=1.0 / 128.0,
            loss_weight=1.0,
        ),
    ),
)

fp16 = dict(loss_scale=512.0)
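
The documentation added in this PR describes adaptive auxiliary loss weighting as choosing a weight so that the weighted auxiliary loss is a predefined fraction of the EMA-smoothed main loss. A minimal sketch of that idea follows. The head implementation is not shown in this diff, so the class name ``AdaptiveAuxLossWeight``, the ``momentum`` parameter, and the reading of ``unlabeled_coef=0.1`` as the target fraction are all assumptions, not the actual behaviour of ``use_dynamic_loss_weighting``.

```python
class AdaptiveAuxLossWeight:
    """Keep the weighted auxiliary loss at a fixed fraction of the smoothed main loss.

    Hypothetical helper for illustration; losses are plain floats here.
    """

    def __init__(self, fraction=0.1, momentum=0.9, eps=1e-12):
        self.fraction = fraction  # target ratio of weighted aux loss to main loss
        self.momentum = momentum  # EMA smoothing factor for the main loss
        self.eps = eps
        self.smoothed_main = None

    def __call__(self, main_loss, aux_loss):
        # Update the exponential moving average of the main loss.
        if self.smoothed_main is None:
            self.smoothed_main = main_loss
        else:
            self.smoothed_main = self.momentum * self.smoothed_main + (1.0 - self.momentum) * main_loss
        # Choose the weight so that weight * aux_loss == fraction * smoothed_main.
        return self.fraction * self.smoothed_main / (abs(aux_loss) + self.eps)


weighting = AdaptiveAuxLossWeight(fraction=0.1, momentum=0.9)
first_weight = weighting(main_loss=2.0, aux_loss=40.0)
second_weight = weighting(main_loss=1.0, aux_loss=10.0)
total_loss = 1.0 + second_weight * 10.0
```

In a real training loop the weight would normally be computed from detached loss values so that it rescales the auxiliary gradient without being differentiated itself.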
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ entrypoints:
  base: otx.algorithms.classification.tasks.ClassificationTrainTask
  openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
  nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
base_model_path: ../../adapters/deep_object_reid/configs/efficientnet_b0/template_experimental.yaml

# Capabilities.
capabilities:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""EfficientNet-V2 config for semi-supervised multi-label classification."""

# pylint: disable=invalid-name

_base_ = ["../../../../../recipes/stages/classification/multilabel/semisl.yaml", "../../base/models/efficientnet_v2.py"]

model = dict(
    task="classification",
    type="SemiSLMultilabelClassifier",
    backbone=dict(
        version="s_21k",
    ),
    head=dict(
        type="SemiLinearMultilabelClsHead",
        use_dynamic_loss_weighting=True,
        unlabeled_coef=0.1,
        in_channels=-1,
        aux_mlp=dict(hid_channels=0, out_channels=1024),
        normalized=True,
        scale=7.0,
        loss=dict(type="AsymmetricAngularLossWithIgnore", gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
        aux_loss=dict(
            type="BarlowTwinsLoss",
            off_diag_penality=1.0 / 128.0,
            loss_weight=1.0,
        ),
    ),
)

fp16 = dict(loss_scale=512.0)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ entrypoints:
base: otx.algorithms.classification.tasks.ClassificationTrainTask
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
base_model_path: ../../adapters/deep_object_reid/configs/efficientnet_v2_s/template_experimental.yaml

# Capabilities.
capabilities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ entrypoints:
base: otx.algorithms.classification.tasks.ClassificationTrainTask
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
base_model_path: ../../adapters/deep_object_reid/configs/mobilenet_v3_large_075/template_experimental.yaml

# Capabilities.
capabilities:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""MobileNet-V3-large-1 config for semi-supervised multi-label classification."""

# pylint: disable=invalid-name

_base_ = ["../../../../../recipes/stages/classification/multilabel/semisl.yaml", "../../base/models/mobilenet_v3.py"]

model = dict(
    task="classification",
    type="SemiSLMultilabelClassifier",
    backbone=dict(mode="large"),
    head=dict(
        type="SemiNonLinearMultilabelClsHead",
        in_channels=960,
        hid_channels=1280,
        use_dynamic_loss_weighting=True,
        unlabeled_coef=0.1,
        aux_mlp=dict(hid_channels=0, out_channels=1024),
        normalized=True,
        scale=7.0,
        loss=dict(type="AsymmetricAngularLossWithIgnore", gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
        aux_loss=dict(
            type="BarlowTwinsLoss",
            off_diag_penality=1.0 / 128.0,
            loss_weight=1.0,
        ),
    ),
)

fp16 = dict(loss_scale=512.0)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ entrypoints:
base: otx.algorithms.classification.tasks.ClassificationTrainTask
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
base_model_path: ../../adapters/deep_object_reid/configs/mobilenet_v3_large_1/template_experimental.yaml

# Capabilities.
capabilities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ entrypoints:
base: otx.algorithms.classification.tasks.ClassificationTrainTask
openvino: otx.algorithms.classification.tasks.ClassificationOpenVINOTask
nncf: otx.algorithms.classification.tasks.nncf.ClassificationNNCFTask
base_model_path: ../../adapters/deep_object_reid/configs/mobilenet_v3_small/template_experimental.yaml

# Capabilities.
capabilities:
Expand Down
2 changes: 1 addition & 1 deletion otx/cli/manager/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,12 @@ def build_workspace(self, new_workspace_path: Optional[str] = None) -> None:
        # Copy config files
        config_files = [
            (model_dir, "model.py", train_type_dir),
            (model_dir, "model_multilabel.py", train_type_dir),
            (model_dir, "data_pipeline.py", train_type_dir),
            (template_dir, "tile_pipeline.py", self.workspace_root),
            (template_dir, "deployment.py", self.workspace_root),
            (template_dir, "hpo_config.yaml", self.workspace_root),
            (template_dir, "model_hierarchical.py", self.workspace_root),
            (template_dir, "model_multilabel.py", self.workspace_root),
        ]
        for target_dir, file_name, dest_dir in config_files:
            self._copy_config_files(target_dir, file_name, dest_dir)
Expand Down
1 change: 1 addition & 0 deletions otx/mpa/cls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import otx.mpa.modules.models.heads.custom_multi_label_non_linear_cls_head
import otx.mpa.modules.models.heads.non_linear_cls_head
import otx.mpa.modules.models.heads.semisl_cls_head
import otx.mpa.modules.models.heads.semisl_multilabel_cls_head
import otx.mpa.modules.models.heads.supcon_cls_head
import otx.mpa.modules.models.losses.asymmetric_angular_loss_with_ignore
import otx.mpa.modules.models.losses.asymmetric_loss_with_ignore
Expand Down
7 changes: 6 additions & 1 deletion otx/mpa/modules/models/classifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,9 @@
#

# flake8: noqa
from . import sam_classifier, semisl_classifier, supcon_classifier
from . import (
    sam_classifier,
    semisl_classifier,
    semisl_multilabel_classifier,
    supcon_classifier,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from mmcls.models.builder import CLASSIFIERS

from otx.mpa.utils.logger import get_logger

from .sam_classifier import SAMImageClassifier

logger = get_logger()


@CLASSIFIERS.register_module()
class SemiSLMultilabelClassifier(SAMImageClassifier):
    """Semi-SL Multilabel Classifier
    This classifier supports unlabeled data by overriding forward_train
    """

    def forward_train(self, imgs, gt_label, **kwargs):
        """Data is transmitted as a classifier training function

        Args:
            imgs (list[Tensor]): List of tensors of shape (1, C, H, W)
                Typically these should be mean centered and std scaled.
            gt_label (Tensor): Ground truth labels for the input labeled images
            kwargs (keyword arguments): Specific to concrete implementation
        """
        if "extra_0" not in kwargs:
            raise ValueError("'extra_0' does not exist in the dataset")
        if "img_strong" not in kwargs:
            raise ValueError("'img_strong' does not exist in the dataset")

        target = gt_label.squeeze()
        unlabeled_data = kwargs["extra_0"]
        x = {}
        x["labeled_weak"] = self.extract_feat(imgs)
        x["labeled_strong"] = self.extract_feat(kwargs["img_strong"])

        img_uw = unlabeled_data["img"]
        x["unlabeled_weak"] = self.extract_feat(img_uw)

        img_us = unlabeled_data["img_strong"]
        x["unlabeled_strong"] = self.extract_feat(img_us)

        losses = dict()
        loss = self.head.forward_train(x, target)
        losses.update(loss)

        return losses
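
To make the batch layout that ``forward_train`` expects easier to see, here is a small stand-alone sketch that mirrors only its input checks and the naming of the four views. It is not part of the PR: ``collect_views`` is a hypothetical helper, and plain strings stand in for image tensors, so no feature extraction takes place.

```python
def collect_views(imgs, **kwargs):
    """Mirror the input checks of forward_train and name the four views it uses."""
    for key in ("extra_0", "img_strong"):
        if key not in kwargs:
            raise ValueError(f"'{key}' does not exist in the dataset")
    unlabeled = kwargs["extra_0"]
    return {
        "labeled_weak": imgs,  # weakly augmented labeled batch
        "labeled_strong": kwargs["img_strong"],  # strongly augmented labeled batch
        "unlabeled_weak": unlabeled["img"],  # weakly augmented unlabeled batch
        "unlabeled_strong": unlabeled["img_strong"],  # strongly augmented unlabeled batch
    }


# Strings stand in for image tensors in this toy call.
batch = collect_views(
    "L_weak",
    img_strong="L_strong",
    extra_0={"img": "U_weak", "img_strong": "U_strong"},
)
```

In the classifier itself each of these four entries is passed through ``extract_feat`` before the dictionary is handed to the head together with the labels.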