fix some bugs
stevewyl committed Jun 3, 2019
1 parent 452ac6c commit 257dabd
Showing 8 changed files with 42 additions and 21 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -61,7 +61,7 @@ pip install git+https://www.github.com/keras-team/keras-contrib.git

3. Trainer: defines the model training workflow, with support for bucket sequences, custom callbacks, and N-fold cross-validation

-* bucket sequences: speed up model training by putting texts of similar length into the same batch to cut redundant padding computation; on text classification tasks this gives RNN networks a speedup of 2x or more (does not yet support networks containing a Flatten layer)
+* bucket sequences: speed up model training by putting texts of similar length into the same batch to cut redundant padding computation; on text classification tasks this gives RNN networks a speedup of 2x or more (**does not yet support networks containing a Flatten layer**; bucketing is sketched below, after this list)

* callbacks: control the training flow through custom callbacks; the preset callbacks currently include early stopping, automatic learning-rate adjustment, and richer evaluation metrics
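
To make the bucketing idea concrete, here is a minimal sketch of the general technique. The helper `bucket_batches` below is hypothetical and is not the Trainer's actual implementation:

```python
from typing import List

def bucket_batches(texts: List[List[str]], batch_size: int) -> List[List[List[str]]]:
    # Sort by token count so neighbouring samples have similar lengths.
    ordered = sorted(texts, key=len)
    # Each slice now only pads to its own longest sequence,
    # not to the longest sequence in the whole corpus.
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

batches = bucket_batches([['a'], ['b', 'c'], ['d', 'e', 'f'], ['g']], batch_size=2)
# -> [[['a'], ['g']], [['b', 'c'], ['d', 'e', 'f']]]
```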

@@ -98,6 +98,7 @@ y_pred = text_classifier.predict(dataset.texts)
# chunk segmentation
# on first import, the model and dictionary data are downloaded automatically
# both single-sentence and multi-sentence input formats are supported; passing the texts to the segmenter as a list is recommended
+# the download paths for the data have been removed from the source code; email the author if you need them
from nlp_toolkit.chunk_segmentor import Chunk_Segmentor
cutter = Chunk_Segmentor()
s = '这是一个能够输出名词短语的分词器,欢迎试用!'  # "This is a segmenter that can output noun phrases. Give it a try!"
3 changes: 3 additions & 0 deletions nlp_toolkit/chunk_segmentor/README.md
@@ -2,6 +2,9 @@

Environment requirement: Python 3.6.5 (only Python 3 is supported for now)

+**No longer maintained or updated**
+**The download paths for the data have been removed from the source code; email the author if you need them**

## Installation

```bash
…
```
26 changes: 13 additions & 13 deletions nlp_toolkit/chunk_segmentor/__init__.py
@@ -12,12 +12,12 @@
MD5_FILE_PATH = DATA_PATH / 'model_data.md5'
UPDATE_TAG_PATH = DATA_PATH / 'last_update.pkl'
UPDATE_INIT_PATH = DATA_PATH / 'init_update.txt'
-MD5_HDFS_PATH = '/user/kdd_wangyilei/chunk_segmentor/model_data.md5'
-MODEL_HDFS_PATH = '/user/kdd_wangyilei/chunk_segmentor/model_data.zip'
-USER_NAME = 'yilei.wang'
-PASSWORD = 'ifchange0829FWGR'
-FTP_PATH_1 = 'ftp://192.168.8.23:21/chunk_segmentor'
-FTP_PATH_2 = 'ftp://211.148.28.11:21/chunk_segmentor'
+MD5_HDFS_PATH = '/user/xxxx/chunk_segmentor/model_data.md5'
+MODEL_HDFS_PATH = '/user/xxxx/chunk_segmentor/model_data.zip'
+USER_NAME = 'xxxx'
+PASSWORD = 'xxxxx'
+FTP_PATH_1 = 'ftp://xxx.xxx.xx.xx:xx/chunk_segmentor'
+FTP_PATH_2 = 'ftp://xxx.xxx.xx.xx:xx/chunk_segmentor'
IP = socket.gethostbyname(socket.gethostname())


@@ -43,7 +43,7 @@ def check_version():
with open(UPDATE_INIT_PATH, 'w') as fout:
fout.write(init_update_time)
else:
-print('Please find a machine that has hadoop or can reach ftp://192.168.8.23:21 or ftp://211.148.28.11:21')
+print('Please find a machine that has hadoop or can reach ftp://xxx.xxx.xx.xx:xx or ftp://xxx.xxx.xx.xx:xx')


def write_config(config_path, new_root_path):
@@ -67,7 +67,7 @@ def download():
os.remove(fname)

if not IP.startswith('127'):
-print('Trying to fetch data from ftp://192.168.8.23:21')
+print('Trying to fetch data from ftp://xxx.xxx.xx.xx:xx')
ret2 = os.system('wget -q --timeout=2 --tries=1 --ftp-user=%s --ftp-password=%s %s/model_data.md5' %
(USER_NAME, PASSWORD, FTP_PATH_1))
if ret2 == 0:
@@ -78,7 +78,7 @@
ret1 = os.system('hadoop fs -get %s' % MODEL_HDFS_PATH)
ret2 = os.system('hadoop fs -get %s' % MD5_HDFS_PATH)
else:
-print('Trying to fetch data from ftp://211.148.28.11:21')
+print('Trying to fetch data from ftp://xxx.xxx.xx.xx:xx')
ret2 = os.system('wget -q --timeout=2 --tries=1 --ftp-user=%s --ftp-password=%s %s/model_data.md5' %
(USER_NAME, PASSWORD, FTP_PATH_2))
if ret2 == 0:
@@ -118,7 +118,7 @@ def get_data_md5():
if ret == 0:
src = 'ftp2'
if ret != 0:
-print('Please find a machine that has hadoop or can reach ftp://192.168.8.23:21 or ftp://211.148.28.11:21')
+print('Please find a machine that has hadoop or can reach ftp://xxx.xxx.xx.xx:xx or ftp://xxx.xxx.xx.xx:xx')
return None
else:
return src
@@ -154,12 +154,12 @@ def update_data(src):
os.remove(fname)
if src == 'hdfs':
print('Trying to pull data from hdfs, takes about 20-30s')
-os.system('hadoop fs -get /user/kdd_wangyilei/chunk_segmentor/model_data.zip')
+os.system('hadoop fs -get /user/xxxxx/chunk_segmentor/model_data.zip')
elif src == 'ftp1':
-print('Trying to fetch data from ftp://192.168.8.23:21')
+print('Trying to fetch data from ftp://xxx.xxx.xx.xx:xx')
os.system('wget --ftp-user=%s --ftp-password=%s %s/model_data.zip' % (USER_NAME, PASSWORD, FTP_PATH_1))
elif src == 'ftp2':
-print('Trying to fetch data from ftp://211.148.28.11:21')
+print('Trying to fetch data from ftp://xxx.xxx.xx.xx:xx')
os.system('wget --ftp-user=%s --ftp-password=%s %s/model_data.zip' % (USER_NAME, PASSWORD, FTP_PATH_2))

os.system('unzip -q model_data.zip')
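
The version check in this module boils down to comparing a freshly downloaded model_data.md5 digest against the local copy (see get_data_md5 and update_data above). Here is a minimal standard-library sketch of that kind of check; these helpers are hypothetical, not the module's actual code:

```python
import hashlib

def file_md5(path: str) -> str:
    # Hash the file in chunks so a large model archive never sits in memory whole.
    digest = hashlib.md5()
    with open(path, 'rb') as fin:
        for chunk in iter(lambda: fin.read(8192), b''):
            digest.update(chunk)
    return digest.hexdigest()

def needs_update(local_md5_path: str, downloaded_md5: str) -> bool:
    # A digest mismatch means a newer model_data.zip is available upstream.
    with open(local_md5_path) as fin:
        return fin.read().strip() != downloaded_md5.strip()
```
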
3 changes: 2 additions & 1 deletion nlp_toolkit/chunk_segmentor/segment.py
@@ -175,6 +175,7 @@ def extract_item(self, item):
poss = list(flatten_gen(complete_poss))
if self.cut_all:
words, poss = zip(*[(x1, y1) for x, y in zip(words, poss) for x1, y1 in self.cut_qualifier(x, y)])
+words = [' ' if word == 's_' else word for word in words]
if self.pos:
d = (words, # C_CUT_WORD
poss, # C_CUT_POS
@@ -186,7 +187,7 @@
return d

def cut_qualifier(self, x, y):
-if y == 'np' and '_' in x:
+if y == 'np' and '_' in x and x not in ['s_', 'ss_', 'lan_']:
for sub_word in x.split('_'):
yield sub_word, y
else:
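
The two changes in this file work together: underscore-joined noun phrases are split back into sub-words, while sentinel tokens such as 's_' (which the first change maps back to a space) must survive intact. A standalone sketch of the generator logic, simplified from the diff above with made-up example inputs:

```python
SENTINELS = {'s_', 'ss_', 'lan_'}  # placeholder tokens that merely look like joined words

def cut_qualifier(word, tag):
    # Split underscore-joined noun phrases into sub-words, keeping sentinels whole.
    if tag == 'np' and '_' in word and word not in SENTINELS:
        for sub_word in word.split('_'):
            yield sub_word, tag
    else:
        yield word, tag

print(list(cut_qualifier('machine_learning', 'np')))  # [('machine', 'np'), ('learning', 'np')]
print(list(cut_qualifier('s_', 'np')))                # [('s_', 'np')], sentinel preserved
```
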
19 changes: 17 additions & 2 deletions nlp_toolkit/chunk_segmentor/utils.py
@@ -187,15 +187,30 @@ def jieba_cut(sent_list, segmentor, qualifier_word=None, mode='accurate', dict_l
# URLs
# r'(?:https?://|www\.)(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
# r'\b[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-] +\.[a-zA-Z0-9-.] +\b' # E-MAIL
+r'&#[\s\w\d]+;'
]
-START_PATTERN = r'(\d+、|\d+\.(?!\d+)|\d+\)|(?<![a-z0-9])[a-z]{1}(?=[、\.\)])|\(\d+\)|[一二三四五六七八九十]+[、\)\.)])'
+START_PATTERN = [
+r'\* ',
+r'\d{1,2}\.\d{1,2}\.\d{1,2}',  # 1.2.1
+r'\d+\t',
+r'([1-9][0-9]){1,2}[。;::,,、\.\t/]{1}\s?(?![年月日\d+])',
+r'([1-9][0-9]){1,2}[))]{1}、?',
+r' \| ',
+r'\n[1-9][0-9]',
+r'\n{2,}',
+r'(?<![A-Za-z0-9/])[A-Za-z]{1}\s?[、\.\)、\t]{1}',
+r'\(1?[1-9]\)',
+r'第?[一二三四五六七八九十]+[、\)\.) \t,]{1}',
+r'\([一二三四五六七八九十]+\)\.?'
+]
+START_PATTERN = re.compile(r'('+'|'.join(START_PATTERN)+')+', re.UNICODE)
END_PATTERN = r'(。|!|?|!|\?|;|;)'
HTML = re.compile(r'('+'|'.join(REGEX_STR)+')', re.UNICODE)


# filter out abnormal characters
def preprocess(string):
-invalid_unicode = u'[\u25A0-\u25FF\u0080-\u00A0\uE000-\uFBFF\u2000-\u2027\u2030-\u206F]+'
+invalid_unicode = u'[\u25A0-\u25FF\u0080-\u00A0\uE000-\uFBFF\u2000-\u201B\u201E-\u2027\u2030-\u206F]+'
lang_char = u'[\u3040-\u309f\u30A0-\u30FF\u1100-\u11FF\u0E00-\u0E7F\u0600-\u06ff\u0750-\u077f\u0400-\u04ff]+'
invalid_char = u'[\xa0\x7f\x9f]+'
string = re.sub(EMOJI_UNICODE, '', string)
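
The rewritten START_PATTERN is compiled into one alternation over all of the list-item markers above. A standalone illustration using a simplified subset of those patterns and a made-up input string:

```python
import re

# A simplified subset of the start-of-item markers from the diff above.
patterns = [
    r'\d{1,2}\.\d{1,2}\.\d{1,2}',      # numbered headings such as 1.2.1
    r'\([一二三四五六七八九十]+\)\.?',  # Chinese-numeral items such as (一)
    r'\* ',                             # bulleted items
]
start = re.compile(r'(' + '|'.join(patterns) + r')+', re.UNICODE)

text = '1.2.1 responsibilities * build models (一) requirements'
print([m.group(0) for m in start.finditer(text)])
# -> ['1.2.1', '* ', '(一)']
```
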
4 changes: 2 additions & 2 deletions nlp_toolkit/classifier.py
@@ -147,7 +147,7 @@ def predict(self, x: Dict[str, List[List[str]]], batch_size=64,
used_time = time.time() - start
logger.info('predict {} samples used {:4.1f}s'.format(
len(x['token']), used_time))
-if result.shape[1] > n_labels:
+if result.shape[1] > n_labels and self.model_name == 'bi_lstm_att':
attention = result[:, n_labels:]
attention = [attention[idx][:l] for idx, l in enumerate(x_len)]
return y_pred, attention
@@ -171,7 +171,7 @@ def evaluate(self, x: Dict[str, List[List[str]]], y: List[str],
def load(self, weight_fname, para_fname):
if self.model_name == 'bi_lstm_att':
self.model = bi_lstm_attention.load(weight_fname, para_fname)
-elif self.model_name == 'multi_head_self_att':
+elif self.model_name == 'transformer':
self.model = Transformer.load(weight_fname, para_fname)
elif self.model_name == 'text_cnn':
self.model = textCNN.load(weight_fname, para_fname)
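
The model_name guard added in this hunk matters because only the bi_lstm_att model appends per-token attention weights after the label columns of its output. A numpy sketch of the slicing that follows, with made-up shapes rather than the classifier's actual tensors:

```python
import numpy as np

n_labels, max_len = 3, 5
# Pretend model output: label probabilities, then attention over padded tokens.
result = np.random.rand(2, n_labels + max_len)
x_len = [4, 5]  # true (unpadded) length of each input sequence

y_scores = result[:, :n_labels]   # class probabilities
attention = result[:, n_labels:]  # per-token attention weights
attention = [attention[i][:l] for i, l in enumerate(x_len)]  # trim the padding
```
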
3 changes: 2 additions & 1 deletion nlp_toolkit/data.py
@@ -75,7 +75,8 @@ def __init__(self, mode, fname='', tran_fname='',
self.config = config
self.data_config = config['data']
self.embed_config = config['embed']
-self.data_format = self.data_config['format']
+if self.task_type == 'sequence':
+    self.data_format = self.data_config['format']
if self.basic_token == 'word':
self.max_tokens = self.data_config['max_words']
self.inner_char = self.data_config['inner_char']
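
The guard added here makes format a sequence-task-only key, so classification configs can omit it. The same pattern in isolation, with hypothetical config values (the 'conll' format name is an assumption, not taken from this repo):

```python
config = {'data': {'max_words': 100, 'inner_char': False, 'format': 'conll'}}
task_type = 'sequence'

# Only sequence-labeling data defines a column format; classification
# datasets may leave the key out entirely.
data_format = config['data']['format'] if task_type == 'sequence' else None
```
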
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@

setup(
name='nlp_toolkit',
-version='1.3.1',
+version='1.3.2',
description='NLP Toolkit with easy model training and applications',
long_description=long_description,
long_description_content_type='text/markdown',
