
GigaSpeech RNN-T experiments #318

Merged · 12 commits · May 13, 2022

Conversation

@wgb14 (Contributor) commented Apr 17, 2022

No description provided.

@wgb14 (Contributor, Author) commented Apr 19, 2022

Posting Dan's suggestions here:

> Guanbo, if you want to run an RNN-T recipe, please now use the setup in pruned_transducer_stateless2 from librispeech. This converges much faster than the old setup. The only options you might want to change are: --lr-epochs (reduce from 6 to some number less than about half the number of epochs you plan to train, e.g. 2 or 3). You might want to add --use-fp16=True to use half precision to speed up training, and you can adjust --max-duration accordingly. On our GPUs, --use-fp16=True allows increasing max-duration from 300 to 550. (However, it will converge more slowly near the start of training if the max-duration is too high, so this is a tradeoff.)
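Putting those suggestions together, a training invocation might look like the sketch below. The script path matches the tracebacks later in this thread; the flag spellings are assumed from the librispeech pruned_transducer_stateless2 recipe, so verify them against `./pruned_transducer_stateless2/train.py --help` before relying on them.

```shell
# Sketch only: flag names assumed from the librispeech recipe, not verified here.
./pruned_transducer_stateless2/train.py \
  --world-size 8 \
  --num-epochs 30 \
  --lr-epochs 3 \
  --use-fp16=True \
  --max-duration 550 \
  --exp-dir pruned_transducer_stateless2/exp
```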

@wgb14 (Contributor, Author) commented Apr 19, 2022

However, each time I set a larger --max-duration, in either the Conformer-CTC or RNN-T scripts, I get this error:

--max-duration 200: (during scan_pessimistic_batches_for_oom)

Traceback (most recent call last):
  File "./pruned_transducer_stateless2/train.py", line 977, in <module>
    main()
  File "./pruned_transducer_stateless2/train.py", line 968, in main
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 3 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/userhome/user/guanbo/icefall_rnnt/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py", line 863, in run
    params=params,
  File "/userhome/user/guanbo/icefall_rnnt/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py", line 944, in scan_pessimistic_batches_for_oom
    loss.backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

--max-duration 150: (during train_one_epoch)

/opt/conda/lib/python3.7/multiprocessing/semaphore_tracker.py:144: UserWarning: semaphore_tracker: There appear to be 11 leaked semaphores to clean up at shutdown
  len(cache))
/opt/conda/lib/python3.7/multiprocessing/semaphore_tracker.py:144: UserWarning: semaphore_tracker: There appear to be 11 leaked semaphores to clean up at shutdown
  len(cache))
Traceback (most recent call last):
  File "./pruned_transducer_stateless2/train.py", line 977, in <module>
    main()
  File "./pruned_transducer_stateless2/train.py", line 968, in main
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 3 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/userhome/user/guanbo/icefall_rnnt/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py", line 892, in run
    rank=rank,
  File "/userhome/user/guanbo/icefall_rnnt/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py", line 686, in train_one_epoch
    scaler.scale(loss).backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

It is said that this "CUDA error: invalid configuration argument" is another kind of CUDA OOM. The GPU cards I'm using have 32 GB of memory. In my experiments, I can only set --max-duration 120 for bpe_500 and --max-duration 150 for bpe_5000, in both Conformer-CTC and RNN-T. I cannot benefit from --use-fp16.

@csukuangfj (Collaborator) commented

I find that you have removed the following block from train.py:

def remove_short_and_long_utt(c: Cut):
    # Keep only utterances with duration between 1 second and 20 seconds
    #
    # Caution: There is a reason to select 20.0 here. Please see
    # ../local/display_manifest_statistics.py
    #
    # You should use ../local/display_manifest_statistics.py to get
    # an utterance duration distribution for your dataset to select
    # the threshold
    return 1.0 <= c.duration <= 20.0

train_cuts = train_cuts.filter(remove_short_and_long_utt)

Could you restore it? You may need to adjust the max duration threshold for GigaSpeech.
From my experience, 20 is also a good choice for the GigaSpeech dataset.
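For readers without lhotse at hand, the effect of that filter can be illustrated with a minimal stand-in Cut class (the real lhotse Cut carries recordings, supervisions, and features; only the duration field matters for this filter):

```python
from dataclasses import dataclass

# Stand-in for lhotse's Cut; only duration is needed to demonstrate the filter.
@dataclass
class Cut:
    id: str
    duration: float  # seconds

def remove_short_and_long_utt(c: Cut) -> bool:
    # Keep only utterances with duration between 1 and 20 seconds,
    # as in the icefall recipe block quoted above.
    return 1.0 <= c.duration <= 20.0

cuts = [Cut("a", 0.4), Cut("b", 5.0), Cut("c", 19.9), Cut("d", 22.0)]
kept = [c.id for c in cuts if remove_short_and_long_utt(c)]
# kept == ["b", "c"]: the 0.4 s and 22.0 s cuts are discarded
```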

@wgb14 (Contributor, Author) commented Apr 20, 2022

> I find that you have removed the following block from train.py:
>
>     def remove_short_and_long_utt(c: Cut):
>         # Keep only utterances with duration between 1 second and 20 seconds
>         #
>         # Caution: There is a reason to select 20.0 here. Please see
>         # ../local/display_manifest_statistics.py
>         #
>         # You should use ../local/display_manifest_statistics.py to get
>         # an utterance duration distribution for your dataset to select
>         # the threshold
>         return 1.0 <= c.duration <= 20.0
>
>     train_cuts = train_cuts.filter(remove_short_and_long_utt)
>
> Could you restore it? You may need to adjust the max duration threshold for GigaSpeech. From my experience, 20 is also a good choice for the GigaSpeech dataset.

GigaSpeech has already filtered out segments longer than 20 s, but kept those shorter than 1 s. I can try this, though, since we did speed perturbation and some utterances could be 22 s long.
But I wonder whether it is really those 20 s to 22 s utterances that caused --max-duration to drop from 550 to 120.
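The 22 s figure follows from speed perturbation: slowing an utterance down by a factor divides its duration by that factor. Assuming the usual 0.9/1.1 perturbation factors (an assumption; the thread does not state them explicitly), a quick check:

```python
# GigaSpeech already caps segments at 20 s before perturbation.
max_orig = 20.0
slow_factor = 0.9  # assumed: the common lhotse perturbation factors are 0.9 and 1.1
max_perturbed = max_orig / slow_factor
print(round(max_perturbed, 1))  # 22.2
```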

@csukuangfj (Collaborator) commented

I am also using the GigaSpeech dataset for training in #312.

It works very well with --max-duration 300 after applying remove_short_and_long_utt()

@csukuangfj (Collaborator) commented

> I can try this though, since we did speed perturbation, and some utterances could be 22 s long.

Yes, I also noticed that. But such utterances account for less than 0.1% of the data, so I think it is safe to remove them.

@danpovey (Collaborator) commented

Our pruned RNN-T is quite memory efficient for longer utterances; I'd be more worried about shorter ones, since constant factors can sometimes be important. But we need to be operating from definite knowledge. Please add a try/except to catch the error and print out the failing batch's details (I thought we had this code at one point); and also, to catch things earlier, possibly do:
export CUDA_LAUNCH_BLOCKING=1
export K2_SYNC_KERNELS=1
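The try/except Dan describes can be sketched as below. The function names and the batch layout here are hypothetical stand-ins, not the recipe's actual code; the point is to report the failing batch on any RuntimeError, not only on "CUDA out of memory":

```python
def backward_with_diagnostics(loss_fn, batch):
    # Run the forward pass; on ANY RuntimeError (not just OOM),
    # report the failing batch's shape before re-raising.
    try:
        loss = loss_fn(batch)
        # in real training code: loss.backward()
        return loss
    except RuntimeError:
        durs = batch["durations"]
        print(f"failing batch: {len(durs)} cuts, "
              f"total {sum(durs):.1f}s, max {max(durs):.1f}s")
        raise

def boom(batch):
    # Simulate the error seen in this thread.
    raise RuntimeError("CUDA error: invalid configuration argument")

try:
    backward_with_diagnostics(boom, {"durations": [1.2, 0.4, 3.3]})
except RuntimeError:
    pass  # the error is re-raised after the diagnostics are printed
```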

@danpovey (Collaborator) commented

... BTW, Desh Raj @desh2608 has reported a similar problem at CLSP. I'm not sure of the cause, but we need to try hard to debug this thoroughly. I'd also maybe try installing nsys and running training under "nsys profile"; that may give us additional info.

@danpovey (Collaborator) commented

... another possibility, which is a bit ugly but would work if the problem is short utterances, would be to enforce a maximum number of utterances per batch, i.e. a maximum batch size, e.g. 512, in the main training loop, by just discarding any extra elements.
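The "discard any extra elements" idea amounts to a few lines in the training loop. This is a sketch with the batch represented as a plain list; 512 is the example value from the comment:

```python
def cap_batch(cuts, max_cuts=512):
    # Enforce a maximum batch size by discarding any cuts beyond max_cuts.
    # Ugly, but it bounds the batch dimension when many very short
    # utterances land in one duration-limited batch.
    return cuts[:max_cuts]

batch = list(range(600))   # pretend each int is a cut
capped = cap_batch(batch)
print(len(capped))         # 512
```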

@pzelasko (Collaborator) commented

> ... another possibility, which is a bit ugly but would work if the problem is short utterances, would be to enforce a maximum number of utterances per batch, i.e. a maximum batch size, e.g. 512, in the main training loop, by just discarding any extra elements.

Just add max_cuts=512 in the sampler.

@wgb14 (Contributor, Author) commented Apr 20, 2022

> ... another possibility, which is a bit ugly but would work if the problem is short utterances, would be to enforce a maximum number of utterances per batch, i.e. a maximum batch size, e.g. 512, in the main training loop, by just discarding any extra elements.
>
> Just add max_cuts=512 in the sampler.

Do we support max_cuts in DynamicBucketingSampler? I didn't see it in
https://github.com/lhotse-speech/lhotse/blob/b3b96a1d64fa3bd97c9c9bf32ef0e8b4806b87ef/lhotse/dataset/sampling/dynamic_bucketing.py#L25

@pzelasko (Collaborator) commented

Oh right, dynamic bucketing is the only one that does not support it yet... It shouldn't be too complicated, but I'd like to test it properly; if you need it, feel free to contribute.

@pzelasko (Collaborator) commented

OK, I found a moment to add it after all; try this PR: lhotse-speech/lhotse#681

@wgb14 (Contributor, Author) commented Apr 21, 2022

> Our pruned RNN-T is quite memory efficient for longer utterances; I'd be more worried about shorter ones, since constant factors can sometimes be important. But we need to be operating from definite knowledge. Please add try/except to catch the error and print out the failing batch's details (I thought we had this code at one point); and also, to catch things earlier possibly, do export CUDA_LAUNCH_BLOCKING=1 export K2_SYNC_KERNELS=1

Yes, we do have this code, but in the original script it only prints when the error is "CUDA out of memory". After changing it to print whatever the error is, I got this log:

2022-04-21 09:36:13,058 ERROR [train_fix.py:964] (1/8) Your GPU ran out of memory with the current max_duration setting. We recommend decreasing max_duration and trying again.
Failing criterion: largest_batch_cuts_duration (=300.0) ...
Traceback (most recent call last):
  File "./pruned_transducer_stateless2/train_fix.py", line 991, in <module>
    main()
  File "./pruned_transducer_stateless2/train_fix.py", line 982, in main
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/userhome/user/guanbo/icefall_rnnt/egs/gigaspeech/ASR/pruned_transducer_stateless2/train_fix.py", line 877, in run
    params=params,
  File "/userhome/user/guanbo/icefall_rnnt/egs/gigaspeech/ASR/pruned_transducer_stateless2/train_fix.py", line 958, in scan_pessimistic_batches_for_oom
    loss.backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: CUDA error: invalid configuration argument

So the failing criterion is still largest_batch_cuts_duration (=300.0), rather than max_num_cuts.

By the way, I got this error with --max-duration 300 even after applying remove_short_and_long_utt(). And

export CUDA_LAUNCH_BLOCKING=1
export K2_SYNC_KERNELS=1

didn't give me any additional information.

@wgb14 (Contributor, Author) commented Apr 21, 2022

Posting the TensorBoard log so far: https://tensorboard.dev/experiment/RkuNltQOR9aI7BGr5AHbcA/
Can anybody give some comments and suggestions? This seems not natural to me.

@csukuangfj (Collaborator) commented Apr 21, 2022

So did you finally manage to train it with --max-duration 300?

The pruned loss looks ok, I think.

There should be some checkpoints in your exp dir. You can try to decode with them and see the WER.
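Checkpoint averaging (the --avg option used throughout this thread) is, at heart, an element-wise mean of the saved parameter tensors. The toy sketch below uses plain floats instead of torch tensors; the real averaging lives in icefall's decode machinery, so treat this only as an illustration of the idea:

```python
def average_checkpoints(state_dicts):
    # Element-wise mean over a list of state dicts with identical keys,
    # mimicking what averaging the last N epoch checkpoints does.
    n = len(state_dicts)
    return {k: sum(sd[k] for sd in state_dicts) / n for k in state_dicts[0]}

a = {"w": 1.0, "b": 0.0}   # pretend: epoch-28.pt
b = {"w": 3.0, "b": 2.0}   # pretend: epoch-29.pt
print(average_checkpoints([a, b]))  # {'w': 2.0, 'b': 1.0}
```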

@csukuangfj (Collaborator) commented

As for the error,

RuntimeError: CUDA error: invalid configuration argument

There is a similar issue at pytorch/pytorch#48573, which has been fixed in pytorch/pytorch#64194

I'm wondering whether you are using a recent version of PyTorch.
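One quick check is whether the running PyTorch postdates the fix. Version strings such as the "1.10.0+cu102" seen later in this thread carry a local "+..." suffix that must be stripped before comparing; the helper names below are hypothetical, and 1.10.0 as the fixed version is taken from wgb14's reply further down:

```python
def version_tuple(version: str):
    # "1.10.0+cu102" -> (1, 10, 0); ignore any local "+..." suffix.
    core = version.split("+")[0]
    return tuple(int(p) for p in core.split(".")[:3])

def has_fix(torch_version: str, fixed_in=(1, 10, 0)) -> bool:
    return version_tuple(torch_version) >= fixed_in

print(has_fix("1.10.0+cu102"))  # True
print(has_fix("1.9.1"))         # False
```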

@csukuangfj (Collaborator) commented Apr 21, 2022

> Can anybody give some comments and suggestions? This seems not natural to me.

The following is the TensorBoard log for #312, which uses 10% of the speed-perturbed GigaSpeech data (i.e., 3k hours) for training.

https://tensorboard.dev/experiment/lVZFGwjKS9iMYHzrsAE2cw/#scalars&_smoothingWeight=0

And the following is the WERs of that PR for the GigaSpeech dev/test datasets:

(py38) kuangfangjun:modified_beam_search$ grep -r -n --color "best for test" log-* | sort -n -k2 | head  -n5
log-decode-epoch-34-avg-11-beam-4-2022-04-20-00-50-01:125:beam_size_4   12.33   best for test
log-decode-epoch-34-avg-16-beam-4-2022-04-20-02-22-16:125:beam_size_4   12.33   best for test
log-decode-epoch-34-avg-17-beam-4-2022-04-20-02-38-56:125:beam_size_4   12.33   best for test
log-decode-epoch-34-avg-18-beam-4-2022-04-20-02-55-32:125:beam_size_4   12.33   best for test
log-decode-epoch-34-avg-20-beam-4-2022-04-20-03-28-57:125:beam_size_4   12.33   best for test
(py38) kuangfangjun:modified_beam_search$ grep -r -n --color "best for dev" log-* | sort -n -k2 | head  -n5
log-decode-epoch-34-avg-15-beam-4-2022-04-20-02-04-44:170:beam_size_4   12.23   best for dev
log-decode-epoch-34-avg-16-beam-4-2022-04-20-02-22-16:170:beam_size_4   12.23   best for dev
log-decode-epoch-34-avg-12-beam-4-2022-04-20-01-08-26:170:beam_size_4   12.25   best for dev
log-decode-epoch-34-avg-14-beam-4-2022-04-20-01-46-06:170:beam_size_4   12.25   best for dev
log-decode-epoch-34-avg-17-beam-4-2022-04-20-02-38-56:170:beam_size_4   12.25   best for dev

In your tensorboard log, the pruned loss is around 0.07, which seems to be normal. Since you are using all the training data, I believe your WERs will be lower.

@wgb14 (Contributor, Author) commented Apr 21, 2022

Thanks, Fangjun. These really help a lot.
No, I'm still training with --max-duration 120, the same as in the Conformer-CTC setup. I'm using torch 1.10.0, a version released after that fix.

@desh2608 desh2608 mentioned this pull request May 2, 2022
@wgb14 (Contributor, Author) commented May 3, 2022

Update results here (DEV set only):

        greedy search   fast beam search   modified beam search
WER     10.76           10.71              10.65

The RNN-T model was trained for 20 epochs, and the best numbers are from the 20th epoch (--epoch 19 --avg 6), while for the Conformer-CTC model the best numbers are from the 19th epoch (--epoch 18 --avg 6).
I'm not sure whether I can get better numbers if I continue to train the RNN-T for 10 more epochs.

@danpovey (Collaborator) commented May 4, 2022

Cool!

@wgb14 wgb14 mentioned this pull request May 6, 2022
@wgb14 (Contributor, Author) commented May 12, 2022

Results are:

                       Dev     Test
greedy search          10.59   10.87
fast beam search       10.56   10.80
modified beam search   10.52   10.62

This PR is ready for review.

@wgb14 wgb14 marked this pull request as ready for review May 12, 2022 04:59
@csukuangfj (Collaborator) commented

Thanks! Looks great to me.

Please remove WIP when you think it is ready to merge.

@wgb14 (Contributor, Author) commented May 13, 2022

Ready to merge now.

By the way, we would like to add the results to the GigaSpeech leaderboard; what's the preferred name of this recipe/model? RNN-T, or something like Pruned Stateless Transducer?

@wgb14 wgb14 changed the title WIP: GigaSpeech RNN-T experiments GigaSpeech RNN-T experiments May 13, 2022
@danpovey (Collaborator) commented

Maybe "pruned stateless RNN-T"?
The conformer model in pruned_transducer_stateless2 also has a lot of changes from the vanilla conformer, but it doesn't really have a name, so I think "pruned stateless RNN-T" is OK for now.

@@ -3,6 +3,13 @@

#### 2022-05-12

#### Conformer encoder + embedding decoder

Conformer encoder + non-current decoder. The encoder a reworked
@csukuangfj (Collaborator) commented:
Suggested change
Conformer encoder + non-current decoder. The encoder a reworked
Conformer encoder + non-recurrent decoder. The encoder is a reworked

@wgb14 (Contributor, Author) replied:

Fixed

@wgb14 (Contributor, Author) commented May 13, 2022

> Maybe pruned stateless RNN-T? The conformer model in pruned_transducer_stateless2 also has a lot of changes from vanilla conformer but it doesn't really have a name, so I think "pruned stateless RNN-T" is OK for now.

I see. I will add some details in RESULTS.md

@csukuangfj (Collaborator) commented

Ok, I am merging it.

@csukuangfj csukuangfj merged commit 48a6a9a into k2-fsa:master May 13, 2022
@csukuangfj (Collaborator) commented

@wgb14

Could you also upload the decoding results, e.g., files like log-*, errs-*, recogs-*, to huggingface?

@csukuangfj (Collaborator) commented May 13, 2022

I just reproduced the greedy search results locally. Here are the decoding logs:
(Note: I have renamed pretrained-epoch-29-avg-11.pt to epoch-11.pt during decoding)

./pruned_transducer_stateless2/decode.py \
  --epoch 11 \
  --avg 1 \
  --decoding-method greedy_search \
  --exp-dir pruned_transducer_stateless2/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --max-duration 600
2022-05-13 11:28:21,598 INFO [decode.py:478] Decoding started
2022-05-13 11:28:21,598 INFO [decode.py:484] Device: cuda:0
2022-05-13 11:28:21,605 INFO [decode.py:493] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_
idx_train': 0, 'log_interval': 500, 'reset_interval': 2000, 'valid_interval': 20000, 'feature_dim': 80, 'subsampling_factor': 4, 'encoder_dim': 512,
'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 20000, 'env_info': {'k2-vers
ion': '1.15.1', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'f8d2dba06c000ffee36aab5b66f24e7c9809f116', 'k2-git-date': 'Thu Apr
21 12:20:34 2022', 'lhotse-version': '1.1.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': True, 'torch-cuda-ver
sion': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '48a6a9a-clean', 'icefall-git-date': 'Fri May 13 11:03:26
 2022', 'icefall-path': '/ceph-fj/fangjun/open-source-2/icefall-master', 'k2-path': '/ceph-fj/fangjun/open-source-2/k2-multi-22/k2/python/k2/__init__
.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-master/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-2-0307200233-b554c565c-lf9qd',
 'IP address': '10.177.74.201'}, 'epoch': 11, 'avg': 1, 'avg_last_n': 0, 'exp_dir': PosixPath('pruned_transducer_stateless2/exp'), 'bpe_model': 'data
/lang_bpe_500/bpe.model', 'decoding_method': 'greedy_search', 'beam_size': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 8, 'context_size': 2, 'max_
sym_per_frame': 1, 'manifest_dir': PosixPath('data/fbank'), 'max_duration': 600, 'bucketing_sampler': True, 'num_buckets': 30, 'concatenate_cuts': Fa
lse, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2, 'enable_spec_aug': True,
'spec_aug_time_warp_factor': 80, 'enable_musan': True, 'subset': 'XL', 'small_dev': False, 'res_dir': PosixPath('pruned_transducer_stateless2/exp/greedy_search'), 'suffix': 'epoch-11-avg-1-context-2-max-sym-per-frame-1', 'blank_id': 0, 'vocab_size': 500}
2022-05-13 11:28:21,605 INFO [decode.py:495] About to create model
2022-05-13 11:28:22,243 INFO [checkpoint.py:112] Loading checkpoint from pruned_transducer_stateless2/exp/epoch-11.pt
2022-05-13 11:28:31,272 INFO [decode.py:525] Number of model parameters: 78648040
2022-05-13 11:28:31,273 INFO [asr_datamodule.py:406] About to get dev cuts
2022-05-13 11:28:31,845 INFO [asr_datamodule.py:415] About to get test cuts
2022-05-13 11:28:35,545 INFO [decode.py:397] batch 0/?, cuts processed until now is 99
2022-05-13 11:29:24,008 INFO [decode.py:415] The transcripts are stored in pruned_transducer_stateless2/exp/greedy_search/recogs-dev-greedy_search-epoch-11-avg-1-context-2-max-sym-per-frame-1.txt
2022-05-13 11:29:24,198 INFO [utils.py:405] [dev-greedy_search] %WER 10.59% [13534 / 127790, 2987 ins, 3472 del, 7075 sub ]
2022-05-13 11:29:24,600 INFO [decode.py:428] Wrote detailed error stats to pruned_transducer_stateless2/exp/greedy_search/errs-dev-greedy_search-epoch-11-avg-1-context-2-max-sym-per-frame-1.txt
2022-05-13 11:29:24,601 INFO [decode.py:445]
For dev, WER of different settings are:
greedy_search   10.59   best for dev

2022-05-13 11:29:26,661 INFO [decode.py:397] batch 0/?, cuts processed until now is 118
2022-05-13 11:30:48,172 INFO [decode.py:397] batch 100/?, cuts processed until now is 10038
2022-05-13 11:31:41,922 INFO [decode.py:397] batch 200/?, cuts processed until now is 18926
2022-05-13 11:31:52,114 INFO [decode.py:415] The transcripts are stored in pruned_transducer_stateless2/exp/greedy_search/recogs-test-greedy_search-epoch-11-avg-1-context-2-max-sym-per-frame-1.txt
2022-05-13 11:31:52,676 INFO [utils.py:405] [test-greedy_search] %WER 10.87% [42460 / 390744, 6906 ins, 11246 del, 24308 sub ]
2022-05-13 11:31:54,014 INFO [decode.py:428] Wrote detailed error stats to pruned_transducer_stateless2/exp/greedy_search/errs-test-greedy_search-epoch-11-avg-1-context-2-max-sym-per-frame-1.txt
2022-05-13 11:31:54,015 INFO [decode.py:445]
For test, WER of different settings are:
greedy_search   10.87   best for test

One thing to note is that it takes only about 1 minute to decode the dev dataset, while it takes about 1 hour and 18 minutes for Conformer-CTC (from the log).
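The %WER lines in the log follow the standard word-error-rate formula, (insertions + deletions + substitutions) / reference words. Recomputing the dev figure from the counts in the log above as a sanity check:

```python
def wer_percent(ins, dels, subs, n_ref):
    # Word error rate in percent, as reported in the icefall decode logs.
    return 100.0 * (ins + dels + subs) / n_ref

# Counts from the dev-set log line: 13534 / 127790, 2987 ins, 3472 del, 7075 sub
print(round(wer_percent(2987, 3472, 7075, 127790), 2))  # 10.59
```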

@wgb14 (Contributor, Author) commented May 13, 2022

> Could you also upload the decoding results, e.g., files like log-*, errs-*, recogs-*, to huggingface?

Sure, will do.

> One thing to note is that it takes only 1 minute to decode the dev dataset, while it takes about 1 hour and 18 minutes for conformer-ctc (from the log)

It seems that your cluster has faster I/O or a larger memory cache. I'm redoing the decoding with --max-duration 600 --num-workers 4, and it takes me about 15 minutes to decode the dev set with greedy search:

2022-05-13 11:52:35,758 INFO [decode.py:482] Decoding started
2022-05-13 11:52:35,759 INFO [decode.py:488] Device: cuda:0
2022-05-13 11:52:35,799 INFO [decode.py:497] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 500, 'reset_interval': 2000, 'valid_interval': 20000, 'feature_dim': 80, 'subsampling_factor': 4, 'encoder_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 20000, 'env_info': {'k2-version': '1.14', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '1b29f0a946f50186aaa82df46a59f492ade9692b', 'k2-git-date': 'Wed Apr 13 08:46:49 2022', 'lhotse-version': '1.1.0', 'torch-version': '1.10.0', 'torch-cuda-available': True, 'torch-cuda-version': '11.1', 'python-version': '3.7', 'icefall-git-branch': 'gigaspeech_rnnt', 'icefall-git-sha1': '2d07df5-dirty', 'icefall-git-date': 'Sun Apr 17 10:01:21 2022', 'icefall-path': '/userhome/user/guanbo/icefall_rnnt', 'k2-path': '/opt/conda/lib/python3.7/site-packages/k2-1.14.dev20220513+cuda11.1.torch1.10.0-py3.7-linux-x86_64.egg/k2/__init__.py', 'lhotse-path': '/userhome/user/guanbo/lhotse/lhotse/__init__.py', 'hostname': 'c168bad00d26f011ec09c520b9eb77fc7b8a-chenx8564-0', 'IP address': '10.104.201.14'}, 'epoch': 29, 'avg': 11, 'avg_last_n': 0, 'exp_dir': PosixPath('pruned_transducer_stateless2/exp_500_8'), 'bpe_model': 'data/lang_bpe_500/bpe.model', 'decoding_method': 'greedy_search', 'beam_size': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 8, 'context_size': 2, 'max_sym_per_frame': 1, 'manifest_dir': PosixPath('data/fbank'), 'max_duration': 600, 'bucketing_sampler': True, 'num_buckets': 30, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 4, 'enable_spec_aug': True, 'spec_aug_time_warp_factor': 80, 'enable_musan': True, 'subset': 'XL', 'small_dev': False, 'res_dir': PosixPath('pruned_transducer_stateless2/exp_500_8/greedy_search'), 'suffix': 
'epoch-29-avg-11-context-2-max-sym-per-frame-1', 'blank_id': 0, 'vocab_size': 500}
2022-05-13 11:52:35,800 INFO [decode.py:499] About to create model
2022-05-13 11:52:36,199 INFO [decode.py:515] averaging ['pruned_transducer_stateless2/exp_500_8/epoch-19.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-20.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-21.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-22.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-23.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-24.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-25.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-26.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-27.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-28.pt', 'pruned_transducer_stateless2/exp_500_8/epoch-29.pt']
2022-05-13 11:55:48,958 INFO [decode.py:529] Number of model parameters: 78648040
2022-05-13 11:55:48,958 INFO [asr_datamodule.py:406] About to get dev cuts
2022-05-13 11:55:49,292 INFO [asr_datamodule.py:415] About to get test cuts
2022-05-13 11:56:45,876 INFO [decode.py:398] batch 0/?, cuts processed until now is 99
2022-05-13 12:10:29,422 INFO [decode.py:415] The transcripts are stored in pruned_transducer_stateless2/exp_500_8/greedy_search/recogs-dev-greedy_search-epoch-29-avg-11-context-2-max-sym-per-frame-1.txt
2022-05-13 12:10:29,589 INFO [utils.py:406] [dev-greedy_search] %WER 10.59% [13536 / 127790, 2978 ins, 3466 del, 7092 sub ]
2022-05-13 12:10:29,980 INFO [decode.py:428] Wrote detailed error stats to pruned_transducer_stateless2/exp_500_8/greedy_search/errs-dev-greedy_search-epoch-29-avg-11-context-2-max-sym-per-frame-1.txt
2022-05-13 12:10:30,011 INFO [decode.py:449] 
For dev, WER of different settings are:
greedy_search	10.59	best for dev

@csukuangfj (Collaborator) commented

> It seems that your cluster has a faster IO or larger memory cache.

I have created a google colab notebook to decode gigaspeech.
See https://colab.research.google.com/drive/14FL2q0uQt3hC4TQ61lseV3wM3gxbY58R?usp=sharing

It takes less than 3 minutes in the Colab notebook to decode the dev dataset with greedy search.

[Screenshot: Colab greedy-search decoding output, 2022-05-13]

yaozengwei added a commit that referenced this pull request May 15, 2022
* Remove ReLU in attention

* Adding diagnostics code...

* Refactor/simplify ConformerEncoder

* First version of rand-combine iterated-training-like idea.

* Improvements to diagnostics (RE those with 1 dim)

* Add pelu to this good-performing setup..

* Small bug fixes/imports

* Add baseline for the PeLU expt, keeping only the small normalization-related changes.

* pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units.

* Double learning rate of exp-scale units

* Combine ExpScale and swish for memory reduction

* Add import

* Fix backprop bug

* Fix bug in diagnostics

* Increase scale on Scale from 4 to 20

* Increase scale from 20 to 50.

* Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module

* Reduce scale from 50 to 20

* Add deriv-balancing code

* Double the threshold in brelu; slightly increase max_factor.

* Fix exp dir

* Convert swish nonlinearities to ReLU

* Replace relu with swish-squared.

* Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset.

* Replace norm on input layer with scale of 0.1.

* Extensions to diagnostics code

* Update diagnostics

* Add BasicNorm module

* Replace most normalizations with scales (still have norm in conv)

* Change exp dir

* Replace norm in ConvolutionModule with a scaling factor.

* use nonzero threshold in DerivBalancer

* Add min-abs-value 0.2

* Fix dirname

* Change min-abs threshold from 0.2 to 0.5

* Scale up pos_bias_u and pos_bias_v before use.

* Reduce max_factor to 0.01

* Fix q*scaling logic

* Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code.

* init 1st conv module to smaller variance

* Change how scales are applied; fix residual bug

* Reduce min_abs from 0.5 to 0.2

* Introduce in_scale=0.5 for SwishExpScale

* Fix scale from 0.5 to 2.0 as I really intended..

* Set scaling on SwishExpScale

* Add identity pre_norm_final for diagnostics.

* Add learnable post-scale for mha

* Fix self.post-scale-mha

* Another rework, use scales on linear/conv

* Change dir name

* Reduce initial scaling of modules

* Bug-fix RE bias

* Cosmetic change

* Reduce initial_scale.

* Replace ExpScaleRelu with DoubleSwish()

* DoubleSwish fix

* Use learnable scales for joiner and decoder

* Add max-abs-value constraint in DerivBalancer

* Add max-abs-value

* Change dir name

* Remove ExpScale in feedforward layers.

* Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer.

* Make DoubleSwish more memory efficient

* Reduce constraints from deriv-balancer in ConvModule.

* Add warmup mode

* Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module.

* Add some extra info to diagnostics

* Add deriv-balancer at output of embedding.

* Add more stats.

* Make epsilon in BasicNorm learnable, optionally.

* Draft of 0mean changes..

* Rework of initialization

* Fix typo

* Remove dead code

* Modifying initialization from normal->uniform; add initial_scale when initializing

* bug fix re sqrt

* Remove xscale from pos_embedding

* Remove some dead code.

* Cosmetic changes/renaming things

* Start adding some files..

* Add more files..

* update decode.py file type

* Add remaining files in pruned_transducer_stateless2

* Fix diagnostics-getting code

* Scale down pruned loss in warmup mode

* Reduce warmup scale on pruned loss from 0.1 to 0.01.

* Remove scale_speed, make swish deriv more efficient.

* Cosmetic changes to swish

* Double warm_step

* Fix bug with import

* Change initial std from 0.05 to 0.025.

* Set also scale for embedding to 0.025.

* Remove logging code that broke with newer Lhotse; fix bug with pruned_loss

* Add norm+balancer to VggSubsampling

* Incorporate changes from master into pruned_transducer_stateless2.

* Add max-abs=6, debugged version

* Change 0.025,0.05 to 0.01 in initializations

* Fix balancer code

* Whitespace fix

* Reduce initial pruned_loss scale from 0.01 to 0.0

* Increase warm_step (and valid_interval)

* Change max-abs from 6 to 10

* Change how warmup works.

* Add changes from master to decode.py, train.py

* Simplify the warmup code; max_abs 10->6

* Make warmup work by scaling layer contributions; leave residual layer-drop

* Fix bug

* Fix test mode with random layer dropout

* Add random-number-setting function in dataloader

* Fix/patch how fix_random_seed() is imported.

* Reduce layer-drop prob

* Reduce layer-drop prob after warmup to 1 in 100

* Change power of lr-schedule from -0.5 to -0.333

* Increase model_warm_step to 4k

* Change max-keep-prob to 0.95

* Refactoring and simplifying conformer and frontend

* Rework conformer, remove some code.

* Reduce 1st conv channels from 64 to 32

* Add another convolutional layer

* Fix padding bug

* Remove dropout in output layer

* Reduce speed of some components

* Initial refactoring to remove unnecessary vocab_size

* Fix RE identity

* Bug-fix

* Add final dropout to conformer

* Remove some un-used code

* Replace nn.Linear with ScaledLinear in simple joiner

* Make 2 projections..

* Reduce initial_speed

* Use initial_speed=0.5

* Reduce initial_speed further from 0.5 to 0.25

* Reduce initial_speed from 0.5 to 0.25

* Change how warmup is applied.

* Bug fix to warmup_scale

* Fix test-mode

* Remove final dropout

* Make layer dropout rate 0.075, was 0.1.

* First draft of model rework

* Various bug fixes

* Change learning speed of simple_lm_proj

* Revert transducer_stateless/ to state in upstream/master

* Fix to joiner to allow different dims

* Some cleanups

* Make training more efficient, avoid redoing some projections.

* Change how warm-step is set

* First draft of new approach to learning rates + init

* Some fixes..

* Change initialization to 0.25

* Fix type of parameter

* Fix weight decay formula by adding 1/1-beta

* Fix weight decay formula by adding 1/1-beta

* Fix checkpoint-writing

* Fix to reading scheduler from optim

* Simplified optimizer, rework some things..

* Reduce model_warm_step from 4k to 3k

* Fix bug in lambda

* Bug-fix RE sign of target_rms

* Changing initial_speed from 0.25 to 01

* Change some defaults in LR-setting rule.

* Remove initial_speed

* Set new scheduler

* Change exponential part of lrate to be epoch based

* Fix bug

* Set 2n rule..

* Implement 2o schedule

* Make lrate rule more symmetric

* Implement 2p version of learning rate schedule.

* Refactor how learning rate is set.

* Fix import

* Modify init (#301)

* update icefall/__init__.py to import more common functions.

* update icefall/__init__.py

* make imports style consistent.

* exclude black check for icefall/__init__.py in pyproject.toml.

* Minor fixes for logging (#296)

* Minor fixes for logging

* Minor fix

* Fix dir names

* Modify beam search to be efficient with current joiner

* Fix adding learning rate to tensorboard

* Fix docs in optim.py

* Support mixed precision training on the reworked model (#305)

* Add mixed precision support

* Minor fixes

* Minor fixes

* Minor fixes

* Tedlium3 pruned transducer stateless (#261)

* update tedlium3-pruned-transducer-stateless-codes

* update README.md

* update README.md

* add fast beam search for decoding

* do a change for RESULTS.md

* do a change for RESULTS.md

* do a fix

* do some changes for pruned RNN-T

* Add mixed precision support

* Minor fixes

* Minor fixes

* Updating RESULTS.md; fix in beam_search.py

* Fix rebase

* Code style check for librispeech pruned transducer stateless2 (#308)

* Update results for tedlium3 pruned RNN-T (#307)

* Update README.md

* Fix CI errors. (#310)

* Add more results

* Fix tensorboard log location

* Add one more epoch of full expt

* fix comments

* Add results for mixed precision with max-duration 300

* Changes for pretrained.py (tedlium3 pruned RNN-T) (#311)

* GigaSpeech recipe (#120)

* initial commit

* support download, data prep, and fbank

* on-the-fly feature extraction by default

* support BPE based lang

* support HLG for BPE

* small fix

* small fix

* chunked feature extraction by default

* Compute features for GigaSpeech by splitting the manifest.

* Fixes after review.

* Split manifests into 2000 pieces.

* set audio duration mismatch tolerance to 0.01

* small fix

* add conformer training recipe

* Add conformer.py without pre-commit checking

* lazy loading and use SingleCutSampler

* DynamicBucketingSampler

* use KaldifeatFbank to compute fbank for musan

* use pretrained language model and lexicon

* use 3gram to decode, 4gram to rescore

* Add decode.py

* Update .flake8

* Delete compute_fbank_gigaspeech.py

* Use BucketingSampler for valid and test dataloader

* Update params in train.py

* Use bpe_500

* update params in decode.py

* Decrease num_paths while CUDA OOM

* Added README

* Update RESULTS

* black

* Decrease num_paths while CUDA OOM

* Decode with post-processing

* Update results

* Remove lazy_load option

* Use default `storage_type`

* Keep the original tolerance

* Use split-lazy

* black

* Update pretrained model

Co-authored-by: Fangjun Kuang <[email protected]>

* Add LG decoding (#277)

* Add LG decoding

* Add log weight pushing

* Minor fixes

* Support computing RNN-T loss with torchaudio (#316)

* Update results for torchaudio RNN-T. (#322)

* Fix some typos. (#329)

* fix fp16 option in example usage (#332)

* Support averaging models with weight tying. (#333)

* Support specifying iteration number of checkpoints for decoding. (#336)

See also #289

* Modified conformer with multi datasets (#312)

* Copy files for editing.

* Use librispeech + gigaspeech with modified conformer.

* Support specifying number of workers for on-the-fly feature extraction.

* Feature extraction code for GigaSpeech.

* Combine XL splits lazily during training.

* Fix warnings in decoding.

* Add decoding code for GigaSpeech.

* Fix decoding the gigaspeech dataset.

We have to use the decoder/joiner networks for the GigaSpeech dataset.

* Disable speed perturb for XL subset.

* Compute the Nbest oracle WER for RNN-T decoding.

* Minor fixes.

* Minor fixes.

* Add results.

* Update results.

* Update CI.

* Update results.

* Fix style issues.

* Update results.

* Fix style issues.

* Update results. (#340)

* Update results.

* Typo fixes.

* Validate generated manifest files. (#338)

* Validate generated manifest files. (#338)

* Save batch to disk on OOM. (#343)

* Save batch to disk on OOM.

* minor fixes

* Fixes after review.

* Fix style issues.

* Fix decoding for gigaspeech in the libri + giga setup. (#345)

* Model average (#344)

* First upload of model average codes.

* minor fix

* update decode file

* update .flake8

* rename pruned_transducer_stateless3 to pruned_transducer_stateless4

* change epoch number counter starting from 1 instead of 0

* minor fix of pruned_transducer_stateless4/train.py

* refactor the checkpoint.py

* minor fix, update docs, and modify the epoch number to count from 1 in the pruned_transducer_stateless4/decode.py

* update author info

* add docs of the scaling in function average_checkpoints_with_averaged_model

* Save batch to disk on exception. (#350)

* Bug fix (#352)

* Keep model_avg on cpu (#348)

* keep model_avg on cpu

* explicitly convert model_avg to cpu

* minor fix

* remove device conversion for model_avg

* modify usage of the model device in train.py

* change model.device to next(model.parameters()).device for decoding

* assert params.start_epoch>0

* assert params.start_epoch>0, params.start_epoch

* Do some changes for aishell/ASR/transducer stateless/export.py (#347)

* do some changes for aishell/ASR/transducer_stateless/export.py

* Support decoding with averaged model when using --iter (#353)

* support decoding with averaged model when using --iter

* minor fix

* minor fix of copyright date

* Stringify torch.__version__ before serializing it. (#354)

* Run decode.py in GitHub actions. (#356)

* Ignore padding frames during RNN-T decoding. (#358)

* Ignore padding frames during RNN-T decoding.

* Fix outdated decoding code.

* Minor fixes.

* Support --iter in export.py (#360)

* GigaSpeech RNN-T experiments (#318)

* Copy RNN-T recipe from librispeech

* flake8

* flake8

* Update params

* gigaspeech decode

* black

* Update results

* syntax highlight

* Update RESULTS.md

* typo

* Update decoding script for gigaspeech and remove duplicate files. (#361)

* Validate that there are no OOV tokens in BPE-based lexicons. (#359)

* Validate that there are no OOV tokens in BPE-based lexicons.

* Typo fixes.

* Decode gigaspeech in GitHub actions (#362)

* Add CI for gigaspeech.

* Update results for libri+giga multi dataset setup. (#363)

* Update results for libri+giga multi dataset setup.

* Update GigaSpeech results (#364)

* Update decode.py

* Update export.py

* Update results

* Update README.md

* Fix GitHub CI for decoding GigaSpeech dev/test datasets (#366)

* modify .flake8

* minor fix

* minor fix

Co-authored-by: Daniel Povey <[email protected]>
Co-authored-by: Wei Kang <[email protected]>
Co-authored-by: Mingshuang Luo <[email protected]>
Co-authored-by: Fangjun Kuang <[email protected]>
Co-authored-by: Guo Liyong <[email protected]>
Co-authored-by: Wang, Guanbo <[email protected]>
Co-authored-by: whsqkaak <[email protected]>
Co-authored-by: pehonnet <[email protected]>
@wgb14 wgb14 deleted the gigaspeech_rnnt branch May 17, 2022 05:19
xIaott-s commented Sep 7, 2023

Thanks, Fangjun. These really help a lot. No, I'm still training with --max-duration 120, same with Conformer-CTC setup. I'm using torch 1.10.0, a version after that fix.

I had the same problem as you, and I'm also using torch 1.10.0. Only when I set --max-duration=120 can the training continue.
Did you find the reason? Is this a bug in torch==1.10.0?
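For what it's worth, a common workaround when a fixed `--max-duration` keeps hitting CUDA OOM is to retry the step with a halved duration budget, which several of the commits above (e.g. "Decrease num_paths while CUDA OOM", "Save batch to disk on OOM") hint at. This is only a rough sketch of that pattern, not code from the icefall recipe; `run_with_fallback` and `fake_step` are hypothetical names, and the real training loop would call the model instead of a stub:

```python
# Hypothetical sketch: on an out-of-memory error, halve the duration
# budget (analogous to --max-duration) and retry until the step fits.

def run_with_fallback(step, budget, min_budget=60):
    """Call step(budget); on OOM, halve the budget and retry.

    Raises if the error is not an OOM or the budget would drop
    below min_budget.
    """
    while True:
        try:
            return step(budget)
        except RuntimeError as e:
            if "out of memory" not in str(e) or budget // 2 < min_budget:
                raise
            budget //= 2  # e.g. 550 -> 275 -> 137 -> 68


# Simulated training step for illustration: "fits" only once the
# budget is <= 120, mimicking the --max-duration 120 limit reported
# in this thread.
def fake_step(budget):
    if budget > 120:
        raise RuntimeError("CUDA out of memory")
    return budget
```

With the stub above, `run_with_fallback(fake_step, 550)` retries down through 275 and 137 before succeeding at 68, while `run_with_fallback(fake_step, 120)` succeeds immediately. In practice this only masks the symptom; the underlying cause (torch version, memory fragmentation, or one pathologically long batch) still deserves investigation.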
