Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[whisper] support arbitrary language and task #2342

Merged
merged 8 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ deepspeed<0.13.0
librosa
openai-whisper
pre-commit==3.5.0
langid
6 changes: 5 additions & 1 deletion test/wenet/dataset/test_datapipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import torch
from torch.utils.data import datapipes
from torch.utils.data.datapipes.iter import IterableWrapper
from functools import partial

from wenet.dataset.datapipes import (SortDataPipe, WenetRawDatasetSource,
WenetTarShardDatasetSource)
from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding,
parse_json, compute_fbank)
parse_json, compute_fbank,
detect_language, detect_task)


@pytest.mark.parametrize("data_list", [
Expand Down Expand Up @@ -98,6 +100,8 @@ def test_dynamic_batch_datapipe(data_list):
dataset = dataset.map(decode_wav)
dataset = dataset.map(compute_fbank)
dataset = dataset.map(fake_labels)
dataset = dataset.map(partial(detect_language, limited_langs=['zh', 'en']))
dataset = dataset.map(detect_task)
max_frames_in_batch = 10000
dataset = dataset.dynamic_batch(
window_class=DynamicBatchWindow(max_frames_in_batch),
Expand Down
6 changes: 2 additions & 4 deletions test/wenet/whisper/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ def test_sinusoids(length, channels):
@pytest.mark.parametrize("model,audio_path", [
("tiny", "test/resources/aishell-BAC009S0724W0121.wav"),
("base", "test/resources/librispeech-1995-1837-0001.wav"),
("small", "test/resources/aishell-BAC009S0724W0121.wav"),
("medium", "test/resources/librispeech-1995-1837-0001.wav"),
])
def test_model(model, audio_path):
default = os.path.join(os.path.expanduser("~"), ".cache")
Expand Down Expand Up @@ -362,9 +360,9 @@ def test_model(model, audio_path):
configs['tokenizer_conf']['special_tokens'],
torch.tensor([dummy_tokens], dtype=torch.long),
ignore_id=-1,
task=task,
tasks=[task],
no_timestamp=True,
language=language,
langs=[language],
use_prev=False)
L = wenet_tokens.size(1)
tgt_mask = ~make_pad_mask(torch.tensor([L], dtype=torch.long),
Expand Down
4 changes: 3 additions & 1 deletion wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def main():
target = batch["target"].to(device)
feats_lengths = batch["feats_lengths"].to(device)
target_lengths = batch["target_lengths"].to(device)
infos = {"tasks": batch["tasks"], "langs": batch["langs"]}
results = model.decode(
args.modes,
feats,
Expand All @@ -257,7 +258,8 @@ def main():
context_graph=context_graph,
blank_id=blank_id,
blank_penalty=args.blank_penalty,
length_penalty=args.length_penalty)
length_penalty=args.length_penalty,
infos=infos)
for i, key in enumerate(keys):
for mode, hyps in results.items():
tokens = hyps[i].tokens
Expand Down
4 changes: 4 additions & 0 deletions wenet/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def Dataset(data_type,
spec_trim_conf = conf.get('spec_trim_conf', {})
dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))

language_conf = conf.get('language_conf', {"limited_langs": ['zh', 'en']})
dataset = dataset.map(partial(processor.detect_language, **language_conf))
dataset = dataset.map(processor.detect_task)

shuffle = conf.get('shuffle', True)
if shuffle:
shuffle_conf = conf.get('shuffle_conf', {})
Expand Down
32 changes: 32 additions & 0 deletions wenet/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import json
from subprocess import PIPE, Popen
from urllib.parse import urlparse
from langid.langid import LanguageIdentifier, model
import logging
import librosa
import random

Expand All @@ -31,6 +33,10 @@

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])

lid = LanguageIdentifier.from_modelstring(model, norm_probs=True)

logging.getLogger('langid').setLevel(logging.INFO)


class UrlOpenError(Exception):

Expand Down Expand Up @@ -79,6 +85,28 @@ def parse_speaker(sample, speaker_dict):
return sample


def detect_language(sample, limited_langs):
assert 'txt' in sample
# NOTE(xcsong): Because language classification may not be very accurate
# (for example, Chinese being classified as Japanese), our workaround,
# given we know for certain that the training data only consists of
# Chinese and English, is to limit the classification results to reduce
# the impact of misclassification.
lid.set_languages(limited_langs)
# i.e., ('zh', 0.9999999909903544)
sample['lang'] = lid.classify(sample['txt'])[0]
return sample


def detect_task(sample):
# TODO(xcsong): Currently, the task is hard-coded to 'transcribe'.
# In the future, we could dynamically determine the task based on
# the contents of sample. For instance, if a sample contains both
# 'txt_en' and 'txt_zh', the task should be set to 'translate'.
sample['task'] = "transcribe"
return sample


def decode_wav(sample):
""" Parse key/wav/txt from json line

Expand Down Expand Up @@ -452,6 +480,8 @@ def padding(data):
torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
]
sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
langs = [sample[i]['lang'] for i in order]
tasks = [sample[i]['task'] for i in order]
label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
dtype=torch.int32)
wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
Expand All @@ -472,6 +502,8 @@ def padding(data):
"target_lengths": label_lengths,
"pcm": padded_wavs,
"pcm_length": wav_lengths,
"langs": langs,
"tasks": tasks,
}
if 'speaker' in sample[0]:
speaker = torch.tensor([sample[i]['speaker'] for i in order],
Expand Down
11 changes: 8 additions & 3 deletions wenet/paraformer/paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,14 @@ def forward(
}

def _calc_att_loss(
self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
ys_pad: torch.Tensor, ys_pad_emb: torch.Tensor,
ys_pad_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self,
encoder_out: torch.Tensor,
encoder_mask: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_emb: torch.Tensor,
ys_pad_lens: torch.Tensor,
infos: Dict[str, List[str]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
decoder_out, _, _ = self.decoder(encoder_out, encoder_mask, ys_pad_emb,
ys_pad_lens)
loss_att = self.criterion_att(decoder_out, ys_pad)
Expand Down
14 changes: 10 additions & 4 deletions wenet/transformer/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ def forward(
encoder_out, encoder_mask = self.filter_blank_embedding(
ctc_probs, encoder_out)
if self.ctc_weight != 1.0:
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths)
loss_att, acc_att = self._calc_att_loss(
encoder_out, encoder_mask, text, text_lengths, {
"langs": batch["langs"],
"tasks": batch["tasks"]
})
else:
loss_att = None
acc_att = None
Expand Down Expand Up @@ -174,6 +177,7 @@ def _calc_att_loss(
encoder_mask: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
infos: Dict[str, List[str]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
self.ignore_id)
Expand Down Expand Up @@ -256,6 +260,7 @@ def decode(
blank_id: int = 0,
blank_penalty: float = 0.0,
length_penalty: float = 0.0,
infos: Dict[str, List[str]] = None,
) -> Dict[str, List[DecodeResult]]:
""" Decode input speech

Expand Down Expand Up @@ -292,7 +297,8 @@ def decode(
results = {}
if 'attention' in methods:
results['attention'] = attention_beam_search(
self, encoder_out, encoder_mask, beam_size, length_penalty)
self, encoder_out, encoder_mask, beam_size, length_penalty,
infos)
if 'ctc_greedy_search' in methods:
results['ctc_greedy_search'] = ctc_greedy_search(
ctc_probs, encoder_lens, blank_id)
Expand All @@ -314,7 +320,7 @@ def decode(
ctc_probs, encoder_out)
results['attention_rescoring'] = attention_rescoring(
self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
reverse_weight)
reverse_weight, infos)
return results

@torch.jit.export
Expand Down
44 changes: 24 additions & 20 deletions wenet/transformer/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@

import math
from collections import defaultdict
from typing import List, Optional
from typing import List, Optional, Dict

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.utils.common import (add_sos_eos, log_add, WHISPER_LANGS,
add_whisper_tokens)
from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens)
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
mask_finished_scores, subsequent_mask)
Expand Down Expand Up @@ -253,6 +252,7 @@ def attention_beam_search(
encoder_mask: torch.Tensor,
beam_size: int = 10,
length_penalty: float = 0.0,
infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
device = encoder_out.device
batch_size = encoder_out.shape[0]
Expand All @@ -265,17 +265,20 @@ def attention_beam_search(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
encoder_mask = encoder_mask.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, 1, maxlen) # (B*N, 1, max_len)

if getattr(model, 'special_tokens', None) is not None \
and "transcribe" in model.special_tokens:
hyps = torch.ones([running_size, 4], dtype=torch.long,
device=device) # (B*N, 4)
# TODO(xcsong): add args for language, task, etc
hyps[:, 0] = model.special_tokens["sot"]
hyps[:,
1] = model.special_tokens["sot"] + 1 + WHISPER_LANGS.index("zh")
hyps[:, 2] = model.special_tokens["transcribe"]
hyps[:, 3] = model.special_tokens["no_timestamps"]
tasks, langs = infos["tasks"], infos["langs"]
tasks = [t for t in tasks for _ in range(beam_size)]
langs = [l for l in langs for _ in range(beam_size)]
hyps = torch.ones([running_size, 0], dtype=torch.long,
device=device) # (B*N, 0)
hyps, _ = add_whisper_tokens(model.special_tokens,
hyps,
model.ignore_id,
tasks=tasks,
no_timestamp=True,
langs=langs,
use_prev=False)
else:
hyps = torch.ones([running_size, 1], dtype=torch.long,
device=device).fill_(model.sos) # (B*N, 1)
Expand Down Expand Up @@ -360,6 +363,7 @@ def attention_rescoring(
encoder_lens: torch.Tensor,
ctc_weight: float = 0.0,
reverse_weight: float = 0.0,
infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
"""
Args:
Expand All @@ -382,15 +386,15 @@ def attention_rescoring(
dtype=torch.long) # (beam_size,)
if getattr(model, 'special_tokens', None) is not None \
and "transcribe" in model.special_tokens:
# TODO(xcsong): add args for language, task, etc
prev_len = hyps_pad.size(1)
hyps_pad, _ = add_whisper_tokens(model.special_tokens,
hyps_pad,
model.ignore_id,
task="transcribe",
no_timestamp=True,
language="zh",
use_prev=False)
hyps_pad, _ = add_whisper_tokens(
model.special_tokens,
hyps_pad,
model.ignore_id,
tasks=[infos["tasks"][b]] * len(hyps),
no_timestamp=True,
langs=[infos["langs"][b]] * len(hyps),
use_prev=False)
cur_len = hyps_pad.size(1)
hyps_lens = hyps_lens + cur_len - prev_len
prefix_len = 4
Expand Down
Loading
Loading