Skip to content

Commit

Permalink
fix double input bug for classification
Browse files Browse the repository at this point in the history
  • Loading branch information
oasis-0927 committed Nov 7, 2019
1 parent 12fb822 commit 8fa51e8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
5 changes: 4 additions & 1 deletion bert_base/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,10 @@ def classification_model_fn(features, labels, mode, params):
graph_def.ParseFromString(f.read())
input_ids = features["input_ids"]
input_mask = features["input_mask"]
input_map = {"input_ids": input_ids, "input_mask": input_mask}
#为了兼容多输入,增加segment_id特征,即训练代码中的input_type_ids特征。
#input_map = {"input_ids": input_ids, "input_mask": input_mask}
segment_ids=features["input_type_ids"]
input_map = {"input_ids": input_ids, "input_mask": input_mask,"segment_ids":segment_ids}
pred_probs = tf.import_graph_def(graph_def, name='', input_map=input_map, return_elements=['pred_prob:0'])

return EstimatorSpec(mode=mode, predictions={
Expand Down
7 changes: 5 additions & 2 deletions bert_base/server/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,11 @@ def optimize_class_model(args, num_labels, logger=None):

bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json'))
from bert_base.train.models import create_classification_model
loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False,
input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)
#为了兼容多输入,增加segment_id特征,即训练代码中的input_type_ids特征。
#loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False,
#input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)
segment_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'segment_ids')
loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, labels=None, num_labels=num_labels)
# pred_ids = tf.argmax(probabilities, axis=-1, output_type=tf.int32, name='pred_ids')
# pred_ids = tf.identity(pred_ids, 'pred_ids')
probabilities = tf.identity(probabilities, 'pred_prob')
Expand Down

0 comments on commit 8fa51e8

Please sign in to comment.