-
Notifications
You must be signed in to change notification settings - Fork 11
/
convert_mobilebert_tf_checkpoint_to_pytorch.py
46 lines (40 loc) · 1.81 KB
/
convert_mobilebert_tf_checkpoint_to_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
import logging
import torch
from model.modeling_mobilebert import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
# Configure the root logger at INFO level when this module is loaded.
# NOTE(review): presumably so progress messages logged by the weight loader are visible — confirm.
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
    """Convert a TensorFlow MobileBERT checkpoint into a saved PyTorch state dict.

    Args:
        tf_checkpoint_path: Location of the TensorFlow checkpoint; forwarded
            unchanged to ``load_tf_weights_in_mobilebert``.
        mobilebert_config_file: JSON file read with
            ``MobileBertConfig.from_json_file`` to build the model.
        pytorch_dump_path: Destination handed to ``torch.save`` for the
            resulting ``state_dict``.
    """
    # Read the architecture description and report it before building anything.
    architecture = MobileBertConfig.from_json_file(mobilebert_config_file)
    build_message = "Building PyTorch model from configuration: {}".format(str(architecture))
    print(build_message)

    # Instantiate the pre-training model and let the loader fill in the weights.
    # The object returned by the loader is the one whose weights get saved.
    converted_model = load_tf_weights_in_mobilebert(
        MobileBertForPreTraining(architecture),
        architecture,
        tf_checkpoint_path,
    )

    # Persist only the weights (state dict), not the whole module object.
    save_message = "Save PyTorch model to {}".format(pytorch_dump_path)
    print(save_message)
    weights = converted_model.state_dict()
    torch.save(weights, pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory string paths.
    # Each entry is (flag, help text), registered in this order.
    required_path_options = (
        ("--tf_checkpoint_path", "Path to the TensorFlow checkpoint path."),
        (
            "--mobilebert_config_file",
            "The config json file corresponding to the pre-trained MobileBERT model. \n"
            "This specifies the model architecture.",
        ),
        ("--pytorch_dump_path", "Path to the output PyTorch model."),
    )

    cli = argparse.ArgumentParser()
    for flag, description in required_path_options:
        cli.add_argument(flag, default=None, type=str, required=True, help=description)

    options = cli.parse_args()
    convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path=options.tf_checkpoint_path,
        mobilebert_config_file=options.mobilebert_config_file,
        pytorch_dump_path=options.pytorch_dump_path,
    )
'''
Example usage:

python convert_mobilebert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path=./prev_trained_model/mobilebert \
--mobilebert_config_file=./prev_trained_model/mobilebert/config.json \
--pytorch_dump_path=./prev_trained_model/mobilebert/pytorch_model.bin
'''