This repository has been archived by the owner on Jun 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 570
/
export_to_tfhub.py
207 lines (176 loc) · 7.6 KB
/
export_to_tfhub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a minimal TF-Hub module for ALBERT models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from albert import modeling
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
# Command-line flags. `albert_directory` is used as a directory: module_fn
# joins "albert_config.json", "30k-clean.model" and "30k-clean.vocab" onto it,
# and main joins `checkpoint_name` onto it to locate the checkpoint.
flags.DEFINE_string(
    "albert_directory", None,
    "Directory containing the pre-trained ALBERT model: albert_config.json "
    "(which specifies the model architecture), the SentencePiece files "
    "30k-clean.model and 30k-clean.vocab, and the checkpoint to export.")

flags.DEFINE_string(
    "checkpoint_name", "model.ckpt-best",
    "Name of the checkpoint under albert_directory to be exported.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_bool(
    "use_einsum", True,
    "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
    "be set to False for TFLite compatibility.")

flags.DEFINE_string("export_path", None, "Path to the output TF-Hub module.")

FLAGS = flags.FLAGS
def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch.

  The batch is flattened to a single [batch_size * seq_length, width] matrix,
  and every example's positions are shifted by that example's row offset so
  one `tf.gather` can pick all requested vectors at once.

  Args:
    sequence_tensor: Rank-3 tensor (checked via `modeling.get_shape_list`
      with expected_rank=3), read as [batch_size, seq_length, width].
    positions: int32 tensor of positions along the sequence axis. It is
      presumably [batch_size, num_positions]; module_fn passes its
      [None, None] `mlm_positions` placeholder. TODO confirm for other callers.

  Returns:
    The gathered vectors, one row of size `width` per requested position
    (presumably [batch_size * num_positions, width]).
  """
  batch_size, seq_length, width = modeling.get_shape_list(
      sequence_tensor, expected_rank=3)

  # Start index of each example inside the flattened batch.
  row_starts = tf.range(0, batch_size, dtype=tf.int32) * seq_length
  # Broadcast the column of row starts across each row of positions.
  flat_positions = tf.reshape(positions + tf.reshape(row_starts, [-1, 1]), [-1])

  flattened = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
  return tf.gather(flattened, flat_positions)
def get_mlm_logits(model, albert_config, mlm_positions):
  """Computes masked-LM logits at `mlm_positions` (from run_pretraining.py).

  The sequence output is gathered at the requested positions and passed
  through a dense layer (width `embedding_size`, activation `hidden_act`) and
  a layer norm. It is then projected onto the vocabulary with the model's
  embedding table (transposed), plus a learned per-token bias.

  Scope and variable names ("cls/predictions", "transform", "output_bias")
  presumably have to match the checkpoint restored in main, so they are fixed.

  Args:
    model: The `modeling.AlbertModel` whose sequence output and embedding
      table are used.
    albert_config: `modeling.AlbertConfig`. Reads embedding_size, hidden_act,
      initializer_range and vocab_size.
    mlm_positions: int32 positions tensor forwarded to `gather_indexes`.

  Returns:
    Logits whose last dimension is vocab_size (it must match the shape of
    `output_bias` for `bias_add`).
  """
  gathered = gather_indexes(model.get_sequence_output(), mlm_positions)

  with tf.variable_scope("cls/predictions"):
    # One extra non-linear transform before the output layer. The original
    # comment notes this matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      hidden = tf.layers.dense(
          gathered,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      hidden = modeling.layer_norm(hidden)

    # Output weights are tied to the input embeddings; only the bias is a
    # separate, output-only variable.
    bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    embedding_table = model.get_embedding_table()
    scores = tf.matmul(hidden, embedding_table, transpose_b=True)
    return tf.nn.bias_add(scores, bias)
def get_sop_log_probs(model, albert_config):
  """Computes sentence-order-prediction (SOP) log-probabilities.

  A two-way linear classifier is applied to the pooled output. The result is
  returned as log-softmax scores; no loss is computed here.

  The original comment, inherited from BERT's next-sentence task, labels
  class 0 "next sentence" and class 1 "random sentence". For SOP this
  presumably means 0 = original order and 1 = swapped. TODO confirm against
  run_pretraining.py.

  Args:
    model: The `modeling.AlbertModel` whose pooled output is classified.
    albert_config: `modeling.AlbertConfig`. Reads hidden_size and
      initializer_range.

  Returns:
    Log-probabilities over the 2 classes (log_softmax over the last axis).
  """
  pooled = model.get_pooled_output()

  # The scope name is kept from BERT ("seq_relationship"). Presumably it must
  # match the variable names stored in the checkpoint restored in main.
  # Per the original comment, this weight matrix is not used after pre-training.
  with tf.variable_scope("cls/seq_relationship"):
    weights = tf.get_variable(
        "output_weights",
        shape=[2, albert_config.hidden_size],
        initializer=modeling.create_initializer(
            albert_config.initializer_range))
    bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())

    scores = tf.nn.bias_add(
        tf.matmul(pooled, weights, transpose_b=True), bias)
    return tf.nn.log_softmax(scores, axis=-1)
def module_fn(is_training):
  """Builds the graph for one variant of the TF-Hub module.

  This is passed to `hub.create_module_spec` in main, once per tag set. It
  does the following:
  - creates int32 placeholders;
  - builds a `modeling.AlbertModel` from `<albert_directory>/albert_config.json`;
  - adds the MLM and SOP heads;
  - registers the config and SentencePiece files as module assets;
  - declares the signatures "tokens", "sop", "mlm" and "tokenization_info".

  Args:
    is_training: Python bool forwarded to `modeling.AlbertModel`.
  """
  def _int_matrix_placeholder(name):
    # Every input is an int32 matrix with both dimensions left dynamic.
    return tf.placeholder(tf.int32, [None, None], name)

  def _in_albert_dir(filename):
    return os.path.join(FLAGS.albert_directory, filename)

  input_ids = _int_matrix_placeholder("input_ids")
  input_mask = _int_matrix_placeholder("input_mask")
  segment_ids = _int_matrix_placeholder("segment_ids")
  mlm_positions = _int_matrix_placeholder("mlm_positions")

  albert_config_path = _in_albert_dir("albert_config.json")
  albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=False,
      use_einsum=FLAGS.use_einsum)

  mlm_logits = get_mlm_logits(model, albert_config, mlm_positions)
  sop_log_probs = get_sop_log_probs(model, albert_config)

  # (tensor name, file path) for each asset, in creation order.
  # "vocab_file" is only for visualization purposes.
  asset_specs = (
      ("config_file", albert_config_path),
      ("vocab_model", _in_albert_dir("30k-clean.model")),
      ("vocab_file", _in_albert_dir("30k-clean.vocab")),
  )
  asset_tensors = {}
  for tensor_name, path in asset_specs:
    asset_tensors[tensor_name] = tf.constant(
        value=path, dtype=tf.string, name=tensor_name)
  # Tensors in ASSET_FILEPATHS are rewritten by TF-Hub so that the exported
  # module stays portable.
  for tensor_name, _ in asset_specs:
    tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS,
                         asset_tensors[tensor_name])

  def _token_inputs(**extra):
    # Inputs shared by every model signature, plus signature-specific ones.
    inputs = dict(
        input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
    inputs.update(extra)
    return inputs

  def _encoder_outputs(**extra):
    # Encoder outputs shared by every model signature, plus head outputs.
    outputs = dict(
        sequence_output=model.get_sequence_output(),
        pooled_output=model.get_pooled_output())
    outputs.update(extra)
    return outputs

  hub.add_signature(
      name="tokens", inputs=_token_inputs(), outputs=_encoder_outputs())
  hub.add_signature(
      name="sop",
      inputs=_token_inputs(),
      outputs=_encoder_outputs(sop_log_probs=sop_log_probs))
  hub.add_signature(
      name="mlm",
      inputs=_token_inputs(mlm_positions=mlm_positions),
      outputs=_encoder_outputs(mlm_logits=mlm_logits))
  # NOTE: the "vocab_file" output is the SentencePiece *model* path
  # (30k-clean.model), not the .vocab listing. Consumers presumably load it as
  # an SPM model; confirm before changing.
  hub.add_signature(
      name="tokenization_info",
      inputs={},
      outputs=dict(
          vocab_file=asset_tensors["vocab_model"],
          do_lower_case=tf.constant(FLAGS.do_lower_case)))
def main(_):
  """Exports the ALBERT TF-Hub module to FLAGS.export_path.

  Two graph variants are registered with `hub.create_module_spec`:
  - the tag set {"train"}, built with is_training=True;
  - the empty tag set, built with is_training=False.

  Weights are restored from `<albert_directory>/<checkpoint_name>`.
  """
  tags_and_args = [
      ({"train"} if is_training else set(), {"is_training": is_training})
      for is_training in (True, False)
  ]
  spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args)

  checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
  tf.logging.info("Using checkpoint {}".format(checkpoint_path))
  spec.export(FLAGS.export_path, checkpoint_path=checkpoint_path)
if __name__ == "__main__":
  # Both flags default to None. module_fn reads files under albert_directory,
  # and main exports to export_path.
  for required_flag in ("albert_directory", "export_path"):
    flags.mark_flag_as_required(required_flag)
  app.run(main)