Skip to content

Commit

Permalink
Merge remote-tracking branch 'dan/master' into ctc-ali
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 23, 2021
2 parents d8bef09 + 455693a commit 4580ff1
Show file tree
Hide file tree
Showing 21 changed files with 717 additions and 629 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/run-yesno-recipe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,20 @@ on:
branches:
- master
pull_request:
branches:
- master
types: [labeled]

jobs:
run-yesno-recipe:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
# os: [ubuntu-18.04, macos-10.15]
# TODO: enable macOS for CPU testing
os: [ubuntu-18.04]
python-version: [3.8]
torch: ["1.8.1"]
k2-version: ["1.9.dev20210919"]
fail-fast: false

steps:
Expand All @@ -54,10 +56,8 @@ jobs:
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip black flake8
python3 -m pip install -U pip
python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
python3 -m pip install torchaudio==0.7.2
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
# We are in ./icefall and there is a file: requirements.txt in it
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ on:
branches:
- master
pull_request:
branches:
- master
types: [labeled]

jobs:
test:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04, macos-10.15]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
k2-version: ["1.7.dev20210914"]
k2-version: ["1.9.dev20210919"]

fail-fast: false

Expand Down
1 change: 0 additions & 1 deletion docs/source/installation/images/k2-v-1.7.svg

This file was deleted.

1 change: 1 addition & 0 deletions docs/source/installation/images/k2-v1.9-blueviolet.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions docs/source/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Installation
.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
:alt: Supported PyTorch versions

.. |k2_versions| image:: ./images/k2-v-1.7.svg
.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg
:alt: Supported k2 versions

``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
Expand All @@ -40,7 +40,7 @@ to install ``k2``.

.. CAUTION::

You need to install ``k2`` with a version at least **v1.7**.
You need to install ``k2`` version **v1.9** or later.

.. HINT::

Expand Down
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/conformer_ctc/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run_encoder(
"""
Args:
x:
The model input. Its shape is [N, T, C].
The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
Expand Down
43 changes: 29 additions & 14 deletions egs/librispeech/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is [N, T, C]
# at entry, feature is (N, T, C)

supervisions = batch["supervisions"]

nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
# nnet_output is (N, T, C)

supervision_segments = torch.stack(
(
Expand All @@ -252,14 +252,19 @@ def decode_one_batch(
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is slightly worse than that of rescored lattices.
return nbest_oracle(
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
return {key: hyps}

if params.method in ["1best", "nbest"]:
if params.method == "1best":
Expand All @@ -272,7 +277,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
)
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa

Expand All @@ -296,17 +301,23 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`

best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
Expand All @@ -316,16 +327,20 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"

ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
else:
for lm_scale in lm_scale_list:
ans[lm_scale_str] = [[] * lattice.shape[0]]
return ans


Expand Down
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/conformer_ctc/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def main():
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
Expand Down
32 changes: 16 additions & 16 deletions egs/librispeech/ASR/conformer_ctc/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
Expand All @@ -34,10 +34,10 @@ def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is [N, T, idim].
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
assert idim >= 7
super().__init__()
Expand All @@ -58,18 +58,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Args:
x:
Its shape is [N, T, idim].
Its shape is (N, T, idim).
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is [N, T, idim]
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
return x


Expand All @@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""

Expand All @@ -93,10 +93,10 @@ def __init__(self, idim: int, odim: int) -> None:
Args:
idim:
Input dim. The input shape is [N, T, idim].
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()

Expand Down Expand Up @@ -149,10 +149,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Args:
x:
Its shape is [N, T, idim].
Its shape is (N, T, idim).
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/conformer_ctc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,14 @@ def compute_loss(
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)

supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
# nnet_output is (N, T, C)

# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
Expand Down
Loading

0 comments on commit 4580ff1

Please sign in to comment.