Skip to content

Commit

Permalink
Convert ScaledEmbedding to nn.Embedding for inference. (#517)
Browse files Browse the repository at this point in the history
* Convert ScaledEmbedding to nn.Embedding for inference.

* Fix CI style issues.
  • Loading branch information
csukuangfj authored Aug 3, 2022
1 parent 58a96e5 commit 6af5a82
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 24 deletions.
5 changes: 1 addition & 4 deletions egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,6 @@ class ScaledEmbedding(nn.Module):
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
Expand All @@ -506,7 +503,7 @@ class ScaledEmbedding(nn.Module):
initial_speed (float, optional): This affects how fast the parameter will
learn near the start of training; you can set it to a value less than
one if you suspect that a module is contributing to instability near
the start of training. Nnote: regardless of the use of this option,
the start of training. Note: regardless of the use of this option,
it's best to use schedulers like Noam that have a warm-up period.
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,19 @@

"""
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
and `nn.Conv2d`.
`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.
The scaled version are required only in the training time. It simplifies our
life by converting them their non-scaled version during inference time.
life by converting them to their non-scaled version during inference.
"""

import copy
import re

import torch
import torch.nn as nn
from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear


def _get_weight(self: torch.nn.Linear):
return self.weight


def _get_bias(self: torch.nn.Linear):
return self.bias
from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear


def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
Expand All @@ -54,10 +46,6 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
"""
assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)

# if not hasattr(torch.nn.Linear, "get_weight"):
# torch.nn.Linear.get_weight = _get_weight
# torch.nn.Linear.get_bias = _get_bias

weight = scaled_linear.get_weight()
bias = scaled_linear.get_bias()
has_bias = bias is not None
Expand Down Expand Up @@ -148,6 +136,34 @@ def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
return conv2d


def scaled_embedding_to_embedding(
scaled_embedding: ScaledEmbedding,
) -> nn.Embedding:
"""Convert an instance of ScaledEmbedding to nn.Embedding.
Args:
scaled_embedding:
The layer to be converted.
Returns:
Return an instance of nn.Embedding that has the same `forward()` behavior
of the given `scaled_embedding`.
"""
assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
embedding = nn.Embedding(
num_embeddings=scaled_embedding.num_embeddings,
embedding_dim=scaled_embedding.embedding_dim,
padding_idx=scaled_embedding.padding_idx,
scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
sparse=scaled_embedding.sparse,
)
weight = scaled_embedding.weight
scale = scaled_embedding.scale

embedding.weight.data.copy_(weight * scale.exp())

return embedding


def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
Expand Down Expand Up @@ -178,6 +194,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
d[name] = scaled_conv1d_to_conv1d(m)
elif isinstance(m, ScaledConv2d):
d[name] = scaled_conv2d_to_conv2d(m)
elif isinstance(m, ScaledEmbedding):
d[name] = scaled_embedding_to_embedding(m)

for k, v in d.items():
if "." in k:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
import copy

import torch
from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
from scaling_converter import (
convert_scaled_to_non_scaled,
scaled_conv1d_to_conv1d,
scaled_conv2d_to_conv2d,
scaled_embedding_to_embedding,
scaled_linear_to_linear,
)
from train import get_params, get_transducer_model
Expand Down Expand Up @@ -135,6 +136,21 @@ def test_scaled_conv2d_to_conv2d():
assert torch.allclose(y1, y4)


def test_scaled_embedding_to_embedding():
scaled_embedding = ScaledEmbedding(
num_embeddings=500,
embedding_dim=10,
padding_idx=0,
)
embedding = scaled_embedding_to_embedding(scaled_embedding)

for s in [10, 100, 300, 500, 800, 1000]:
x = torch.randint(low=0, high=500, size=(s,))
scaled_y = scaled_embedding(x)
y = embedding(x)
assert torch.equal(scaled_y, y)


def test_convert_scaled_to_non_scaled():
for inplace in [False, True]:
model = get_model()
Expand Down Expand Up @@ -193,6 +209,7 @@ def main():
test_scaled_linear_to_linear()
test_scaled_conv1d_to_conv1d()
test_scaled_conv2d_to_conv2d()
test_scaled_embedding_to_embedding()
test_convert_scaled_to_non_scaled()


Expand Down
9 changes: 6 additions & 3 deletions icefall/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,13 @@ def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
if hasattr(lattice, "aux_labels"):
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
word_fsa
)
else:
word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)

word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
word_fsa
)

path_to_utt_map = self.shape.row_ids(1)

Expand Down

0 comments on commit 6af5a82

Please sign in to comment.