Convert ScaledEmbedding to nn.Embedding for inference. #517

Merged · 2 commits · Aug 3, 2022
5 changes: 1 addition & 4 deletions egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -495,9 +495,6 @@ class ScaledEmbedding(nn.Module):
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
-max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-is renormalized to have norm :attr:`max_norm`.
-norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
@@ -506,7 +503,7 @@ class ScaledEmbedding(nn.Module):
initial_speed (float, optional): This affects how fast the parameter will
learn near the start of training; you can set it to a value less than
one if you suspect that a module is contributing to instability near
-the start of training. Nnote: regardless of the use of this option,
+the start of training. Note: regardless of the use of this option,
it's best to use schedulers like Noam that have a warm-up period.
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
scaling_converter.py
@@ -16,27 +16,19 @@

"""
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
-and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
-and `nn.Conv2d`.
+`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
+`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.

The scaled version are required only in the training time. It simplifies our
-life by converting them their non-scaled version during inference time.
+life by converting them to their non-scaled version during inference.
"""

import copy
import re

import torch
import torch.nn as nn
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
-
-
-def _get_weight(self: torch.nn.Linear):
-    return self.weight
-
-
-def _get_bias(self: torch.nn.Linear):
-    return self.bias
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear


@@ -54,10 +46,6 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
"""
assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)

# if not hasattr(torch.nn.Linear, "get_weight"):
# torch.nn.Linear.get_weight = _get_weight
# torch.nn.Linear.get_bias = _get_bias

weight = scaled_linear.get_weight()
bias = scaled_linear.get_bias()
has_bias = bias is not None
@@ -148,6 +136,34 @@ def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
return conv2d


+def scaled_embedding_to_embedding(
+    scaled_embedding: ScaledEmbedding,
+) -> nn.Embedding:
+    """Convert an instance of ScaledEmbedding to nn.Embedding.
+
+    Args:
+      scaled_embedding:
+        The layer to be converted.
+    Returns:
+      Return an instance of nn.Embedding that has the same `forward()` behavior
+      of the given `scaled_embedding`.
+    """
+    assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
+    embedding = nn.Embedding(
+        num_embeddings=scaled_embedding.num_embeddings,
+        embedding_dim=scaled_embedding.embedding_dim,
+        padding_idx=scaled_embedding.padding_idx,
+        scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
+        sparse=scaled_embedding.sparse,
+    )
+    weight = scaled_embedding.weight
+    scale = scaled_embedding.scale
+
+    embedding.weight.data.copy_(weight * scale.exp())
+
+    return embedding
+

def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
@@ -178,6 +194,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
            d[name] = scaled_conv1d_to_conv1d(m)
        elif isinstance(m, ScaledConv2d):
            d[name] = scaled_conv2d_to_conv2d(m)
+        elif isinstance(m, ScaledEmbedding):
+            d[name] = scaled_embedding_to_embedding(m)

    for k, v in d.items():
        if "." in k:
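The dispatch loop above builds a dict from dotted module names to their converted counterparts; the lines that apply that dict are truncated in this view, but the visible `for k, v in d.items(): if "." in k:` suggests walking back to each parent module and reassigning the attribute. A self-contained sketch of that replace-by-name idea with plain `torch` modules follows; `replace_modules_by_name` is an illustrative helper written here for the sketch, not icefall's code:

```python
import torch.nn as nn


def replace_modules_by_name(model: nn.Module, replacements: dict) -> nn.Module:
    """Assign each new module onto its parent module, keyed by dotted name."""
    for name, new_module in replacements.items():
        if "." in name:
            parent_name, attr = name.rsplit(".", maxsplit=1)
            parent = model.get_submodule(parent_name)
        else:
            parent, attr = model, name
        # nn.Module.__setattr__ registers the replacement as a child module.
        setattr(parent, attr, new_module)
    return model


# Toy model with a nested embedding, standing in for a transducer decoder.
model = nn.Sequential()
model.add_module("decoder", nn.Sequential())
model.decoder.add_module("embedding", nn.Embedding(10, 4))

new_emb = nn.Embedding(10, 4)
replace_modules_by_name(model, {"decoder.embedding": new_emb})
assert model.decoder.embedding is new_emb
```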
(test file for scaling_converter)
@@ -25,11 +25,12 @@
import copy

import torch
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
from scaling_converter import (
    convert_scaled_to_non_scaled,
    scaled_conv1d_to_conv1d,
    scaled_conv2d_to_conv2d,
+    scaled_embedding_to_embedding,
    scaled_linear_to_linear,
)
from train import get_params, get_transducer_model
@@ -135,6 +136,21 @@ def test_scaled_conv2d_to_conv2d():
assert torch.allclose(y1, y4)


+def test_scaled_embedding_to_embedding():
+    scaled_embedding = ScaledEmbedding(
+        num_embeddings=500,
+        embedding_dim=10,
+        padding_idx=0,
+    )
+    embedding = scaled_embedding_to_embedding(scaled_embedding)
+
+    for s in [10, 100, 300, 500, 800, 1000]:
+        x = torch.randint(low=0, high=500, size=(s,))
+        scaled_y = scaled_embedding(x)
+        y = embedding(x)
+        assert torch.equal(scaled_y, y)


def test_convert_scaled_to_non_scaled():
for inplace in [False, True]:
model = get_model()
@@ -193,6 +209,7 @@ def main():
    test_scaled_linear_to_linear()
    test_scaled_conv1d_to_conv1d()
    test_scaled_conv2d_to_conv2d()
+    test_scaled_embedding_to_embedding()
    test_convert_scaled_to_non_scaled()


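The new test asserts exact `torch.equal` equality rather than `allclose`. That works because an embedding lookup is linear in the weight matrix: multiplying every row by `scale.exp()` before the lookup performs exactly the same element-wise multiplications as scaling the looked-up rows afterwards, so the results match bit for bit. A minimal standalone check of that identity using only plain `torch`; `TinyScaledEmbedding` below is an illustrative stand-in written for this sketch, not icefall's `ScaledEmbedding`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyScaledEmbedding(nn.Module):
    """Stand-in for a scaled embedding: a weight matrix plus a learned log-scale."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.scale = nn.Parameter(torch.tensor(0.5))  # stored as the log of the scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # During training the scale is applied at lookup time.
        return F.embedding(x, self.weight) * self.scale.exp()


scaled = TinyScaledEmbedding(num_embeddings=100, embedding_dim=8)

# Fold the scale into the weights once, mirroring what scaled_embedding_to_embedding does.
plain = nn.Embedding(num_embeddings=100, embedding_dim=8)
plain.weight.data.copy_(scaled.weight * scaled.scale.exp())

x = torch.randint(low=0, high=100, size=(50,))
with torch.no_grad():
    assert torch.equal(scaled(x), plain(x))
```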
9 changes: 6 additions & 3 deletions icefall/decode.py
@@ -334,10 +334,13 @@ def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
        if hasattr(lattice, "aux_labels"):
            # delete token IDs as it is not needed
            del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
        else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
-
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )

        path_to_utt_map = self.shape.row_ids(1)