Skip to content

Commit

Permalink
update class_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
chhluo committed Mar 19, 2022
1 parent bf094a8 commit 2f70b2d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 30 deletions.
15 changes: 8 additions & 7 deletions configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
_base_ = [
'../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
]

num_things_classes = 80
num_stuff_classes = 53
num_classes = num_things_classes + num_stuff_classes
model = dict(
type='MaskFormer',
backbone=dict(
Expand All @@ -19,8 +21,8 @@
in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside
feat_channels=256,
out_channels=256,
num_things_classes=80,
num_stuff_classes=53,
num_things_classes=num_things_classes,
num_stuff_classes=num_stuff_classes,
num_queries=100,
pixel_decoder=dict(
type='TransformerEncoderPixelDecoder',
Expand Down Expand Up @@ -87,11 +89,10 @@
init_cfg=None),
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
use_sigmoid=False,
loss_weight=1.0,
reduction='mean',
class_weight=1.0),
class_weight=[1.0] * num_classes + [0.1]),
loss_mask=dict(
type='FocalLoss',
use_sigmoid=True,
Expand All @@ -109,8 +110,8 @@
loss_weight=1.0)),
panoptic_fusion_head=dict(
type='MaskFormerFusionHead',
num_things_classes=80,
num_stuff_classes=53,
num_things_classes=num_things_classes,
num_stuff_classes=num_stuff_classes,
loss_panoptic=None,
init_cfg=None),
train_cfg=dict(
Expand Down
26 changes: 3 additions & 23 deletions mmdet/models/dense_heads/maskformer_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,9 @@ def __init__(self,
positional_encoding=None,
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
use_sigmoid=False,
loss_weight=1.0,
class_weight=1.0),
class_weight=[1.0] * 133 + [0.1]),
loss_mask=dict(
type='FocalLoss',
use_sigmoid=True,
Expand Down Expand Up @@ -124,25 +123,7 @@ def __init__(self,
sampler_cfg = dict(type='MaskPseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)

self.bg_cls_weight = 0
class_weight = loss_cls.get('class_weight', None)
if class_weight is not None and (self.__class__ is MaskFormerHead):
assert isinstance(class_weight, float), 'Expected ' \
'class_weight to have type float. Found ' \
f'{type(class_weight)}.'
# NOTE following the official MaskFormerHead repo, bg_cls_weight
# means relative classification weight of the VOID class.
bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
assert isinstance(bg_cls_weight, float), 'Expected ' \
'bg_cls_weight to have type float. Found ' \
f'{type(bg_cls_weight)}.'
class_weight = torch.ones(self.num_classes + 1) * class_weight
# set VOID class as the last indice
class_weight[self.num_classes] = bg_cls_weight
loss_cls.update({'class_weight': class_weight})
if 'bg_cls_weight' in loss_cls:
loss_cls.pop('bg_cls_weight')
self.bg_cls_weight = bg_cls_weight
self.class_weight = loss_cls.class_weight
self.loss_cls = build_loss(loss_cls)
self.loss_mask = build_loss(loss_mask)
self.loss_dice = build_loss(loss_dice)
Expand Down Expand Up @@ -384,8 +365,7 @@ def loss_single(self, cls_scores, mask_preds, gt_labels_list,
labels = labels.flatten(0, 1)
label_weights = label_weights.flatten(0, 1)

class_weight = cls_scores.new_ones(self.num_classes + 1)
class_weight[-1] = self.bg_cls_weight
class_weight = cls_scores.new_tensor(self.class_weight)
loss_cls = self.loss_cls(
cls_scores,
labels,
Expand Down

0 comments on commit 2f70b2d

Please sign in to comment.