-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmisc.py
176 lines (146 loc) · 5.51 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import torch
import numpy as np
import random
import time
import logging
import logging.handlers
import subprocess
from datetime import datetime
class BlackHole(object):
    """Null object that absorbs every attribute write, attribute read and call.

    Useful as a stand-in for an optional collaborator (for example a logger):
    any chain such as ``obj.info('msg')`` or ``obj.a.b()`` evaluates without
    error and without effect.
    """

    def __getattr__(self, name):
        """Return this same instance for any attribute not found normally."""
        return self

    def __call__(self, *args, **kwargs):
        """Accept any arguments, do nothing, and return this same instance."""
        return self

    def __setattr__(self, name, value):
        """Discard the assignment; nothing is ever stored on the instance."""
        return None
class CheckpointManager(object):
    """Save, index and reload model checkpoints kept in a single directory.

    Checkpoint files are named ``ckpt_<step>.pt`` (no score) or
    ``ckpt_<step>_<score>.pt``. Scores are treated as "lower is better":
    ``get_best_ckpt_idx`` selects the minimum score and
    ``get_worst_ckpt_idx`` the maximum. Checkpoints without a score are
    ignored by both, but still count for ``get_latest_ckpt_idx``.

    Each entry of ``self.ckpts`` is a dict with the keys ``'score'``
    (float or None), ``'file'`` (file name inside ``save_dir``) and
    ``'iteration'`` (int).
    """

    def __init__(self, save_dir, logger=None):
        """Create ``save_dir`` if needed and index the checkpoints already in it.

        Args:
            save_dir: Directory that holds the checkpoint files.
            logger: Object with logging-style methods (``warning`` is used
                here). When omitted, a ``BlackHole`` instance is used so
                that nothing is logged.
        """
        super().__init__()
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.ckpts = []
        # Build the fallback lazily instead of in the signature so that the
        # default is not evaluated once at definition time and shared.
        self.logger = BlackHole() if logger is None else logger
        for fname in os.listdir(self.save_dir):
            if not fname.startswith('ckpt'):
                continue
            parsed = self._parse_filename(fname)
            if parsed is None:
                # A stray file must not prevent the manager from starting.
                self.logger.warning('Skipping unrecognized checkpoint file: %s', fname)
                continue
            iteration, score = parsed
            self.ckpts.append({
                'score': score,
                'file': fname,
                'iteration': iteration,
            })

    @staticmethod
    def _parse_filename(fname):
        """Return ``(iteration, score)`` parsed from ``fname``, or None if it does not match.

        ``score`` is None for the score-less form ``ckpt_<step>.pt``.
        """
        suffix = '.pt'
        if not fname.endswith(suffix):
            return None
        # Strip the extension first: the score itself contains a '.', so
        # splitting on '.' would cut it in half.
        fields = fname[:-len(suffix)].split('_')
        if fields[0] != 'ckpt':
            return None
        try:
            if len(fields) == 2:
                return int(fields[1]), None
            if len(fields) == 3:
                return int(fields[1]), float(fields[2])
        except ValueError:
            return None
        return None

    def get_worst_ckpt_idx(self):
        """Return the index of the highest-scored checkpoint, or None if there is none.

        On ties the entry that appears last in ``self.ckpts`` wins.
        """
        idx = -1
        worst = float('-inf')
        for i, ckpt in enumerate(self.ckpts):
            score = ckpt['score']
            if score is None:
                continue
            if score >= worst:
                idx = i
                worst = score
        return idx if idx >= 0 else None

    def get_best_ckpt_idx(self):
        """Return the index of the lowest-scored checkpoint, or None if there is none.

        On ties the entry that appears last in ``self.ckpts`` wins.
        """
        idx = -1
        best = float('inf')
        for i, ckpt in enumerate(self.ckpts):
            score = ckpt['score']
            if score is None:
                continue
            if score <= best:
                idx = i
                best = score
        return idx if idx >= 0 else None

    def get_latest_ckpt_idx(self):
        """Return the index of the checkpoint with the largest iteration, or None if empty."""
        idx = -1
        latest_it = -1
        for i, ckpt in enumerate(self.ckpts):
            if ckpt['iteration'] > latest_it:
                idx = i
                latest_it = ckpt['iteration']
        return idx if idx >= 0 else None

    def save(self, model, args, score=None, others=None, step=None):
        """Write a checkpoint file and register it in ``self.ckpts``.

        The file stores a dict with the keys ``'args'``, ``'state_dict'``
        (from ``model.state_dict()``) and ``'others'``.

        Args:
            model: Object providing ``state_dict()``.
            args: Stored as-is under ``'args'``.
            score: Optional score; becomes part of the file name.
            others: Stored as-is under ``'others'``.
            step: Iteration number; required and must be greater than -1.

        Returns:
            True once the checkpoint has been written.

        Raises:
            ValueError: If ``step`` is missing or not greater than -1.
        """
        # Explicit check instead of `assert`: asserts are stripped under -O
        # and comparing the default None with an int raises TypeError.
        if step is None or not (step > -1):
            raise ValueError('Please define the value of step')

        if score is None:
            fname = 'ckpt_%d.pt' % int(step)
        else:
            fname = 'ckpt_%d_%.6f.pt' % (int(step), float(score))
        path = os.path.join(self.save_dir, fname)

        torch.save({
            'args': args,
            'state_dict': model.state_dict(),
            'others': others
        }, path)

        # Record the same keys the directory scan produces, so that
        # get_latest_ckpt_idx() keeps working after a save.
        self.ckpts.append({
            'score': None if score is None else float(score),
            'file': fname,
            'iteration': int(step),
        })

        return True

    def load_best(self):
        """Load and return the checkpoint with the lowest score.

        Raises:
            IOError: If no scored checkpoint is registered.
        """
        idx = self.get_best_ckpt_idx()
        if idx is None:
            raise IOError('No checkpoints found.')
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
        return ckpt

    def load_latest(self):
        """Load and return the checkpoint with the largest iteration.

        Raises:
            IOError: If no checkpoint is registered.
        """
        idx = self.get_latest_ckpt_idx()
        if idx is None:
            raise IOError('No checkpoints found.')
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
        return ckpt

    def load_selected(self, file):
        """Load and return the checkpoint stored as ``file`` inside ``save_dir``."""
        ckpt = torch.load(os.path.join(self.save_dir, file))
        return ckpt
def seed_all(seed, cudnn_benchmark=True):
    """Seed the Python, NumPy and PyTorch random generators and configure cuDNN.

    Args:
        seed: Seed applied to ``random``, ``numpy.random``, ``torch`` (CPU and
            all CUDA devices) and exported as ``PYTHONHASHSEED``.
        cudnn_benchmark: Value assigned to ``torch.backends.cudnn.benchmark``.
            The default ``True`` keeps the previous behaviour (faster; the
            network graph and input sizes should stay fixed). Pass ``False``
            for the slower, more reproducible setting described in
            https://pytorch.org/docs/stable/notes/randomness.html
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE: hash randomization of the running interpreter is decided at
    # start-up; exporting the variable here only reaches child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.backends.cudnn.enabled = True
    # Speed/reproducibility tradeoff: benchmark mode may select different
    # convolution algorithms from run to run.
    torch.backends.cudnn.benchmark = bool(cudnn_benchmark)
    torch.backends.cudnn.deterministic = True
def git_commit(logger, log_dir=None, git_name=None):
    """Log the current source revision, dump the working-tree diff and commit it.

    Looks for a git repository starting from the current directory and
    searching parent directories. If one is found, the HEAD commit is
    logged, the output of ``git diff`` is optionally written to
    ``<log_dir>/compareHead.diff``, and all changes are staged and committed.
    If none is found, a warning is logged and nothing else happens.

    Args:
        logger: Object with logging-style ``info`` and ``warning`` methods.
        log_dir: Directory that receives ``compareHead.diff``; the diff file
            is skipped when None.
        git_name: Commit message for the snapshot commit. Defaults to the
            current local time formatted as ``%y%m%d_%H%M%S``.
    """
    import git  # function-scope import: only needed when this function is called

    try:
        repo = git.Repo(search_parent_directories=True)
    except git.exc.InvalidGitRepositoryError:
        # Best effort: running outside a repository is not an error.
        logger.warning('Not inside a git repository; skipping source snapshot.')
        return

    head = repo.head.object
    git_sha = head.hexsha
    git_date = datetime.fromtimestamp(head.committed_date).strftime('%Y-%m-%d')
    git_message = head.message
    logger.info('Source is from Commit {} ({}): {}'.format(git_sha[:8], git_date, git_message.strip()))

    # Also create diff file in the log directory
    if log_dir is not None:
        with open(os.path.join(log_dir, 'compareHead.diff'), 'w') as fid:
            subprocess.run(['git', 'diff'], stdout=fid)

    git_name = git_name if git_name is not None else datetime.now().strftime("%y%m%d_%H%M%S")
    # Argument lists instead of a shell string: the commit message is passed
    # verbatim, so quotes or shell metacharacters in it cannot break or
    # hijack the command.
    subprocess.run(['git', 'add', '--all'])
    subprocess.run(['git', 'commit', '--all', '-m', git_name])
def get_logger(name, log_dir=None):
    """Return the logger called ``name`` with console and optional file output.

    Handlers already attached to that logger are removed and closed first,
    so calling this function again with the same name reconfigures the
    logger instead of stacking handlers (which would emit every message
    multiple times).

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_dir: When given, messages of level INFO and above are also
            written to ``<log_dir>/log.txt``. The file is opened in ``'w'``
            mode, so an existing file is overwritten. The directory is
            created if it does not exist.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # Start from a clean slate; see docstring.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if log_dir is not None:
        os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'), mode='w')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.info('Output and logs will be saved to: {}'.format(log_dir))

    return logger
def get_new_log_dir(root='./logs', prefix='', postfix=''):
    """Create a fresh, timestamp-named log directory under ``root``.

    The directory name is ``prefix`` + local time formatted as
    ``%y%m%d_%H%M%S`` + ``postfix``.

    Returns:
        A ``(log_dir, name)`` tuple: the path of the created directory and
        its bare name.

    Raises:
        FileExistsError: If a directory with that name already exists.
    """
    timestamp = time.strftime("%y%m%d_%H%M%S", time.localtime())
    run_name = ''.join((prefix, timestamp, postfix))
    run_dir = os.path.join(root, run_name)
    os.makedirs(run_dir)
    return run_dir, run_name