# Copyright (c) Facebook, Inc. and its affiliates.
import argparse

import torch


def replace_module_prefix(state_dict, prefix, replace_with=""):
    """
    Strip a prefix from the keys of a state_dict. This is needed when loading
    VISSL-trained checkpoints into models that do not use VISSL's key naming.
    `prefix` is replaced with `replace_with` (empty by default) at the start of
    every key that begins with it.
    """
    state_dict = {
        (key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key): val
        for (key, val) in state_dict.items()
    }
    return state_dict
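

# Example: replace_module_prefix({"_feature_blocks.conv1.weight": w}, "_feature_blocks.")
# returns {"conv1.weight": w}; keys that do not start with the prefix are left unchanged.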


def convert_and_save_model(args, replace_prefix):
    model_path = args.model_path
    model = torch.load(model_path, map_location=torch.device("cpu"))

    # get the model trunk to rename
    if "classy_state_dict" in model.keys():
        model_trunk = model["classy_state_dict"]["base_model"]["model"]["trunk"]
    elif "model_state_dict" in model.keys():
        model_trunk = model["model_state_dict"]
    else:
        model_trunk = model
    print(f"Input model loaded. Number of params: {len(model_trunk.keys())}")

    # convert the trunk by stripping the VISSL prefix from the keys
    converted_model = replace_module_prefix(model_trunk, replace_prefix)
    print(f"Converted model. Number of params: {len(converted_model.keys())}")

    # save the state
    output_filename = f"{args.output_name}.pth"
    output_model_filepath = f"{args.output_dir}/{output_filename}"
    print(f"Saving model: {output_model_filepath}")
    torch.save(converted_model, output_model_filepath)
    print("DONE!")
    print(f"Input model : {model_path}")
    print(f"Output model : {output_model_filepath}")
    return converted_model
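

# Sanity-check sketch (assumes the checkpoint holds a VISSL ResNet-50 trunk): the
# converted dict should load into torchvision with only the classifier head missing, e.g.
#   import torchvision
#   result = torchvision.models.resnet50().load_state_dict(converted_model, strict=False)
#   # expected: result.missing_keys == ["fc.weight", "fc.bias"] and no unexpected keys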


def compare_keys(source_dict, target_dict_path):
    print("\n Comparing keys \n")
    target_dict = torch.load(target_dict_path, map_location=torch.device("cpu"))
    if "state_dict" in target_dict.keys():
        target_dict = target_dict["state_dict"]
    print(
        f"Same number of params : {len(source_dict.keys()) == len(target_dict.keys())} "
        f"Source/Target = {len(source_dict.keys())}/{len(target_dict.keys())}"
    )
    for key in source_dict.keys():
        if key not in target_dict.keys():
            print(f"{key} in source dict is not in target dict.")
    for key in target_dict.keys():
        if key not in source_dict.keys():
            print(f"Source dict doesn't have {key}.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert VISSL ResNe(X)t models to Torchvision"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        required=True,
        help="Model URL or file that contains the state dict",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        required=True,
        help="Output directory where the converted state dictionary will be saved",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default=None,
        required=True,
        help="Output model name (a '.pth' extension is appended)",
    )
    parser.add_argument(
        "--target_dict_path",
        type=str,
        default=None,
        required=False,
        help="Optional state dict to compare the converted model's keys against",
    )
    args = parser.parse_args()
    converted_model = convert_and_save_model(args, replace_prefix="_feature_blocks.")
    if args.target_dict_path is not None:
        compare_keys(converted_model, args.target_dict_path)


if __name__ == "__main__":
    main()
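

# Example invocation (paths below are hypothetical):
#   python convert.py \
#       --model_path ./checkpoints/model_final_checkpoint_phase999.torch \
#       --output_dir ./converted \
#       --output_name converted_vissl_resnet50 \
#       --target_dict_path ./resnet50-torchvision.pth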