diff --git a/configs/detection/tienet/base_detector/atss_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/atss_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..190eb71 --- /dev/null +++ b/configs/detection/tienet/base_detector/atss_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,72 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='ATSS', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='ATSSHead', + num_classes=4, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # training and testing settings + train_cfg=dict( + assigner=dict(type='ATSSAssigner', topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/detection/uod_air/faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py similarity index 96% rename from configs/detection/uod_air/faster-rcnn_r50_fpn_1x_urpc-coco.py rename to configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py index 2cb93f6..a28ff6f 100644 --- a/configs/detection/uod_air/faster-rcnn_r50_fpn_1x_urpc-coco.py +++ b/configs/detection/tienet/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -1,6 +1,6 @@ _base_ = [ - '../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', - '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' ] # model settings diff --git a/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py b/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py new file mode 100644 index 0000000..8f27672 --- /dev/null +++ b/configs/detection/tienet/base_detector/fcos_r50-caffe_fpn_gn-head_1x_urpc-coco.py @@ -0,0 +1,77 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='FCOS', + data_preprocessor=dict( + 
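+        # NOTE: caffe-style ResNet weights expect BGR inputs with mean
+        # subtraction only (std of 1), hence bgr_to_rgb=False below.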
type='DetDataPreprocessor',
+        mean=[102.9801, 115.9465, 122.7717],
+        std=[1.0, 1.0, 1.0],
+        bgr_to_rgb=False,
+        pad_size_divisor=32),
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=False),
+        norm_eval=True,
+        style='caffe',
+        init_cfg=dict(
+            type='Pretrained',
+            checkpoint='open-mmlab://detectron/resnet50_caffe')),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',  # use P5
+        num_outs=5,
+        relu_before_extra_convs=True),
+    bbox_head=dict(
+        type='FCOSHead',
+        num_classes=4,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        strides=[8, 16, 32, 64, 128],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
+    # testing settings
+    test_cfg=dict(
+        nms_pre=1000,
+        min_bbox_size=0,
+        score_thr=0.05,
+        nms=dict(type='nms', iou_threshold=0.5),
+        max_per_img=100))
+
+# learning rate
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
+        end=1000),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=12,
+        by_epoch=True,
+        milestones=[8, 11],
+        gamma=0.1)
+]
+
+# optimizer
+optim_wrapper = dict(
+    optimizer=dict(lr=0.01),
+    paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.),
+    clip_grad=dict(max_norm=35, norm_type=2))  # loss may become NaN without clip_grad
diff --git a/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py
new file mode 100644
index 0000000..a229916
--- /dev/null
+++ b/configs/detection/tienet/base_detector/paa_r50_fpn_1x_urpc-coco.py
@@ -0,0 +1,93 @@
+_base_ = [
+    '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py',
+    '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
+model = dict(
+    type='PAA',
+    data_preprocessor=dict(
+        type='DetDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True,
+        pad_size_divisor=32),
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch',
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5),
+    bbox_head=dict(
+        type='PAAHead',
+        reg_decoded_bbox=True,
+        score_voting=True,
+        topk=9,
+        num_classes=4,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[0.1, 0.1, 0.2, 0.2]),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
+    # training and testing settings
+    train_cfg=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.1,
+            neg_iou_thr=0.1,
+            min_pos_iou=0,
+            ignore_iof_thr=-1),
+        allowed_border=-1,
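+        # NOTE: the deliberately loose IoU thresholds (0.1) above only gather
+        # candidate anchors; PAAHead then re-assigns them by fitting a
+        # two-component GMM to the candidate losses (see PAAHead.paa_reassign).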
pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/configs/detection/tienet/base_detector/tood_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/base_detector/tood_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..66cd58f --- /dev/null +++ b/configs/detection/tienet/base_detector/tood_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,81 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='TOOD', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='TOODHead', + num_classes=4, + in_channels=256, + stacked_convs=6, + feat_channels=256, + anchor_type='anchor_free', + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)), + train_cfg=dict( + initial_epoch=4, + initial_assigner=dict(type='ATSSAssigner', topk=9), + assigner=dict(type='TaskAlignedAssigner', topk=13), + alpha=1, + beta=6, + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) diff --git a/configs/detection/tienet/base_editor/tienet_enhance_model.py b/configs/detection/tienet/base_editor/tienet_enhance_model.py index 70cdd39..dc8f5bb 100644 --- a/configs/detection/tienet/base_editor/tienet_enhance_model.py +++ b/configs/detection/tienet/base_editor/tienet_enhance_model.py @@ -31,7 +31,7 @@ type='StructureFFTLoss', radius=4, pass_type='high', - channel_mean=True, + channel_mean=False, loss_type='mse', guid_filter=dict( type='GuidedFilter2d', radius=32, eps=1e-4, fast_s=2), diff --git a/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py new file mode 100644 index 
0000000..17910b9 --- /dev/null +++ b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_4xbs4_urpc-coco.py @@ -0,0 +1,3 @@ +_base_ = ['./tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py'] + +train_dataloader = dict(batch_size=4, num_workers=4) diff --git a/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..f503b19 --- /dev/null +++ b/configs/detection/tienet/tienet_faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,35 @@ +# default scope is mmdet +_base_ = [ + './base_editor/tienet_enhance_model.py', + './base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py' +] + +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceModel', + detector={{_base_.model}}, + enhance_model={{_base_.enhance_model}}, + train_mode='enhance', + pred_mode='enhance', + detach_enhance_img=False) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +model_wrapper_cfg = dict( + type='lqit.SelfEnhanceModelDDP', + broadcast_buffers=False, + find_unused_parameters=False) diff --git a/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py new file mode 100644 index 0000000..3c5ab3b --- /dev/null +++ b/configs/detection/tienet/tienet_retinanet_r50_fpn_1x_4xbs4_urpc-coco.py @@ -0,0 +1,3 @@ +_base_ = ['./tienet_retinanet_r50_fpn_1x_urpc-coco.py'] + +train_dataloader = dict(batch_size=4, num_workers=4) diff --git a/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..512d9d7 --- /dev/null +++ b/configs/detection/uod_air/base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='FasterRCNN', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='lqit.UFPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=0, + add_extra_convs='on_output', + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + 
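+        # NOTE: RPN classification is class-agnostic objectness, so a
+        # sigmoid-based binary cross-entropy loss is used here.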
loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=4, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + )) diff --git a/configs/detection/uod_air/retinanet_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py similarity index 90% rename from configs/detection/uod_air/retinanet_r50_fpn_1x_urpc-coco.py rename to configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py index eef8d7a..9dd042d 100644 --- a/configs/detection/uod_air/retinanet_r50_fpn_1x_urpc-coco.py +++ b/configs/detection/uod_air/base_detector/retinanet_r50_ufpn_1x_urpc-coco.py @@ -1,6 +1,6 @@ _base_ = [ - '../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', - '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' + '../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' ] # model settings @@ -23,12 +23,12 @@ style='pytorch', init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), neck=dict( - type='FPN', + type='lqit.UFPN', in_channels=[256, 512, 1024, 2048], out_channels=256, - start_level=1, - add_extra_convs='on_input', - num_outs=5), + start_level=0, + add_extra_convs='on_output', + num_outs=6), bbox_head=dict( type='RetinaHead', num_classes=4, diff --git a/configs/detection/uod_air/base_ehance_head/enhance_head.py b/configs/detection/uod_air/base_ehance_head/enhance_head.py new file mode 100644 index 0000000..0d308f2 --- /dev/null +++ b/configs/detection/uod_air/base_ehance_head/enhance_head.py @@ -0,0 +1,14 @@ +enhance_head = dict( + _scope_='lqit', + type='BasicEnhanceHead', + in_channels=256, + feat_channels=256, + num_convs=2, + loss_enhance=dict(type='L1Loss', loss_weight=1.0), + gt_preprocessor=dict( + 
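+        # NOTE: keep these statistics in sync with the detector's
+        # data_preprocessor; DetectorWithEnhanceHead.process_gt_preprocessor
+        # warns on mismatch and falls back to the detector's values.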
type='GTPixelPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + element_name='img')) diff --git a/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py b/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py new file mode 100644 index 0000000..c8a8d94 --- /dev/null +++ b/configs/detection/uod_air/uod-air_faster-rcnn_r50_fpn_1x_urpc-coco.py @@ -0,0 +1,23 @@ +_base_ = [ + './base_detector/faster-rcnn_r50_fpn_1x_urpc-coco.py', + './base_ehance_head/enhance_head.py' +] + +# model settings +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceHead', + detector={{_base_.model}}, + enhance_head={{_base_.enhance_head}}, + vis_enhance=False) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) diff --git a/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py new file mode 100644 index 0000000..2b03a40 --- /dev/null +++ b/configs/detection/uod_air/uod-air_retinanet_r50_ufpn_1x_urpc-coco.py @@ -0,0 +1,41 @@ +_base_ = [ + './base_detector/retinanet_r50_ufpn_1x_urpc-coco.py', + './base_ehance_head/enhance_head.py' +] + +# model settings +model = dict( + _delete_=True, + type='lqit.DetectorWithEnhanceHead', + detector={{_base_.model}}, + enhance_head={{_base_.enhance_head}}, + vis_enhance=False) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='lqit.SetInputImageAsGT'), + dict(type='lqit.PackInputs') +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad + +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] diff --git a/lqit/detection/models/detectors/__init__.py b/lqit/detection/models/detectors/__init__.py index 9e25cb0..b9fe038 100644 --- a/lqit/detection/models/detectors/__init__.py +++ b/lqit/detection/models/detectors/__init__.py @@ -1,3 +1,4 @@ +from .detector_with_enhance_head import DetectorWithEnhanceHead from .detector_with_enhance_model import DetectorWithEnhanceModel from .edffnet import EDFFNet from .multi_input_wrapper import MultiInputDetectorWrapper @@ -6,5 +7,6 @@ __all__ = [ 'TwoStageWithEnhanceHead', 'MultiInputDetectorWrapper', - 'SingleStageDetector', 'EDFFNet', 'DetectorWithEnhanceModel' + 'SingleStageDetector', 'EDFFNet', 'DetectorWithEnhanceModel', + 'DetectorWithEnhanceHead' ] diff --git a/lqit/detection/models/detectors/detector_with_enhance_head.py b/lqit/detection/models/detectors/detector_with_enhance_head.py index 6bcf7bc..92e110d 100644 --- a/lqit/detection/models/detectors/detector_with_enhance_head.py +++ 
b/lqit/detection/models/detectors/detector_with_enhance_head.py @@ -1,78 +1,162 @@ import copy -from typing import Optional +import warnings +from typing import Dict, Optional, Tuple, Union import torch from mmdet.models import SingleStageDetector, TwoStageDetector -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from mmengine.model import BaseModel from torch import Tensor -from lqit.common.structures import SampleList +from lqit.common.structures import OptSampleList, SampleList from lqit.edit.models import add_pixel_pred_to_datasample from lqit.registry import MODELS +from lqit.utils import ConfigType, OptConfigType, OptMultiConfig +ForwardResults = Union[Dict[str, Tensor], SampleList, Tuple[Tensor], Tensor] -@MODELS.register_module() -class SingleStageWithEnhanceHead(SingleStageDetector): - """Base class for two-stage detectors with enhance head. - Two-stage detectors typically consisting of a region proposal network and a - task-specific regression head. +@MODELS.register_module() +class DetectorWithEnhanceHead(BaseModel): + """Detector with enhance head. + + Args: + detector (dict or ConfigDict): Config for detector. + enhance_head (dict or ConfigDict, optional): Config for enhance head. + vis_enhance (bool): Whether to visualize the enhanced image during + inference. Defaults to False. + init_cfg (dict or ConfigDict, optional): The config to control the + initialization. Defaults to None. """ def __init__(self, - backbone: ConfigType, - neck: OptConfigType = None, - bbox_head: OptConfigType = None, + detector: ConfigType, enhance_head: OptConfigType = None, vis_enhance: Optional[bool] = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=neck, - bbox_head=bbox_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) + super().__init__(init_cfg=init_cfg) + + # process gt_preprocessor + if enhance_head is not None: + enhance_head = self.process_gt_preprocessor(detector, enhance_head) + + # build data_preprocessor + self.data_preprocessor = MODELS.build(detector['data_preprocessor']) + # build detector + self.detector = MODELS.build(detector) + if isinstance(self.detector, SingleStageDetector): + self.detector_type = 'SingleStage' + elif isinstance(self.detector, TwoStageDetector): + self.detector_type = 'TwoStage' + else: + raise TypeError( + f'Only support SingleStageDetector and TwoStageDetector, ' + f'but got {type(self.detector)}.') + # build enhance head if enhance_head is not None: self.enhance_head = MODELS.build(enhance_head) + else: + self.enhance_head = None + if vis_enhance: + assert self.with_enhance_head self.vis_enhance = vis_enhance + @staticmethod + def process_gt_preprocessor(detector, enhance_head): + """Process the gt_preprocessor of enhance head.""" + data_preprocessor = detector.get('data_preprocessor', None) + data_preprocessor_mean = data_preprocessor['mean'] + data_preprocessor_std = data_preprocessor['std'] + data_preprocessor_bgr_to_rgb = data_preprocessor['bgr_to_rgb'] + data_preprocessor_pad_size_divisor = \ + data_preprocessor['pad_size_divisor'] + + gt_preprocessor = enhance_head.get('gt_preprocessor', None) + gt_preprocessor_mean = gt_preprocessor['mean'] + gt_preprocessor_std = gt_preprocessor['std'] + gt_preprocessor_bgr_to_rgb = gt_preprocessor['bgr_to_rgb'] + gt_preprocessor_pad_size_divisor = 
gt_preprocessor['pad_size_divisor']
+
+        if data_preprocessor_mean != gt_preprocessor_mean:
+            warnings.warn(
+                'the `mean` of data_preprocessor and gt_preprocessor '
+                'are different, using the `mean` of data_preprocessor.')
+            enhance_head['gt_preprocessor']['mean'] = data_preprocessor_mean
+        if data_preprocessor_std != gt_preprocessor_std:
+            warnings.warn(
+                'the `std` of data_preprocessor and gt_preprocessor '
+                'are different, using the `std` of data_preprocessor.')
+            enhance_head['gt_preprocessor']['std'] = data_preprocessor_std
+        if data_preprocessor_bgr_to_rgb != gt_preprocessor_bgr_to_rgb:
+            warnings.warn(
+                'the `bgr_to_rgb` of data_preprocessor and gt_preprocessor '
+                'are different, using the `bgr_to_rgb` of '
+                'data_preprocessor.')
+            enhance_head['gt_preprocessor']['bgr_to_rgb'] = \
+                data_preprocessor_bgr_to_rgb
+        if data_preprocessor_pad_size_divisor != \
+                gt_preprocessor_pad_size_divisor:
+            warnings.warn('the `pad_size_divisor` of data_preprocessor and '
+                          'gt_preprocessor are different, using the '
+                          '`pad_size_divisor` of data_preprocessor.')
+            enhance_head['gt_preprocessor']['pad_size_divisor'] = \
+                data_preprocessor_pad_size_divisor
+        return enhance_head
+
     @property
     def with_enhance_head(self) -> bool:
-        """bool: whether the detector has a RoI head"""
+        """bool: Whether the detector has an enhance head."""
         return hasattr(self, 'enhance_head') and self.enhance_head is not None

-    def _forward(self, batch_inputs: Tensor,
-                 batch_data_samples: SampleList) -> tuple:
-        """Network forward process. Usually includes backbone, neck and head
-        forward without any post-processing.
+    def forward(self,
+                inputs: Tensor,
+                data_samples: OptSampleList = None,
+                mode: str = 'tensor') -> ForwardResults:
+        """The unified entry for a forward process in both training and test.
+
+        The method should accept three modes: "tensor", "predict" and "loss":
+
+        - "tensor": Forward the whole network and return tensor or tuple of
+          tensor without any post-processing, same as a common nn.Module.
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`DetDataSample`.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method doesn't handle either back propagation or
+        parameter update, which are supposed to be done in :meth:`train_step`.

         Args:
-            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            inputs (Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (list[:obj:`DetDataSample`], optional): A batch of
+                data samples that contain annotations and predictions.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'tensor'.

         Returns:
-            tuple: A tuple of features from ``rpn_head`` and ``roi_head``
-            forward.
+            The return type depends on ``mode``.
+
+            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
+            - If ``mode="predict"``, return a list of :obj:`DetDataSample`.
+            - If ``mode="loss"``, return a dict of tensor.
         """
-        x = self.extract_feat(batch_inputs)
-        results = self.bbox_head.forward(x)
-        if self.with_enhance_head:
-            enhance_outs = self.enhance_head.forward(x)
-            results = results + (enhance_outs, )
-        return results
+        if mode == 'loss':
+            return self.loss(inputs, data_samples)
+        elif mode == 'predict':
+            return self.predict(inputs, data_samples)
+        elif mode == 'tensor':
+            return self._forward(inputs, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}". '
+                               'Only supports loss, predict and tensor mode')
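+    # A minimal usage sketch of the three modes (names assumed: `model` is
+    # a built DetectorWithEnhanceHead and `inputs`/`data_samples` come from
+    # `model.data_preprocessor`):
+    #
+    #   losses = model(inputs, data_samples, mode='loss')
+    #   results = model(inputs, data_samples, mode='predict')
+    #   feats = model(inputs, data_samples, mode='tensor')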
-    def loss(self, batch_inputs: Tensor,
-             batch_data_samples: SampleList) -> dict:
-        """Calculate losses from a batch of inputs and data samples.
+    def calculate_det_loss(self, x: Tuple[Tensor],
+                           batch_data_samples: SampleList) -> dict:
+        """Calculate detection loss.

         Args:
-            batch_inputs (Tensor): Input images of shape (N, C, H, W).
-                These should usually be mean centered and std scaled.
+            x (tuple[Tensor]): Tuple of multi-level img features.
             batch_data_samples (List[:obj:`DetDataSample`]): The batch
                 data samples. It usually includes information such
                 as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
@@ -80,29 +164,53 @@ def loss(self, batch_inputs: Tensor,
         Returns:
             dict: A dictionary of loss components
         """
-        x = self.extract_feat(batch_inputs)
-
-        losses = dict()
-        if self.with_enhance_head:
-
-            enhance_loss = self.enhance_head.loss(x, batch_data_samples)
-            # avoid loss override
-            assert not set(enhance_loss.keys()) & set(losses.keys())
-            losses.update(enhance_loss)
+        # a neck that returns more than five levels (e.g. UFPN with
+        # num_outs=6) keeps an extra high-resolution map for the enhance
+        # head; detection only consumes the last five levels
+        if len(x) > 5:
+            x = x[1:]
+        if self.detector_type == 'SingleStage':
+            losses = self.detector.bbox_head.loss(x, batch_data_samples)
+        else:
+            losses = dict()
+            # RPN forward and loss
+            if self.detector.with_rpn:
+                proposal_cfg = self.detector.train_cfg.get(
+                    'rpn_proposal', self.detector.test_cfg.rpn)
+                rpn_data_samples = copy.deepcopy(batch_data_samples)
+                # set cat_id of gt_labels to 0 in RPN
+                for data_sample in rpn_data_samples:
+                    data_sample.gt_instances.labels = \
+                        torch.zeros_like(data_sample.gt_instances.labels)
+
+                rpn_losses, rpn_results_list = \
+                    self.detector.rpn_head.loss_and_predict(
+                        x, rpn_data_samples, proposal_cfg=proposal_cfg)
+                # avoid get same name with roi_head loss
+                keys = rpn_losses.keys()
+                for key in list(keys):
+                    if 'loss' in key and 'rpn' not in key:
+                        rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
+                losses.update(rpn_losses)
+            else:
+                assert batch_data_samples[0].get('proposals', None) is not None
+                # use pre-defined proposals in InstanceData for the second
+                # stage to extract ROI features.
+                rpn_results_list = [
+                    data_sample.proposals for data_sample in batch_data_samples
+                ]
+
+            roi_losses = self.detector.roi_head.loss(x, rpn_results_list,
+                                                     batch_data_samples)
+            losses.update(roi_losses)

-        det_losses = self.bbox_head.loss(x, batch_data_samples)
-        losses.update(det_losses)
         return losses

-    def predict(self,
-                batch_inputs: Tensor,
-                batch_data_samples: SampleList,
-                rescale: bool = True) -> SampleList:
-        """Predict results from a batch of inputs and data samples with post-
-        processing.
+    def predict_det_results(self,
+                            x: Tuple[Tensor],
+                            batch_data_samples: SampleList,
+                            rescale: bool = True) -> SampleList:
+        """Predict detection results.

         Args:
-            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            x (tuple[Tensor]): Tuple of multi-level img features.
             batch_data_samples (List[:obj:`DetDataSample`]): The Data
                 Samples. It usually includes information such as
                 `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
@@ -110,7 +218,7 @@ def predict,
                 Defaults to True.

         Returns:
-            list[:obj:`DataSample`]: Return the detection results of the
+            list[:obj:`DetDataSample`]: Return the detection results of the
                 input images. The returns value is DetDataSample,
                 which usually contain 'pred_instances'. And the
                 ``pred_instances`` usually contains following keys.
@@ -123,58 +231,62 @@ def predict,
                 the last dimension 4 arrange as (x1, y1, x2, y2).
             - masks (Tensor): Has a shape (num_instances, H, W).
""" - x = self.extract_feat(batch_inputs) - results_list = self.bbox_head.predict( - x, batch_data_samples, rescale=rescale) - - if self.vis_enhance and self.with_enhance_head: - enhance_list = self.enhance_head.predict( + if len(x) > 5: + x = x[1:] + if self.detector_type == 'SingleStage': + results_list = self.detector.bbox_head.predict( x, batch_data_samples, rescale=rescale) - batch_data_samples = add_pixel_pred_to_datasample( - data_samples=batch_data_samples, pixel_list=enhance_list) - - batch_data_samples = self.add_pred_to_datasample( + else: + assert self.detector.with_bbox, 'Bbox head must be implemented.' + # If there are no pre-defined proposals, use RPN to get proposals + if batch_data_samples[0].get('proposals', None) is None: + rpn_results_list = self.detector.rpn_head.predict( + x, batch_data_samples, rescale=False) + else: + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + results_list = self.detector.roi_head.predict( + x, rpn_results_list, batch_data_samples, rescale=rescale) + + batch_data_samples = self.detector.add_pred_to_datasample( batch_data_samples, results_list) return batch_data_samples + def det_head_forward(self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> tuple: + """Forward process of detection head. -@MODELS.register_module() -class TwoStageWithEnhanceHead(TwoStageDetector): - """Base class for two-stage detectors with enhance head. - - Two-stage detectors typically consisting of a region proposal network and a - task-specific regression head. - """ - - def __init__(self, - backbone: ConfigType, - neck: OptConfigType = None, - rpn_head: OptConfigType = None, - roi_head: OptConfigType = None, - enhance_head: OptConfigType = None, - vis_enhance: Optional[bool] = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None) -> None: - super().__init__( - backbone=backbone, - neck=neck, - rpn_head=rpn_head, - roi_head=roi_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - data_preprocessor=data_preprocessor, - init_cfg=init_cfg) - - if enhance_head is not None: - self.enhance_head = MODELS.build(enhance_head) - self.vis_enhance = vis_enhance + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - @property - def with_enhance_head(self) -> bool: - """bool: whether the detector has a RoI head""" - return hasattr(self, 'enhance_head') and self.enhance_head is not None + Returns: + tuple: A tuple of features from detector head (`bbox_head` in + single-stage detector or `rpn_head` and `roi_head` in + two-stage detector). 
+ """ + if len(x) > 5: + x = x[1:] + if self.detector_type == 'SingleStage': + results = self.detector.bbox_head.forward(x) + else: + results = () + if self.detector.with_rpn: + rpn_results_list = self.detector.rpn_head.predict( + x, batch_data_samples, rescale=False) + else: + assert batch_data_samples[0].get('proposals', None) is not None + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + roi_outs = self.detector.roi_head.forward(x, rpn_results_list, + batch_data_samples) + results = results + (roi_outs, ) + return results def _forward(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> tuple: @@ -188,24 +300,12 @@ def _forward(self, batch_inputs: Tensor, tuple: A tuple of features from ``rpn_head`` and ``roi_head`` forward. """ - results = () - x = self.extract_feat(batch_inputs) - - if self.with_rpn: - rpn_results_list = self.rpn_head.predict( - x, batch_data_samples, rescale=False) - else: - assert batch_data_samples[0].get('proposals', None) is not None - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] + x = self.detector.extract_feat(batch_inputs) - if self.with_enhance_head: + results = self.det_head_forward(x, batch_data_samples) + if self.vis_enhance: enhance_outs = self.enhance_head.forward(x) results = results + (enhance_outs, ) - - roi_outs = self.roi_head.forward(x, rpn_results_list) - results = results + (roi_outs, ) return results def loss(self, batch_inputs: Tensor, @@ -222,10 +322,9 @@ def loss(self, batch_inputs: Tensor, Returns: dict: A dictionary of loss components """ - x = self.extract_feat(batch_inputs) + x = self.detector.extract_feat(batch_inputs) losses = dict() - if self.with_enhance_head: enhance_loss = self.enhance_head.loss(x, batch_data_samples) @@ -233,36 +332,10 @@ def loss(self, batch_inputs: Tensor, assert not set(enhance_loss.keys()) & set(losses.keys()) losses.update(enhance_loss) - # RPN forward and loss - if self.with_rpn: - proposal_cfg = self.train_cfg.get('rpn_proposal', - self.test_cfg.rpn) - rpn_data_samples = copy.deepcopy(batch_data_samples) - # set cat_id of gt_labels to 0 in RPN - for data_sample in rpn_data_samples: - data_sample.gt_instances.labels = \ - torch.zeros_like(data_sample.gt_instances.labels) - - rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict( - x, rpn_data_samples, proposal_cfg=proposal_cfg) - # avoid get same name with roi_head loss - keys = rpn_losses.keys() - for key in keys: - if 'loss' in key and 'rpn' not in key: - rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) - losses.update(rpn_losses) - else: - # TODO: Not support currently, should have a check at Fast R-CNN - assert batch_data_samples[0].get('proposals', None) is not None - # use pre-defined proposals in InstanceData for the second stage - # to extract ROI features. - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] - roi_losses = self.roi_head.loss(x, rpn_results_list, - batch_data_samples) - losses.update(roi_losses) - + det_losses = self.calculate_det_loss(x, batch_data_samples) + # avoid loss override + assert not set(det_losses.keys()) & set(losses.keys()) + losses.update(det_losses) return losses def predict(self, @@ -294,27 +367,14 @@ def predict(self, the last dimension 4 arrange as (x1, y1, x2, y2). - masks (Tensor): Has a shape (num_instances, H, W). """ - assert self.with_bbox, 'Bbox head must be implemented.' 
- x = self.extract_feat(batch_inputs) - - # If there are no pre-defined proposals, use RPN to get proposals - if batch_data_samples[0].get('proposals', None) is None: - rpn_results_list = self.rpn_head.predict( - x, batch_data_samples, rescale=False) - else: - rpn_results_list = [ - data_sample.proposals for data_sample in batch_data_samples - ] + x = self.detector.extract_feat(batch_inputs) + batch_data_samples = self.predict_det_results( + x, batch_data_samples, rescale=rescale) - if self.vis_enhance and self.with_enhance_head: + if self.vis_enhance: enhance_list = self.enhance_head.predict( x, batch_data_samples, rescale=rescale) batch_data_samples = add_pixel_pred_to_datasample( data_samples=batch_data_samples, pixel_list=enhance_list) - results_list = self.roi_head.predict( - x, rpn_results_list, batch_data_samples, rescale=rescale) - - batch_data_samples = self.add_pred_to_datasample( - batch_data_samples, results_list) return batch_data_samples diff --git a/lqit/detection/models/detectors/detector_with_enhance_model.py b/lqit/detection/models/detectors/detector_with_enhance_model.py index 892668f..d10ecb4 100644 --- a/lqit/detection/models/detectors/detector_with_enhance_model.py +++ b/lqit/detection/models/detectors/detector_with_enhance_model.py @@ -1,7 +1,6 @@ import copy from typing import Any, Dict, Optional, Tuple, Union -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig from mmengine.model import BaseModel from mmengine.model.wrappers import MMDistributedDataParallel as MMDDP from mmengine.utils import is_list_of @@ -11,6 +10,7 @@ from lqit.detection.utils import merge_det_results from lqit.edit.models.post_processor import add_pixel_pred_to_datasample from lqit.registry import MODEL_WRAPPERS, MODELS +from lqit.utils import ConfigType, OptConfigType, OptMultiConfig ForwardResults = Union[Dict[str, Tensor], SampleList, Tuple[Tensor], Tensor] @@ -30,7 +30,7 @@ class DetectorWithEnhanceModel(BaseModel): enhance_model (dict or ConfigDict, optional): Config for enhance model. loss_weight (list): Detection loss weight for raw and enhanced image. Only used when `train_mode` is `both`. - vis_enhance (bool): Whether visualize enhance image during inference. + vis_enhance (bool): Whether visualize enhanced image during inference. Defaults to False. train_mode (str): Train mode of detector, support `raw`, `enhance` and `both`. Defaults to `enhance`. 
@@ -99,7 +99,7 @@ def __init__,
 
     @property
     def with_enhance_model(self) -> bool:
-        """bool: whether the detector has a Enhance Model"""
+        """bool: Whether the detector has an enhance model."""
         return (hasattr(self, 'enhance_model')
                 and self.enhance_model is not None)
 
diff --git a/lqit/detection/models/necks/__init__.py b/lqit/detection/models/necks/__init__.py
index d463b99..35ca8e3 100644
--- a/lqit/detection/models/necks/__init__.py
+++ b/lqit/detection/models/necks/__init__.py
@@ -1,3 +1,4 @@
 from .dffpn import DFFPN
+from .ufpn import UFPN
 
-__all__ = ['DFFPN']
+__all__ = ['DFFPN', 'UFPN']
diff --git a/lqit/detection/models/necks/ufpn.py b/lqit/detection/models/necks/ufpn.py
new file mode 100644
index 0000000..74b5531
--- /dev/null
+++ b/lqit/detection/models/necks/ufpn.py
@@ -0,0 +1,249 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmdet.models.necks.fpn import FPN
+from torch import Tensor
+
+from lqit.registry import MODELS
+from lqit.utils import ConfigType, MultiConfig, OptConfigType
+
+
+@MODELS.register_module()
+class UFPN(FPN):
+    """UNet-based Feature Pyramid Network, UFPN.
+
+    This is an implementation of paper `Underwater Object Detection Aided
+    by Image Reconstruction
+    `_.
+
+    Args:
+        in_channels (list[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        num_outs (int): Number of output scales. Defaults to 6.
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Defaults to 0.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Defaults to -1, which means the
+            last level.
+        add_extra_convs (bool | str): If bool, it decides whether to add conv
+            layers on top of the original feature maps. Defaults to
+            `'on_output'`. If True, it is equivalent to
+            `add_extra_convs='on_input'`. If str, it specifies the source
+            feature map of the extra convs. Only the following options are
+            allowed
+
+            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+            - 'on_lateral': Last feature map after lateral convs.
+            - 'on_output': The last output feature map after fpn convs.
+        relu_before_extra_convs (bool): Whether to apply relu before the extra
+            conv. Defaults to False.
+        no_norm_on_lateral (bool): Whether to apply norm on lateral.
+            Defaults to False.
+        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            convolution layer. Defaults to None.
+        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            normalization layer. Defaults to None.
+        act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            activation layer in ConvModule. Defaults to None.
+        upsample_cfg (:obj:`ConfigDict` or dict): Config dict for the
+            interpolate layer. Defaults to dict(mode='nearest').
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict]): Initialization config dict.
+ """ + + def __init__( + self, + in_channels: List[int], + out_channels: int, + num_outs: int = 6, + start_level: int = 0, + end_level: int = -1, + add_extra_convs: str = 'on_output', + relu_before_extra_convs: bool = False, + no_norm_on_lateral: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + act_cfg: OptConfigType = None, + upsample_cfg: ConfigType = dict(mode='nearest'), + init_cfg: MultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform') + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + num_outs=num_outs, + start_level=start_level, + end_level=end_level, + add_extra_convs=add_extra_convs, + relu_before_extra_convs=relu_before_extra_convs, + no_norm_on_lateral=no_norm_on_lateral, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg, + init_cfg=init_cfg) + # add encoder pathway + self.encode_convs = nn.ModuleList() + self.connect_convs = nn.ModuleList() + for i in range(self.start_level, self.num_outs + self.start_level): + connect_conv = ConvModule( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + if i < self.num_outs + self.start_level - 1: + encode_conv = ConvModule( + out_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.encode_convs.append(encode_conv) + self.connect_convs.append(connect_conv) + + # add decoder pathway + self.decode_convs = nn.ModuleList() + decode_conv = ConvModule( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, # may use ReLU or LeakyReLU + inplace=False) + + self.decode_convs.append(decode_conv) + + for _ in range(self.start_level + 1, self.num_outs + self.start_level): + conv1 = ConvModule( + out_channels * 2, # concat with other feature map + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, # may use ReLU or LeakyReLU + inplace=False) + conv2 = ConvModule( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, # may use ReLU or LeakyReLU + inplace=False) + decode_conv = nn.Sequential(conv1, conv2) + self.decode_convs.append(decode_conv) + + def forward(self, inputs: Tuple[Tensor]) -> tuple: + """Forward function. + + Args: + inputs (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + + Returns: + tuple: Feature maps, each is a 4D-tensor. + """ + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. 
+            if 'scale_factor' in self.upsample_cfg:
+                # fix runtime error of "+=" inplace operation in PyTorch 1.10
+                laterals[i - 1] = laterals[i - 1] + F.interpolate(
+                    laterals[i], **self.upsample_cfg)
+            else:
+                prev_shape = laterals[i - 1].shape[2:]
+                laterals[i - 1] = laterals[i - 1] + F.interpolate(
+                    laterals[i], size=prev_shape, **self.upsample_cfg)
+
+        # build outputs
+        # part 1: from original levels
+        inter_outs = [
+            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+        ]
+        # build extra outputs
+
+        if self.num_outs > len(inter_outs):
+            # use max pool to get more levels on top of outputs
+            # (e.g., Faster R-CNN, Mask R-CNN)
+            if not self.add_extra_convs:
+                for i in range(self.num_outs - used_backbone_levels):
+                    inter_outs.append(
+                        F.max_pool2d(inter_outs[-1], 1, stride=2))
+            # add conv layers on top of original feature maps (RetinaNet)
+            else:
+                if self.add_extra_convs == 'on_input':
+                    extra_source = inputs[self.backbone_end_level - 1]
+                elif self.add_extra_convs == 'on_lateral':
+                    extra_source = laterals[-1]
+                elif self.add_extra_convs == 'on_output':
+                    extra_source = inter_outs[-1]
+                else:
+                    raise NotImplementedError
+                inter_outs.append(
+                    self.fpn_convs[used_backbone_levels](extra_source))
+                for i in range(used_backbone_levels + 1, self.num_outs):
+                    if self.relu_before_extra_convs:
+                        inter_outs.append(self.fpn_convs[i](F.relu(
+                            inter_outs[-1])))
+                    else:
+                        inter_outs.append(self.fpn_convs[i](inter_outs[-1]))
+
+        # part 2: add encoder path
+        connect_feats = [
+            self.connect_convs[i](inter_outs[i]) for i in range(self.num_outs)
+        ]
+
+        encode_outs = []
+        encode_outs.append(connect_feats[0])
+        for i in range(0, self.num_outs - 1):
+            encode_outs.append(self.encode_convs[i](connect_feats[i]) +
+                               connect_feats[i + 1])
+
+        # part 3: add decoder levels
+        decode_outs = [
+            torch.zeros_like(encode_outs[i]) for i in range(self.num_outs)
+        ]
+        decode_outs[-1] = self.decode_convs[0](encode_outs[-1])
+        for i in range(1, self.num_outs):
+            reverse_i = self.num_outs - i
+            if 'scale_factor' in self.upsample_cfg:
+                up_feat = F.interpolate(decode_outs[reverse_i],
+                                        **self.upsample_cfg)
+            else:
+                prev_shape = encode_outs[reverse_i - 1].shape[2:]
+                up_feat = F.interpolate(
+                    decode_outs[reverse_i],
+                    size=prev_shape,
+                    **self.upsample_cfg)
+
+            decode_outs[reverse_i - 1] = self.decode_convs[i](
+                torch.cat((encode_outs[reverse_i - 1], up_feat), dim=1))
+
+        return tuple(decode_outs)
diff --git a/lqit/edit/models/editor_heads/basic_enhance_head.py b/lqit/edit/models/editor_heads/basic_enhance_head.py
index 3a8cae8..6034ad1 100644
--- a/lqit/edit/models/editor_heads/basic_enhance_head.py
+++ b/lqit/edit/models/editor_heads/basic_enhance_head.py
@@ -232,13 +232,12 @@ def loss_by_feat_single(self, enhance_img, gt_img, img_meta):
 
 @MODELS.register_module()
 class BasicEnhanceHead(BaseEnhanceHead):
-    """[(convs)+ShufflePixes] * 2
-    """
+    """[(convs) + PixelShuffle] * 2"""
 
     def __init__(self,
                  in_channels=256,
                  feat_channels=256,
-                 num_convs=5,
+                 num_convs=2,
                  conv_cfg=None,
                  norm_cfg=None,
                  act_cfg=dict(type='ReLU'),
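
A quick standalone shape check helps when reviewing the UFPN wiring. The sketch below is not part of this patch; it assumes `lqit` and its `mmdet`/`mmcv` dependencies are installed, builds UFPN with the same arguments as the `retinanet_r50_ufpn_1x_urpc-coco.py` config, and verifies that six 256-channel levels come out for ResNet-50-like inputs:

```python
# Hypothetical smoke test for UFPN (not part of this patch).
import torch
from lqit.detection.models.necks import UFPN

neck = UFPN(
    in_channels=[256, 512, 1024, 2048],  # ResNet-50 C2-C5 channels
    out_channels=256,
    start_level=0,
    add_extra_convs='on_output',
    num_outs=6)
neck.eval()

# Fake C2-C5 features for a 256x256 input (strides 4, 8, 16, 32).
feats = tuple(
    torch.rand(1, c, 256 // s, 256 // s)
    for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32]))

with torch.no_grad():
    outs = neck(feats)

# Six decoder outputs, strides 4 to 128, all with 256 channels; the
# stride-4 level feeds the enhance head while detection uses the rest.
assert len(outs) == 6
assert all(o.shape[1] == 256 for o in outs)
print([tuple(o.shape[2:]) for o in outs])  # [(64, 64), (32, 32), ..., (2, 2)]
```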