Merge branch 'releases/2.2.0' into eugene/fix-detr-dtype-casting
eugene123tw authored Dec 4, 2024
2 parents f5cd1ec + 5d6f8d3 commit ef32906
Showing 123 changed files with 811 additions and 694 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -73,6 +73,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4035>)
- Bump onnx to 1.17.0 to address CVE-2024-5187
(<https://github.com/openvinotoolkit/training_extensions/pull/4063>)
- Decouple DinoV2 for semantic segmentation task
(<https://github.com/openvinotoolkit/training_extensions/pull/4136>)

### Bug fixes

@@ -126,6 +128,10 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4107>)
- Fix empty annotation in tiling
(<https://github.com/openvinotoolkit/training_extensions/pull/4124>)
- Fix patching of early stopping in tools/converter.py, update headers in templates, and change the training schedule for classification
(<https://github.com/openvinotoolkit/training_extensions/pull/4131>)
- Fix tensor type compatibility in the dynamic soft label assigner and the RTMDet head
(<https://github.com/openvinotoolkit/training_extensions/pull/4140>)
- Fix DETR loss calculations by casting target class indices to type long (see the sketch below)
(<https://github.com/openvinotoolkit/training_extensions/pull/4143>)
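A minimal sketch of the dtype issue behind that last fix, assuming only the standard PyTorch cross-entropy API (tensor names and shapes are illustrative, not the actual DETR criterion code): `torch.nn.functional.cross_entropy` requires class-index targets of type long (int64), so indices arriving as int32 or float must be cast before the loss call.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 8 matched queries, 4 classes (not the real DETR criterion).
logits = torch.randn(8, 4)
target = torch.tensor([0, 2, 1, 3, 0, 0, 2, 1], dtype=torch.int32)

# cross_entropy expects int64 (long) class indices; int32 targets raise
# "expected scalar type Long" -- casting fixes the loss computation.
loss = F.cross_entropy(logits, target.long())
```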

2 changes: 1 addition & 1 deletion src/otx/algo/callbacks/adaptive_early_stopping.py
@@ -20,7 +20,7 @@ def __init__(
self,
monitor: str,
min_delta: float = 0.0,
patience: int = 3,
patience: int = 10,
verbose: bool = False,
mode: str = "min",
strict: bool = True,
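For context on the `patience` change above: in a Lightning-style early-stopping callback, `patience` is the number of consecutive validation checks allowed without the monitored metric improving by at least `min_delta` before training halts. A minimal sketch under that assumption, using PyTorch Lightning's stock callback rather than the OTX subclass (the monitored metric name is illustrative):

```python
from lightning.pytorch.callbacks import EarlyStopping

# With patience=10, training tolerates 10 non-improving validation rounds
# (instead of 3) before stopping: fewer premature stops on noisy metrics,
# at the cost of potentially longer runs.
early_stop = EarlyStopping(monitor="val/accuracy", mode="max", min_delta=0.0, patience=10)
```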
167 changes: 158 additions & 9 deletions src/otx/algo/classification/backbones/vision_transformer.py
@@ -5,6 +5,7 @@
"""Copy from mmpretrain/models/backbones/vision_transformer.py."""
from __future__ import annotations

import math
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal

@@ -46,6 +47,7 @@
"vit-huge",
"dinov2-s",
"dinov2-small",
"dinov2-small-seg",
"dinov2-b",
"dinov2-base",
"dinov2-l",
@@ -87,6 +89,7 @@ class VisionTransformer(BaseModule):
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
interpolate_offset: Workaround offset applied when interpolating positional embeddings.
lora: Enable LoRA training.
"""

@@ -147,6 +150,17 @@ class VisionTransformer(BaseModule):
"num_heads": 6,
"reg_tokens": 4,
"no_embed_class": True,
},
),
**dict.fromkeys(
["dinov2-small-seg"], # segmentation
{
"patch_size": 14,
"embed_dim": 384,
"depth": 12,
"num_heads": 6,
"reg_tokens": 0,
"no_embed_class": False,
"init_values": 1e-5,
},
),
@@ -193,9 +207,9 @@ class VisionTransformer(BaseModule):

def __init__( # noqa: PLR0913
self,
arch: VIT_ARCH_TYPE = "vit-base",
arch: VIT_ARCH_TYPE | str = "vit-base",
img_size: int | tuple[int, int] = 224,
patch_size: int | tuple[int, int] | None = None,
patch_size: int | None = None,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int | None = None,
@@ -221,6 +235,7 @@ def __init__( # noqa: PLR0913
mlp_layer: nn.Module | None = None,
act_layer: LayerType | None = None,
norm_layer: LayerType | None = None,
interpolate_offset: float = 0.1,
lora: bool = False,
) -> None:
super().__init__()
@@ -231,7 +246,7 @@ def __init__( # noqa: PLR0913
arch_settings: dict[str, Any] = self.arch_zoo[arch]

self.img_size: int | tuple[int, int] = img_size
self.patch_size: int | tuple[int, int] = patch_size or arch_settings.get("patch_size", 16)
self.patch_size: int = patch_size or arch_settings.get("patch_size", 16)
self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
depth = depth or arch_settings.get("depth", 12)
num_heads = num_heads or arch_settings.get("num_heads", 12)
@@ -251,6 +266,7 @@ def __init__( # noqa: PLR0913
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.interpolate_offset = interpolate_offset

embed_args = {}
if dynamic_img_size:
@@ -353,15 +369,17 @@ def resize_positional_embeddings(pos_embed: torch.Tensor, new_shape: tuple[int,
# convert dinov2 pretrained weights
state_dict = torch.load(checkpoint_path)
state_dict.pop("mask_token", None)
state_dict["reg_token"] = state_dict.pop("register_tokens")
if "reg_token" in state_dict:
state_dict["reg_token"] = state_dict.pop("register_tokens")
state_dict["cls_token"] = state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0]

img_size = (self.img_size, self.img_size) if isinstance(self.img_size, int) else self.img_size
patch_size = (self.patch_size, self.patch_size) if isinstance(self.patch_size, int) else self.patch_size
state_dict["pos_embed"] = resize_positional_embeddings(
state_dict.pop("pos_embed")[:, 1:],
(img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
)
patch_size = (self.patch_size, self.patch_size)
if state_dict["pos_embed"].shape != self.pos_embed.shape:
state_dict["pos_embed"] = resize_positional_embeddings(
state_dict.pop("pos_embed")[:, 1:],
(img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
)
self.load_state_dict(state_dict, strict=False)
else:
msg = f"Unsupported `checkpoint_extension` {checkpoint_ext}, please choose from 'npz' or 'pth'."
@@ -401,6 +419,137 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:

return self.pos_drop(x)

def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
"""Interpolates the positional encoding to match the input dimensions.
Args:
x (torch.Tensor): Input tensor.
w (int): Width of the input image.
h (int): Height of the input image.
Returns:
torch.Tensor: Tensor with interpolated positional encoding.
"""
previous_dtype = x.dtype
npatch = x.shape[1] - 1
n = self.pos_embed.shape[1] - 1
if npatch == n and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
m = int(math.sqrt(n)) # Recover the number of patches in each dimension
if m * m != n:
msg = f"Expected m * m to equal n, but got m={m}, n={n}"
raise ValueError(msg)
kwargs = {}
if self.interpolate_offset:
# fix float error by introducing small offset
sx = float(w0 + self.interpolate_offset) / m
sy = float(h0 + self.interpolate_offset) / m
kwargs["scale_factor"] = (sx, sy)
else:
# Simply specify an output size instead of a scale factor
kwargs["size"] = (w0, h0)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),
mode="bicubic",
**kwargs,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
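To make the resizing above concrete, a standalone sketch with illustrative sizes: a DINOv2-style model with `patch_size=14` pretrained at 224x224 stores a 16x16 grid of patch position embeddings, and a 518x518 input needs a 37x37 grid, so the stored grid is resized bicubically, with the small `interpolate_offset` nudging the scale factor to avoid floating-point rounding in the computed output size.

```python
import torch
import torch.nn.functional as F

m, dim = 16, 384                  # pretrained grid (224 // 14) and embed dim, illustrative
w0 = h0 = 518 // 14               # target grid: 37 x 37
pos = torch.randn(1, m * m, dim)  # stored patch position embeddings

offset = 0.1                      # same trick as interpolate_offset above
scale = (w0 + offset) / m         # 37.1 / 16, which floors to a 37-pixel output
resized = F.interpolate(
    pos.reshape(1, m, m, dim).permute(0, 3, 1, 2),  # -> (1, dim, 16, 16)
    scale_factor=(scale, scale),
    mode="bicubic",
)
assert resized.shape[-2:] == (w0, h0)               # -> (1, dim, 37, 37)
```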

def prepare_tokens_with_masks(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor:
"""Prepare tokens with optional masks.
Args:
x (torch.Tensor): Input tensor.
masks (torch.Tensor | None): Optional masks tensor.
Returns:
torch.Tensor: Tensor with prepared tokens.
"""
_, _, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)

if self.reg_token is not None:
x = torch.cat(
(
x[:, :1],
self.reg_token.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)

return x

def _get_intermediate_layers_not_chunked(self, x: torch.Tensor, n: int = 1) -> list[torch.Tensor]:
"""Get intermediate layers without chunking.
Args:
x (torch.Tensor): Input tensor.
n (int): Number of last blocks to take. If it's a list, take the specified blocks.
Returns:
list[torch.Tensor]: List of intermediate layer outputs.
"""
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
if len(output) != len(blocks_to_take):
msg = f"only {len(output)} / {len(blocks_to_take)} blocks found"
raise RuntimeError(msg)
return output

def get_intermediate_layers(
self,
x: torch.Tensor,
n: int = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm: bool = True,
) -> tuple:
"""Get intermediate layers of the VisionTransformer.
Args:
x (torch.Tensor): Input tensor.
n (int): Number of last blocks to take. If it's a list, take the specified blocks.
reshape (bool): Whether to reshape the output feature maps.
return_class_token (bool): Whether to return the class token.
norm (bool): Whether to apply normalization to the outputs.
Returns:
tuple: A tuple containing the intermediate layer outputs.
"""
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_reg_tokens :] for out in outputs]
if reshape:
b, _, w, h = x.shape
outputs = [
out.reshape(b, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
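The method above is the hook the decoupled DinoV2 segmentation path can consume: with `reshape=True` it yields `(B, C, H/patch, W/patch)` feature maps ready for a decode head. A hypothetical usage sketch, assuming the `dinov2-small-seg` preset added earlier in this file (downstream wiring is not shown):

```python
import torch

# Hypothetical: build the backbone with the new segmentation preset.
backbone = VisionTransformer(arch="dinov2-small-seg", img_size=518)

images = torch.randn(2, 3, 518, 518)
# Take the last 4 blocks as spatial feature maps; 518 // 14 = 37 and
# embed_dim = 384, so each entry is (2, 384, 37, 37).
feats = backbone.get_intermediate_layers(images, n=4, reshape=True, norm=True)
for f in feats:
    print(f.shape)  # torch.Size([2, 384, 37, 37])
```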

def forward(
self,
x: torch.Tensor,
9 changes: 3 additions & 6 deletions src/otx/algo/classification/efficientnet.py
@@ -14,7 +14,7 @@
from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:

return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
10 changes: 4 additions & 6 deletions src/otx/algo/classification/mobilenet_v3.py
@@ -15,7 +15,7 @@
from otx.algo.classification.backbones import OTXMobileNetV3
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelNonLinearClsHead,
SemiSLLinearClsHead,
@@ -313,14 +313,12 @@ def _build_model(self, head_config: dict) -> nn.Module:

copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
in_channels = 960 if self.mode == "large" else 576

return HLabelClassifier(
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=960,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=in_channels),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
9 changes: 3 additions & 6 deletions src/otx/algo/classification/timm_model.py
@@ -15,12 +15,12 @@
from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
9 changes: 3 additions & 6 deletions src/otx/algo/classification/torchvision_model.py
@@ -14,12 +14,12 @@
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.core.data.entity.classification import (
@@ -315,11 +315,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained)
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.in_features,
**head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.in_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
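All four classification diffs above make the same swap: the CBAM-based hierarchical head, which consumed raw `(B, C, H, W)` feature maps behind an identity neck, is replaced by a linear hierarchical head fed through global average pooling. A minimal sketch of that data flow, with a plain `nn.Linear` standing in for `HierarchicalLinearClsHead` (the real head adds per-group multiclass and multilabel logic):

```python
import torch
from torch import nn

feature_map = torch.randn(2, 960, 7, 7)  # backbone output: (B, C, H, W)

# GlobalAveragePooling(dim=2) collapses the spatial dims to (B, C) ...
pooled = feature_map.mean(dim=(2, 3))    # -> (2, 960)

# ... so a purely linear head suffices downstream.
head = nn.Linear(960, 21)                # 21 = illustrative total number of class logits
logits = head(pooled)                    # -> (2, 21)
```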