-
Notifications
You must be signed in to change notification settings - Fork 1
/
gvars.py
87 lines (70 loc) · 4.08 KB
/
gvars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import platform

import numpy as np

import arguments as a
# The parsed arguments are NOT re-created in this module. Re-creating them in
# every file looked suspect, though it may have been a workaround for a
# conflict with ray workers; the disabled calls were `a.create_args()` and
# `a.arguments_check(a.args)`.

# Relative project directories.
FEAT_MOD_DIR = './models/feat_extractors/'
VIZ_DIR = './viz'
VIZ_DATA_DIR = './viz/data/'
WEIGHT_DIR = './weights'
MODEL_DIR = './models'
LOG_DIR = './logs'
C_DIR = './cstates'

# Directory of precomputed data for a fixed size / sigma pair. Sigma is fixed
# here instead of being read from `a.args.sigma`, because args are not
# initialised in this module (see note above).
DMAP_SIZE = 15
DMAP_SIGMA = 4
DMAP_DIR = './data/precompute/size_{}_sigma_{}/'.format(DMAP_SIZE, DMAP_SIGMA)

# Learning-rate bounds (used in search.py and cyclic).
MIN_LR = 1e-5
MAX_LR = 0.1
def resolve_absdir(nodename, whereami):
    """Return the absolute project root for the machine the code runs on.

    Args:
        nodename: Host name of the current machine.
        whereami: Value of the ``WHEREAMI`` environment variable, or None
            when it is unset.

    Returns:
        The project root path. The ``'hydra'`` host takes precedence over
        ``WHEREAMI``; any unrecognised combination falls back to the default
        scratch directory.
    """
    if nodename == 'hydra':
        return '/Scratch/repository/skiff/'
    # WHEREAMI is set per conda env, e.g. conda env config vars WHEREAMI='laptop'.
    return {
        'laptop': '/home/mks29/clones/cow_flow/',
        'fg': '/home/matthew/Desktop/laptop_desktop/clones/cow_flow/',
    }.get(whereami, '/Scratch/mks29/cow_flow/')


# platform.node() is portable; os.uname() does not exist on Windows, so the
# previous lookup made importing this module fail there.
ABSDIR = resolve_absdir(platform.node(), os.getenv('WHEREAMI'))
# Result directories for the '256' and '800x600' sets of best checkpoints.
bm_dir = 'final_models_paper/256/'
bm_dir_86 = 'final_models_paper/800x600/'

# Model name -> path of its best checkpoint: one of the directories above
# followed by the run name.
BEST_MODEL_PATH_DICT = dict(
    UNet=f'{bm_dir}best_5DD_UNet_hydra_BS64_LR_I0.002_E5000_DIM256_OPTIMadam_weight_none_09_05_2022_20_05_07',
    UNet_seg=f'{bm_dir}best_9S1_UNet_seg_ml-17_BS16_LR_I0.0002_E500_DIM256_OPTIMsgd_weight_MX_SZ_4_none_11_05_2022_13_13_32',
    CSRNet=f'{bm_dir}best_4H_CSRNet_hydra_BS64_LR_I0.0001_E1000_DIM256_OPTIMsgd_weight_none_04_05_2022_16_48_45',
    FCRN=f'{bm_dir}best_9CC_FCRN_hydra_BS64_LR_I0.1_E900_DIM256_OPTIMsgd_weight_none_WD_None_09_05_2022_20_05_07',
    LCFCN=f'{bm_dir_86}best_86_LCFCN_A_LCFCN_quartet_BS1_LR_I1e-05_E1000_DIM608_OPTIMadam_weighted_MX_SZ_4_none_WD_1e-05_15_06_2022_14_17_53',
    # Previously used MCNN checkpoint:
    # bm_dir + 'best_86_MCNN_B_MCNN_quatern_BS32_LR_I0.001_E1000_DIM608_OPTIMadam_weighted_none_WD_1e-05_FT_15_06_2022_17_21_20'
    MCNN=f'{bm_dir}best_relu3X_MCNN_hydra_BS64_LR_I0.001_E1000_DIM256_OPTIMsgd_SC1000_weighted_none_WD_1e-05_12_08_2022_12_01_49',
    Res50=f'{bm_dir}best_4G2_Res50_hydra_BS64_LR_I0.0001_E1000_DIM256_OPTIMsgd_weight_none_04_05_2022_17_21_42',
    NF=f'{bm_dir_86}best_86_widereal_w_vgg1NF_C5_NF_hydra_BS8_LR_I0.0001_E250_DIM608_OPTIMadam_FE_vgg16_bn_NC5_conv_JC_weighted_none_JO_PY_1_WD_1e-05_15_07_2022_17_52_19',
    NF_FRZ=f'{bm_dir_86}best_2_FRZ86_widereal_w_vgg1NF_C5_NF_hydra_BS12_LR_I0.0001_E250_DIM608_OPTIMadam_FE_vgg16_bn_NC5_conv_JC_weighted_none_PT_PY_1_WD_1e-05_29_08_2022_13_28_01',
    VGG=f'{bm_dir_86}best_vgg_real1_VGG_hydra_BS16_LR_I0.001_E1000_DIM608_OPTIMsgd_SC1000_weighted_none_WD_0.01_29_09_2022_19_06_45',
)
# Baseline model identifiers, each paired with its display title. Deriving
# both lists from these pairs keeps BASELINE_MODEL_NAMES and
# BASELINE_TITLE_NAMES in the same order by construction.
_BASELINE_NAME_TITLE_PAIRS = (
    ('UNet', 'UNet (d)'),
    ('CSRNet', 'CSRNet'),
    ('FCRN', 'FCRN'),
    ('LCFCN', 'LC-FCN'),
    ('UNet_seg', 'UNet (s)'),
    ('MCNN', 'MCNN'),
    ('Res50', 'ResNet-50'),
    ('VGG', 'VGG-16'),
)
BASELINE_MODEL_NAMES = [name for name, _title in _BASELINE_NAME_TITLE_PAIRS]
BASELINE_TITLE_NAMES = [title for _name, title in _BASELINE_NAME_TITLE_PAIRS]

# Error files, one per entry of BASELINE_MODEL_NAMES and in the same order.
# NOTE(review): every entry is currently the Res50 run's .errors file, which
# looks like a placeholder - TODO confirm the intended per-model files.
_RES50_ERRORS_FILE = '4G2_Res50_hydra_BS64_LR_I0.0001_E1000_DIM256_OPTIMsgd_weight_none_04_05_2022_17_21_42.errors'
ERROR_FILES_LIST = [_RES50_ERRORS_FILE] * len(BASELINE_MODEL_NAMES)
# Subnet types accepted for the 'NF' model in train mode (checked in check_gvars).
SUBNETS = ['conv', 'conv_shallow', 'fc', 'MCNN', 'UNet', 'conv_deep']

# Thresholds 0.0, 0.5, ..., 19.5: 40 values, stop value 20 excluded.
THRES_SEQ = np.arange(0, 20, 0.5, dtype=float)

# Module-level mutable settings. A `global` statement at module scope has no
# effect, so plain assignments are sufficient.
FILTERS = 0
SUBNET_BN = False
def _require(condition, message):
    """Raise AssertionError with *message* unless *condition* is true.

    Unlike a bare ``assert`` statement, this check is not removed when Python
    runs with ``-O``. AssertionError is kept so callers see the same
    exception type as before.
    """
    if not condition:
        raise AssertionError(message)


def check_gvars():
    """Sanity-check the parsed arguments in ``a.args`` against module constants.

    Checks:
      * a baseline model (one of BASELINE_MODEL_NAMES) must have ``noise == 0``;
      * unless ``mode == 'plot'`` with ``plot_errors`` set, ``model_name`` must
        be a baseline, ``'NF'`` or ``'ALL'``;
      * when ``model_name == 'NF'`` and ``mode == 'train'``, ``subnet_type``
        must be one of SUBNETS.

    Raises:
        AssertionError: if any check fails.
    """
    args = a.args
    if args.model_name in BASELINE_MODEL_NAMES:
        _require(args.noise == 0,
                 f'baseline model {args.model_name!r} requires noise == 0, got {args.noise!r}')
    # `plot_errors` is only read when mode == 'plot' (short-circuit `and`).
    if not (args.mode == 'plot' and args.plot_errors):
        allowed_models = BASELINE_MODEL_NAMES + ['NF', 'ALL']
        _require(args.model_name in allowed_models,
                 f'unknown model_name {args.model_name!r}; expected one of {allowed_models}')
    if args.model_name == 'NF' and args.mode == 'train':
        _require(args.subnet_type in SUBNETS,
                 f'unknown subnet_type {args.subnet_type!r}; expected one of {SUBNETS}')


if __name__ == "__main__":
    check_gvars()