From 8fa51e8aaae8af69fa76bfd55eb52f959cffad68 Mon Sep 17 00:00:00 2001 From: JuntingGuo Date: Thu, 7 Nov 2019 18:36:32 +0800 Subject: [PATCH] fix double input bug for classification --- bert_base/server/__init__.py | 5 ++++- bert_base/server/graph.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/bert_base/server/__init__.py b/bert_base/server/__init__.py index 26cb51a..55690d4 100644 --- a/bert_base/server/__init__.py +++ b/bert_base/server/__init__.py @@ -459,7 +459,10 @@ def classification_model_fn(features, labels, mode, params): graph_def.ParseFromString(f.read()) input_ids = features["input_ids"] input_mask = features["input_mask"] - input_map = {"input_ids": input_ids, "input_mask": input_mask} + #为了兼容多输入,增加segment_id特征,即训练代码中的input_type_ids特征。 + #input_map = {"input_ids": input_ids, "input_mask": input_mask} + segment_ids=features["input_type_ids"] + input_map = {"input_ids": input_ids, "input_mask": input_mask,"segment_ids":segment_ids} pred_probs = tf.import_graph_def(graph_def, name='', input_map=input_map, return_elements=['pred_prob:0']) return EstimatorSpec(mode=mode, predictions={ diff --git a/bert_base/server/graph.py b/bert_base/server/graph.py index f30fc31..18369e1 100644 --- a/bert_base/server/graph.py +++ b/bert_base/server/graph.py @@ -339,8 +339,11 @@ def optimize_class_model(args, num_labels, logger=None): bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json')) from bert_base.train.models import create_classification_model - loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, - input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels) + #为了兼容多输入,增加segment_id特征,即训练代码中的input_type_ids特征。 + #loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, + #input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels) + segment_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'segment_ids') + loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, labels=None, num_labels=num_labels) # pred_ids = tf.argmax(probabilities, axis=-1, output_type=tf.int32, name='pred_ids') # pred_ids = tf.identity(pred_ids, 'pred_ids') probabilities = tf.identity(probabilities, 'pred_prob')