[Ready to merge] Pruned transducer stateless5 recipe for tal_csasr (mix Chinese chars and English BPE) #428

Merged

Conversation

luomingshuang
Collaborator

In this PR, I build a pruned transducer stateless5 recipe for the TAL_CSASR dataset. I mix Chinese characters and English BPE units (for the English text parts) as modeling tokens.

@luomingshuang
Collaborator Author

luomingshuang commented Jun 17, 2022

When I try to train this recipe, I run into a problem.
My running command is:

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' python pruned_transducer_stateless5/train.py \
  --max-duration 100 \
  --world-size 8 \
  --num-workers 2 \
  --start-epoch 2
```

The value of max-duration is small, and I have already filtered the cuts to 1–18 seconds, but the memory usage during training is still very large.
[screenshot: GPU memory usage during training]

And after some iterations, a CUDA out-of-memory error occurs.
[screenshot: CUDA out-of-memory error traceback]

@luomingshuang
Collaborator Author

The lengths of the samples' texts vary widely. Some examples:

记 完 了 我 们 来 看
▁BE ▁AWARE ▁OF ▁SOMETHING
一 开 口 就 是 一 个 正 确 的 句 子 至 少 是 一 个 完 整 的 句 子
对 呀 他 的 朋 友 是 ▁HE ▁FRIEND
比 如 说 ▁PLAY ▁THE ▁PIANO P LA Y ▁THE ▁VI OL IN 然 后 再 加 上 还 有 什 么 乐 器
嗯 我 给 你 打 开
这 个 受 伤 一 定 不 是 啊 自 己 把 自 己 造 成 的 啊 一 定 是 这 个 外 界 因 素 导 致 的 这 种 伤 口
刚 说 在 什 么 上 面 有 ▁ON 是 吗 有 ▁ABOVE ▁OVER 这 三 个 词 它 们 都 是 在 什 么 上 面 的 意 思 区 别 在 于 ▁OVER 越 过 什 么 上 面 ▁ABOVE 在 什 么 上 面
这 四 个 句 子 我 们 都 用 了 ▁AS ▁SOON ▁AS 来 翻 译 你 看 一 下 这 四 个 句 子 其 实 都 有 一 个 特 点 就 这 里 一 有 人 求 助 篮 球 比 赛 一 结 束 包 括 刚 才 的 我 们 翻 译 的 是 一 听 到 呼 救 声 还 有 一 离 开 教 室
然 后 嗯 中 国 传 统 文 化
嗯 这 个 牛 牛 怎 么 说
啊 那 你 没 看 见 你 没 看 见 好 的 来 再 给 你 个 机 会 好 吧 来 再 再 选 一 遍 快 点 告 诉 我 选 什 么
然 后 我 们 来 看 一 下 下 面
哦 然 后 这 个 容 忍 这 个 也 讲 了 对 吧 我 们 说 诶 那 么 回 顾 一 下 容 忍 我 们 之 前 应 该 提 过 容 忍 有 哪 些 ▁STAND 对 ▁STAND 也 是 熟 词 考 生 义
也 是 一 样 的 嗯 记 不 记 得 老 师 有 跟 你 说 过 ▁COULD 和 ▁CAN 的 区 别 就 是 它 比 较 委 婉
好 那 咱 们 现 在 来 看 第 四 篇 文 章
写 的 太 小 了 那 个 一 开 始 写 的 时 候 放 开 放 大 了 去 写
之 二 分 之 一 对 不 对 你 要 记 得 啊 概 率 这 种 这 种 一 定 要 化 到 最 简 那 么 这 里 我 必 须 写 二 分 之 一 不 能 写 四 分 之 二 你 写 四 分 之 二 肯 定 是 错 误 的 但 是 四 分 之 二 和 二 分 之 一 你 看 就 是
▁CANADA ▁CANADA
这 样 这 样 得 到 之 后 下 面 是
好 那 这 个 ▁SHARE
对 还 有 一 些 这 些 都 是 一 样 的 ▁TO DAY S ▁NEWSPAPER 都 是 这 样
票 票 是 啥 呢 票 是 啥 ▁TICKET
呃 形 容 词 性 的 物 主 代 词
▁B ▁O ▁W ▁BOW 你 知 道 碗 吧 碗 是 再 加 一 个 ▁L 嘛 ▁BOW L 对 吧 把 这 ▁L 去 掉 ▁BOW 去 弓 弯 腰

@pzelasko
Collaborator

Since it’s a larger dataset, you can try increasing num_cuts_for_bins_estimate and num_buckets in DynamicBucketingSampler. It will lead to higher quality bucketing. You might need to increase buffer_size a bit too if you have more buckets.
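For reference, a minimal sketch of how these knobs could be passed when constructing the sampler (parameter names are lhotse's; the values are illustrative, not tuned, and the surrounding asr_datamodule setup is assumed):

```python
from lhotse.dataset import DynamicBucketingSampler

# A minimal sketch; values below are illustrative, not tuned.
train_sampler = DynamicBucketingSampler(
    train_cuts,
    max_duration=100.0,                # total seconds of audio per batch
    shuffle=True,
    num_buckets=50,                    # more buckets -> tighter duration range per bucket
    num_cuts_for_bins_estimate=20000,  # more cuts -> better bucket boundary estimates
    buffer_size=30000,                 # keep enough cuts buffered to fill all buckets
)
```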

@pzelasko
Collaborator

This function can help you tune the settings to minimize the overall padding: https://github.com/lhotse-speech/lhotse/blob/94e9ed9c67bcb4b4329e907ae335947dbbce99e9/lhotse/dataset/sampling/utils.py#L89
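A hedged usage sketch of that helper (find_pessimistic_batches); the exact signature and return values may differ between lhotse versions:

```python
from lhotse.dataset.sampling.utils import find_pessimistic_batches

# A minimal sketch: ask the sampler for its "worst-case" batches
# (e.g. largest total duration, longest single cut) so they can be
# fed through the model once to check memory headroom before training.
batches, deltas = find_pessimistic_batches(train_sampler)
for criterion, cut_set in batches.items():
    total = sum(cut.duration for cut in cut_set)
    print(f"{criterion}: {len(cut_set)} cuts, {total:.1f}s total audio")
```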

@luomingshuang
Collaborator Author

OK, thanks. I will try your suggestions.

@luomingshuang
Collaborator Author

luomingshuang commented Jun 17, 2022

Oh, after increasing the relevant parameters (num_cuts_for_bins_estimate=20000, num_buckets=800), the above error still occurs.

@pzelasko
Collaborator

num_buckets=800 might be excessive; what about something like 50?

In the batches where the texts have a large difference in length, do you also observe a large difference in audio durations?

@luomingshuang
Collaborator Author

Or maybe I can try using BucketingSampler instead of DynamicBucketingSampler?
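If that route were tried, a rough sketch might look like the following (assuming lhotse's BucketingSampler with extra kwargs forwarded to the underlying per-bucket sampler; values are illustrative):

```python
from lhotse.dataset import BucketingSampler

# A minimal sketch of a static bucketing sampler as an alternative
# to DynamicBucketingSampler; values are illustrative, not tuned.
train_sampler = BucketingSampler(
    train_cuts,
    max_duration=100.0,
    num_buckets=50,
    shuffle=True,
)
```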

```
99%    18.0
99.5%  18.8
99.9%  20.8
max    36.5
```
Collaborator Author


Here are the duration statistics for the train cuts.

Collaborator Author


```
train_set:
Duration statistics (seconds):
mean    5.8
std     4.1
min     0.3
25%     2.8
50%     4.4
75%     7.3
99%     18.0
99.5%   18.8
99.9%   20.8
max     36.5
```

Collaborator


What about within-batch statistics? I am trying to understand if:

  • a) the bucketing sampler is doing a poor job by putting cuts with very different durations together (so we can tune the settings of the sampler to do better), or
  • b) if the mini-batch cut durations are close to each other (in which case there is nothing we can do)

```python
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 18.0
```
Collaborator


From #428 (comment)
You will drop 1% of the training data, which may be too much.

Collaborator Author


Once I fix the above error, I will consider increasing the max duration.

@luomingshuang
Collaborator Author

I tried using cuts = cuts.sort_by_duration(ascending=False) in compute_fbank_tal_csasr.py to sort the cuts. I then print num_frames for each batch during training as follows (see the sketch after this output):

num_frames in a batch: tensor([122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122,
        122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122,
        122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122,
        122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122,
        122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122,
        122, 122, 122, 122, 122, 122, 122, 122, 122, 122, 122],
       dtype=torch.int32)
num_frames in a batch: tensor([144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144,
        144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144,
        144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144,
        144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144,
        144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144, 144],
       dtype=torch.int32)
num_frames in a batch: tensor([123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123],
       dtype=torch.int32)
num_frames in a batch: tensor([125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125,
        125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125,
        125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125,
        125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125,
        125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 125,
        125, 125, 125, 125, 125, 125, 125, 125, 125, 125], dtype=torch.int32)
num_frames in a batch: tensor([113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113,
        113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113,
        113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113,
        113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113,
        113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113,
        113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113, 113,
        113, 113, 113, 113], dtype=torch.int32)
num_frames in a batch: tensor([143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143,
        143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143,
        143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143,
        143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143,
        143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143, 143],
       dtype=torch.int32)
num_frames in a batch: tensor([133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133,
        133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133,
        133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133,
        133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133,
        133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133, 133,
        133, 133, 133, 133, 133], dtype=torch.int32)
num_frames in a batch: tensor([138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138,
        138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138,
        138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138,
        138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138,
        138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138, 138,
        138, 138], dtype=torch.int32)
num_frames in a batch: tensor([119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
        119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
        119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
        119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
        119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
        119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119],
       dtype=torch.int32)
num_frames in a batch: tensor([127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
        127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
        127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
        127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
        127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
        127, 127, 127, 127, 127, 127, 127, 127], dtype=torch.int32)
num_frames in a batch: tensor([123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123,
        123, 123, 123, 123, 123, 123, 123, 123, 123, 123, 123],
       dtype=torch.int32)
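For completeness, a minimal sketch of the two steps described above (the cut sorting in compute_fbank_tal_csasr.py and the per-batch num_frames print; the batch layout assumes lhotse's K2SpeechRecognitionDataset supervisions dict, and variable names are illustrative):

```python
# In compute_fbank_tal_csasr.py: sort the cuts by duration before
# feature extraction, as mentioned above.
cuts = cuts.sort_by_duration(ascending=False)

# In the training loop: print the per-utterance frame counts of each
# mini-batch to see how much padding the sampler produces.
num_frames = batch["supervisions"]["num_frames"]
print("num_frames in a batch:", num_frames)
```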

@pzelasko
Collaborator

pzelasko commented Jun 17, 2022

It looks like these batches don't have any padding at all. I think you don't have any other option than decreasing max duration.

@luomingshuang
Collaborator Author

OK, I have decreased the max duration to 90 and started training. Let's wait and see how it goes.
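The corresponding command just lowers --max-duration relative to the one shown earlier:

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' python pruned_transducer_stateless5/train.py \
  --max-duration 90 \
  --world-size 8 \
  --num-workers 2
```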

@kobenaxie
Contributor

How many tokens are in your vocabulary? Maybe the vocab size is too large?

@luomingshuang
Collaborator Author

The number of modeling tokens (Chinese characters plus English BPE units) is 7341. Yes, it is large. Now I filter the training cuts to a duration range of 1–20 seconds and use a max duration of 90 (see the sketch below). The training has run normally for 5 epochs; I will let it continue running.
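For reference, a minimal sketch of such a duration filter, following the pattern of the train.py snippet reviewed above (the function and variable names are illustrative):

```python
def remove_short_and_long_utt(c):
    # Keep only utterances between 1 and 20 seconds; use
    # ../local/display_manifest_statistics.py to inspect the duration
    # distribution when picking these thresholds.
    return 1.0 <= c.duration <= 20.0

train_cuts = train_cuts.filter(remove_short_and_long_utt)
```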

@pkufool
Collaborator

pkufool commented Jun 18, 2022

@luomingshuang Does it still raise an OOM error after you sort by duration?

@luomingshuang
Collaborator Author

At present, it has run normally for 5 epochs with max-duration=90 and sort_by_duration applied when computing the fbank features. But if I only sort the cuts without decreasing the max duration, it still raises the OOM error, according to my attempts yesterday.

@luomingshuang Does it still raise an OOM error after you sort by duration?

@luomingshuang
Collaborator Author

luomingshuang commented Jun 22, 2022

The results (WER(%)) for pruned_transducer_stateless5 trained with TAL_CSASR:

| decoding-method | epoch(iter) | avg | dev | test |
|--|--|--|--|--|
| greedy_search | 30 | 24 | 7.49 | 7.58 |
| modified_beam_search | 30 | 24 | 7.33 | 7.38 |
| fast_beam_search | 30 | 24 | 7.32 | 7.42 |
| greedy_search (use-averaged-model=True) | 30 | 24 | 7.30 | 7.39 |
| modified_beam_search (use-averaged-model=True) | 30 | 24 | 7.15 | 7.22 |
| fast_beam_search (use-averaged-model=True) | 30 | 24 | 7.18 | 7.26 |
| greedy_search | 348000 | 30 | 7.46 | 7.54 |
| modified_beam_search | 348000 | 30 | 7.24 | 7.36 |
| fast_beam_search | 348000 | 30 | 7.25 | 7.39 |

@csukuangfj
Collaborator

The results (WER(%)) for pruned_transducer_stateless5 trained with TAL_CSASR:

Do you have any baseline to compare with?

@luomingshuang
Collaborator Author

Em....It seems that I can't find any baseline on this dataset.

The results (WER(%)) for pruned_transducer_stateless5 trained with TAL_CSASR:

Do you have any baseline to compare with?

@csukuangfj
Collaborator

Em....It seems that I can't find any baseline on this dataset.

The results (WER(%)) for pruned_transducer_stateless5 trained with TAL_CSASR:

Do you have any baseline to compare with?

ok, you are creating the baseline for others.

@luomingshuang luomingshuang changed the title [WIP] Pruned transducer stateless5 recipe for tal_csasr (mix Chinese chars and English BPE) [Ready to merge] Pruned transducer stateless5 recipe for tal_csasr (mix Chinese chars and English BPE) Jun 23, 2022
@luomingshuang luomingshuang added ready and removed ready labels Jun 23, 2022
@fanlu
Contributor

fanlu commented Jun 24, 2022

The results (WER(%)) for pruned_transducer_stateless5 trained with TAL_CSASR:

| decoding-method | epoch(iter) | avg | dev | test |
|--|--|--|--|--|
| greedy_search | 30 | 24 | 7.49 | 7.58 |
| modified_beam_search | 30 | 24 | 7.33 | 7.38 |
| fast_beam_search | 30 | 24 | 7.32 | 7.42 |
| greedy_search (use-averaged-model=True) | 30 | 24 | 7.30 | 7.39 |
| modified_beam_search (use-averaged-model=True) | 30 | 24 | 7.15 | 7.22 |
| fast_beam_search (use-averaged-model=True) | 30 | 24 | 7.18 | 7.26 |
| greedy_search | 348000 | 30 | 7.46 | 7.54 |
| modified_beam_search | 348000 | 30 | 7.24 | 7.36 |
| fast_beam_search | 348000 | 30 | 7.25 | 7.39 |

Do you have the results for Chinese CER and English WER respectively?

@luomingshuang
Collaborator Author

Maybe I can test the Chinese and English decoding results respectively.

@luomingshuang
Collaborator Author

The results (CER(%) and WER(%)) for pruned_transducer_stateless5 trained with TAL_CSASR (zh: Chinese, en: English), including Chinese CER and English WER respectively:

| decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--|--|--|
| greedy_search (use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 | 7.39 | 6.66 | 19.13 |
| modified_beam_search (use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22 | 6.50 | 18.70 |
| fast_beam_search (use-averaged-model=True) | 30 | 24 | 7.18 | 6.39 | 18.90 | 7.27 | 6.55 | 18.77 |

@luomingshuang luomingshuang added ready and removed ready labels Jun 24, 2022
@luomingshuang luomingshuang added ready and removed ready labels Jun 28, 2022
@luomingshuang luomingshuang merged commit 2cb1618 into k2-fsa:master Jun 28, 2022