Skip to content

Commit

Permalink
Refactor to allow for a wider model in TIMM (#3976)
Browse files Browse the repository at this point in the history
* update for releases 2.2.0rc0

* Fix Classification explain forward issue (#3867)

Fix bug

* Fix e2e code error (#3871)

* Update test_cli.py

* Update tests/e2e/cli/test_cli.py

Co-authored-by: Eunwoo Shin <[email protected]>

* Update test_cli.py

* Update test_cli.py

---------

Co-authored-by: Eunwoo Shin <[email protected]>

* Add documentation about configurable input size (#3870)

* add docs about configurable input size

* update api usecase and fix bug

* Fix zero-shot e2e (#3876)

Fix

* Fix DeiT for multi-label classification (#3881)

Remove init_args

* Fix Semi-SL for ViT accuracy drop (#3883)

Remove init_args

* Update docs for 2.2 (#3884)

Update docs

* Fix mean and scale for segmentation task (#3885)

fix mean and scale

* Update MAPI in 2.2 (#3889)

* Bump MAPI

* Update exportable code requirements

* Improve Semi-SL for LiteHRNet (small-medium case) (#3891)

* change drop pixels value

* go safe, change only tested models

* minor

* Improve h-cls for eff models (#3893)

* Update step size for eff v2

* Update effb0 recipe

* Fix maskrcnn swin nncf acc drop (#3900)

update maskrcnn swimt model type to transformer

* Add keypoint detection recipe for single object cases (#3903)

* add rtmpose_tiny for single obj

* add rtmpose_tiny for single obj

* modify test subset name

* fix unit test

* update recipe with reset

* Improve acc drop of efficientnetv2 for h-label cls (#3907)

* Add warmup_iters for effv2

* Update max_epochs

* Fix pretrained weight cached dir for timm (#3909)

* Fix pretrained_weight for timm

* Fix unit-test

* Fix keypoint detection single obj recipe (#3915)

* add rtmpose_tiny for single obj

* modify test subset name

* fix unit test

* property for pck

* Fix cached dir for timm & hugging-face (#3914)

* Fix cached dir

* Pretrained weight download unit-test

* Fix pre-commit

* Fix wrong template id mapping for anomaly (#3916)

* Update script to allow setting otx version using env. variable (#3913)

* Fix Datamodule creation for OV in AutoConfigurator (#3920)

Fix datamodule for ov

* Update tpp file for 2.2.0 (#3921)

* Fix names for ignored scope [HOT-FIX, 2.2.0] (#3924)

fix names for ignored scope

* Fix classification rt_info (#3922)

* Restore output_raw_scores for classificaiton

* Add uts

* Fix linter

* Update label info (#3925)

add label info to init

Signed-off-by: Ashwin Vaidya <[email protected]>

* Fix binary classification metric task (#3928)

* Fix binary classification

* Add unit-tests

* Improve MaskRCNN SwinT NNCF (#3929)

* ignore heads and disable smooth quant

* add activations_range_estimator_params

* update changelog

* Fix get_item for Chained Tasks in Classification (#3931)

* Fix Task Chain

* Add multi-label case as well

* Add multi-label case as well2

* Add H-label case

* Correct Keyerror for h-label cls in label_groups for dm_label_categories using label's id/key (#3932)

Modify label_groups for dm_label_categories with id/key of label

* Remove datumaro attribute id from tiling, add subset names (#3933)

* remove datumaro attribute id from tiling

* add subset names

* Fix soft predictions for Semantic Segmentation (#3934)

fix soft preds

* Update STFPM config (#3935)

* Add missing pretrained weights when creating a docker image (#3938)

* Fix pre-trained weight downloader

* Remove if condition for pretrained wiehgt download

* Change default option 'full' to 'base' in otx install (#3937)

* Change option full to base for otx install

* Fix wrong code

* Fix issue

* Fix docs

* Fix auto adapt batch size in Converter (#3939)

* Enable auto adapt batch size into converter

* Fix wrong

* Fix hpo converter (#3940)

* save best hp after hpo

* add test

* Fix tiling XAI out of range (#3943)

- Fix tile merge XAI out of range

* enable model export (#3952)

Signed-off-by: Ashwin Vaidya <[email protected]>

* Move templates from OTX1.X to OTX2.X (#3951)

* add otx1.6 templates

* added new models

* delete entrypoints and nncf cfg

* updated some hyperparams

* fix for rtmdet_tiny

* updated converter

* Update classification templates

* Update det, r-det, vpm

* Update template.yaml

* changed warmaup value in train.yaml

---------

Co-authored-by: Kang, Harim <[email protected]>
Co-authored-by: Kim, Sungchul <[email protected]>

* Add missing tile recipes and various tile recipe changes  (#3942)

* add missing tile recipes

* Fix tiling XAI out of range (#3943)

- Fix tile merge XAI out of range

* update xai tile merge

* update rtdetr

* update tile recipes

* update rtdetr tile postprocess

* update rtdetr recipes and tile recipes

* update tile recipes

* fix rtdetr unittest

* update recipes

* refactor tile unit test

* address pr reviews

* remove unnecessary files

* update color channel

* fix image channel passing

* include tiling in cli integration test

* remove transform_bbox

---------

Co-authored-by: Vladislav Sovrasov <[email protected]>

* Support ImageFromBytes (#3948)

* add image_from_bytes

Signed-off-by: Ashwin Vaidya <[email protected]>

* refactor code

Signed-off-by: Ashwin Vaidya <[email protected]>

* allow empty anomalous masks

Signed-off-by: Ashwin Vaidya <[email protected]>

---------

Signed-off-by: Ashwin Vaidya <[email protected]>

* Change categories mapping logic (#3946)

* change pre-filtering logic

* Update src/otx/core/data/pre_filtering.py

Co-authored-by: Eunwoo Shin <[email protected]>

---------

Co-authored-by: Eunwoo Shin <[email protected]>

* Update for 2.2.0rc1 (#3956)

* Include Geti arrow dataset subset names (#3962)

* restrited number of output masks by tiling

* add geti subset name

* update num of max pred

* Include full image with anno in case there's no tile in tile dataset (#3964)

* include full image with anno incase there's no tile in dataset

* update test

* Add type checker in converter for callable functions (optimizer, scheduler) (#3968)

Fix converter callable functions (optimizer, scheduler)

* Update for 2.2.0rc2 (#3969)

update for 2.2.0rc2

* Refactor TIMM

* Remove experimental recipes

* Revert timm version

* Fix conflict2

* Fix unit-test

---------

Signed-off-by: Ashwin Vaidya <[email protected]>
Co-authored-by: Yunchu Lee <[email protected]>
Co-authored-by: Emily Chun <[email protected]>
Co-authored-by: Eunwoo Shin <[email protected]>
Co-authored-by: Kim, Sungchul <[email protected]>
Co-authored-by: Prokofiev Kirill <[email protected]>
Co-authored-by: Vladislav Sovrasov <[email protected]>
Co-authored-by: Sooah Lee <[email protected]>
Co-authored-by: Eugene Liu <[email protected]>
Co-authored-by: Wonju Lee <[email protected]>
Co-authored-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
11 people authored Sep 27, 2024
1 parent a25a94f commit 00d5604
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 51 deletions.
39 changes: 12 additions & 27 deletions src/otx/algo/classification/backbones/timm.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,47 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""EfficientNetV2 model.
"""Timm Backbone Class for OTX classification.
Original papers:
- 'EfficientNetV2: Smaller Models and Faster Training,' https://arxiv.org/abs/2104.00298,
- 'Adversarial Examples Improve Image Recognition,' https://arxiv.org/abs/1911.09665.
"""
from __future__ import annotations

from typing import Literal

import timm
import torch
from torch import nn

TimmModelType = Literal[
"mobilenetv3_large_100_miil_in21k",
"mobilenetv3_large_100_miil",
"tresnet_m",
"tf_efficientnetv2_s.in21k",
"tf_efficientnetv2_s.in21ft1k",
"tf_efficientnetv2_m.in21k",
"tf_efficientnetv2_m.in21ft1k",
"tf_efficientnetv2_b0",
]


class TimmBackbone(nn.Module):
"""Timm backbone model."""
"""Timm backbone model.
Args:
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
pretrained (bool, optional): Whether to load pretrained weights. Defaults to False.
"""

def __init__(
self,
backbone: TimmModelType,
model_name: str,
pretrained: bool = False,
pooling_type: str = "avg",
**kwargs,
):
super().__init__(**kwargs)
self.backbone = backbone
self.model_name = model_name
self.pretrained: bool | dict = pretrained
self.is_mobilenet = backbone.startswith("mobilenet")

self.model = timm.create_model(
self.backbone,
self.model_name,
pretrained=pretrained,
num_classes=1000,
)

self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx.
self.num_head_features = self.model.num_features
self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features
self.pooling_type = pooling_type
self.num_features = self.model.num_features

def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]:
"""Forward."""
Expand All @@ -60,11 +50,6 @@ def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]:

def extract_features(self, x: torch.Tensor) -> torch.Tensor:
"""Extract features."""
if self.is_mobilenet:
x = self.model.conv_stem(x)
x = self.model.bn1(x)
x = self.model.act1(x)
return self.model.blocks(x)
return self.model.forward_features(x)

def get_config_optim(self, lrs: list[float] | float) -> list[dict[str, float]]:
Expand Down
104 changes: 90 additions & 14 deletions src/otx/algo/classification/timm_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""EfficientNetV2 model implementation."""
"""TIMM wrapper model class for OTX."""

from __future__ import annotations

Expand All @@ -12,7 +12,7 @@
import torch
from torch import nn

from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
from otx.algo.classification.backbones.timm import TimmBackbone
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
Expand Down Expand Up @@ -50,12 +50,38 @@


class TimmModelForMulticlassCls(OTXMulticlassClsModel):
"""TimmModel for multi-class classification task."""
"""TimmModel for multi-class classification task.
Args:
label_info (LabelInfoTypes): The label information for the classification task.
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional): Model input size in the order of height and width.
Defaults to (224, 224).
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
metric (MetricCallable, optional): The metric callable for evaluating the model.
Defaults to MultiClassClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): The training type.
Example:
1. API
>>> model = TimmModelForMulticlassCls(
... model_name="tf_efficientnetv2_s.in21k",
... label_info=<Number-of-classes>,
... )
2. CLI
>>> otx train \
... --model otx.algo.classification.timm_model.TimmModelForMulticlassCls \
... --model.model_name tf_efficientnetv2_s.in21k
"""

def __init__(
self,
label_info: LabelInfoTypes,
backbone: TimmModelType,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
Expand All @@ -64,7 +90,7 @@ def __init__(
torch_compile: bool = False,
train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
) -> None:
self.backbone = backbone
self.model_name = model_name
self.pretrained = pretrained

super().__init__(
Expand Down Expand Up @@ -92,7 +118,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
neck = GlobalAveragePooling(dim=2)
if self.train_type == OTXTrainType.SEMI_SUPERVISED:
return SemiSLClassifier(
Expand Down Expand Up @@ -142,20 +168,45 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t


class TimmModelForMultilabelCls(OTXMultilabelClsModel):
"""TimmModel for multi-label classification task."""
"""TimmModel for multi-label classification task.
Args:
label_info (LabelInfoTypes): The label information for the classification task.
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional): Model input size in the order of height and width.
Defaults to (224, 224).
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
metric (MetricCallable, optional): The metric callable for evaluating the model.
Defaults to MultiLabelClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
Example:
1. API
>>> model = TimmModelForMultilabelCls(
... model_name="tf_efficientnetv2_s.in21k",
... label_info=<Number-of-classes>,
... )
2. CLI
>>> otx train \
... --model otx.algo.classification.timm_model.TimmModelForMultilabelCls \
... --model.model_name tf_efficientnetv2_s.in21k
"""

def __init__(
self,
label_info: LabelInfoTypes,
backbone: TimmModelType,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.backbone = backbone
self.model_name = model_name
self.pretrained = pretrained

super().__init__(
Expand All @@ -182,7 +233,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
return ImageClassifier(
backbone=backbone,
neck=GlobalAveragePooling(dim=2),
Expand Down Expand Up @@ -222,22 +273,47 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t


class TimmModelForHLabelCls(OTXHlabelClsModel):
"""EfficientNetV2 Model for hierarchical label classification task."""
"""Timm Model for hierarchical label classification task.
Args:
label_info (HLabelInfo): The label information for the classification task.
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional): Model input size in the order of height and width.
Defaults to (224, 224).
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
metric (MetricCallable, optional): The metric callable for evaluating the model.
Defaults to HLabelClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
Example:
1. API
>>> model = TimmModelForHLabelCls(
... model_name="tf_efficientnetv2_s.in21k",
... label_info=<h-label-info>,
... )
2. CLI
>>> otx train \
... --model otx.algo.classification.timm_model.TimmModelForHLabelCls \
... --model.model_name tf_efficientnetv2_s.in21k
"""

label_info: HLabelInfo

def __init__(
self,
label_info: HLabelInfo,
backbone: TimmModelType,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.backbone = backbone
self.model_name = model_name
self.pretrained = pretrained

super().__init__(
Expand Down Expand Up @@ -267,7 +343,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, head_config: dict) -> nn.Module:
backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
model:
class_path: otx.algo.classification.timm_model.TimmModelForHLabelCls
init_args:
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
init_args:
label_info: 1000
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
init_args:
label_info: 1000
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k
train_type: SEMI_SUPERVISED

optimizer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMultilabelCls
init_args:
label_info: 1000
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/algo/classification/backbones/test_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

class TestOTXEfficientNetV2:
def test_forward(self):
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k")
assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8])

def test_get_config_optim(self):
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k")
assert model.get_config_optim([0.01])[0]["lr"] == 0.01
assert model.get_config_optim(0.01)[0]["lr"] == 0.01

Expand All @@ -24,5 +24,5 @@ def test_check_pretrained_weight_download(self):
if target.exists():
shutil.rmtree(target)
assert not target.exists()
TimmBackbone(backbone="tf_efficientnetv2_s.in21k", pretrained=True)
TimmBackbone(model_name="tf_efficientnetv2_s.in21k", pretrained=True)
assert target.exists()
6 changes: 3 additions & 3 deletions tests/unit/algo/classification/test_timm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def fxt_multi_class_cls_model():
return TimmModelForMulticlassCls(
label_info=10,
backbone="tf_efficientnetv2_s.in21k",
model_name="tf_efficientnetv2_s.in21k",
)


Expand Down Expand Up @@ -59,7 +59,7 @@ def test_predict_step(self, fxt_multi_class_cls_model, fxt_multiclass_cls_batch_
def fxt_multi_label_cls_model():
return TimmModelForMultilabelCls(
label_info=10,
backbone="tf_efficientnetv2_s.in21k",
model_name="tf_efficientnetv2_s.in21k",
)


Expand Down Expand Up @@ -97,7 +97,7 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_
def fxt_h_label_cls_model(fxt_hlabel_cifar):
return TimmModelForHLabelCls(
label_info=fxt_hlabel_cifar,
backbone="tf_efficientnetv2_s.in21k",
model_name="tf_efficientnetv2_s.in21k",
)


Expand Down

0 comments on commit 00d5604

Please sign in to comment.