-
Notifications
You must be signed in to change notification settings - Fork 9
/
convert_ckpt.py
61 lines (50 loc) · 1.78 KB
/
convert_ckpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
@date: 2021/11/22
@description: Convert a training checkpoint into an inference checkpoint
"""
import argparse
import os
import torch
from config.defaults import merge_from_file
def parse_option():
    """Parse the command-line arguments of the conversion script.

    Recognised options:
        --cfg FILE       (required) path of the config file.
        --output_path    (optional) path of the output checkpoint; None if omitted.

    Every parsed argument is echoed as ``name : value`` after an
    "arguments:" header, followed by a 50-dash separator line.

    Returns:
        argparse.Namespace holding ``cfg`` and ``output_path``.
    """
    cli = argparse.ArgumentParser(
        description='Conversion training ckpt into inference ckpt')
    cli.add_argument('--cfg', type=str, required=True, metavar='FILE',
                     help='path of config file')
    cli.add_argument('--output_path', type=str, help='path of output ckpt')
    parsed = cli.parse_args()

    # Echo the effective configuration so runs are easy to audit from logs.
    print("arguments:")
    for key, value in vars(parsed).items():
        print(key, ":", value)
    print("-" * 50)
    return parsed
def convert_ckpt():
    """Extract the network entry of the best training checkpoint into its own file.

    Steps:
      1. Parse CLI args (``--cfg``, ``--output_path``) and load the config.
      2. Build the checkpoint directory
         ``checkpoints/<decoder_name>_<output_name>_Net/<TAG>`` from
         ``config.MODEL.ARGS[0]`` and ``config.TAG``.
      3. Pick the file in that directory whose name contains ``_best_``.
      4. Load it and save only its ``'net'`` entry to ``--output_path``.
         When that option is omitted or empty, the output is
         ``<ck_dir>/best.pkl``.

    If the checkpoint directory or a best checkpoint is missing, a message is
    printed and the function returns None without raising.
    """
    args = parse_option()
    config = merge_from_file(args.cfg)
    model_args = config.MODEL.ARGS[0]
    ck_dir = os.path.join(
        "checkpoints",
        f"{model_args['decoder_name']}_{model_args['output_name']}_Net",
        config.TAG)
    print(f"Processing {ck_dir}")

    if not os.path.isdir(ck_dir):
        print(f"Checkpoint directory not found: {ck_dir}")
        return

    # Sort so the selection is deterministic: os.listdir order is arbitrary.
    # Only regular files are considered, because torch.load cannot read a directory.
    model_paths = sorted(
        name for name in os.listdir(ck_dir)
        if '_best_' in name and os.path.isfile(os.path.join(ck_dir, name)))
    if not model_paths:
        print("Not find best ckpt")
        return
    if len(model_paths) > 1:
        print(f"Multiple best ckpts found {model_paths}, using {model_paths[0]}")
    model_path = os.path.join(ck_dir, model_paths[0])
    print(f"Loading {model_path}")

    # Conversion does not need a GPU. Keep the original cuda:0 mapping when CUDA
    # exists, but fall back to CPU so the script also runs on CPU-only machines.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(model_path, map_location=device)
    net = checkpoint['net']

    # An omitted (None) or empty --output_path falls back to the default location.
    output_path = args.output_path or os.path.join(ck_dir, 'best.pkl')
    print(f"Save on: {output_path}")
    # os.path.dirname() is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create a parent directory when there is one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(net, output_path)
# Script entry point: run the conversion only when executed directly, not on import.
if __name__ == "__main__":
    convert_ckpt()