Skip to content

Commit

Permalink
Use new APIs with k2.RaggedTensor (#38)
Browse files Browse the repository at this point in the history
* Use new APIs with k2.RaggedTensor

* Fix style issues.

* Update the installation doc, saying it requires at least k2 v1.7

* Use k2 v1.7
  • Loading branch information
csukuangfj authored Sep 8, 2021
1 parent 331e5eb commit abadc71
Show file tree
Hide file tree
Showing 26 changed files with 197 additions and 147 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-yesno-recipe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip
python3 -m pip install k2==1.4.dev20210822+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
python3 -m pip install torchaudio==0.7.2
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ jobs:
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
k2-version: ["1.4.dev20210822"]
k2-version: ["1.7.dev20210908"]

fail-fast: false

steps:
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ path.sh
exp
exp*/
*.pt
download/
download
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import sphinx_rtd_theme


# -- Project information -----------------------------------------------------

project = "icefall"
Expand Down
2 changes: 1 addition & 1 deletion docs/source/installation/images/device-CPU_CUDA-orange.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/source/installation/images/k2-v-1.7.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/source/installation/images/os-Linux_macOS-ff69b4.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 18 additions & 5 deletions docs/source/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Installation
- |device|
- |python_versions|
- |torch_versions|
- |k2_versions|

.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
:alt: Supported operating systems
Expand All @@ -20,7 +21,10 @@ Installation
.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
:alt: Supported PyTorch versions

icefall depends on `k2 <https://github.com/k2-fsa/k2>`_ and
.. |k2_versions| image:: ./images/k2-v-1.7.svg
:alt: Supported k2 versions

``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
`lhotse <https://github.com/lhotse-speech/lhotse>`_.

We recommend you to install ``k2`` first, as ``k2`` is bound to
Expand All @@ -32,12 +36,16 @@ installs its dependency PyTorch, which can be reused by ``lhotse``.
--------------

Please refer to `<https://k2.readthedocs.io/en/latest/installation/index.html>`_
to install `k2`.
to install ``k2``.

.. CAUTION::

You need to install ``k2`` with a version at least **v1.7**.

.. HINT::

If you have already installed PyTorch and don't want to replace it,
please install a version of k2 that is compiled against the version
please install a version of ``k2`` that is compiled against the version
of PyTorch you are using.

(2) Install lhotse
Expand All @@ -50,10 +58,15 @@ to install ``lhotse``.

Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.

.. CAUTION::

If you have installed ``torchaudio``, please consider uninstalling it before
installing ``lhotse``. Otherwise, it may update your already installed PyTorch.

(3) Download icefall
--------------------

icefall is a collection of Python scripts, so you don't need to install it
``icefall`` is a collection of Python scripts, so you don't need to install it
and we don't provide a ``setup.py`` to install it.

What you need is to download it and set the environment variable ``PYTHONPATH``
Expand Down Expand Up @@ -367,7 +380,7 @@ Now let us run the training part:
.. CAUTION::

We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU
We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
even if there are GPUs available.

The training log is given below:
Expand Down
1 change: 0 additions & 1 deletion docs/source/recipes/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,3 @@ We may add recipes for other tasks as well in the future.
yesno

librispeech

18 changes: 9 additions & 9 deletions docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ After downloading, you will have the following files:
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 10 files
Expand Down Expand Up @@ -256,14 +256,14 @@ The output is:
2021-08-24 16:57:28,098 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
Expand Down Expand Up @@ -297,14 +297,14 @@ The decoding output is:
2021-08-24 16:39:54,010 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
Expand Down
1 change: 0 additions & 1 deletion egs/librispeech/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,3 @@ We searched the lm_score_scale for best results, the scales that produced the WE
|--|--|
|test-clean|0.8|
|test-other|0.9|

19 changes: 19 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)

Expand Down Expand Up @@ -116,6 +117,17 @@ def get_parser():
""",
)

parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)

return parser


Expand Down Expand Up @@ -541,6 +553,13 @@ def main():
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))

if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return

model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
Expand Down
3 changes: 1 addition & 2 deletions egs/librispeech/ASR/conformer_ctc/test_subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
# limitations under the License.


from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
from subsampling import Conv2dSubsampling, VggSubsampling


def test_conv2d_subsampling():
Expand Down
9 changes: 4 additions & 5 deletions egs/librispeech/ASR/conformer_ctc/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@


import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import (
Transformer,
add_eos,
add_sos,
decoder_padding_mask,
encoder_padding_mask,
generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
)

from torch.nn.utils.rnn import pad_sequence


def test_encoder_padding_mask():
supervisions = {
Expand Down
6 changes: 3 additions & 3 deletions egs/librispeech/ASR/local/compile_hlg.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:

LG.labels[LG.labels >= first_token_disambig_id] = 0

assert isinstance(LG.aux_labels, k2.RaggedInt)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0

LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)

logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
Expand Down
5 changes: 4 additions & 1 deletion egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def get_params() -> AttributeDict:
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "whole-lattice-rescoring",
# "method": "1best",
# "method": "nbest",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30,
"num_paths": 100,
}
)
return params
Expand Down Expand Up @@ -424,6 +426,7 @@ def main():
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return

model.to(device)
model.eval()
Expand Down
Empty file modified egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
100644 → 100755
Empty file.
6 changes: 3 additions & 3 deletions egs/yesno/ASR/local/compile_hlg.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:

LG.labels[LG.labels >= first_token_disambig_id] = 0

assert isinstance(LG.aux_labels, k2.RaggedInt)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0

LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)

logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
Expand Down
1 change: 1 addition & 0 deletions egs/yesno/ASR/tdnn/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def main():
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return

model.to(device)
model.eval()
Expand Down
Loading

0 comments on commit abadc71

Please sign in to comment.