-
Notifications
You must be signed in to change notification settings - Fork 7
/
mdl_selector.py
69 lines (65 loc) · 1.51 KB
/
mdl_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
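"""
Select the model, loss, and evaluator components that match a config's
concatenation type (``cfg.ds.conc_type``) and model name (``cfg.mdl.name``).
"""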
from mdl_vog import (
ImgGrnd_SEP,
ImgGrnd_TEMP,
ImgGrnd_SPAT,
VidGrnd_SEP,
VidGrnd_TEMP,
VidGrnd_SPAT,
VOG_SEP,
VOG_TEMP,
VOG_SPAT,
LossB_SEP,
LossB_TEMP,
LossB_SPAT
)
from eval_vsrl_corr import (
EvaluatorSEP,
EvaluatorTEMP,
EvaluatorSPAT
)
def get_mdl_loss_eval(cfg):
    """Return the model, loss and evaluator matching ``cfg``.

    Args:
        cfg: config object. Reads ``cfg.ds.conc_type`` (one of 'sep',
            'svsq', 'temp', 'spat') and ``cfg.mdl.name`` (one of 'igrnd',
            'vgrnd', 'vog').

    Returns:
        dict with keys 'mdl', 'loss', 'eval' holding the selected objects.
        They are returned as-is; nothing is instantiated here.

    Raises:
        NotImplementedError: if ``conc_type`` or the model name is not
            supported. The message names the offending value.
    """
    conc_type = cfg.ds.conc_type
    mdl_type = cfg.mdl.name

    # 'svsq' reuses exactly the same components as 'sep'.
    setting = 'sep' if conc_type == 'svsq' else conc_type

    # Validate both keys before touching any of the imported components,
    # keeping the original order: conc_type is checked first.
    if setting not in ('sep', 'temp', 'spat'):
        raise NotImplementedError(
            f"Unsupported conc_type {conc_type!r}; "
            "expected one of 'sep', 'svsq', 'temp', 'spat'"
        )
    if mdl_type not in ('igrnd', 'vgrnd', 'vog'):
        raise NotImplementedError(
            f"Unsupported model name {mdl_type!r} for conc_type "
            f"{conc_type!r}; expected one of 'igrnd', 'vgrnd', 'vog'"
        )

    # setting -> (model per model name, loss, evaluator)
    components = {
        'sep': (
            {'igrnd': ImgGrnd_SEP, 'vgrnd': VidGrnd_SEP, 'vog': VOG_SEP},
            LossB_SEP,
            EvaluatorSEP,
        ),
        'temp': (
            {'igrnd': ImgGrnd_TEMP, 'vgrnd': VidGrnd_TEMP, 'vog': VOG_TEMP},
            LossB_TEMP,
            EvaluatorTEMP,
        ),
        'spat': (
            {'igrnd': ImgGrnd_SPAT, 'vgrnd': VidGrnd_SPAT, 'vog': VOG_SPAT},
            LossB_SPAT,
            EvaluatorSPAT,
        ),
    }
    mdl_by_name, loss, evl = components[setting]

    return {
        'mdl': mdl_by_name[mdl_type],
        'loss': loss,
        'eval': evl
    }