Skip to content

Commit

Permalink
add mask2former
Browse files Browse the repository at this point in the history
  • Loading branch information
chhluo committed Feb 23, 2022
1 parent cac3563 commit d1b4585
Show file tree
Hide file tree
Showing 18 changed files with 2,144 additions and 48 deletions.
236 changes: 236 additions & 0 deletions configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
_base_ = [
'../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
]
# TODO based on ../maskformer/maskformer_r50_mstrain_64x1_300e_coco.py
model = dict(
type='Mask2Former',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
panoptic_head=dict(
type='Mask2FormerHead',
in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside
feat_channels=256,
out_channels=256,
num_things_classes=80,
num_stuff_classes=53,
num_queries=100,
num_transformer_feat_level=3,
pixel_decoder=dict(
type='MSDeformAttnPixelDecoder',
num_return_feat_levels=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.0,
act_cfg=dict(type='ReLU', inplace=True)),
feedforward_channels=1024,
ffn_dropout=0.0,
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
init_cfg=None),
enforce_decoder_input_project=False,
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
transformer_decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
num_layers=9,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True),
# the following parameter was not used,
# just make current api happy
feedforward_channels=2048,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
init_cfg=None),
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
use_sigmoid=False,
loss_weight=2.0,
reduction='mean',
class_weight=1.0),
loss_mask=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
bce_expand_one_hot=False,
reduction='mean',
loss_weight=5.0),
loss_dice=dict(
type='DiceLoss',
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=5.0)),
train_cfg=dict(
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
assigner=dict(
type='MaskHungarianAssigner',
cls_cost=dict(type='ClassificationCost', weight=2.0),
mask_cost=dict(
type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True),
dice_cost=dict(
type='DiceCost', weight=5.0, pred_act=True, eps=1.0)),
sampler=dict(type='MaskPseudoSampler')),
test_cfg=dict(
panoptic_on=True,
# the dataset doesn't support the evaluation of semantic segmentation.
semantic_on=False,
instance_on=True,
max_dets_per_image=100,
object_mask_thr=0.8,
iou_thr=0.8),
init_cfg=None)

# dataset settings
image_size = (1024, 1024)
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(
type='LoadPanopticAnnotations',
with_bbox=True,
with_mask=True,
with_seg=True),
dict(type='RandomFlip', flip_ratio=0.5),
# large scale jittering
dict(
type='Resize',
img_scale=image_size,
ratio_range=(0.1, 2.0),
multiscale_mode='range',
keep_ratio=True),
dict(
type='RandomCrop',
crop_size=image_size,
crop_type='absolute',
recompute_bbox=True,
allow_negative_crop=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=image_size),
dict(type='DefaultFormatBundle', img_to_float=True),
dict(
type='Collect',
keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data_root = 'data/coco/'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(pipeline=train_pipeline),
val=dict(
pipeline=test_pipeline,
ins_ann_file=data_root + 'annotations/instances_val2017.json',
),
test=dict(
pipeline=test_pipeline,
ins_ann_file=data_root + 'annotations/instances_val2017.json',
))

embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
# optimizer
optimizer = dict(
type='AdamW',
lr=0.0001,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi,
},
norm_decay_mult=0.0))
optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2))

# learning policy
lr_config = dict(
policy='step',
gamma=0.1,
by_epoch=False,
step=[327778, 355092],
warmup='linear',
warmup_by_epoch=False,
warmup_ratio=1.0, # no warmup
warmup_iters=10)

runner = dict(type='IterBasedRunner', max_iters=368750)

log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
dict(type='TensorboardLoggerHook', by_epoch=False)
])
workflow = [('train', 5000)]
checkpoint_config = dict(by_epoch=False, interval=5000)
# TODO add metric segm
evaluation = dict(interval=5000, metric='PQ')
62 changes: 62 additions & 0 deletions configs/mask2former/mask2former_swin-t-p4-w7_lsj_8x2_50e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
_base_ = ['./mask2former_r50_lsj_8x2_50e_coco.py']
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa

depths = [2, 2, 6, 2]
model = dict(
type='Mask2Former',
backbone=dict(
_delete_=True,
type='SwinTransformer',
embed_dims=96,
depths=depths,
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.3,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=False,
convert_weights=True,
frozen_stages=-1,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
panoptic_head=dict(
type='Mask2FormerHead', in_channels=[96, 192, 384, 768]),
init_cfg=None)

# set all layers in backbone to lr_mult=0.1
# set all norm layers, position_embeding,
# query_embeding, level_embeding to decay_multi=0.0
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
custom_keys = {
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'backbone.patch_embed.norm': backbone_norm_multi,
'backbone.norm': backbone_norm_multi,
'absolute_pos_embed': backbone_embed_multi,
'relative_position_bias_table': backbone_embed_multi,
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi
}
custom_keys.update({
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
for stage_id, num_blocks in enumerate(depths)
for block_id in range(num_blocks)
})
custom_keys.update({
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
for stage_id in range(len(depths) - 1)
})
# optimizer
optimizer = dict(
type='AdamW',
lr=0.0001,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
12 changes: 12 additions & 0 deletions mmdet/apis/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def single_gpu_test(model,
if isinstance(result[0], tuple):
result = [(bbox_results, encode_mask_results(mask_results))
for bbox_results, mask_results in result]
elif isinstance(result[0], dict) and 'ins_results' in result[0]:
for j in range(len(result)):
bbox_results, mask_results = result[j]['ins_results']
result[j]['ins_results'] = (bbox_results,
encode_mask_results(mask_results))

results.extend(result)

for _ in range(batch_size):
Expand Down Expand Up @@ -104,6 +110,12 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
if isinstance(result[0], tuple):
result = [(bbox_results, encode_mask_results(mask_results))
for bbox_results, mask_results in result]
elif isinstance(result[0], dict) and 'ins_results' in result[0]:
for j in range(len(result)):
bbox_results, mask_results = result[j]['ins_results']
result[j]['ins_results'] = (
bbox_results, encode_mask_results(mask_results))

results.extend(result)

if rank == 0:
Expand Down
6 changes: 3 additions & 3 deletions mmdet/core/bbox/match_costs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_match_cost
from .match_cost import (BBoxL1Cost, ClassificationCost, DiceCost,
FocalLossCost, IoUCost)
from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost,
DiceCost, FocalLossCost, IoUCost)

__all__ = [
'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost',
'FocalLossCost', 'DiceCost'
'FocalLossCost', 'DiceCost', 'CrossEntropyLossCost'
]
Loading

0 comments on commit d1b4585

Please sign in to comment.