From 32fd3c9ed695f9db103398fd5a5d54f04dfaa01a Mon Sep 17 00:00:00 2001 From: songwenli <327361201@qq.com> Date: Fri, 12 Jul 2019 15:10:56 +0800 Subject: [PATCH] issue#101: pass lstm_size from server args to create_model and expose -lstm_size option --- bert_base/server/graph.py | 2 +- bert_base/server/helper.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bert_base/server/graph.py b/bert_base/server/graph.py index cbe011e..f30fc31 100644 --- a/bert_base/server/graph.py +++ b/bert_base/server/graph.py @@ -286,7 +286,7 @@ def optimize_ner_model(args, num_labels, logger=None): from bert_base.train.models import create_model (total_loss, logits, trans, pred_ids) = create_model( bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=None, - labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0) + labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0, lstm_size=args.lstm_size) pred_ids = tf.identity(pred_ids, 'pred_ids') saver = tf.train.Saver() diff --git a/bert_base/server/helper.py b/bert_base/server/helper.py index 337a794..f053869 100644 --- a/bert_base/server/helper.py +++ b/bert_base/server/helper.py @@ -98,6 +98,8 @@ def get_args_parser(): help='masking the embedding on [CLS] and [SEP] with zero. \ When pooling_strategy is in {CLS_TOKEN, FIRST_TOKEN, SEP_TOKEN, LAST_TOKEN} \ then the embedding is preserved, otherwise the embedding is masked to zero before pooling') + group2.add_argument('-lstm_size', type=int, default=128, + help='size of lstm units.') group3 = parser.add_argument_group('Serving Configs', 'config how server utilizes GPU/CPU resources')