-
Notifications
You must be signed in to change notification settings - Fork 0
/
checkpoint2model.py
51 lines (40 loc) · 2.36 KB
/
checkpoint2model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os.path as osp
import torch
def main():
    """Extract the latent support sets weights from a training checkpoint file.

    An auxiliary script that reads `<exp>/models/checkpoint.pt` and writes the object stored under its
    `latent_support_sets` key to `<exp>/models/latent_support_sets-<iter>.pt`, where `<iter>` is the value stored
    under the checkpoint's `iter` key, zero-padded to 7 digits. No other file is written.

    Options:
        ================================================================================================================
        --exp : set experiment's wip model dir (expected to have been created by `train.py`), i.e., it should contain
                a sub-directory `models/` with a checkpoint file (`checkpoint.pt`).
        ================================================================================================================

    Raises:
        NotADirectoryError: if `--exp` or its `models/` sub-directory is not an existing directory.
        FileNotFoundError: if `models/checkpoint.pt` does not exist.
        KeyError: if the checkpoint lacks the `iter` or the `latent_support_sets` entry.
    """
    parser = argparse.ArgumentParser(description="Convert a checkpoint file into a latent support sets weights file")
    parser.add_argument('--exp', type=str, required=True, help="set experiment's model dir (created by `train.py`)")

    # Parse given arguments
    args = parser.parse_args()

    # Check structure of `args.exp`
    if not osp.isdir(args.exp):
        raise NotADirectoryError("Invalid given directory: {}".format(args.exp))
    models_dir = osp.join(args.exp, 'models')
    if not osp.isdir(models_dir):
        raise NotADirectoryError("Invalid models directory: {}".format(models_dir))
    checkpoint_file = osp.join(models_dir, 'checkpoint.pt')
    if not osp.isfile(checkpoint_file):
        raise FileNotFoundError("Checkpoint file not found: {}".format(checkpoint_file))

    print("#. Convert checkpoint file into latent support sets weights file...")

    # Load checkpoint file. Tensors are mapped to the CPU so that a checkpoint written on a GPU machine can be
    # converted on a machine without CUDA; as a consequence the extracted weights are stored as CPU tensors.
    # NOTE: `torch.load` relies on pickle -- only run this script on checkpoint files from a trusted source.
    checkpoint_dict = torch.load(checkpoint_file, map_location='cpu')
    print("  \\__Checkpoint dictionary: {}".format(checkpoint_dict.keys()))

    # Fail early, with a descriptive message, when the checkpoint does not have the entries used below
    missing_keys = [key for key in ('iter', 'latent_support_sets') if key not in checkpoint_dict]
    if missing_keys:
        raise KeyError("Checkpoint file {} is missing key(s): {}".format(checkpoint_file, ', '.join(missing_keys)))

    # Get checkpoint iteration (coerced to `int`, since the `{:07d}` file name pattern only accepts integers)
    checkpoint_iter = int(checkpoint_dict['iter'])
    print("  \\__Checkpoint iteration: {}".format(checkpoint_iter))

    # Save latent support sets (LSS) weights file
    print("  \\__Save checkpoint latent support sets LSS weights file...")
    torch.save(checkpoint_dict['latent_support_sets'],
               osp.join(models_dir, 'latent_support_sets-{:07d}.pt'.format(checkpoint_iter)))
# Run the conversion only when this file is executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()