-
Notifications
You must be signed in to change notification settings - Fork 9.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: YanxingLiu <[email protected]>
- Loading branch information
1 parent
d45bbda
commit 073626f
Showing
17 changed files
with
1,556 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection | ||
|
||
[GLIP: Grounded Language-Image Pre-training](https://arxiv.org/abs/2112.03857) | ||
|
||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
In this paper, we present an open-set object detector, called Grounding DINO, by marrying Transformer-based detector DINO with grounded pre-training, which can detect arbitrary objects with human inputs such as category names or referring expressions. The key solution of open-set object detection is introducing language to a closed-set detector for open-set concept generalization. To effectively fuse language and vision modalities, we conceptually divide a closed-set detector into three phases and propose a tight fusion solution, which includes a feature enhancer, a language-guided query selection, and a cross-modality decoder for cross-modality fusion. While previous works mainly evaluate open-set object detection on novel categories, we propose to also perform evaluations on referring expression comprehension for objects specified with attributes. Grounding DINO performs remarkably well on all three settings, including benchmarks on COCO, LVIS, ODinW, and RefCOCO/+/g. Grounding DINO achieves a 52.5 AP on the COCO detection zero-shot transfer benchmark, i.e., without any training data from COCO. It sets a new record on the ODinW zero-shot benchmark with a mean 26.1 AP. | ||
|
||
<div align=center> | ||
<img src="https://github.com/open-mmlab/mmdetection/assets/42299757/0ed51aeb-3d53-42d8-8563-f6d21364ac95"/> | ||
</div> | ||
|
||
## Installation | ||
|
||
```shell | ||
cd $MMDETROOT | ||
|
||
# source installation | ||
pip install -r requirements/multimodal.txt | ||
|
||
# or mim installation | ||
mim install mmdet[multimodal] | ||
``` | ||
|
||
``` | ||
cd $MMDETROOT | ||
wget https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth | ||
python demo/image_demo.py \ | ||
demo/demo.jpg \ | ||
configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py \ | ||
--weights groundingdino_swint_ogc_mmdet-822d7e9d.pth \ | ||
--texts 'bench . car .' | ||
``` | ||
|
||
<div align=center> | ||
<img src="https://github.com/open-mmlab/mmdetection/assets/42299757/3a3bd6f1-e2ed-43d4-aa22-0bb07ee6f20b"/> | ||
</div> | ||
|
||
## Results and Models | ||
|
||
| Model | backbone | COCO mAP | Pre-Train Data | Config | Download | | ||
| :--------------: | :------: | :------: | :----------------------------------------------: | :------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: | | ||
| Grounding DINO-T | Swin-T | 48.5 | O365,GoldG,Cap4M | [config](grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth) | | ||
| Grounding DINO-B | Swin-B | 56.9 | COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO | [config](grounding_dino_swin-b_pretrain_mixeddata.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth) | | ||
|
||
Note: | ||
|
||
1. The weights corresponding to the zero-shot model are adopted from the official weights and converted using the [script](../../tools/model_converters/groundingdino_to_mmdet.py). We have not retrained the model for the time being. |
16 changes: 16 additions & 0 deletions
16
configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
_base_ = [ | ||
'./grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py', | ||
] | ||
|
||
model = dict( | ||
type='GroundingDINO', | ||
backbone=dict( | ||
pretrain_img_size=384, | ||
embed_dims=128, | ||
depths=[2, 2, 18, 2], | ||
num_heads=[4, 8, 16, 32], | ||
window_size=12, | ||
drop_path_rate=0.3, | ||
patch_norm=True), | ||
neck=dict(in_channels=[256, 512, 1024]), | ||
) |
127 changes: 127 additions & 0 deletions
127
configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
_base_ = [ | ||
'../_base_/datasets/coco_detection.py', | ||
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' | ||
] | ||
|
||
lang_model_name = 'bert-base-uncased' | ||
|
||
model = dict( | ||
type='GroundingDINO', | ||
num_queries=900, | ||
with_box_refine=True, | ||
as_two_stage=True, | ||
data_preprocessor=dict( | ||
type='DetDataPreprocessor', | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
bgr_to_rgb=True, | ||
pad_mask=False, | ||
), | ||
language_model=dict( | ||
type='BertModel', | ||
name=lang_model_name, | ||
pad_to_max=False, | ||
use_sub_sentence_represent=True, | ||
special_tokens_list=['[CLS]', '[SEP]', '.', '?'], | ||
add_pooling_layer=True, | ||
), | ||
backbone=dict( | ||
type='SwinTransformer', | ||
embed_dims=96, | ||
depths=[2, 2, 6, 2], | ||
num_heads=[3, 6, 12, 24], | ||
window_size=7, | ||
mlp_ratio=4, | ||
qkv_bias=True, | ||
qk_scale=None, | ||
drop_rate=0., | ||
attn_drop_rate=0., | ||
drop_path_rate=0.2, | ||
patch_norm=True, | ||
out_indices=(1, 2, 3), | ||
with_cp=False, | ||
convert_weights=False), | ||
neck=dict( | ||
type='ChannelMapper', | ||
in_channels=[192, 384, 768], | ||
kernel_size=1, | ||
out_channels=256, | ||
act_cfg=None, | ||
bias=True, | ||
norm_cfg=dict(type='GN', num_groups=32), | ||
num_outs=4), | ||
encoder=dict( | ||
num_layers=6, | ||
# visual layer config | ||
layer_cfg=dict( | ||
self_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0), | ||
ffn_cfg=dict( | ||
embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), | ||
# text layer config | ||
text_layer_cfg=dict( | ||
self_attn_cfg=dict(num_heads=4, embed_dims=256, dropout=0.0), | ||
ffn_cfg=dict( | ||
embed_dims=256, feedforward_channels=1024, ffn_drop=0.0)), | ||
# fusion layer config | ||
fusion_layer_cfg=dict( | ||
v_dim=256, | ||
l_dim=256, | ||
embed_dim=1024, | ||
num_heads=4, | ||
init_values=1e-4), | ||
), | ||
decoder=dict( | ||
num_layers=6, | ||
return_intermediate=True, | ||
layer_cfg=dict( | ||
# query self attention layer | ||
self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), | ||
# cross attention layer query to text | ||
cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), | ||
# cross attention layer query to image | ||
cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), | ||
ffn_cfg=dict( | ||
embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), | ||
post_norm_cfg=None), | ||
positional_encoding=dict( | ||
num_feats=128, normalize=True, offset=0.0, temperature=20), | ||
bbox_head=dict( | ||
type='GroundingDINOHead', | ||
num_classes=80, | ||
sync_cls_avg_factor=True, | ||
max_text_len=256, | ||
loss_cls=dict( | ||
type='FocalLoss', | ||
use_sigmoid=True, | ||
gamma=2.0, | ||
alpha=0.25, | ||
loss_weight=1.0), # 2.0 in DeformDETR | ||
loss_bbox=dict(type='L1Loss', loss_weight=5.0)), | ||
dn_cfg=dict( # TODO: Move to model.train_cfg ? | ||
label_noise_scale=0.5, | ||
box_noise_scale=1.0, # 0.4 for DN-DETR | ||
group_cfg=dict(dynamic=True, num_groups=None, | ||
num_dn_queries=100)), # TODO: half num_dn_queries | ||
# training and testing settings | ||
train_cfg=None, | ||
test_cfg=dict(max_per_img=300)) | ||
|
||
test_pipeline = [ | ||
dict( | ||
type='LoadImageFromFile', backend_args=None, | ||
imdecode_backend='pillow'), | ||
dict( | ||
type='FixScaleResize', | ||
scale=(800, 1333), | ||
keep_ratio=True, | ||
backend='pillow'), | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict( | ||
type='PackDetInputs', | ||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', | ||
'scale_factor', 'text', 'custom_entities')) | ||
] | ||
|
||
val_dataloader = dict( | ||
dataset=dict(pipeline=test_pipeline, return_classes=True)) | ||
test_dataloader = val_dataloader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
Collections: | ||
- Name: Grounding DINO | ||
Metadata: | ||
Training Data: Objects365, GoldG, CC3M and COCO | ||
Training Techniques: | ||
- AdamW | ||
- Multi Scale Train | ||
- Gradient Clip | ||
Training Resources: A100 GPUs | ||
Architecture: | ||
- Swin Transformer | ||
- BERT | ||
Paper: | ||
URL: https://arxiv.org/abs/2303.05499 | ||
Title: 'Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection | ||
' | ||
README: configs/grounding_dino/README.md | ||
Code: | ||
URL: | ||
Version: v3.0.0 | ||
|
||
Models: | ||
- Name: grounding_dino_swin-t_pretrain_obj365_goldg_cap4m | ||
In Collection: Grounding DINO | ||
Config: configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py | ||
Results: | ||
- Task: Object Detection | ||
Dataset: COCO | ||
Metrics: | ||
box AP: 48.5 | ||
Weights: https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth | ||
- Name: grounding_dino_swin-b_pretrain_mixeddata | ||
In Collection: GLIPGrounding DINO | ||
Config: configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py | ||
Results: | ||
- Task: Object Detection | ||
Dataset: COCO | ||
Metrics: | ||
box AP: 56.9 | ||
Weights: https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.