Fix ernie csc decode (PaddlePaddle#1005)
* Use list() instead of tokenizer.tokenize()

* Use list() instead of tokenizer.tokenize() in the Taskflow

* Add --max_seq_length to the README training commands

* Add dynamic-graph prediction to the text_correction task (usage sketch below)

* Fix the Windows prediction bug
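For context, a hedged usage sketch of the affected task. The Taskflow name and the batch_size / num_workers / lazy_load keyword arguments are taken from the code changed in this commit; the exact output format is not shown in the diff and is assumed.

```python
# Minimal usage sketch, assuming PaddleNLP >= 2.1 with the "text_correction"
# Taskflow backed by ERNIE-CSC. The kwargs below are the ones read by the new
# dynamic-graph prediction path added in this commit.
from paddlenlp import Taskflow

corrector = Taskflow("text_correction", batch_size=2, num_workers=0, lazy_load=False)
results = corrector(["遇到逆竟时，我们必须勇于面对", "人生就是如此"])
print(results)  # corrected sentences plus per-character error info (format may vary by version)
```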
joey12300 authored Sep 13, 2021
1 parent 1914c8a commit 42447b8
Showing 5 changed files with 87 additions and 142 deletions.
4 changes: 2 additions & 2 deletions examples/text_correction/ernie-csc/README.md
@@ -69,13 +69,13 @@ python change_sgml_to_txt.py -i extra_train_ds/train.sgml -o extra_train_ds/trai
### Single-GPU training

```python
python train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/
python train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/ --max_seq_length 192
```

### Multi-GPU training

```python
python -m paddle.distributed.launch --gpus "0,1" train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/
python -m paddle.distributed.launch --gpus "0,1" train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/ --max_seq_length 192
```

## Model prediction
6 changes: 3 additions & 3 deletions examples/text_correction/ernie-csc/predict.py
@@ -89,9 +89,9 @@ def predict(self, data, batch_size=1):
is_test=True)

batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id), # segment
Pad(axis=0, pad_val=self.pinyin_vocab.token_to_idx[self.pinyin_vocab.pad_token]), # pinyin
Pad(axis=0, pad_val=self.tokenizer.pad_token_id, dtype='int64'), # input
Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id, dtype='int64'), # segment
Pad(axis=0, pad_val=self.pinyin_vocab.token_to_idx[self.pinyin_vocab.pad_token], dtype='int64'), # pinyin
Stack(axis=0, dtype='int64'), # length
): [data for data in fn(samples)]
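The Windows fix is the explicit `dtype='int64'`: without it the padded arrays take the platform's default integer type, which is 32-bit on Windows, presumably clashing with the int64 inputs the exported predictor expects. A minimal sketch of the difference on toy data, not part of the diff:

```python
# Sketch only: why Pad benefits from an explicit dtype on Windows, where
# NumPy (< 2.0) defaults Python ints to int32 rather than int64.
import numpy as np
from paddlenlp.data import Pad

samples = [[1, 2, 3], [4, 5]]
print(np.asarray(samples[0]).dtype)                           # int32 on Windows, int64 on Linux/macOS
print(Pad(axis=0, pad_val=0)(samples).dtype)                  # typically follows the platform default
print(Pad(axis=0, pad_val=0, dtype='int64')(samples).dtype)   # always int64
```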

6 changes: 1 addition & 5 deletions examples/text_correction/ernie-csc/predict_sighan.py
@@ -57,15 +57,11 @@ def write_sighan_result_to_file(args, corr_preds, det_preds, lengths,
lengths[i], tokenizer,
args.max_seq_length)
words = list(words)
if len(words) > args.max_seq_length - 2:
words = words[:args.max_seq_length - 2]
words = ''.join(words)

pred_result = list(pred_result)
result = ids
if pred_result == words:
result += ', 0'
else:
pred_result = list(pred_result)
assert len(pred_result) == len(
words), "pred_result: {}, words: {}".format(pred_result,
words)
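The comparison and the length assert above rely on both sides being per-character sequences of the same text span. A toy illustration of that per-character comparison, not the file's actual variables:

```python
# Toy illustration only: per-character comparison as used above.
words = list("遇到逆竟时")
pred_result = list("遇到逆境时")
print(len(pred_result) == len(words))  # True: character-level, lengths match
print(pred_result == words)            # False: a correction at position 3 ("竟" -> "境")
```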
59 changes: 8 additions & 51 deletions examples/text_correction/ernie-csc/utils.py
@@ -39,7 +39,7 @@ def convert_example(example,
ignore_label=-1,
is_test=False):
source = example["source"]
words = tokenizer.tokenize(text=source)
words = list(source)
if len(words) > max_seq_length - 2:
words = words[:max_seq_length - 2]
length = len(words)
@@ -50,7 +50,6 @@
# Use pad token in pinyin emb to map word emb [CLS], [SEP]
pinyins = lazy_pinyin(
source, style=Style.TONE3, neutral_tone_with_five=True)

pinyin_ids = [0]
# Align pinyin and chinese char
pinyin_offset = 0
@@ -71,7 +70,7 @@

if not is_test:
target = example["target"]
correction_labels = tokenizer.tokenize(text=target)
correction_labels = list(target)
if len(correction_labels) > max_seq_length - 2:
correction_labels = correction_labels[:max_seq_length - 2]
correction_labels = tokenizer.convert_tokens_to_ids(correction_labels)
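The core of the decode fix: `list()` yields exactly one token per character, so positions stay aligned with the pinyin sequence and the per-position detection/correction labels, whereas WordPiece tokenization merges ASCII runs and emits `##` sub-words. A rough illustration; the tokenizer name and the exact sub-word split are assumptions for the example:

```python
# Rough illustration, assuming the ernie-1.0 tokenizer; the exact sub-word
# split of the ASCII run may differ by vocabulary version.
from paddlenlp.transformers import ErnieTokenizer

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
text = "我爱NLP2021"
print(tokenizer.tokenize(text))  # merges the ASCII run, e.g. ['我', '爱', 'nlp', '##2021', ...]
print(list(text))                # ['我', '爱', 'N', 'L', 'P', '2', '0', '2', '1'] - one item per character
# Only the second form keeps a strict one-to-one alignment between characters,
# pinyin ids, and per-position detection/correction labels.
```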
@@ -114,64 +113,22 @@ def parse_decode(words, corr_preds, det_preds, lengths, tokenizer,
max_seq_length):
UNK = tokenizer.unk_token
UNK_id = tokenizer.convert_tokens_to_ids(UNK)
tokens = tokenizer.tokenize(words)
if len(tokens) > max_seq_length - 2:
tokens = tokens[:max_seq_length - 2]

corr_pred = corr_preds[1:1 + lengths].tolist()
det_pred = det_preds[1:1 + lengths].tolist()
words = list(words)
rest_words = []
if len(words) > max_seq_length - 2:
rest_words = words[max_seq_length - 2:]
words = words[:max_seq_length - 2]

assert len(tokens) == len(
corr_pred
), "The number of tokens should be equal to the number of labels {}: {}: {}".format(
len(tokens), len(corr_pred), tokens)
pred_result = ""

align_offset = 0
# Need to be aligned
if len(words) != len(tokens):
first_unk_flag = True
for j, word in enumerate(words):
if word.isspace():
tokens.insert(j + 1, word)
corr_pred.insert(j + 1, UNK_id)
det_pred.insert(j + 1, 0) # No error
elif tokens[j] != word:
if tokenizer.convert_tokens_to_ids(word) == UNK_id:
if first_unk_flag:
first_unk_flag = False
corr_pred[j] = UNK_id
det_pred[j] = 0
else:
tokens.insert(j, UNK)
corr_pred.insert(j, UNK_id)
det_pred.insert(j, 0) # No error
continue
elif tokens[j] == UNK:
# Remove rest unk
k = 0
while k + j < len(tokens) and tokens[k + j] == UNK:
k += 1
tokens = tokens[:j] + tokens[j + k:]
corr_pred = corr_pred[:j] + corr_pred[j + k:]
det_pred = det_pred[:j] + det_pred[j + k:]
else:
# Maybe English, number, or suffix
token = tokens[j].lstrip("##")
corr_pred = corr_pred[:j] + [UNK_id] * len(
token) + corr_pred[j + 1:]
det_pred = det_pred[:j] + [0] * len(token) + det_pred[j +
1:]
tokens = tokens[:j] + list(token) + tokens[j + 1:]
first_unk_flag = True

for j, word in enumerate(words):
candidates = tokenizer.convert_ids_to_tokens(corr_pred[j])
if det_pred[j] == 0 or candidates == UNK or candidates == '[PAD]':
if not is_chinese_char(ord(word)) or det_pred[
j] == 0 or candidates == UNK or candidates == '[PAD]':
pred_result += word
else:
pred_result += candidates.lstrip("##")

pred_result += ''.join(rest_words)
return pred_result
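The other half of the rewrite: characters beyond `max_seq_length - 2` never reach the model, so `parse_decode` now splits them off up front and appends them verbatim after the corrected prefix. A standalone sketch of just that truncation handling, with a toy length and no model call:

```python
# Standalone sketch of the rest_words handling above; max_seq_length is a toy value.
max_seq_length = 8
words = list("这是一条超过最大长度限制的句子")
rest_words = []
if len(words) > max_seq_length - 2:
    rest_words = words[max_seq_length - 2:]
    words = words[:max_seq_length - 2]

corrected_prefix = ''.join(words)              # stands in for the model-corrected part
print(corrected_prefix + ''.join(rest_words))  # output keeps the full input length
```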
154 changes: 73 additions & 81 deletions paddlenlp/taskflow/text_correction.py
@@ -104,6 +104,18 @@ def __init__(self, task, model, **kwargs):
)
self._pypinyin = pypinyin
self._max_seq_length = 128
self._batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id), # segment
Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]), # pinyin
Stack(axis=0, dtype='int64'), # length
): [data for data in fn(samples)]
self._num_workers = self.kwargs[
'num_workers'] if 'num_workers' in self.kwargs else 0
self._batch_size = self.kwargs[
'batch_size'] if 'batch_size' in self.kwargs else 1
self._lazy_load = self.kwargs[
'lazy_load'] if 'lazy_load' in self.kwargs else False

def _construct_input_spec(self):
"""
@@ -141,61 +153,83 @@ def _construct_tokenizer(self, model):

def _preprocess(self, inputs, padding=True, add_special_tokens=True):
inputs = self._check_input_text(inputs)
batch_size = self.kwargs[
'batch_size'] if 'batch_size' in self.kwargs else 1
trans_func = self._convert_example

batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id), # segment
Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]), # pinyin
Stack(axis=0, dtype='int64'), # length
): [data for data in fn(samples)]

examples = []
texts = []
for text in inputs:
if not (isinstance(text, str) and len(text) > 0):
continue
example = {"source": text.strip()}
input_ids, token_type_ids, pinyin_ids, length = trans_func(example)
input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
example)
examples.append((input_ids, token_type_ids, pinyin_ids, length))
texts.append(example["source"])

batch_examples = [
examples[idx:idx + batch_size]
for idx in range(0, len(examples), batch_size)
examples[idx:idx + self._batch_size]
for idx in range(0, len(examples), self._batch_size)
]
batch_texts = [
texts[idx:idx + batch_size]
for idx in range(0, len(examples), batch_size)
texts[idx:idx + self._batch_size]
for idx in range(0, len(examples), self._batch_size)
]
outputs = {}
outputs['batch_examples'] = batch_examples
outputs['batch_texts'] = batch_texts
self.batchify_fn = batchify_fn
if not self._static_mode:

def read(inputs):
for text in inputs:
example = {"source": text.strip()}
input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
example)
yield input_ids, token_type_ids, pinyin_ids, length

infer_ds = load_dataset(read, inputs=inputs, lazy=self._lazy_load)
outputs['data_loader'] = paddle.io.DataLoader(
infer_ds,
collate_fn=self._batchify_fn,
num_workers=self._num_workers,
batch_size=self._batch_size,
shuffle=False,
return_list=True)

return outputs

def _run_model(self, inputs):
"""
Run the task model from the outputs of the `_tokenize` function.
"""
results = []
with static_mode_guard():
for examples in inputs['batch_examples']:
token_ids, token_type_ids, pinyin_ids, lengths = self.batchify_fn(
examples)
self.input_handles[0].copy_from_cpu(token_ids)
self.input_handles[1].copy_from_cpu(pinyin_ids)
self.predictor.run()
det_preds = self.output_handle[0].copy_to_cpu()
char_preds = self.output_handle[1].copy_to_cpu()

batch_result = []
for i in range(len(lengths)):
batch_result.append(
(det_preds[i], char_preds[i], lengths[i]))
results.append(batch_result)
if not self._static_mode:
with dygraph_mode_guard():
for examples in inputs['data_loader']:
token_ids, token_type_ids, pinyin_ids, lengths = examples
det_preds, char_preds = self._model(token_ids, pinyin_ids)
det_preds = det_preds.numpy()
char_preds = char_preds.numpy()
lengths = lengths.numpy()

batch_result = []
for i in range(len(lengths)):
batch_result.append(
(det_preds[i], char_preds[i], lengths[i]))
results.append(batch_result)
else:
with static_mode_guard():
for examples in inputs['batch_examples']:
token_ids, token_type_ids, pinyin_ids, lengths = self._batchify_fn(
examples)
self.input_handles[0].copy_from_cpu(token_ids)
self.input_handles[1].copy_from_cpu(pinyin_ids)
self.predictor.run()
det_preds = self.output_handle[0].copy_to_cpu()
char_preds = self.output_handle[1].copy_to_cpu()

batch_result = []
for i in range(len(lengths)):
batch_result.append(
(det_preds[i], char_preds[i], lengths[i]))
results.append(batch_result)
inputs['batch_results'] = results
return inputs
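In the dynamic-graph branch, inference data comes from a generator-backed dataset plus a `paddle.io.DataLoader` whose collate function pads each field. A self-contained sketch of that pipeline with toy features; the reader and field contents here are illustrative, not the task's real inputs:

```python
# Self-contained sketch of the dynamic-graph data pipeline with toy features.
import paddle
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.datasets import load_dataset

def read(inputs):
    for text in inputs:
        ids = list(range(1, len(text) + 1))   # stand-in for real token ids
        yield ids, len(ids)

ds = load_dataset(read, inputs=["遇到逆竟时", "人生就是如此"], lazy=False)

batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=0, dtype='int64'),    # padded token ids
    Stack(dtype='int64'),                     # original lengths
): [data for data in fn(samples)]

loader = paddle.io.DataLoader(
    ds, collate_fn=batchify_fn, batch_size=2, shuffle=False, return_list=True)
for token_ids, lengths in loader:
    print(token_ids.shape, lengths.numpy())   # e.g. [2, 6] and [5 6]
```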

@@ -232,7 +266,7 @@ def _postprocess(self, inputs):

def _convert_example(self, example):
source = example["source"]
words = self._tokenizer.tokenize(text=source)
words = list(source)
if len(words) > self._max_seq_length - 2:
words = words[:self._max_seq_length - 2]
length = len(words)
@@ -269,64 +303,22 @@ def _convert_example(self, example):
def _parse_decode(self, words, corr_preds, det_preds, lengths):
UNK = self._tokenizer.unk_token
UNK_id = self._tokenizer.convert_tokens_to_ids(UNK)
tokens = self._tokenizer.tokenize(words)
if len(tokens) > self._max_seq_length - 2:
tokens = tokens[:self._max_seq_length - 2]

corr_pred = corr_preds[1:1 + lengths].tolist()
det_pred = det_preds[1:1 + lengths].tolist()
words = list(words)
rest_words = []
if len(words) > self._max_seq_length - 2:
            rest_words = words[self._max_seq_length - 2:]
words = words[:self._max_seq_length - 2]

assert len(tokens) == len(
corr_pred
), "The number of tokens should be equal to the number of labels {}: {}: {}".format(
len(tokens), len(corr_pred), tokens)
pred_result = ""

align_offset = 0
# Need to be aligned
if len(words) != len(tokens):
first_unk_flag = True
for j, word in enumerate(words):
if word.isspace():
tokens.insert(j + 1, word)
corr_pred.insert(j + 1, UNK_id)
det_pred.insert(j + 1, 0) # No error
elif tokens[j] != word:
if self._tokenizer.convert_tokens_to_ids(word) == UNK_id:
if first_unk_flag:
first_unk_flag = False
corr_pred[j] = UNK_id
det_pred[j] = 0
else:
tokens.insert(j, UNK)
corr_pred.insert(j, UNK_id)
det_pred.insert(j, 0) # No error
continue
elif tokens[j] == UNK:
# Remove rest unk
k = 0
while k + j < len(tokens) and tokens[k + j] == UNK:
k += 1
tokens = tokens[:j] + tokens[j + k:]
corr_pred = corr_pred[:j] + corr_pred[j + k:]
det_pred = det_pred[:j] + det_pred[j + k:]
else:
# Maybe English, number, or suffix
token = tokens[j].lstrip("##")
corr_pred = corr_pred[:j] + [UNK_id] * len(
token) + corr_pred[j + 1:]
det_pred = det_pred[:j] + [0] * len(token) + det_pred[
j + 1:]
tokens = tokens[:j] + list(token) + tokens[j + 1:]
first_unk_flag = True

for j, word in enumerate(words):
candidates = self._tokenizer.convert_ids_to_tokens(corr_pred[j])
if det_pred[j] == 0 or candidates == UNK or candidates == '[PAD]':
if not is_chinese_char(ord(word)) or det_pred[
j] == 0 or candidates == UNK or candidates == '[PAD]':
pred_result += word
else:
pred_result += candidates.lstrip("##")

pred_result += ''.join(rest_words)
return pred_result
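The added `is_chinese_char` guard copies every non-Chinese character (letters, digits, punctuation, whitespace) through unchanged instead of letting a model prediction overwrite it. A toy sketch with a stand-in range check; the real code uses the `is_chinese_char` helper already imported by this module:

```python
# Toy sketch: stand-in CJK range check for illustration only.
def is_chinese_char(cp):
    return 0x4E00 <= cp <= 0x9FFF   # basic CJK Unified Ideographs block only

for ch in "我爱NLP 2021!":
    keep_as_is = not is_chinese_char(ord(ch))
    print(repr(ch), "copied through" if keep_as_is else "may be corrected")
```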
