Skip to content

Commit

Permalink
[Feature] Support TIENet
Browse files Browse the repository at this point in the history
  • Loading branch information
BIGWangYuDong committed Oct 25, 2023
1 parent 21e9f74 commit 66a9d8d
Show file tree
Hide file tree
Showing 17 changed files with 1,411 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
_base_ = [
'../../_base_/datasets/urpc2020/urpc2020-validation_coco_detection.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]

# model settings
model = dict(
type='RetinaNet',
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=4,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
sampler=dict(
type='PseudoSampler'), # Focal loss should use PseudoSampler
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))

# optimizer
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
clip_grad=dict(max_norm=35, norm_type=2)) # loss may NaN without clip_grad

param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
end=1000),
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
38 changes: 38 additions & 0 deletions configs/detection/tienet/base_editor/tienet_enhance_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
enhance_model = dict(
_scope_='lqit',
type='BaseEditModel',
destruct_gt=True,
data_preprocessor=dict(
type='EditDataPreprocessor',
mean=[0.0, 0.0, 0.0],
std=[255.0, 255.0, 255.0],
bgr_to_rgb=False,
gt_name='img'),
generator=dict(
type='TIENetGenerator',
model=dict(
type='TIENetEnhanceModel',
in_channels=3,
feat_channels=64,
out_channels=3,
num_blocks=3,
expand_ratio=0.5,
kernel_size=[1, 3, 5],
output_weight=[1.0, 1.0],
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='SiLU'),
use_depthwise=True),
spacial_pred='structure',
structure_pred='structure',
spacial_loss=dict(type='SpatialLoss', loss_weight=1.0),
tv_loss=dict(type='MaskedTVLoss', loss_mode='mse', loss_weight=10.0),
structure_loss=dict(
type='StructureFFTLoss',
radius=4,
pass_type='high',
channel_mean=True,
loss_type='mse',
guid_filter=dict(
type='GuidedFilter2d', radius=32, eps=1e-4, fast_s=2),
loss_weight=0.1)))
35 changes: 35 additions & 0 deletions configs/detection/tienet/tienet_retinanet_r50_fpn_1x_urpc-coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# default scope is mmdet
_base_ = [
'./base_editor/tienet_enhance_model.py',
'./base_detector/retinanet_r50_fpn_1x_urpc-coco.py'
]

model = dict(
_delete_=True,
type='lqit.DetectorWithEnhanceModel',
detector={{_base_.model}},
enhance_model={{_base_.enhance_model}},
train_mode='enhance',
pred_mode='enhance',
detach_enhance_img=False)

optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
clip_grad=dict(max_norm=35, norm_type=2))

# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='lqit.SetInputImageAsGT'),
dict(type='lqit.PackInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

model_wrapper_cfg = dict(
type='lqit.SelfEnhanceModelDDP',
broadcast_buffers=False,
find_unused_parameters=False)
8 changes: 5 additions & 3 deletions lqit/common/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .lark_manager import (MonitorManager, MonitorTracker,
context_monitor_manager, get_user_name,
initialize_monitor_manager, send_alert_message)
context_monitor_manager, get_error_message,
get_user_name, initialize_monitor_manager,
send_alert_message)

__all__ = [
'send_alert_message', 'get_user_name', 'initialize_monitor_manager',
'context_monitor_manager', 'MonitorTracker', 'MonitorManager'
'context_monitor_manager', 'MonitorTracker', 'MonitorManager',
'get_error_message'
]
27 changes: 17 additions & 10 deletions lqit/common/utils/lark_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,7 @@ def monitor_exception(self) -> None:
assert self.url is not None, \
'Please run `MonitorManager.start_monitor` first.'

filtered_trace = traceback.format_exc().split('\n')[-15:]
format_trace = ''
for line in filtered_trace:
format_trace += '\n' + line

# try to add error message into logger else directly print message
try:
print_log(format_trace, logger='current')
except Exception:
print(format_trace)
format_trace = get_error_message()
title = 'Task Error Report'
content = f"{self.user_name}'s {self.task_type} task\n" \
f'Config file: {self.cfg_file}\n' \
Expand Down Expand Up @@ -346,3 +337,19 @@ def context_monitor_manager(monitor_manager: Optional[MonitorManager] = None):
monitor_manager.stop_monitor()
else:
yield


def get_error_message() -> None:
"""Catch and format exception information, send alert message to Feishu."""

filtered_trace = traceback.format_exc().split('\n')[-15:]
format_trace = ''
for line in filtered_trace:
format_trace += '\n' + line

# try to add error message into logger else directly print message
try:
print_log(format_trace, logger='current')
except Exception:
print(format_trace)
return format_trace
3 changes: 2 additions & 1 deletion lqit/detection/models/detectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .detector_with_enhance_model import DetectorWithEnhanceModel
from .edffnet import EDFFNet
from .multi_input_wrapper import MultiInputDetectorWrapper
from .single_stage_enhance_head import SingleStageDetector
from .two_stage_enhance_head import TwoStageWithEnhanceHead

__all__ = [
'TwoStageWithEnhanceHead', 'MultiInputDetectorWrapper',
'SingleStageDetector', 'EDFFNet'
'SingleStageDetector', 'EDFFNet', 'DetectorWithEnhanceModel'
]
Loading

0 comments on commit 66a9d8d

Please sign in to comment.