-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
106 lines (84 loc) · 2.95 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import urllib
import torch
from torch.utils import model_zoo
import os
class CheckpointIO(object):
    ''' CheckpointIO class.

    It handles saving and loading checkpoints. Modules are registered under
    a name; on save each registered module's ``state_dict()`` is stored
    under that name, and on load it is restored via ``load_state_dict()``.

    Args:
        checkpoint_dir (str): path where checkpoints are saved
        **kwargs: modules to register, keyed by the name used in the
            checkpoint; each must provide state_dict()/load_state_dict()
    '''

    def __init__(self, checkpoint_dir='./chkpts', **kwargs):
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir
        # exist_ok=True avoids the check-then-create race when several
        # processes construct a CheckpointIO for the same directory.
        os.makedirs(checkpoint_dir, exist_ok=True)

    def register_modules(self, **kwargs):
        ''' Registers modules in current module dictionary.

        Args:
            **kwargs: modules to add; an already registered name is replaced
        '''
        self.module_dict.update(kwargs)

    def save(self, filename, **kwargs):
        ''' Saves the current module dictionary.

        Args:
            filename (str): name of output file; relative names are placed
                inside the checkpoint directory
            **kwargs: additional values stored next to the module states.
                A key equal to a registered module name is overwritten by
                that module's state dict.
        '''
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)

        outdict = kwargs
        for name, module in self.module_dict.items():
            outdict[name] = module.state_dict()
        torch.save(outdict, filename)

    def load(self, filename):
        '''Loads a module dictionary from local file or url.

        Args:
            filename (str): name of saved module dictionary, or an
                http(s) url pointing to one

        Returns:
            dict: checkpoint entries that do not belong to a registered
                module
        '''
        if is_url(filename):
            return self.load_url(filename)
        return self.load_file(filename)

    def load_file(self, filename):
        '''Loads a module dictionary from file.

        Args:
            filename (str): name of saved module dictionary; relative names
                are resolved against the checkpoint directory

        Returns:
            dict: checkpoint entries that do not belong to a registered
                module

        Raises:
            FileExistsError: if the file does not exist
        '''
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)

        if not os.path.exists(filename):
            # The exception type is kept as FileExistsError so that existing
            # callers catching it keep working; the message carries the
            # path that was actually looked up.
            raise FileExistsError(
                'Checkpoint file not found: %s' % filename)

        print(filename)
        print('=> Loading checkpoint from local file...')
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # sources you trust.
        state_dict = torch.load(filename)
        return self.parse_state_dict(state_dict)

    def load_url(self, url):
        '''Load a module dictionary from url.

        Args:
            url (str): url to saved model

        Returns:
            dict: checkpoint entries that do not belong to a registered
                module
        '''
        print(url)
        print('=> Loading checkpoint from url...')
        # NOTE: the downloaded file is unpickled; only use trusted urls.
        state_dict = model_zoo.load_url(url, progress=True)
        return self.parse_state_dict(state_dict)

    def parse_state_dict(self, state_dict):
        '''Parse state_dict of model and return scalars.

        Every registered module found in ``state_dict`` gets its state
        restored; a warning is printed for registered modules that are
        missing from the checkpoint.

        Args:
            state_dict (dict): State dict of model

        Returns:
            dict: entries of ``state_dict`` whose key is not a registered
                module name
        '''
        for name, module in self.module_dict.items():
            if name in state_dict:
                module.load_state_dict(state_dict[name])
            else:
                print('Warning: Could not find %s in checkpoint!' % name)

        scalars = {key: value for key, value in state_dict.items()
                   if key not in self.module_dict}
        return scalars
def is_url(url):
    ''' Checks whether a string is an http or https url.

    Args:
        url (str): string to test, e.g. a file name or a web address

    Returns:
        bool: True if the parsed scheme is 'http' or 'https'
    '''
    # Imported here because a bare `import urllib` does not guarantee that
    # the `urllib.parse` submodule is loaded.
    from urllib.parse import urlparse

    return urlparse(url).scheme in ('http', 'https')