
WIP: Try to use multiple datasets with pruned transducer loss #245

Closed
wants to merge 10 commits
egs/librispeech/ASR/pruned_transducer_stateless/train.py (3 changes: 1 addition & 2 deletions)
@@ -405,7 +405,7 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute transducer loss given the model and its inputs.
 
     Args:
       params:
@@ -599,7 +599,6 @@ def maybe_log_param_relative_changes():
             )
 
         if batch_idx % params.log_interval == 0:
-
             if tb_writer is not None:
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -0,0 +1,347 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import torch
from model import Transducer


def greedy_search(
    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
) -> List[int]:
    """
    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Supports only N==1
        for now.
      max_sym_per_frame:
        Maximum number of symbols per frame. If it is set to 0, the WER
        would be 100%.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    encoder_out = model.simple_encoder_linear(encoder_out)
    encoder_out = model.encoder_linear(encoder_out)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size

    device = model.device

    decoder_input = torch.tensor(
        [blank_id] * context_size, device=device
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_out = model.simple_decoder_linear(decoder_out)
    decoder_out = model.decoder_linear(decoder_out)

    T = encoder_out.size(1)
    t = 0
    hyp = [blank_id] * context_size

    # Maximum symbols per utterance.
    max_sym_per_utt = 1000

    # symbols per frame
    sym_per_frame = 0

    # symbols per utterance decoded so far
    sym_per_utt = 0

    while t < T and sym_per_utt < max_sym_per_utt:
        if sym_per_frame >= max_sym_per_frame:
            sym_per_frame = 0
            t += 1
            continue

        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        # fmt: on
        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
        # logits is (1, 1, 1, vocab_size)

        y = logits.argmax().item()
        if y != blank_id:
            hyp.append(y)
            decoder_input = torch.tensor(
                [hyp[-context_size:]], device=device
            ).reshape(1, context_size)

            decoder_out = model.decoder(decoder_input, need_pad=False)
            decoder_out = model.simple_decoder_linear(decoder_out)
            decoder_out = model.decoder_linear(decoder_out)

            sym_per_utt += 1
            sym_per_frame += 1
        else:
            sym_per_frame = 0
            t += 1
    hyp = hyp[context_size:]  # remove the blank prefix used as initial context

    return hyp
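
# A hedged usage sketch for `greedy_search`. The PR does not include the
# decoding script, so `features`, `x_lens`, and the encoder's return
# signature below are assumptions for illustration, not part of this
# file's API.
def _greedy_search_example(
    model: Transducer, features: torch.Tensor, x_lens: torch.Tensor
) -> List[int]:
    # `features` is assumed to be a (1, T, num_features) tensor, since
    # greedy_search supports only batch size 1.
    model.eval()
    with torch.no_grad():
        # Assumed encoder API: returns (N, T, C) outputs and their lengths.
        encoder_out, _ = model.encoder(features, x_lens)
        # Allow at most 3 symbols per frame; 0 would force all-blank output.
        return greedy_search(model, encoder_out, max_sym_per_frame=3)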


@dataclass
class Hypothesis:
    # The predicted tokens so far.
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]

    # The log prob of ys
    log_prob: float

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
        return "_".join(map(str, self.ys))


class HypothesisList(object):
    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
        """
        Args:
          data:
            A dict of Hypotheses. Its key is its `value.key`.
        """
        if data is None:
            self._data = {}
        else:
            self._data = data

    @property
    def data(self):
        return self._data

    def add(self, hyp: Hypothesis):
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated using
        `log-sum-exp` with the existing one.

        Args:
          hyp:
            The hypothesis to be added.
        """
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]
            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
        else:
            self._data[key] = hyp

    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
        """Get the most probable hypothesis, i.e., the one with
        the largest `log_prob`.

        Args:
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
        """
        if length_norm:
            return max(
                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
            )
        else:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob)

    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.

        Args:
          hyp:
            The hypothesis to be removed from `self`.
            Note: It must be contained in `self`. Otherwise,
            an exception is raised.
        """
        key = hyp.key
        assert key in self, f"{key} does not exist"
        del self._data[key]

    def filter(self, threshold: float) -> "HypothesisList":
        """Remove all Hypotheses whose log_prob is less than threshold.

        Caution:
          `self` is not modified. Instead, a new HypothesisList is returned.

        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
          whose `log_prob` is greater than the given `threshold`.
        """
        ans = HypothesisList()
        for key, hyp in self._data.items():
            if hyp.log_prob > threshold:
                ans.add(hyp)  # shallow copy
        return ans

    def topk(self, k: int) -> "HypothesisList":
        """Return the top-k hypotheses."""
        hyps = list(self._data.items())

        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]

        ans = HypothesisList(dict(hyps))
        return ans

    def __contains__(self, key: str):
        return key in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self) -> int:
        return len(self._data)

    def __str__(self) -> str:
        s = []
        # `__iter__` yields Hypothesis objects, so collect their string keys
        for hyp in self:
            s.append(hyp.key)
        return ", ".join(s)


def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[int]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf

    espnet/nets/beam_search_transducer.py#L247 is used as a reference.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Supports only N==1
        for now.
      beam:
        Beam size.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size

    device = model.device

    decoder_input = torch.tensor(
        [blank_id] * context_size, device=device
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)

    T = encoder_out.size(1)
    t = 0

    B = HypothesisList()
    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))

    max_sym_per_utt = 20000

    sym_per_utt = 0

    decoder_cache: Dict[str, torch.Tensor] = {}

    while t < T and sym_per_utt < max_sym_per_utt:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        # fmt: on
        A = B
        B = HypothesisList()

        joint_cache: Dict[str, torch.Tensor] = {}

        # TODO(fangjun): Implement prefix search to update the `log_prob`
        # of hypotheses in A

        while True:
            y_star = A.get_most_probable()
            A.remove(y_star)

            cached_key = y_star.key

            if cached_key not in decoder_cache:
                decoder_input = torch.tensor(
                    [y_star.ys[-context_size:]], device=device
                ).reshape(1, context_size)

                decoder_out = model.decoder(decoder_input, need_pad=False)
                decoder_cache[cached_key] = decoder_out
            else:
                decoder_out = decoder_cache[cached_key]

            cached_key += f"-t-{t}"
            if cached_key not in joint_cache:
                logits = model.joiner(
                    current_encoder_out, decoder_out.unsqueeze(1)
                )

                # TODO(fangjun): Cache the blank posterior

                log_prob = logits.log_softmax(dim=-1)
                # log_prob is (1, 1, 1, vocab_size)
                log_prob = log_prob.squeeze()
                # Now log_prob is (vocab_size,)
                joint_cache[cached_key] = log_prob
            else:
                log_prob = joint_cache[cached_key]

            # First, process the blank symbol
            skip_log_prob = log_prob[blank_id]
            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()

            # ys[:] returns a copy of ys
            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

            # Second, process other non-blank labels
            values, indices = log_prob.topk(beam + 1)
            for i, v in zip(indices.tolist(), values.tolist()):
                if i == blank_id:
                    continue
                new_ys = y_star.ys + [i]
                new_log_prob = y_star.log_prob + v
                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))

            # Check whether B contains at least "beam" elements that are
            # more probable than the most probable hypothesis in A
            A_most_probable = A.get_most_probable()

            kept_B = B.filter(A_most_probable.log_prob)

            if len(kept_B) >= beam:
                B = kept_B.topk(beam)
                break

        t += 1

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove the blanks
    return ys
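
# A hedged end-to-end sketch comparing the two decoders; the encoder call
# mirrors the assumptions in `_greedy_search_example` above. Note that in
# this WIP, `greedy_search` applies the simple/joint projection layers
# (`simple_encoder_linear`, `encoder_linear`, etc.) while `beam_search`
# does not, so the two paths are not yet symmetric.
def _beam_search_example(
    model: Transducer, features: torch.Tensor, x_lens: torch.Tensor
) -> List[int]:
    model.eval()
    with torch.no_grad():
        encoder_out, _ = model.encoder(features, x_lens)  # assumed API
        # Larger beams explore more alternative label sequences per frame
        # at higher cost; beam=4 matches the default above.
        return beam_search(model, encoder_out, beam=4)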