From 2ad5723f74dedf2de52f302d35f550042191d558 Mon Sep 17 00:00:00 2001 From: luochunhua Date: Sun, 27 Feb 2022 15:35:18 +0000 Subject: [PATCH] [Feature] Add Mask2Former to mmdet update doc update doc format deepcopy pixel_decoder cfg move mask_pseudo_sampler cfg to config file move part of postprocess from head to detector fix bug in postprocessing move class setting from head to config file remove if else move mask2bbox to mask/util update docstring update docstring in result2json fix bug update class_weight add maskformer_fusion_head add maskformer fusion head update add cfg for filter_low_score update maskformer update class_weight update config update unit test rename param update comments in config rename variable, rm arg, update unit tests update mask2bbox add unit test for mask2bbox replace unsqueeze(1) and squeeze(1) add unit test for maskformer_fusion_head update docstrings update docstring delete \ remove modification to ce loss update docstring update docstring update docstring of ce loss update unit test update docstring update docstring update docstring rename rename add msdeformattn pixel decoder maskformer refactor add strides in config remove redundant code remove redundant code update unit test update config update --- .../mask2former_r50_lsj_8x2_50e_coco.py | 253 +++++++++++ ...ormer_swin-t-p4-w7-224_lsj_8x2_50e_coco.py | 62 +++ mmdet/core/bbox/match_costs/__init__.py | 6 +- mmdet/core/bbox/match_costs/match_cost.py | 65 +++ mmdet/datasets/coco.py | 3 - mmdet/datasets/coco_panoptic.py | 17 +- mmdet/models/dense_heads/__init__.py | 4 +- mmdet/models/dense_heads/mask2former_head.py | 430 ++++++++++++++++++ mmdet/models/detectors/__init__.py | 3 +- mmdet/models/detectors/mask2former.py | 27 ++ .../test_dense_heads/test_mask2former_head.py | 216 +++++++++ tests/test_models/test_forward.py | 111 +++++ tests/test_utils/test_assigner.py | 24 + 13 files changed, 1212 insertions(+), 9 deletions(-) create mode 100644 configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py create mode 100644 configs/mask2former/mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py create mode 100644 mmdet/models/dense_heads/mask2former_head.py create mode 100644 mmdet/models/detectors/mask2former.py create mode 100644 tests/test_models/test_dense_heads/test_mask2former_head.py diff --git a/configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py b/configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py new file mode 100644 index 00000000000..54d138fce3f --- /dev/null +++ b/configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py @@ -0,0 +1,253 @@ +_base_ = [ + '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py' +] +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes +model = dict( + type='Mask2Former', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + panoptic_head=dict( + type='Mask2FormerHead', + in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + init_cfg=None), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=2048, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0)), + panoptic_fusion_head=dict( + type='MaskFormerFusionHead', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=2.0), + mask_cost=dict( + type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True), + dice_cost=dict( + type='DiceCost', weight=5.0, pred_act=True, eps=1.0)), + sampler=dict(type='MaskPseudoSampler')), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=None) + +# dataset settings +image_size = (1024, 1024) +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile', to_float32=True), + dict( + type='LoadPanopticAnnotations', + with_bbox=True, + with_mask=True, + with_seg=True), + dict(type='RandomFlip', flip_ratio=0.5), + # large scale jittering + dict( + type='Resize', + img_scale=image_size, + ratio_range=(0.1, 2.0), + multiscale_mode='range', + keep_ratio=True), + dict( + type='RandomCrop', + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=image_size), + dict(type='DefaultFormatBundle', img_to_float=True), + dict( + type='Collect', + keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data_root = 'data/coco/' +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict(pipeline=train_pipeline), + val=dict( + pipeline=test_pipeline, + ins_ann_file=data_root + 'annotations/instances_val2017.json', + ), + test=dict( + pipeline=test_pipeline, + ins_ann_file=data_root + 'annotations/instances_val2017.json', + )) + +embed_multi = dict(lr_mult=1.0, decay_mult=0.0) +# optimizer +optimizer = dict( + type='AdamW', + lr=0.0001, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999), + paramwise_cfg=dict( + custom_keys={ + 'backbone': dict(lr_mult=0.1, decay_mult=1.0), + 'query_embed': embed_multi, + 'query_feat': embed_multi, + 'level_embed': embed_multi, + }, + norm_decay_mult=0.0)) +optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + gamma=0.1, + by_epoch=False, + step=[327778, 355092], + warmup='linear', + warmup_by_epoch=False, + warmup_ratio=1.0, # no warmup + warmup_iters=10) + +max_iters = 368750 +runner = dict(type='IterBasedRunner', max_iters=max_iters) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + dict(type='TensorboardLoggerHook', by_epoch=False) + ]) +interval = 200000 +workflow = [('train', interval)] +checkpoint_config = dict( + by_epoch=False, interval=interval, save_last=True, max_keep_ckpts=3) + +# Before 365001th iteration, we do evaluation every 200000 iterations. +# After 365000th iteration, we do evaluation every 368750 iterations, +# which means do evaluation at the end of training. +# In all, we do evaluation at the 200000th iteration and the +# last iteratoin. +dynamic_intervals = [(max_iters // interval * interval + 1, max_iters)] +evaluation = dict( + interval=interval, dynamic_intervals=dynamic_intervals, metric='PQ') diff --git a/configs/mask2former/mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py b/configs/mask2former/mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py new file mode 100644 index 00000000000..70e3103e482 --- /dev/null +++ b/configs/mask2former/mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py @@ -0,0 +1,62 @@ +_base_ = ['./mask2former_r50_lsj_8x2_50e_coco.py'] +pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa + +depths = [2, 2, 6, 2] +model = dict( + type='Mask2Former', + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dims=96, + depths=depths, + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=(0, 1, 2, 3), + with_cp=False, + convert_weights=True, + frozen_stages=-1, + init_cfg=dict(type='Pretrained', checkpoint=pretrained)), + panoptic_head=dict( + type='Mask2FormerHead', in_channels=[96, 192, 384, 768]), + init_cfg=None) + +# set all layers in backbone to lr_mult=0.1 +# set all norm layers, position_embeding, +# query_embeding, level_embeding to decay_multi=0.0 +backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0) +backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0) +embed_multi = dict(lr_mult=1.0, decay_mult=0.0) +custom_keys = { + 'backbone': dict(lr_mult=0.1, decay_mult=1.0), + 'backbone.patch_embed.norm': backbone_norm_multi, + 'backbone.norm': backbone_norm_multi, + 'absolute_pos_embed': backbone_embed_multi, + 'relative_position_bias_table': backbone_embed_multi, + 'query_embed': embed_multi, + 'query_feat': embed_multi, + 'level_embed': embed_multi +} +custom_keys.update({ + f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi + for stage_id, num_blocks in enumerate(depths) + for block_id in range(num_blocks) +}) +custom_keys.update({ + f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi + for stage_id in range(len(depths) - 1) +}) +# optimizer +optimizer = dict( + type='AdamW', + lr=0.0001, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999), + paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0)) diff --git a/mmdet/core/bbox/match_costs/__init__.py b/mmdet/core/bbox/match_costs/__init__.py index 81ee588571e..1b636795082 100644 --- a/mmdet/core/bbox/match_costs/__init__.py +++ b/mmdet/core/bbox/match_costs/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .builder import build_match_cost -from .match_cost import (BBoxL1Cost, ClassificationCost, DiceCost, - FocalLossCost, IoUCost) +from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost, + DiceCost, FocalLossCost, IoUCost) __all__ = [ 'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost', - 'FocalLossCost', 'DiceCost' + 'FocalLossCost', 'DiceCost', 'CrossEntropyLossCost' ] diff --git a/mmdet/core/bbox/match_costs/match_cost.py b/mmdet/core/bbox/match_costs/match_cost.py index 3c0a164b3c8..7ac0ad0f6df 100644 --- a/mmdet/core/bbox/match_costs/match_cost.py +++ b/mmdet/core/bbox/match_costs/match_cost.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +import torch.nn.functional as F from mmdet.core.bbox.iou_calculators import bbox_overlaps from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh @@ -281,3 +282,67 @@ def __call__(self, mask_preds, gt_masks): mask_preds = mask_preds.sigmoid() dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) return dice_cost * self.weight + + +@MATCH_COST.register_module() +class CrossEntropyLossCost: + """CrossEntropyLossCost. + + Args: + weight (int | float, optional): loss weight. Defaults to 1. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to True. + Examples: + >>> from mmdet.core.bbox.match_costs import CrossEntropyLossCost + >>> import torch + >>> bce = CrossEntropyLossCost(use_sigmoid=True) + >>> cls_pred = torch.tensor([[7.6, 1.2], [-1.3, 10]]) + >>> gt_labels = torch.tensor([[1, 1], [1, 0]]) + >>> print(bce(cls_pred, gt_labels)) + """ + + def __init__(self, weight=1., use_sigmoid=True): + assert use_sigmoid, 'use_sigmoid = False is not supported yet.' + self.weight = weight + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): The prediction with shape (num_query, 1, *) or + (num_query, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + + Returns: + Tensor: Cross entropy cost matrix in shape (num_query, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits( + cls_pred, torch.ones_like(cls_pred), reduction='none') + neg = F.binary_cross_entropy_with_logits( + cls_pred, torch.zeros_like(cls_pred), reduction='none') + cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \ + torch.einsum('nc,mc->nm', neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits. + gt_labels (Tensor): Labels. + + Returns: + Tensor: Cross entropy cost matrix with weight in + shape (num_query, num_gt). + """ + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(cls_pred, gt_labels) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 46e3a6cbdd6..bcdd4df3981 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -405,9 +405,6 @@ def evaluate_det_segm(self, 'bbox', 'segm', 'proposal', 'proposal_fast'. logger (logging.Logger | str | None): Logger used for printing related information during evaluation. Default: None. - jsonfile_prefix (str | None): The prefix of json files. It includes - the file path and the prefix of filename, e.g., "a/b/prefix". - If not specified, a temp file will be created. Default: None. classwise (bool): Whether to evaluating the AP for each class. proposal_nums (Sequence[int]): Proposal number used for evaluating recalls, such as recall@100, recall@1000. diff --git a/mmdet/datasets/coco_panoptic.py b/mmdet/datasets/coco_panoptic.py index 7afc077cc03..53ef5947d1e 100644 --- a/mmdet/datasets/coco_panoptic.py +++ b/mmdet/datasets/coco_panoptic.py @@ -457,8 +457,20 @@ def results2json(self, results, outfile_prefix): different data types. This method will automatically recognize the type, and dump them to json files. + .. code-block:: none + + [ + { + 'pan_results': np.array, # shape (h, w) + # ins_results which includes bboxes and RLE encoded masks + # is optional. + 'ins_results': (list[np.array], list[list[str]]) + }, + ... + ] + Args: - results (dict): Testing results of the dataset. + results (list[dict]): Testing results of the dataset. outfile_prefix (str): The filename prefix of the json files. If the prefix is "somepath/xxx", the json files will be named "somepath/xxx.panoptic.json", "somepath/xxx.bbox.json", @@ -597,6 +609,7 @@ def evaluate(self, if 'PQ' in metrics: eval_pan_results = self.evaluate_pan_json( result_files, outfile_prefix, logger, classwise, nproc=nproc) + eval_results.update(eval_pan_results) metrics.remove('PQ') @@ -611,11 +624,13 @@ def evaluate(self, 'shuold not be None' coco_gt = COCO(self.ins_ann_file) + panoptic_cat_ids = self.cat_ids self.cat_ids = coco_gt.get_cat_ids(cat_names=self.THING_CLASSES) eval_ins_results = self.evaluate_det_segm(results, result_files, coco_gt, metrics, logger, classwise, **kwargs) + self.cat_ids = panoptic_cat_ids eval_results.update(eval_ins_results) if tmp_dir is not None: diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index e931e608028..375197a6987 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -20,6 +20,7 @@ from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead from .lad_head import LADHead from .ld_head import LDHead +from .mask2former_head import Mask2FormerHead from .maskformer_head import MaskFormerHead from .nasfcos_head import NASFCOSHead from .paa_head import PAAHead @@ -50,5 +51,6 @@ 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead', 'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead', - 'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead' + 'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead', + 'Mask2FormerHead' ] diff --git a/mmdet/models/dense_heads/mask2former_head.py b/mmdet/models/dense_heads/mask2former_head.py new file mode 100644 index 00000000000..78e4d49bbd8 --- /dev/null +++ b/mmdet/models/dense_heads/mask2former_head.py @@ -0,0 +1,430 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer_sequence) +from mmcv.ops import point_sample +from mmcv.runner import ModuleList + +from mmdet.core import build_assigner, build_sampler, reduce_mean +from mmdet.models.utils import get_uncertain_point_coords_with_randomness +from ..builder import HEADS, build_loss +from .anchor_free_head import AnchorFreeHead +from .maskformer_head import MaskFormerHead + + +@HEADS.register_module() +class Mask2FormerHead(MaskFormerHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. + loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs): + super(AnchorFreeHead, self).__init__(init_cfg) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers.\ + attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers.\ + attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence( + transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding( + positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + def init_weights(self): + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, + img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (num_gts, ). + gt_masks (Tensor): Ground truth mask for each image, each with + shape (num_gts, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample( + mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, + 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample( + gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, + 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, + gt_labels, gt_points_masks, + img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, + gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries, )) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + num_total_pos, + num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list, + gt_labels_list, gt_masks_list, + img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). + """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros( + (batch_size, ) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.flatten( + 2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[ + (i + 1) % self.num_transformer_feat_level].shape[-2:]) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 9f05a282c18..5f2b3088de4 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -17,6 +17,7 @@ from .htc import HybridTaskCascade from .kd_one_stage import KnowledgeDistillationSingleStageDetector from .lad import LAD +from .mask2former import Mask2Former from .mask_rcnn import MaskRCNN from .mask_scoring_rcnn import MaskScoringRCNN from .maskformer import MaskFormer @@ -51,5 +52,5 @@ 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX', 'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD', - 'MaskFormer' + 'MaskFormer', 'Mask2Former' ] diff --git a/mmdet/models/detectors/mask2former.py b/mmdet/models/detectors/mask2former.py new file mode 100644 index 00000000000..b9ad2ed25d3 --- /dev/null +++ b/mmdet/models/detectors/mask2former.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..builder import DETECTORS +from .maskformer import MaskFormer + + +@DETECTORS.register_module() +class Mask2Former(MaskFormer): + r"""Implementation of `Masked-attention Mask + Transformer for Universal Image Segmentation + `_.""" + + def __init__(self, + backbone, + neck=None, + panoptic_head=None, + panoptic_fusion_head=None, + train_cfg=None, + test_cfg=None, + init_cfg=None): + super().__init__( + backbone, + neck=neck, + panoptic_head=panoptic_head, + panoptic_fusion_head=panoptic_fusion_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg) diff --git a/tests/test_models/test_dense_heads/test_mask2former_head.py b/tests/test_models/test_dense_heads/test_mask2former_head.py new file mode 100644 index 00000000000..66d144301b2 --- /dev/null +++ b/tests/test_models/test_dense_heads/test_mask2former_head.py @@ -0,0 +1,216 @@ +import numpy as np +import torch +from mmcv import ConfigDict + +from mmdet.core.mask import BitmapMasks +from mmdet.models.dense_heads import Mask2FormerHead + + +def test_mask2former_head_loss(): + """Tests head loss when truth is empty and non-empty.""" + base_channels = 64 + img_metas = [{ + 'batch_input_shape': (128, 160), + 'img_shape': (126, 160, 3), + 'ori_shape': (63, 80, 3) + }, { + 'batch_input_shape': (128, 160), + 'img_shape': (120, 160, 3), + 'ori_shape': (60, 80, 3) + }] + feats = [ + torch.rand((2, 64 * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i))) + for i in range(4) + ] + num_things_classes = 80 + num_stuff_classes = 53 + num_classes = num_things_classes + num_stuff_classes + config = ConfigDict( + dict( + type='Mask2FormerHead', + in_channels=[base_channels * 2**i for i in range(4)], + feat_channels=base_channels, + out_channels=base_channels, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=base_channels, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=base_channels, + feedforward_channels=base_channels * 4, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + feedforward_channels=base_channels * 4, + ffn_dropout=0.0, + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=base_channels // 2, + normalize=True), + init_cfg=None), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=base_channels // 2, + normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=base_channels, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=base_channels, + feedforward_channels=base_channels * 8, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + # the following parameter was not used, + # just make current api happy + feedforward_channels=base_channels * 8, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + train_cfg=dict( + num_points=256, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=2.0), + mask_cost=dict( + type='CrossEntropyLossCost', + weight=5.0, + use_sigmoid=True), + dice_cost=dict( + type='DiceCost', weight=5.0, pred_act=True, eps=1.0)), + sampler=dict(type='MaskPseudoSampler')), + test_cfg=dict( + panoptic_on=True, + semantic_on=False, + instance_on=True, + max_dets_per_image=100, + object_mask_thr=0.8, + iou_thr=0.8))) + self = Mask2FormerHead(**config) + self.init_weights() + all_cls_scores, all_mask_preds = self.forward(feats, img_metas) + # Test that empty ground truth encourages the network to predict background + gt_labels_list = [torch.LongTensor([]), torch.LongTensor([])] + gt_masks_list = [ + torch.zeros((0, 128, 160)).long(), + torch.zeros((0, 128, 160)).long() + ] + + empty_gt_losses = self.loss(all_cls_scores, all_mask_preds, gt_labels_list, + gt_masks_list, img_metas) + # When there is no truth, the cls loss should be nonzero but there should + # be no mask loss. + for key, loss in empty_gt_losses.items(): + if 'cls' in key: + assert loss.item() > 0, 'cls loss should be non-zero' + elif 'mask' in key: + assert loss.item( + ) == 0, 'there should be no mask loss when there are no true mask' + elif 'dice' in key: + assert loss.item( + ) == 0, 'there should be no dice loss when there are no true mask' + + # when truth is non-empty then both cls, mask, dice loss should be nonzero + # random inputs + gt_labels_list = [ + torch.tensor([10, 100]).long(), + torch.tensor([100, 10]).long() + ] + mask1 = torch.zeros((2, 128, 160)).long() + mask1[0, :50] = 1 + mask1[1, 50:] = 1 + mask2 = torch.zeros((2, 128, 160)).long() + mask2[0, :, :50] = 1 + mask2[1, :, 50:] = 1 + gt_masks_list = [mask1, mask2] + two_gt_losses = self.loss(all_cls_scores, all_mask_preds, gt_labels_list, + gt_masks_list, img_metas) + for loss in two_gt_losses.values(): + assert loss.item() > 0, 'all loss should be non-zero' + + # test forward_train + gt_bboxes = None + gt_labels = [ + torch.tensor([10]).long(), + torch.tensor([10]).long(), + ] + thing_mask1 = np.zeros((1, 128, 160), dtype=np.int32) + thing_mask1[0, :50] = 1 + thing_mask2 = np.zeros((1, 128, 160), dtype=np.int32) + thing_mask2[0, :, 50:] = 1 + gt_masks = [ + BitmapMasks(thing_mask1, 128, 160), + BitmapMasks(thing_mask2, 128, 160), + ] + stuff_mask1 = torch.zeros((1, 128, 160)).long() + stuff_mask1[0, :50] = 10 + stuff_mask1[0, 50:] = 100 + stuff_mask2 = torch.zeros((1, 128, 160)).long() + stuff_mask2[0, :, 50:] = 10 + stuff_mask2[0, :, :50] = 100 + gt_semantic_seg = [stuff_mask1, stuff_mask2] + + self.forward_train(feats, img_metas, gt_bboxes, gt_labels, gt_masks, + gt_semantic_seg) + + # test inference mode + self.simple_test(feats, img_metas) diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index 6b28ba61514..3e5f80ba80f 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -811,3 +811,114 @@ def test_maskformer_forward(): rescale=True, return_loss=False) batch_results.append(result) + + +def test_mask2former_forward(): + model_cfg = _get_detector_cfg( + 'mask2former/mask2former_r50_lsj_8x2_50e_coco.py') + base_channels = 32 + model_cfg.backbone.depth = 18 + model_cfg.backbone.init_cfg = None + model_cfg.backbone.base_channels = base_channels + model_cfg.panoptic_head.in_channels = [ + base_channels * 2**i for i in range(4) + ] + model_cfg.panoptic_head.feat_channels = base_channels + model_cfg.panoptic_head.out_channels = base_channels + model_cfg.panoptic_head.pixel_decoder.encoder.\ + transformerlayers.attn_cfgs.embed_dims = base_channels + model_cfg.panoptic_head.pixel_decoder.encoder.\ + transformerlayers.ffn_cfgs.embed_dims = base_channels + model_cfg.panoptic_head.pixel_decoder.encoder.\ + transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 4 + model_cfg.panoptic_head.pixel_decoder.\ + positional_encoding.num_feats = base_channels // 2 + model_cfg.panoptic_head.positional_encoding.\ + num_feats = base_channels // 2 + model_cfg.panoptic_head.transformer_decoder.\ + transformerlayers.attn_cfgs.embed_dims = base_channels + model_cfg.panoptic_head.transformer_decoder.\ + transformerlayers.ffn_cfgs.embed_dims = base_channels + model_cfg.panoptic_head.transformer_decoder.\ + transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 8 + model_cfg.panoptic_head.transformer_decoder.\ + transformerlayers.feedforward_channels = base_channels * 8 + + from mmdet.core import BitmapMasks + from mmdet.models import build_detector + detector = build_detector(model_cfg) + + # Test forward train with non-empty truth batch + detector.train() + img_metas = [ + { + 'batch_input_shape': (128, 160), + 'img_shape': (126, 160, 3), + 'ori_shape': (63, 80, 3), + 'pad_shape': (128, 160, 3) + }, + ] + img = torch.rand((1, 3, 128, 160)) + gt_bboxes = None + gt_labels = [ + torch.tensor([10]).long(), + ] + thing_mask1 = np.zeros((1, 128, 160), dtype=np.int32) + thing_mask1[0, :50] = 1 + gt_masks = [ + BitmapMasks(thing_mask1, 128, 160), + ] + stuff_mask1 = torch.zeros((1, 128, 160)).long() + stuff_mask1[0, :50] = 10 + stuff_mask1[0, 50:] = 100 + gt_semantic_seg = [ + stuff_mask1, + ] + losses = detector.forward( + img=img, + img_metas=img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + gt_masks=gt_masks, + gt_semantic_seg=gt_semantic_seg, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward train with an empty truth batch + gt_bboxes = [ + torch.empty((0, 4)).float(), + ] + gt_labels = [ + torch.empty((0, )).long(), + ] + mask = np.zeros((0, 128, 160), dtype=np.uint8) + gt_masks = [ + BitmapMasks(mask, 128, 160), + ] + gt_semantic_seg = [ + torch.randint(0, 133, (0, 128, 160)), + ] + losses = detector.forward( + img, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + gt_masks=gt_masks, + gt_semantic_seg=gt_semantic_seg, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward test + detector.eval() + with torch.no_grad(): + img_list = [g[None, :] for g in img] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + rescale=True, + return_loss=False) + batch_results.append(result) diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py index 7728510b166..c40584a50fe 100644 --- a/tests/test_utils/test_assigner.py +++ b/tests/test_utils/test_assigner.py @@ -606,3 +606,27 @@ def test_mask_hungarian_match_assigner(): assert torch.all(assign_result.gt_inds > -1) assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0) assert (assign_result.labels > -1).sum() == gt_labels.size(0) + + # test with mask bce mode + assigner_cfg = dict( + cls_cost=dict(type='ClassificationCost', weight=0.0), + mask_cost=dict( + type='CrossEntropyLossCost', weight=1.0, use_sigmoid=True), + dice_cost=dict(type='DiceCost', weight=0.0, pred_act=True, eps=1.0)) + self = MaskHungarianAssigner(**assigner_cfg) + assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks, + img_meta) + assert torch.all(assign_result.gt_inds > -1) + assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0) + assert (assign_result.labels > -1).sum() == gt_labels.size(0) + + # test with mask ce mode + assigner_cfg = dict( + cls_cost=dict(type='ClassificationCost', weight=0.0), + mask_cost=dict( + type='CrossEntropyLossCost', weight=1.0, use_sigmoid=False), + dice_cost=dict(type='DiceCost', weight=0.0, pred_act=True, eps=1.0)) + self = MaskHungarianAssigner(**assigner_cfg) + with pytest.raises(NotImplementedError): + assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks, + img_meta)