-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
executable file
·138 lines (107 loc) · 4.35 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/python
# Borrowed from tensorpack; credit goes to Yuxin Wu.
# The loading sequence is:
#    default config in this file
# -> provided setting file (you cannot add new config after this)
# -> manually overridden config
# -> computed config in finalize_configs (you can change config after this)
import os
import pprint
import yaml
__all__ = ["config", "finalize_configs"]
class AttrDict:
    """Attribute-style, nested configuration container.

    Reading a missing attribute creates (and stores) an empty child
    ``AttrDict``, so nested settings can be written as ``cfg.a.b.c = value``
    without pre-creating ``cfg.a`` and ``cfg.a.b``.

    After :meth:`freeze`, reading a missing attribute raises
    ``AttributeError`` and assigning a *new* attribute is rejected, which
    avoids accidental creation of new hierarchies (e.g. through typos).
    Existing attributes can still be reassigned while frozen.
    """

    # Class-level default; ``freeze`` shadows it with an instance attribute.
    _freezed = False

    def __getattr__(self, name):
        """Create, store and return an empty child node for an unknown name.

        Python only calls this when normal attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is a dunder name, or if this node is
                frozen.
        """
        # Protocol probes such as copy.deepcopy's getattr(x, "__deepcopy__",
        # None) or pickle's lookup of "__setstate__" must see "not defined".
        # Otherwise they would receive a fresh (non-callable) AttrDict and fail.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        if self._freezed:
            raise AttributeError(name)
        ret = AttrDict()
        setattr(self, name, ret)
        return ret

    def __setattr__(self, name, value):
        # Once frozen, only attributes that already exist may be reassigned.
        if self._freezed and name not in self.__dict__:
            raise AttributeError("Cannot create new attribute!")
        super().__setattr__(name, value)

    def __str__(self):
        return pprint.pformat(self.to_dict(), indent=1)

    __repr__ = __str__

    def to_dict(self):
        """Convert to a nested plain ``dict``, skipping ``_``-prefixed keys."""
        return {
            k: v.to_dict() if isinstance(v, AttrDict) else v
            for k, v in self.__dict__.items()
            if not k.startswith("_")
        }

    def update_args(self, args):
        """Update values from command-line strings of the form ``a.b.c=value``.

        Each entry is split on its first ``=``. Every component of the dotted
        key must already exist (checked with ``assert``).

        If the current value is a ``str``, the new value is stored verbatim.
        Otherwise it is parsed with ``eval``.

        NOTE(review): ``eval`` executes arbitrary Python. Only pass trusted
        arguments (e.g. your own command line), never external input.
        """
        for cfg in args:
            keys, v = cfg.split("=", maxsplit=1)
            keylist = keys.split(".")
            dic = self
            for k in keylist[:-1]:
                assert k in dir(dic), "Unknown config key: {}".format(k)
                dic = getattr(dic, k)
            key = keylist[-1]
            assert key in dir(dic), "Unknown config key: {}".format(key)
            oldv = getattr(dic, key)
            if not isinstance(oldv, str):
                v = eval(v)  # SECURITY: trusted input only (see docstring).
            setattr(dic, key, v)

    @staticmethod
    def _load_yaml_file(path):
        """Parse the UTF-8 YAML file at ``path`` and return the loaded object."""
        # YAML streams are UTF-8 by spec. Be explicit so the platform default
        # encoding (e.g. cp1252/gbk on Windows) is not used.
        with open(path, "r", encoding="utf-8") as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    def update_with_yaml(self, rel_path):
        """Override settings from ``configs/<rel_path>`` next to this file.

        Also records the file name up to its first ``.`` as ``setting_name``.
        """
        base_path = os.path.dirname(os.path.abspath(__file__))
        setting_path = os.path.normpath(os.path.join(base_path, "configs", rel_path))
        setting_name = os.path.basename(setting_path).split(".")[0]
        self.update_with_dict(self._load_yaml_file(setting_path))
        setattr(self, "setting_name", setting_name)

    def init_with_yaml(self):
        """Load the base settings from ``root_setting.yaml`` next to this file."""
        base_path = os.path.dirname(os.path.abspath(__file__))
        setting_path = os.path.normpath(os.path.join(base_path, "root_setting.yaml"))
        self.update_with_dict(self._load_yaml_file(setting_path))

    def update_with_text(self, text):
        """Override settings from a YAML document given as a string."""
        # NOTE(review): FullLoader can construct some Python-specific tags.
        # Do not feed untrusted text here; consider yaml.safe_load if such
        # tags are not needed.
        overrided_setting = yaml.load(text, Loader=yaml.FullLoader)
        self.update_with_dict(overrided_setting)

    def update_with_dict(self, dicts):
        """Recursively merge a (possibly nested) mapping into this node.

        Nested ``dict`` values are merged into child ``AttrDict`` nodes, which
        are created on demand unless frozen. All other values overwrite.
        ``None`` (what ``yaml.load`` returns for an empty document) is a no-op.
        """
        if dicts is None:
            return
        for k, v in dicts.items():
            if isinstance(v, dict):
                getattr(self, k).update_with_dict(v)
            else:
                setattr(self, k, v)

    def freeze(self, freezed=True):
        """Freeze this node and all nested nodes; pass ``False`` to unfreeze."""
        self._freezed = freezed
        for v in self.__dict__.values():
            if isinstance(v, AttrDict):
                v.freeze(freezed)

    # Avoid silent bugs: comparing configs with ==/!= is almost always a mistake.
    def __eq__(self, _):
        raise NotImplementedError()

    def __ne__(self, _):
        raise NotImplementedError()
# Module-level singleton configuration; `_C` is a short alias that saves typing.
config = _C = AttrDict()
# Settings can be assigned directly below, e.g. `_C.model_dir = r'.\checkpoint'`,
# or provided in root_setting.yaml.
def finalize_configs(input_cfg=_C, freeze=True, verbose=True):
    """Fill in computed settings on ``input_cfg`` and optionally freeze it.

    Sets ``input_cfg.base_path`` to the absolute directory containing this
    file. That is the same directory ``AttrDict.update_with_yaml`` and
    ``AttrDict.init_with_yaml`` use as their base.

    Args:
        input_cfg: config to finalize; defaults to the module-level config.
        freeze: if true, freeze ``input_cfg`` so no new keys can be added
            (existing keys can still be reassigned).
        verbose: accepted for API compatibility; currently unused.
    """
    # os.path.dirname(__file__) may be relative (or even "") depending on how
    # this file was loaded. abspath keeps base_path independent of that.
    input_cfg.base_path = os.path.dirname(os.path.abspath(__file__))
    if freeze:
        input_cfg.freeze()
if __name__ == "__main__":
    # Debug aid: print a marker, then the directory this module was loaded from.
    module_dir = os.path.dirname(__file__)
    print("?")
    print(module_dir)