Merge pull request PaddlePaddle#15 from jiweibo/elmo
[ELMo] Add elmo lac demo.
Showing 4 changed files with 331 additions and 3 deletions.
## LAC word segmentation demo based on ELMo

### 1. Prepare the environment

Install Paddle 1.7 or later in your environment. For installation instructions, please follow the [PaddlePaddle official site](https://www.paddlepaddle.org.cn/).

### 2. Download the model and test data

1) **Get the inference model**

Download the model from this [link](https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/elmo/elmo.tgz); if you want more **model training details**, visit this [link](https://github.com/PaddlePaddle/models/tree/release/1.8/PaddleNLP/pretrain_language_models/ELMo). Extract the archive into the root directory of this project.

2) **Get the data**

Download the data from this [link](https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/elmo/elmo_data.tgz) and extract it into the root directory of this project.
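If you prefer to script the setup, here is a minimal sketch using only the Python 3 standard library; it assumes the archives unpack into the `elmo` and `elmo_data` directories that `infer.py` expects:

```
import tarfile
import urllib.request

for url in [
    "https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/elmo/elmo.tgz",
    "https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/elmo/elmo_data.tgz",
]:
    name = url.rsplit("/", 1)[-1]
    urllib.request.urlretrieve(url, name)  # download into the project root
    with tarfile.open(name) as tar:
        tar.extractall(".")                # unpack next to infer.py
```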
### 3. Run inference

`reader.py` implements data loading.
`infer.py` creates the predictor, reads the input, runs inference, and fetches the output.

Run:
```
python infer.py
```
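The script also accepts optional flags (see `parse_args` in `infer.py`); for example, the following mirrors the defaults but runs on GPU (a GPU-enabled Paddle build is required for `--use_gpu 1`):

```
python infer.py --model_dir elmo --testdata_dir elmo_data/dev --use_gpu 1
```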
The word segmentation results are:

```
1 sample's result: <UNK>/n 电脑/vn 对/v-I 胎儿/v-I 影响/vn-B 大/v-I 吗/a
2 sample's result: 这个/r 跟/p 我们/ns 一直/p 传承/n 《/p 易经/n 》/n 的/u 精神/n 是/v-I 分/v 不/d 开/v 的/u
3 sample's result: 他们/p 不/r 但/ad 上/v-I 名医/v-I 门诊/n ,/w 还/n 兼/ns-I 作/ns-I 门诊/n 医生/v-I 的/n 顾问/v-I 团/nt
4 sample's result: 负责/n 外商/v-I 投资/v-I 企业/n 和/v-I 外国/v-I 企业/n 的/u 税务/nr-I 登记/v-I ,/w 纳税/n 申报/vn 和/n 税收/vn 资料/n 的/u 管理/n ,/w 全面/c 掌握/n 税收/vn 信息/n
5 sample's result: 采用/ns-I 弹性/ns-I 密封/ns-I 结构/n ,/w 实现/n 零/v-B 间隙/v-I
6 sample's result: 要/r 做/n 好/p 这/n 三/p 件/vn 事/n ,/w 支行/q 从/q 风险/n 管理/p 到/a 市场/q 营销/n 策划/c 都/p 必须/vn 专业/n 到位/vn
7 sample's result: 那么/nz-B ,/r 请/v-I 你/v-I 一定/nz-B 要/d-I 幸福/ad ./v-I
8 sample's result: 叉车/ns-I 在/ns-I 企业/n 的/u 物流/n 系统/vn 中/ns-I 扮演/ns-I 着/v-I 非常/q 重要/n 的/u 角色/n ,/w 是/u 物料/vn 搬运/ns-I 设备/n 中/vn 的/u 主力/ns-I 军/v-I
9 sample's result: 我/r 真/t 的/u 能够/vn 有/ns-I 机会/ns-I 拍摄/v-I 这部/vn 电视/ns-I 剧/v-I 么/vn
10 sample's result: 这种/r 情况/n 应该/v-I 是/v-I 没/n 有/p 危害/n 的/u
```
### Related links
- [Paddle Inference Quick Start]()
- [Paddle Inference Python API]()
**infer.py**

```
#coding: utf-8
import numpy as np
import paddle
import argparse
import reader
import sys

from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor


def parse_args():
    """
    Parse the input parameters.
    """
    parser = argparse.ArgumentParser("Inference for lexical analyzer.")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="elmo",
        help="The folder where the inference model is located.")
    parser.add_argument(
        "--model_file",
        type=str,
        default="",
        help="The model file path, used when --model_dir is empty.")
    parser.add_argument(
        "--params_file",
        type=str,
        default="",
        help="The parameters file path, used when --model_dir is empty.")
    parser.add_argument(
        "--testdata_dir",
        type=str,
        default="elmo_data/dev",
        help="The folder where the test data is located.")
    parser.add_argument(
        "--use_gpu",
        type=int,
        default=0,
        help="Whether or not to use GPU. 0-->CPU 1-->GPU")
    parser.add_argument(
        "--word_dict_path",
        type=str,
        default="elmo_data/vocabulary_min5k.txt",
        help="The path of the word dictionary.")
    parser.add_argument(
        "--label_dict_path",
        type=str,
        default="elmo_data/tag.dic",
        help="The path of the label dictionary.")
    parser.add_argument(
        "--word_rep_dict_path",
        type=str,
        default="elmo_data/q2b.dic",
        help="The path of the word replacement dictionary.")

    args = parser.parse_args()
    return args


def to_lodtensor(data):
    """
    Flatten a batch of variable-length sequences and build the LoD
    (level-of-detail) offsets that describe the sequence boundaries.
    """
    seq_lens = [len(seq) for seq in data]
    cur_len = 0
    lod = [cur_len]
    for l in seq_lens:
        cur_len += l
        lod.append(cur_len)
    flattened_data = np.concatenate(data, axis=0).astype("int64")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])
    return flattened_data, [lod]
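
# A quick illustration of to_lodtensor (hypothetical batch of two sequences
# of lengths 3 and 2): all ids are stacked into shape (5, 1) and the LoD
# records the cumulative offsets.
#
#   flat, lod = to_lodtensor([np.array([1, 2, 3]), np.array([4, 5])])
#   flat.shape  # (5, 1)
#   lod         # [[0, 3, 5]]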


def create_predictor(args):
    if args.model_dir != "":
        config = AnalysisConfig(args.model_dir)
    else:
        config = AnalysisConfig(args.model_file, args.params_file)

    config.switch_use_feed_fetch_ops(False)
    config.enable_memory_optim()
    if args.use_gpu:
        # Initial GPU memory pool of 1000 MB on device 0.
        config.enable_use_gpu(1000, 0)
    else:
        # If MKLDNN is not enabled, you can still set the math-library
        # thread count; it should not exceed the number of CPU cores.
        config.set_cpu_math_library_num_threads(4)

    predictor = create_paddle_predictor(config)
    return predictor
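
# Usage sketch: the predictor is built entirely from the parsed flags,
# exactly as the __main__ block below does:
#   predictor = create_predictor(parse_args())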


def run(predictor, datas, lods):
    input_names = predictor.get_input_names()
    for i, name in enumerate(input_names):
        input_tensor = predictor.get_input_tensor(name)
        input_tensor.reshape(datas[i].shape)
        input_tensor.copy_from_cpu(datas[i].copy())
        input_tensor.set_lod(lods[i])

    # Run the inference.
    predictor.zero_copy_run()

    results = []
    # Fetch the results from the output tensors.
    output_names = predictor.get_output_names()
    for i, name in enumerate(output_names):
        output_tensor = predictor.get_output_tensor(name)
        output_data = output_tensor.copy_to_cpu()
        results.append(output_data)
    return results


if __name__ == '__main__':
    args = parse_args()
    word2id_dict = reader.load_reverse_dict(args.word_dict_path)
    label2id_dict = reader.load_reverse_dict(args.label_dict_path)
    word_rep_dict = reader.load_dict(args.word_rep_dict_path)
    word_dict_len = max(map(int, word2id_dict.values())) + 1
    label_dict_len = max(map(int, label2id_dict.values())) + 1

    pred = create_predictor(args)

    test_data = paddle.batch(
        reader.file_reader(args.testdata_dir, word2id_dict, label2id_dict,
                           word_rep_dict),
        batch_size=1)
    batch_id = 0
    id2word = {v: k for k, v in word2id_dict.items()}
    id2label = {v: k for k, v in label2id_dict.items()}
    for data in test_data():
        batch_id += 1
        word_data, word_lod = to_lodtensor(list(map(lambda x: x[0], data)))
        target_data, target_lod = to_lodtensor(list(map(lambda x: x[1], data)))
        result_list = run(pred, [word_data, target_data],
                          [word_lod, target_lod])
        # The first three outputs are summary counts (unused here);
        # result_list[3] holds the predicted tag ids.
        number_infer = np.array(result_list[0])
        number_label = np.array(result_list[1])
        number_correct = np.array(result_list[2])
        lac_result = ""
        for i in range(len(data[0][0])):
            lac_result += id2word[data[0][0][i]] + '/' + id2label[np.array(
                result_list[3]).tolist()[i][0]] + " "
        print("%d sample's result:" % batch_id, lac_result)
        # Only print the first 10 samples.
        if batch_id >= 10:
            sys.exit(0)
```
**reader.py**

```
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding: utf-8
"""
The file_reader converts the raw corpus into model input.
"""
import os
import io


def file_reader(file_dir,
                word2id_dict,
                label2id_dict,
                word_replace_dict,
                filename_feature=""):
    """
    Define the reader that reads the labeled files in file_dir.
    """

    def reader():
        """
        The data generator: yields (word ids, label ids) per line.
        """
        for root, dirs, files in os.walk(file_dir):
            for filename in files:
                for line in io.open(
                        os.path.join(root, filename), 'r', encoding='utf8'):
                    line = line.strip("\n")
                    if len(line) == 0:
                        continue
                    # Words and labels are separated by the last tab.
                    seg_tag = line.rfind("\t")
                    word_part = line[0:seg_tag].strip().split(' ')
                    label_part = line[seg_tag + 1:]
                    word_idx = []
                    for word in word_part:
                        if word in word_replace_dict:
                            word = word_replace_dict[word]
                        if word in word2id_dict:
                            word_idx.append(int(word2id_dict[word]))
                        else:
                            word_idx.append(int(word2id_dict["<UNK>"]))
                    target_idx = []
                    labels = label_part.strip().split(" ")
                    for label in labels:
                        if label in label2id_dict:
                            target_idx.append(int(label2id_dict[label]))
                        else:
                            target_idx.append(int(label2id_dict["O"]))
                    # Skip malformed lines where words and labels disagree.
                    if len(word_idx) != len(target_idx):
                        print(line)
                        continue
                    yield word_idx, target_idx

    return reader
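
# Expected corpus format (illustrative): one sample per line, the word
# sequence and the tag sequence separated by a tab, tokens by spaces, e.g.
#
#   这个 跟 我们<TAB>r p ns
#
# which yields ([id(这个), id(跟), id(我们)], [id(r), id(p), id(ns)]).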


def test_reader(file_dir,
                word2id_dict,
                label2id_dict,
                word_replace_dict,
                filename_feature=""):
    """
    Define the reader that reads raw test files in file_dir,
    character by character.
    """

    def reader():
        """
        The data generator: yields (word ids, characters) per line.
        """
        for root, dirs, files in os.walk(file_dir):
            for filename in files:
                if not filename.startswith(filename_feature):
                    continue
                for line in io.open(
                        os.path.join(root, filename), 'r', encoding='utf8'):
                    line = line.strip("\n")
                    if len(line) == 0:
                        continue
                    seg_tag = line.rfind("\t")
                    if seg_tag == -1:
                        seg_tag = len(line)
                    word_part = line[0:seg_tag]
                    word_idx = []
                    words = word_part
                    for word in words:
                        # Map control characters to a plain space.
                        if ord(word) < 0x20:
                            word = ' '
                        if word in word_replace_dict:
                            word = word_replace_dict[word]
                        if word in word2id_dict:
                            word_idx.append(int(word2id_dict[word]))
                        else:
                            word_idx.append(int(word2id_dict["OOV"]))
                    yield word_idx, words

    return reader


def load_reverse_dict(dict_path):
    """
    Load a vocabulary file with one term per line and map each term to
    its line index.
    """
    result_dict = {}
    for idx, line in enumerate(io.open(dict_path, "r", encoding='utf8')):
        terms = line.strip("\n")
        result_dict[terms] = idx
    return result_dict


def load_dict(dict_path):
    """
    Load a vocabulary file with one term per line and map each line index
    to its term.
    """
    result_dict = {}
    for idx, line in enumerate(io.open(dict_path, "r", encoding='utf8')):
        terms = line.strip("\n")
        result_dict[idx] = terms
    return result_dict
```
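A minimal sketch of how the two loaders relate (the file name `tiny_vocab.txt` and its contents are hypothetical; run it from the project root so `reader.py` is importable):

```
import io
import reader

with io.open("tiny_vocab.txt", "w", encoding="utf8") as f:
    f.write(u"<UNK>\n电脑\n")  # two terms, one per line

word2id = reader.load_reverse_dict("tiny_vocab.txt")  # {u'<UNK>': 0, u'电脑': 1}
id2word = reader.load_dict("tiny_vocab.txt")          # {0: u'<UNK>', 1: u'电脑'}
```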