[paraformer] timestamp (#2277)
* [paraformer] timestamp first commit

* [paraformer] cif2 timestamp forward works

* [paraformer] cif2 forward works and export jit works

* [paraformer] fix import

* [paraformer] timestamp works

* [paraformer] add textgrid support commit, rm later

* [paraformer] delete textgrid

* [paraformer] fix unit test
Mddct authored Jan 6, 2024
1 parent 62a486f commit 6243184
Showing 7 changed files with 236 additions and 57 deletions.
2 changes: 1 addition & 1 deletion test/wenet/text/test_paraformer_tokenizer.py
@@ -15,7 +15,7 @@ def paraformer_tokenizer(request):
     _download_fn(download_root, seg_dict)

     config_name = 'config.yaml'
-    _download_fn(download_root, config_name)
+    _download_fn(download_root, config_name, version='v1.2.4')
     with open(os.path.join(download_root, config_name), 'r') as fin:
         configs = yaml.load(fin, Loader=yaml.FullLoader)
     wenet_units = os.path.join(download_root, 'units.txt')
21 changes: 12 additions & 9 deletions wenet/cli/paraformer_model.py
@@ -5,7 +5,9 @@
 import torchaudio.compliance.kaldi as kaldi

 from wenet.cli.hub import Hub
-from wenet.paraformer.search import paraformer_beautify_result, paraformer_greedy_search
+from wenet.paraformer.search import (gen_timestamps_from_peak,
+                                     paraformer_beautify_result,
+                                     paraformer_greedy_search)
 from wenet.text.paraformer_tokenizer import ParaformerTokenizer


@@ -45,28 +47,29 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
                                   dtype=torch.int64,
                                   device=feats.device)

-        decoder_out, token_num = self.model.forward_paraformer(
+        decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
             feats, feats_lens)

-        res = paraformer_greedy_search(decoder_out, token_num)[0]
-
+        cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
+        res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
         result = {}
         result['confidence'] = res.confidence
         result['text'] = paraformer_beautify_result(
             self.tokenizer.detokenize(res.tokens)[1])
         if tokens_info:
             tokens_info = []
+            times = gen_timestamps_from_peak(res.times,
+                                             num_frames=tp_alphas.size(1),
+                                             frame_rate=0.02)
+
             for i, x in enumerate(res.tokens):
                 tokens_info.append({
                     'token': self.tokenizer.char_dict[x],
-                    # TODO(Mddct): support times
-                    # 'start': 0,
-                    # 'end': 0,
+                    'start': times[i][0],
+                    'end': times[i][1],
                     'confidence': res.tokens_confidence[i]
                 })
             result['tokens'] = tokens_info
-
-        # result = ''.join(hyp)
         return result

     def align(self, audio_file: str, label: str) -> dict:
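Taken together, the new pieces in this file form one timestamp pipeline: forward_paraformer now also returns the timestamp-predictor alphas (tp_alphas), a second CIF pass over those alphas yields one firing peak per token, and gen_timestamps_from_peak turns the peaks into per-token spans. A minimal sketch of the flow, assuming model, feats, and feats_lens are prepared exactly as in transcribe above:

    decoder_out, token_num, tp_alphas = model.forward_paraformer(feats, feats_lens)

    # A second CIF pass over the timestamp alphas gives one firing peak per token.
    cif_peaks = model.forward_cif_peaks(tp_alphas, token_num)
    res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]

    # Peak positions become (start, end) pairs in seconds, one per token;
    # frame_rate=0.02 follows the value used in the diff above.
    times = gen_timestamps_from_peak(res.times,
                                     num_frames=tp_alphas.size(1),
                                     frame_rate=0.02)

Each entry of result['tokens'] then carries 'start' and 'end' in seconds next to the token text and its confidence.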
66 changes: 48 additions & 18 deletions wenet/paraformer/cif.py
@@ -23,17 +23,19 @@

 class Cif(nn.Module):

-    def __init__(self,
-                 idim,
-                 l_order,
-                 r_order,
-                 threshold=1.0,
-                 dropout=0.1,
-                 smooth_factor=1.0,
-                 noise_threshold=0,
-                 tail_threshold=0.45,
-                 residual=True,
-                 cnn_groups=0):
+    def __init__(
+        self,
+        idim,
+        l_order,
+        r_order,
+        threshold=1.0,
+        dropout=0.1,
+        smooth_factor=1.0,
+        noise_threshold=0.0,
+        tail_threshold=0.45,
+        residual=True,
+        cnn_groups=0,
+    ):
         super().__init__()

         self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
@@ -50,13 +52,15 @@ def __init__(self,
         self.tail_threshold = tail_threshold
         self.residual = residual

-    def forward(self,
-                hidden,
-                target_label: Optional[torch.Tensor] = None,
-                mask: torch.Tensor = torch.tensor(0),
-                ignore_id: int = -1,
-                mask_chunk_predictor: Optional[torch.Tensor] = None,
-                target_label_length: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        hidden,
+        target_label: Optional[torch.Tensor] = None,
+        mask: torch.Tensor = torch.tensor(0),
+        ignore_id: int = -1,
+        mask_chunk_predictor: Optional[torch.Tensor] = None,
+        target_label_length: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         h = hidden
         context = h.transpose(1, 2)
         queries = self.pad(context)
@@ -94,6 +98,7 @@ def forward(self,
             alphas,
             token_num,
             mask=mask)
+
         acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

         if target_length is None and self.tail_threshold > 0.0:
@@ -217,6 +222,31 @@ def forward(self, token_length, pre_token_length):
         return loss


+def cif_without_hidden(alphas: torch.Tensor, threshold: float):
+    # https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/models/predictor/cif.py#L187
+    batch_size, len_time = alphas.size()
+
+    # loop vars
+    integrate = torch.zeros([batch_size], device=alphas.device)
+    # intermediate vars along time
+    list_fires = []
+
+    for t in range(len_time):
+        alpha = alphas[:, t]
+
+        integrate += alpha
+        list_fires.append(integrate)
+
+        fire_place = integrate >= threshold
+        integrate = torch.where(
+            fire_place, integrate -
+            torch.ones([batch_size], device=alphas.device) * threshold,
+            integrate)
+
+    fires = torch.stack(list_fires, 1)
+    return fires
+
+
 def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
     batch_size, len_time, hidden_size = hidden.size()
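For intuition, cif_without_hidden integrates the per-frame weights and records the running total at every frame; a token boundary ("fire") is wherever that total crosses the threshold, after which the integrator keeps only the overflow. A toy run (a sketch; the alpha values are made up for illustration):

    import torch
    from wenet.paraformer.cif import cif_without_hidden

    alphas = torch.tensor([[0.4, 0.5, 0.3, 0.9]])
    fires = cif_without_hidden(alphas, threshold=1.0)
    # fires == [[0.4, 0.9, 1.2, 1.1]]: the total reaches 1.2 at frame 2
    # (fire, integrator resets to 0.2), then 0.2 + 0.9 = 1.1 at frame 3.
    peaks = fires >= 1.0  # [[False, False, True, True]] -> one peak per token

This is the same integrate-and-fire rule as cif() below, minus the weighted sum of hidden states, which is what makes it cheap enough to run a second time just for timestamps.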
48 changes: 33 additions & 15 deletions wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
@@ -7,14 +7,12 @@
 from pathlib import Path
 import shutil
 import urllib.request
+import torch
 from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple

 import yaml

-from wenet.utils.checkpoint import save_checkpoint
-from wenet.utils.init_model import init_model
-

 def _load_paraformer_cmvn(cmvn_file) -> Tuple[List, List]:
     with open(cmvn_file, 'r', encoding='utf-8') as f:
@@ -107,7 +105,8 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
     configs['lfr_conf'] = {'lfr_m': 7, 'lfr_n': 6}

     configs['input_dim'] = configs['lfr_conf']['lfr_m'] * 80
-    configs['predictor'] = 'cif_predictor'
+    # configs['predictor'] = 'cif_predictor'
+    configs['predictor'] = 'paraformer_predictor'
     configs['predictor_conf'] = configs.pop('predictor_conf')
     configs['predictor_conf']['cnn_groups'] = 1
     configs['predictor_conf']['residual'] = False
@@ -162,10 +161,26 @@
     return configs


-def convert_to_wenet_state_dict(args, configs, wenet_model_path):
-    args.checkpoint = args.paraformer_model
-    model, _ = init_model(args, configs)
-    save_checkpoint(model, wenet_model_path)
+def convert_to_wenet_state_dict(args, wenet_model_path):
+    wenet_state_dict = {}
+    checkpoint = torch.load(args.paraformer_model, map_location='cpu')
+    for name in checkpoint.keys():
+        wenet_name = name
+
+        if wenet_name.startswith('predictor.cif_output2'):
+            wenet_name = wenet_name.replace('predictor.cif_output2.',
+                                            'predictor.tp_output.')
+        elif wenet_name.startswith('predictor.cif'):
+            wenet_name = wenet_name.replace('predictor.cif',
+                                            'predictor.predictor.cif')
+        elif wenet_name.startswith('predictor.upsample'):
+            wenet_name = wenet_name.replace('predictor.', 'predictor.tp_')
+        elif wenet_name.startswith('predictor.blstm'):
+            wenet_name = wenet_name.replace('predictor.', 'predictor.tp_')
+
+        wenet_state_dict[wenet_name] = checkpoint[name].float()
+
+    torch.save(wenet_state_dict, wenet_model_path)


 def get_args():
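The renaming above maps FunASR's flat predictor keys onto wenet's new layout, which nests the original CIF under predictor.predictor and prefixes the timestamp branch (upsample, blstm, cif_output2) with tp_. Illustrative before/after pairs (a sketch; the concrete FunASR key names are assumptions, only the prefixes matter to the logic above):

    # predictor.cif_output2.weight   -> predictor.tp_output.weight
    # predictor.cif_conv1d.weight    -> predictor.predictor.cif_conv1d.weight
    # predictor.upsample_cnn.weight  -> predictor.tp_upsample_cnn.weight
    # predictor.blstm.weight_ih_l0   -> predictor.tp_blstm.weight_ih_l0

All other keys pass through unchanged, cast to float32.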
@@ -190,11 +205,15 @@ def get_args():
     return args


-def _download_fn(output_dir, name, renmae: Optional[str] = None):
+def _download_fn(output_dir,
+                 name,
+                 renmae: Optional[str] = None,
+                 version: str = 'master'):
     url = "https://www.modelscope.cn/api/v1/"\
         "models/damo/"\
-        "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"\
-        "/repo?Revision=v1.0.4&FilePath=" + name
+        "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"\
+        "/repo?Revision={}&FilePath=".format(version) + name
+    # "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"\
     if renmae is None:
         output_file = os.path.join(output_dir, name)
     else:
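With name='config.yaml' and version='v1.2.4', for example, the rewritten function fetches https://www.modelscope.cn/api/v1/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/repo?Revision=v1.2.4&FilePath=config.yaml; note that the assets now come from the vad-punc variant of the model rather than the plain one.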
@@ -232,7 +251,7 @@ def may_get_assets_and_refine_args(args):
         config_name = 'config.yaml'
         args.paraformer_config = os.path.join(assets_dir, config_name)
         if not os.path.exists(args.paraformer_config):
-            _download_fn(assets_dir, config_name)
+            _download_fn(assets_dir, config_name, version='v1.2.4')
     if args.paraformer_cmvn is None:
         cmvn_name = 'am.mvn'
         args.paraformer_cmvn = os.path.join(assets_dir, cmvn_name)
@@ -280,11 +299,10 @@ def main():
         'tokenizer_conf'
     ]
     wenet_train_yaml = os.path.join(args.output_dir, "train.yaml")
-    wenet_configs = convert_to_wenet_yaml(configs, wenet_train_yaml,
-                                          fields_to_keep)
+    convert_to_wenet_yaml(configs, wenet_train_yaml, fields_to_keep)

     wenet_model_path = os.path.join(args.output_dir, "wenet_paraformer.pt")
-    convert_to_wenet_state_dict(args, wenet_configs, wenet_model_path)
+    convert_to_wenet_state_dict(args, wenet_model_path)

     print("Please check {} {} {} {} {} in {}".format(json_cmvn_path,
                                                      wenet_train_yaml,
