Skip to content

Commit

Permalink
update multi_quantization installation (#469)
Browse files Browse the repository at this point in the history
* update multi_quantization installation

* Update egs/librispeech/ASR/pruned_transducer_stateless6/train.py

Co-authored-by: Fangjun Kuang <[email protected]>

Co-authored-by: Fangjun Kuang <[email protected]>
  • Loading branch information
glynpu and csukuangfj authored Jul 13, 2022
1 parent bc2882d commit f8d28f0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
6 changes: 3 additions & 3 deletions egs/librispeech/ASR/distillation_with_hubert.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
fi

# Install quantization toolkit:
# pip install git+https://github.com/danpovey/quantization.git@master
# when testing this code:
# commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
# pip install git+https://github.com/k2-fsa/multi_quantization.git
# or
# pip install multi_quantization

has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then
Expand Down
6 changes: 4 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless6/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from icefall.utils import add_sos

from quantization.prediction import JointCodebookLoss
from multi_quantization.prediction import JointCodebookLoss


class Transducer(nn.Module):
Expand Down Expand Up @@ -75,7 +75,9 @@ def __init__(
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim, num_codebooks=num_codebooks
predictor_channels=encoder_dim,
num_codebooks=num_codebooks,
is_joint=False,
)

def forward(
Expand Down
5 changes: 5 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless6/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,11 @@ def run(rank, world_size, args):
The return value of get_parser().parse_args()
"""
params = get_params()

# Note: it's better to set --spec-aug-time-warp-factor=-1
# when doing distillation with vq.
assert args.spec_aug_time_warp_factor < 1

params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600
Expand Down

0 comments on commit f8d28f0

Please sign in to comment.