forked from fundamentalvision/Deformable-DETR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
64 lines (53 loc) · 2.6 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from argparse import Namespace
from models import build_model as build
# torch.hub hubconf.py convention: package names that must be importable before
# the entrypoints below can be loaded (presumably checked by torch.hub.load).
dependencies = ["torch", "torchvision"]
def build_deformable_detr(panoptic=False, num_classes=91, **kwargs):
    """Build a Deformable DETR model from keyword overrides.

    Assembles an ``argparse.Namespace`` shaped like the training script's
    parsed options and passes it to ``models.build_model``. Any option below
    may be overridden by passing a keyword argument of the same name.

    Args:
        panoptic: If True, ``args.dataset_file`` is ``'coco_panoptic'``,
            otherwise ``'coco'``.
        num_classes: Stored as ``args.num_classes``.
        **kwargs: Option overrides. Besides the Namespace attributes set
            below, two keys are special:

            - ``device`` (default ``'cuda'``) is also the device the built
              model is moved to.
            - ``return_postprocessors`` (default False) selects the return
              shape.

            The feature-level count may be given as ``num_feature_levels``
            (the attribute name, like every other option) or, for backward
            compatibility, as ``feature_levels``. Keys that are not
            recognised are ignored.

    Returns:
        ``model``, or ``(model, postprocessors)`` when
        ``return_postprocessors`` is truthy. The criterion returned by
        ``build`` is discarded.
    """
    args = Namespace()
    args.dataset_file = 'coco_panoptic' if panoptic else 'coco'
    args.device = kwargs.get('device', 'cuda')
    args.num_classes = num_classes
    # Prefer the attribute name for consistency with every other option;
    # 'feature_levels' was the only accepted spelling previously and is kept
    # as a fallback so existing callers keep working.
    args.num_feature_levels = kwargs.get(
        'num_feature_levels', kwargs.get('feature_levels', 4))
    args.aux_loss = kwargs.get('aux_loss', True)
    args.with_box_refine = kwargs.get('with_box_refine', False)
    args.masks = kwargs.get('masks', False)
    args.frozen_weights = kwargs.get('frozen_weights', None)

    # loss coefficients
    args.mask_loss_coef = kwargs.get('mask_loss_coef', 1.0)
    args.dice_loss_coef = kwargs.get('dice_loss_coef', 1.0)
    args.cls_loss_coef = kwargs.get('cls_loss_coef', 2.0)
    args.bbox_loss_coef = kwargs.get('bbox_loss_coef', 5.0)
    args.giou_loss_coef = kwargs.get('giou_loss_coef', 2.0)
    args.focal_alpha = kwargs.get('focal_alpha', 0.25)

    # backbone
    args.backbone = kwargs.get('backbone', 'resnet50')
    args.lr_backbone = kwargs.get('lr_backbone', 2e-5)
    args.dilation = kwargs.get('dilation', False)

    # positional encoding ('sine' by default; the original notes 'learned'
    # as the alternative)
    args.position_embedding = kwargs.get('position_embedding', 'sine')
    args.hidden_dim = kwargs.get('hidden_dim', 256)

    # transformer
    args.nheads = kwargs.get('nheads', 8)
    args.dim_feedforward = kwargs.get('dim_feedforward', 1024)
    args.enc_layers = kwargs.get('enc_layers', 6)
    args.dec_layers = kwargs.get('dec_layers', 6)
    args.dropout = kwargs.get('dropout', 0.1)
    args.dec_n_points = kwargs.get('dec_n_points', 4)
    args.enc_n_points = kwargs.get('enc_n_points', 4)
    args.num_queries = kwargs.get('num_queries', 300)
    args.two_stage = kwargs.get('two_stage', False)

    # set_cost_* options (presumably the set-matching costs — confirm in models)
    args.set_cost_class = kwargs.get('set_cost_class', 2)
    args.set_cost_bbox = kwargs.get('set_cost_bbox', 5)
    args.set_cost_giou = kwargs.get('set_cost_giou', 2)

    model, criterion, postprocessors = build(args)
    model.to(args.device)
    if kwargs.get('return_postprocessors', False):
        return model, postprocessors
    return model
def deformable_detr_r50(num_classes=91, return_postprocessor=False, **kwargs):
    """torch.hub entrypoint: Deformable DETR with a ResNet-50 backbone.

    Args:
        num_classes: Forwarded to ``build_deformable_detr``.
        return_postprocessor: When truthy, also return the ``'bbox'`` entry
            of the postprocessors mapping produced by the builder.
        **kwargs: Forwarded unchanged to ``build_deformable_detr``. Passing
            ``backbone`` here collides with the fixed ``'resnet50'`` and
            raises ``TypeError``.

    Returns:
        ``model``, or ``(model, bbox_postprocessor)`` when
        ``return_postprocessor`` is truthy.
    """
    if not return_postprocessor:
        # Plain model: let the builder's own return value pass straight through.
        return build_deformable_detr(backbone='resnet50', num_classes=num_classes, **kwargs)

    built_model, postprocessor_map = build_deformable_detr(
        backbone='resnet50',
        num_classes=num_classes,
        return_postprocessors=return_postprocessor,
        **kwargs
    )
    bbox_postprocessor = postprocessor_map['bbox']
    return built_model, bbox_postprocessor