-
Notifications
You must be signed in to change notification settings - Fork 4
/
config.py
116 lines (100 loc) · 4.13 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.config import CfgNode as CN
def add_DFormer_config(cfg):
    """
    Register the DFormer option defaults on ``cfg`` in place.

    ``MODEL.DFORMER`` (with a nested ``TEST`` node) and ``MODEL.SWIN`` are
    assigned fresh ``CN()`` nodes. Additional keys are set on the ``INPUT``,
    ``INPUT.CROP``, ``SOLVER`` and ``MODEL.SEM_SEG_HEAD`` nodes, which this
    function reads from ``cfg`` and therefore expects to exist already.
    Nothing is returned.
    """
    # NOTE: configs from original maskformer

    # ---- data config ----
    input_cfg = cfg.INPUT
    # select the dataset mapper
    input_cfg.DATASET_MAPPER_NAME = "mask_former_semantic"
    # Color augmentation
    input_cfg.COLOR_AUG_SSD = False
    # We retry random cropping until no single category in semantic segmentation GT occupies more
    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
    input_cfg.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Pad image and segmentation GT in dataset mapper.
    input_cfg.SIZE_DIVISIBILITY = -1

    # ---- solver config ----
    solver_cfg = cfg.SOLVER
    # weight decay on embedding
    solver_cfg.WEIGHT_DECAY_EMBED = 0.0
    # optimizer
    solver_cfg.OPTIMIZER = "ADAMW"
    solver_cfg.BACKBONE_MULTIPLIER = 0.1

    # ---- mask_former model config ----
    model_cfg = cfg.MODEL
    model_cfg.DFORMER = CN()
    # Re-read the node from cfg so every later write lands on the stored node.
    dformer = model_cfg.DFORMER

    # loss
    dformer.DEEP_SUPERVISION = True
    dformer.NO_OBJECT_WEIGHT = 0.1
    dformer.CLASS_WEIGHT = 1.0
    dformer.DICE_WEIGHT = 1.0
    dformer.MASK_WEIGHT = 20.0

    # transformer config
    dformer.NHEADS = 8
    dformer.DROPOUT = 0.1
    dformer.DIM_FEEDFORWARD = 2048
    dformer.ENC_LAYERS = 0
    dformer.DEC_LAYERS = 6
    dformer.PRE_NORM = False
    dformer.SNR_SCALE = 0.1
    dformer.SAMPLE_STEP = 1

    dformer.HIDDEN_DIM = 256
    dformer.NUM_OBJECT_QUERIES = 100

    dformer.TRANSFORMER_IN_FEATURE = "res5"
    dformer.ENFORCE_INPUT_PROJ = False

    # DFormer inference config
    dformer.TEST = CN()
    test_cfg = dformer.TEST
    test_cfg.SEMANTIC_ON = True
    test_cfg.INSTANCE_ON = False
    test_cfg.PANOPTIC_ON = False
    test_cfg.OBJECT_MASK_THRESHOLD = 0.0
    test_cfg.OVERLAP_THRESHOLD = 0.0
    test_cfg.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False

    # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
    # you can use this config to override
    dformer.SIZE_DIVISIBILITY = 32

    # ---- pixel decoder config ----
    sem_seg_head = model_cfg.SEM_SEG_HEAD
    sem_seg_head.MASK_DIM = 256
    # adding transformer in pixel decoder
    sem_seg_head.TRANSFORMER_ENC_LAYERS = 0
    # pixel decoder
    sem_seg_head.PIXEL_DECODER_NAME = "BasePixelDecoder"

    # ---- swin transformer backbone ----
    model_cfg.SWIN = CN()
    swin = model_cfg.SWIN
    swin.PRETRAIN_IMG_SIZE = 224
    swin.PATCH_SIZE = 4
    swin.EMBED_DIM = 96
    swin.DEPTHS = [2, 2, 6, 2]
    swin.NUM_HEADS = [3, 6, 12, 24]
    swin.WINDOW_SIZE = 7
    swin.MLP_RATIO = 4.0
    swin.QKV_BIAS = True
    swin.QK_SCALE = None
    swin.DROP_RATE = 0.0
    swin.ATTN_DROP_RATE = 0.0
    swin.DROP_PATH_RATE = 0.3
    swin.APE = False
    swin.PATCH_NORM = True
    swin.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
    swin.USE_CHECKPOINT = False

    # NOTE: maskformer2 extra configs

    # transformer module
    dformer.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder"

    # LSJ aug
    input_cfg.IMAGE_SIZE = 1024
    input_cfg.MIN_SCALE = 0.1
    input_cfg.MAX_SCALE = 2.0

    # MSDeformAttn encoder configs
    sem_seg_head.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
    sem_seg_head.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
    sem_seg_head.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8

    # point loss configs
    # Number of points sampled during training for a mask point head.
    dformer.TRAIN_NUM_POINTS = 112 * 112
    # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
    # original paper.
    dformer.OVERSAMPLE_RATIO = 3.0
    # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
    # the original paper.
    dformer.IMPORTANCE_SAMPLE_RATIO = 0.75