Move semisl helpers to utils
sovrasov committed Mar 15, 2023
1 parent 8990127 commit b811539
Showing 2 changed files with 59 additions and 53 deletions.
56 changes: 3 additions & 53 deletions otx/mpa/modules/models/heads/semisl_multilabel_cls_head.py
@@ -4,7 +4,6 @@

import torch
from mmcls.models.builder import HEADS, build_loss
from torch import nn

from otx.mpa.modules.models.heads.custom_multi_label_linear_cls_head import (
    CustomMultiLabelLinearClsHead,
@@ -13,6 +12,8 @@
    CustomMultiLabelNonLinearClsHead,
)

from .utils import generate_aux_mlp, LossBalancer


class SemiMultilabelClsHead:
"""Multilabel Classification head for Semi-SL.
@@ -186,55 +187,4 @@ def __init__(
    def forward_train(self, x, gt_label):
        return self.forward_train_with_last_layers(
            x, gt_label, final_cls_layer=self.classifier, final_emb_layer=self.aux_mlp
        )


def generate_aux_mlp(aux_mlp_cfg: dict, in_channels: int):
    out_channels = aux_mlp_cfg["out_channels"]
    if out_channels <= 0:
        raise ValueError(f"out_channels={out_channels} must be a positive integer")
    if "hid_channels" in aux_mlp_cfg and aux_mlp_cfg["hid_channels"] > 0:
        hid_channels = aux_mlp_cfg["hid_channels"]
        mlp = nn.Sequential(
            nn.Linear(in_features=in_channels, out_features=hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hid_channels, out_features=out_channels),
        )
    else:
        mlp = nn.Linear(in_features=in_channels, out_features=out_channels)

    return mlp


class EMAMeter:
    def __init__(self, alpha=0.9):
        self.alpha = alpha
        self.reset()

    def reset(self):
        self.val = 0

    def update(self, val):
        self.val = self.alpha * self.val + (1 - self.alpha) * val


class LossBalancer:
    def __init__(self, num_losses, weights=None, ema_weight=0.7) -> None:
        self.EPS = 1e-9
        self.avg_estimators = [EMAMeter(ema_weight) for _ in range(num_losses)]

        if weights is not None:
            assert len(weights) == num_losses
            self.final_weights = weights
        else:
            self.final_weights = [1.0] * num_losses

    def balance_losses(self, losses):
        total_loss = 0.0
        for i, l in enumerate(losses):
            self.avg_estimators[i].update(float(l))
            total_loss += (
                self.final_weights[i] * l / (self.avg_estimators[i].val + self.EPS) * self.avg_estimators[0].val
            )

        return total_loss
56 changes: 56 additions & 0 deletions otx/mpa/modules/models/heads/utils.py
@@ -0,0 +1,56 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: MIT
#

from torch import nn


def generate_aux_mlp(aux_mlp_cfg: dict, in_channels: int):
    """Build the auxiliary embedding head: a 2-layer MLP, or a single linear layer."""
    out_channels = aux_mlp_cfg["out_channels"]
    if out_channels <= 0:
        raise ValueError(f"out_channels={out_channels} must be a positive integer")
    if "hid_channels" in aux_mlp_cfg and aux_mlp_cfg["hid_channels"] > 0:
        hid_channels = aux_mlp_cfg["hid_channels"]
        mlp = nn.Sequential(
            nn.Linear(in_features=in_channels, out_features=hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hid_channels, out_features=out_channels),
        )
    else:
        # No positive hidden width configured: fall back to a single linear projection.
        mlp = nn.Linear(in_features=in_channels, out_features=out_channels)

    return mlp


class EMAMeter:
    """Exponential moving average of a scalar value."""

    def __init__(self, alpha=0.9):
        self.alpha = alpha
        self.reset()

    def reset(self):
        self.val = 0

    def update(self, val):
        self.val = self.alpha * self.val + (1 - self.alpha) * val


class LossBalancer:
    """Rescales several losses so each contributes at the running scale of the first one."""

    def __init__(self, num_losses, weights=None, ema_weight=0.7) -> None:
        self.EPS = 1e-9
        self.avg_estimators = [EMAMeter(ema_weight) for _ in range(num_losses)]

        if weights is not None:
            assert len(weights) == num_losses
            self.final_weights = weights
        else:
            self.final_weights = [1.0] * num_losses

    def balance_losses(self, losses):
        total_loss = 0.0
        for i, loss in enumerate(losses):
            # Normalize each loss by its own running magnitude, then rescale it
            # to the running magnitude of the first loss.
            self.avg_estimators[i].update(float(loss))
            total_loss += (
                self.final_weights[i] * loss / (self.avg_estimators[i].val + self.EPS) * self.avg_estimators[0].val
            )

        return total_loss
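
For quick reference, a minimal usage sketch of the relocated helpers (the channel sizes and loss values below are illustrative only, not taken from the repository):

import torch

from otx.mpa.modules.models.heads.utils import LossBalancer, generate_aux_mlp

# Hypothetical sizes: project 512-d features through a 128-d hidden layer to 64-d.
aux_mlp = generate_aux_mlp({"hid_channels": 128, "out_channels": 64}, in_channels=512)
embeddings = aux_mlp(torch.randn(4, 512))  # -> shape (4, 64)

# Balance two losses (e.g. supervised + unsupervised) to the scale of the first.
balancer = LossBalancer(num_losses=2)
total = balancer.balance_losses([torch.tensor(1.2), torch.tensor(0.03)])

Since each loss is divided by its own running average and rescaled by the first loss's, the balancing adapts as training progresses rather than relying on fixed coefficients alone.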
