Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli] support model dir, add related docs #2087

Merged
merged 2 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
pip install git+https://github.com/wenet-e2e/wenet.git
```

Command-line usage(use `-h` for parameters):
**Command-line usage** (use `-h` for parameters):

``` sh
wenet --language chinese audio.wav
```

Python programming usage:
**Python programming usage**:

``` python
import wenet
Expand All @@ -43,6 +43,8 @@ result = model.transcribe('audio.wav')
print(result['text'])
```

Please refer [python usage](docs/python_package.md) for more command line and python programming usage.

### Install for training & deployment

- Clone the repo
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ wenet is an tansformer-based end-to-end ASR toolkit.
:maxdepth: 2
:caption: Contents:

./python_package.md
./train.rst
./production.rst
./reference.rst
Expand Down
32 changes: 32 additions & 0 deletions docs/python_package.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Python Package


## Install

``` sh
pip install git+https://github.com/wenet-e2e/wenet.git
```

## Command line Usage

``` sh
wenet --language chinese audio.wav
```

You can specify the following parameters.

* `-l` or `--language`: chinese/english are supported now.
* `-m` or `--model_dir`: your own model dir
* `-t` or `--show_tokens_info`: show the token level information such as timestamp, confidence, etc.


## Python Programming Usage

``` python
import wenet

model = wenet.load_model('chinese')
# or model = wenet.load_model(model_dir='xxx')
result = model.transcribe('audio.wav')
print(result['text'])
```
9 changes: 5 additions & 4 deletions wenet/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@


class Model:
def __init__(self, language: str):
model_dir = Hub.get_model_by_lang(language)
def __init__(self, model_dir: str):
model_path = os.path.join(model_dir, 'final.zip')
units_path = os.path.join(model_dir, 'units.txt')
self.model = torch.jit.load(model_path)
Expand Down Expand Up @@ -74,5 +73,7 @@ def transcribe(self, audio_file: str, tokens_info: bool = False):
return result


def load_model(language: str) -> Model:
return Model(language)
def load_model(language: str = None, model_dir: str = None) -> Model:
if model_dir is None:
model_dir = Hub.get_model_by_lang(language)
return Model(model_dir)
11 changes: 8 additions & 3 deletions wenet/cli/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,24 @@

import argparse

from wenet.cli.model import Model
from wenet.cli.model import load_model


def get_args():
parser = argparse.ArgumentParser(description='')
parser.add_argument('audio_file', help='audio file to transcribe')
parser.add_argument('--language',
parser.add_argument('-l',
'--language',
choices=[
'chinese',
'english',
],
default='chinese',
help='language type')
parser.add_argument('-m',
'--model_dir',
default=None,
help='specify your own model dir')
parser.add_argument('-t',
'--show_tokens_info',
action='store_true',
Expand All @@ -38,7 +43,7 @@ def get_args():

def main():
args = get_args()
model = Model(args.language)
model = load_model(args.language, args.model_dir)
result = model.transcribe(args.audio_file, args.show_tokens_info)
print(result)

Expand Down
Loading