[cli] support gpu for cli (#2101)
* support gpu for cli

* update options device to gpu
yuekaizhang authored Nov 2, 2023
1 parent 387c1a1 commit 67628f6
Showing 4 changed files with 47 additions and 11 deletions.
10 changes: 10 additions & 0 deletions docs/python_package.md
@@ -7,6 +7,15 @@
pip install git+https://github.com/wenet-e2e/wenet.git
```

+## Development Install
+
+``` sh
+git clone https://github.com/wenet-e2e/wenet.git
+cd wenet
+pip install -e .
+```
+
+
## Command line Usage

``` sh
@@ -17,6 +26,7 @@ You can specify the following parameters.

* `-l` or `--language`: chinese/english are supported now.
* `-m` or `--model_dir`: your own model dir
+* `-g` or `--gpu`: the GPU device id; the default value -1 means decoding on CPU.
* `-t` or `--show_tokens_info`: show the token level information such as timestamp, confidence, etc.

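
For reference, a minimal command-line example of the new `-g`/`--gpu` option (assuming the `wenet` console entry point from the Command line Usage section above and a local `test.wav` placeholder file):

``` sh
# decode test.wav on GPU 0; drop -g (or pass -1, the default) to stay on CPU
wenet --language chinese -g 0 test.wav
```
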
20 changes: 16 additions & 4 deletions wenet/cli/model.py
@@ -27,16 +27,23 @@


class Model:
-    def __init__(self, model_dir: str):
+    def __init__(self, model_dir: str, gpu: int = -1):
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.model = torch.jit.load(model_path)
+        if gpu >= 0:
+            device = 'cuda:{}'.format(gpu)
+        else:
+            device = 'cpu'
+        self.device = torch.device(device)
+        self.model = self.model.to(self.device)
        self.symbol_table = read_symbol_table(units_path)
        self.char_dict = {v: k for k, v in self.symbol_table.items()}

    def compute_feats(self, audio_file: str) -> torch.Tensor:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
        waveform = waveform.to(torch.float)
+        waveform = waveform.to(self.device)
        feats = kaldi.fbank(waveform,
                            num_mel_bins=80,
                            frame_length=25,
@@ -52,7 +59,10 @@ def _decode(self,
                label: str = None) -> dict:
        feats = self.compute_feats(audio_file)
        encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1)
-        encoder_lens = torch.tensor([encoder_out.size(1)], dtype=torch.long)
+        encoder_lens = torch.tensor([
+            encoder_out.size(1)],
+            dtype=torch.long,
+            device=encoder_out.device)
        ctc_probs = self.model.ctc_activation(encoder_out)
        if label is None:
            ctc_prefix_results = ctc_prefix_beam_search(
@@ -117,7 +127,9 @@ def align(self, audio_file: str, label: str) -> dict:
        return self._decode(audio_file, True, label)


-def load_model(language: str = None, model_dir: str = None) -> Model:
+def load_model(language: str = None,
+               model_dir: str = None,
+               gpu: int = -1) -> Model:
    if model_dir is None:
        model_dir = Hub.get_model_by_lang(language)
-    return Model(model_dir)
+    return Model(model_dir, gpu)
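
A minimal Python sketch of the updated `load_model` API in this file (hedged: `transcribe` is the model's existing inference method, not shown in this hunk, and `test.wav` is a placeholder path):

```python
from wenet.cli.model import load_model

# gpu=0 moves the scripted model and its input features to cuda:0;
# the default gpu=-1 keeps everything on the CPU.
model = load_model(language='chinese', gpu=0)
result = model.transcribe('test.wav')  # returns a dict with the decoded result
print(result)
```
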
19 changes: 14 additions & 5 deletions wenet/cli/paraformer_model.py
@@ -11,26 +11,35 @@

class Paraformer:

-    def __init__(self, model_dir: str) -> None:
+    def __init__(self, model_dir: str, device: int = -1) -> None:

        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.model = torch.jit.load(model_path)
+        if device >= 0:
+            device = 'cuda:{}'.format(device)
+        else:
+            device = 'cpu'
+        self.device = torch.device(device)
+        self.model = self.model.to(self.device)
        symbol_table = read_symbol_table(units_path)
        self.char_dict = {v: k for k, v in symbol_table.items()}
        self.eos = 2

    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
-        waveform = waveform.to(torch.float)
+        waveform = waveform.to(torch.float).to(self.device)
        feats = kaldi.fbank(waveform,
                            num_mel_bins=80,
                            frame_length=25,
                            frame_shift=10,
                            energy_floor=0.0,
                            sample_frequency=16000)
        feats = feats.unsqueeze(0)
-        feats_lens = torch.tensor([feats.size(1)], dtype=torch.int64)
+        feats_lens = torch.tensor([
+            feats.size(1)],
+            dtype=torch.int64,
+            device=feats.device)

        decoder_out, token_num = self.model.forward_paraformer(
            feats, feats_lens)
@@ -62,7 +71,7 @@ def align(self, audio_file: str, label: str) -> dict:
        raise NotImplementedError("Align is currently not supported")


-def load_model(model_dir: str = None) -> Paraformer:
+def load_model(model_dir: str = None, gpu: int = -1) -> Paraformer:
    if model_dir is None:
        model_dir = Hub.get_model_by_lang('paraformer')
-    return Paraformer(model_dir)
+    return Paraformer(model_dir, gpu)
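
A parallel sketch for the Paraformer path (hedged: the `load_paraformer` alias simply mirrors the name used in `wenet/cli/transcribe.py` below, and `test.wav` is again a placeholder):

```python
from wenet.cli.paraformer_model import load_model as load_paraformer

# model_dir=None falls back to Hub.get_model_by_lang('paraformer') as in the diff;
# gpu=0 selects cuda:0, while the default gpu=-1 decodes on the CPU.
paraformer = load_paraformer(model_dir=None, gpu=0)
print(paraformer.transcribe('test.wav'))
```
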
9 changes: 7 additions & 2 deletions wenet/cli/transcribe.py
@@ -33,6 +33,11 @@ def get_args():
                        '--model_dir',
                        default=None,
                        help='specify your own model dir')
+    parser.add_argument('-g',
+                        '--gpu',
+                        type=int,
+                        default='-1',
+                        help='gpu id to decode, default is cpu.')
    parser.add_argument('-t',
                        '--show_tokens_info',
                        action='store_true',
@@ -53,9 +58,9 @@ def main():
    args = get_args()

    if args.paraformer:
-        model = load_paraformer(args.model_dir)
+        model = load_paraformer(args.model_dir, args.gpu)
    else:
-        model = load_model(args.language, args.model_dir)
+        model = load_model(args.language, args.model_dir, args.gpu)
    if args.align:
        result = model.align(args.audio_file, args.label)
    else:
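
The Paraformer branch of `main()` would be reached with an invocation along these lines (hedged: the `--paraformer` flag is inferred from `args.paraformer` above and is not part of this diff):

``` sh
# decode with the paraformer model on GPU 0; the model is fetched via Hub when no -m is given
wenet --paraformer -g 0 test.wav
```
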
