Skip to content

Commit

Permalink
[paraformer] post process eos and @@ (#2099)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored Nov 2, 2023
1 parent 55d0fc7 commit 387c1a1
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 3 deletions.
6 changes: 4 additions & 2 deletions wenet/cli/paraformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.paraformer.search import paraformer_greedy_search
from wenet.paraformer.search import paraformer_beautify_result, paraformer_greedy_search
from wenet.utils.file_utils import read_symbol_table


class Paraformer:

def __init__(self, model_dir: str) -> None:

model_path = os.path.join(model_dir, 'final.zip')
Expand Down Expand Up @@ -39,7 +40,8 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
result = {}
result['confidence'] = res.confidence
# # TODO(Mddct): deal with '@@' and 'eos'
result['rec'] = "".join([self.char_dict[x] for x in res.tokens])
result['rec'] = paraformer_beautify_result(
[self.char_dict[x] for x in res.tokens])

if tokens_info:
tokens_info = []
Expand Down
105 changes: 104 additions & 1 deletion wenet/paraformer/search.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,115 @@
import math
from typing import List, Tuple
from typing import Any, List, Tuple, Union
import torch

from wenet.transformer.search import DecodeResult
from wenet.utils.mask import (make_non_pad_mask, mask_finished_preds,
mask_finished_scores)


def _isChinese(ch: str):
if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039' or ch == '@':
return True
return False


def _isAllChinese(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
cur = cur.replace('<unk>', '')
cur = cur.replace('<OOV>', '')
word_lists.append(cur)

if len(word_lists) == 0:
return False

for ch in word_lists:
if _isChinese(ch) is False:
return False
return True


def _isAllAlpha(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
cur = cur.replace('<unk>', '')
cur = cur.replace('<OOV>', '')
word_lists.append(cur)

if len(word_lists) == 0:
return False

for ch in word_lists:
if ch.isalpha() is False and ch != "'":
return False
elif ch.isalpha() is True and _isChinese(ch) is True:
return False

return True


def paraformer_beautify_result(tokens: List[str]) -> str:
    """Join decoded tokens into a readable string.

    ``<sos>``, ``<eos>`` and ``<blank>`` tokens are discarded. The rest is
    rendered according to what the sequence contains:

    * every token passes ``_isAllChinese``: tokens are concatenated with
      their inner spaces removed;
    * every token passes ``_isAllAlpha``: each token becomes a word and
      words are separated by a single space;
    * otherwise (mixed content): tokens containing ``@@`` have the marker
      removed and are glued to the following alphabetic token, alphabetic
      words are followed by a space, that space is removed again when a
      Chinese token comes directly after the word, and any other token is
      emitted unchanged.

    A fragment still pending at the end of the sequence (the last token
    carried ``@@``) is emitted instead of being dropped.

    Args:
        tokens: decoded token strings, in order.

    Returns:
        The joined text with leading and trailing whitespace stripped.
    """
    # Wash the token list: remove sentence-boundary and blank symbols.
    middle_lists = [
        token for token in tokens
        if token not in ('<sos>', '<eos>', '<blank>')
    ]
    word_lists = []
    # Pieces of the word currently being assembled from '@@' fragments.
    word_item = ''

    # all chinese characters
    if _isAllChinese(middle_lists):
        word_lists = [ch.replace(' ', '') for ch in middle_lists]

    # all alpha characters
    elif _isAllAlpha(middle_lists):
        for ch in middle_lists:
            # NOTE(review): a token containing '@' does not look able to
            # pass _isAllAlpha, so this arm is probably never taken here;
            # kept for parity with the mixed branch.
            if '@@' in ch:
                word_item += ch.replace('@@', '')
            else:
                word_lists.append(word_item + ch)
                word_lists.append(' ')
                word_item = ''

    # mix characters
    else:
        # True only while the last element of word_lists is the separator
        # space appended after an alphabetic word.
        alpha_blank = False
        for ch in middle_lists:
            if _isAllChinese(ch):
                if alpha_blank:
                    # No space between an alphabetic word and a Chinese
                    # token that immediately follows it.
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
            elif '@@' in ch:
                word_item += ch.replace('@@', '')
                alpha_blank = False
            elif _isAllAlpha(ch):
                word_lists.append(word_item + ch)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
            else:
                word_lists.append(ch)
                # The separator space is no longer the last element, so a
                # later pop() must not fire and remove this token.
                alpha_blank = False

    # Flush a dangling '@@' fragment so recognised text is never lost.
    if word_item:
        word_lists.append(word_item)

    return ''.join(word_lists).strip()


def paraformer_greedy_search(
decoder_out: torch.Tensor,
decoder_out_lens: torch.Tensor) -> List[DecodeResult]:
Expand Down

0 comments on commit 387c1a1

Please sign in to comment.