From f70592b82bebc06dbed259a9dca2870e486669f5 Mon Sep 17 00:00:00 2001 From: macan Date: Wed, 6 Mar 2019 11:57:33 +0800 Subject: [PATCH] fix do_train error bug --- bert_base/train/train_helper.py | 8 ++++---- setup.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/bert_base/train/train_helper.py b/bert_base/train/train_helper.py index 4b95be5..3484c3e 100644 --- a/bert_base/train/train_helper.py +++ b/bert_base/train/train_helper.py @@ -37,12 +37,12 @@ def get_args_parser(): group2 = parser.add_argument_group('Model Config', 'config the model params') group2.add_argument('-max_seq_length', type=int, default=128, help='The maximum total input sequence length after WordPiece tokenization.') - group2.add_argument('-do_train', type=bool, default=True, + group2.add_argument('-do_train', action='store_false', default=True, help='Whether to run training.') - group2.add_argument('-do_eval', type=bool, default=True, + group2.add_argument('-do_eval', action='store_false', default=True, help='Whether to run eval on the dev set.') - group2.add_argument('-do_predict', type=bool, default=True, - help='Whether to run the model in inference mode on the test set.') + group2.add_argument('-do_predict', action='store_false', default=True, + help='Whether to run the predict in inference mode on the test set.') group2.add_argument('-batch_size', type=int, default=64, help='Total batch size for training, eval and predict.') group2.add_argument('-learning_rate', type=float, default=1e-5, diff --git a/setup.py b/setup.py index 7d45c1d..396107c 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,5 @@ +# encoding =utf-8 + from os import path import codecs from setuptools import setup, find_packages @@ -11,7 +13,7 @@ # print(__version__) setup( name='bert_base', - version='0.0.8', + version='0.0.9', description='Use Google\'s BERT for Chinese natural language processing tasks such as named entity recognition and provide server services', url='https://github.com/macanv/BERT-BiLSTM-CRF-NER', long_description=open('README.md', 'r', encoding='utf-8').read(),