Skip to content

Commit

Permalink
[wenet] mv all ctc functions to ctc_utils.py (#2057)
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 authored Oct 17, 2023
1 parent 900f68d commit 2d8c2a9
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 27 deletions.
2 changes: 1 addition & 1 deletion wenet/bin/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.ctc_util import forced_align
from wenet.utils.ctc_utils import forced_align
from wenet.utils.common import get_subsample
from wenet.utils.init_model import init_model

Expand Down
2 changes: 1 addition & 1 deletion wenet/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.utils.common import remove_duplicates_and_blank
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.file_utils import read_symbol_table


Expand Down
2 changes: 1 addition & 1 deletion wenet/transformer/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from wenet.utils.common import (add_sos_eos,
log_add,
remove_duplicates_and_blank,
reverse_pad_list)
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
mask_finished_scores, subsequent_mask)

Expand Down
24 changes: 0 additions & 24 deletions wenet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,30 +232,6 @@ def get_subsample(config):
return 8


def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp


def replace_duplicates_with_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
new_hyp.append(hyp[cur])
prev = cur
cur += 1
while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0:
new_hyp.append(0)
cur += 1
return new_hyp


def log_add(args: List[int]) -> float:
"""
Expand Down
28 changes: 28 additions & 0 deletions wenet/utils/ctc_util.py → wenet/utils/ctc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import numpy as np

import torch

def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp


def replace_duplicates_with_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
new_hyp.append(hyp[cur])
prev = cur
cur += 1
while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0:
new_hyp.append(0)
cur += 1
return new_hyp


def insert_blank(label, blank_id=0):
"""Insert blank token between every two label token."""
label = np.expand_dims(label, 1)
Expand Down

0 comments on commit 2d8c2a9

Please sign in to comment.