Skip to content

Commit

Permalink
Add streaming modified beam search (#142)
Browse files Browse the repository at this point in the history
* Add streaming modified beam search

* Add streaming modified beam search for conv_emformer

* Fix issue in streaming modified beam search

* Fix issue in modified beam search for conv emformer

* Add get_texts in conv_emformer
  • Loading branch information
ezerhouni authored Sep 30, 2022
1 parent 415f126 commit b163428
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-conformer-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "modified_beam_search"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-conv-emformer-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "modified_beam_search"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
torch: ["1.11.0", "1.7.1"]
torchaudio: ["0.11.0", "0.7.2"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "modified_beam_search"]
exclude:
- torch: "1.11.0"
torchaudio: "0.7.2"
Expand Down
111 changes: 111 additions & 0 deletions sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@

from sherpa import (
VALID_FAST_BEAM_SEARCH_METHOD,
Hypotheses,
Hypothesis,
Lexicon,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_one_best,
streaming_greedy_search,
streaming_modified_beam_search,
)


Expand Down Expand Up @@ -339,3 +342,111 @@ def get_texts(self, stream: Stream) -> str:
return self.sp.decode(
stream.hyp[self.beam_search_params["context_size"] :]
)


class ModifiedBeamSearch:
    """Streaming modified-beam-search decoder for the conv-emformer
    transducer server.

    Each stream carries its own ``Hypotheses`` collection in
    ``stream.hyps``; :meth:`process` advances a batch of streams by one
    feature chunk and re-ranks the hypotheses.
    """

    def __init__(self, beam_search_params: dict):
        """
        Args:
          beam_search_params:
            A dict with at least the keys ``blank_id``, ``context_size``
            and ``num_active_paths``.
        """
        self.beam_search_params = beam_search_params

    def init_stream(self, stream: Stream):
        """Attach decoding attributes to a newly created stream.

        Initializes ``stream.hyps`` with a single hypothesis whose ``ys``
        consists of ``context_size`` blanks (the decoder-network context).
        """
        hyp = [self.beam_search_params["blank_id"]] * self.beam_search_params[
            "context_size"
        ]
        stream.hyps = Hypotheses([Hypothesis(ys=hyp, log_prob=0.0)])

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and do search with the
        modified_beam_search method.

        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyps` are
            updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad

        state_list, feature_list = [], []
        hyp_list = []
        processed_frames_list = []

        for s in stream_list:
            hyp_list.append(s.hyps)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            # Read chunk_length_pad frames but advance by only
            # chunk_length, keeping the padding overlap for the next chunk.
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

        features = torch.stack(feature_list, dim=0).to(device)
        states = stack_states(state_list)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list,
            device=device,
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

        # Note: There are no paddings for streaming ASR. Each stream
        # has the same input number of frames, i.e., server.chunk_length.
        next_hyps_list = streaming_modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            hyps=hyp_list,
            num_active_paths=self.beam_search_params["num_active_paths"],
        )

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.hyps = next_hyps_list[i]
            # Endpointing info: trailing blanks of the current best path.
            trailing_blanks = s.hyps.get_most_probable(True).num_trailing_blanks
            s.num_trailing_blank_frames = trailing_blanks

    def get_texts(self, stream: Stream) -> str:
        """Return the decoded text of the currently best hypothesis.

        The first ``context_size`` tokens are the blank decoder context
        and are stripped before decoding.
        """
        hyp = stream.hyps.get_most_probable(True).ys[
            self.beam_search_params["context_size"] :
        ]
        if hasattr(self, "sp"):
            # BPE model: decode token ids with sentencepiece.
            result = self.sp.decode(hyp)
        else:
            # Otherwise map ids through the token table.
            result = [self.token_table[i] for i in hyp]
            result = "".join(result)

        return result
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import sentencepiece as spm
import torch
import websockets
from beam_search import FastBeamSearch, GreedySearch
from beam_search import FastBeamSearch, GreedySearch, ModifiedBeamSearch
from stream import Stream

from sherpa import (
Expand Down Expand Up @@ -275,6 +275,8 @@ def __init__(
beam_search_params,
device,
)
elif decoding_method == "modified_beam_search":
self.beam_search = ModifiedBeamSearch(beam_search_params)
else:
raise ValueError(
f"Decoding method {decoding_method} is not supported."
Expand Down
121 changes: 121 additions & 0 deletions sherpa/bin/streaming_pruned_transducer_statelessX/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@

from sherpa import (
VALID_FAST_BEAM_SEARCH_METHOD,
Hypotheses,
Hypothesis,
Lexicon,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_one_best,
streaming_greedy_search,
streaming_modified_beam_search,
)


Expand Down Expand Up @@ -378,3 +381,121 @@ def get_texts(self, stream: Stream) -> str:
result = "".join(result).replace("▁", " ")

return result


class ModifiedBeamSearch:
    """Streaming modified-beam-search decoder for the streaming pruned
    transducer server.

    Each stream carries its own ``Hypotheses`` collection in
    ``stream.hyps``; :meth:`process` advances a batch of streams by one
    feature chunk and re-ranks the hypotheses.
    """

    def __init__(self, beam_search_params: dict):
        """
        Args:
          beam_search_params:
            A dict with at least the keys ``blank_id``, ``context_size``
            and ``num_active_paths``.
        """
        self.beam_search_params = beam_search_params

    def init_stream(self, stream: Stream):
        """Attach decoding attributes to a newly created stream.

        Initializes ``stream.hyps`` with a single hypothesis whose ``ys``
        consists of ``context_size`` blanks (the decoder-network context).
        """
        hyp = [self.beam_search_params["blank_id"]] * self.beam_search_params[
            "context_size"
        ]
        stream.hyps = Hypotheses([Hypothesis(ys=hyp, log_prob=0.0)])

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and do modified_beam_search.

        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyps` are
            updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        subsampling_factor = server.subsampling_factor
        # Note: chunk_size, left_context and right_context are in frames
        # after subsampling
        chunk_size = server.decode_chunk_size
        left_context = server.decode_left_context
        right_context = server.decode_right_context

        batch_size = len(stream_list)

        state_list, feature_list, processed_frames_list = [], [], []
        hyp_list = []

        for s in stream_list:
            hyp_list.append(s.hyps)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            # Read chunk_length pre-subsampling frames but advance by only
            # chunk_size * subsampling_factor, leaving the right-context
            # overlap for the next chunk.
            f = s.features[:chunk_length]
            s.features = s.features[chunk_size * subsampling_factor :]
            b = torch.cat(f, dim=0)
            feature_list.append(b)

        features = torch.stack(feature_list, dim=0).to(device)

        # Per-stream (attn_cache, conv_cache) tensors are batched along
        # dim=2 for the streaming encoder.
        states = [
            torch.stack([x[0] for x in state_list], dim=2),
            torch.stack([x[1] for x in state_list], dim=2),
        ]

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        processed_frames = torch.tensor(processed_frames_list, device=device)

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            states=states,
            processed_frames=processed_frames,
            left_context=left_context,
            right_context=right_context,
        )

        next_hyps_list = streaming_modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            hyps=hyp_list,
            num_active_paths=self.beam_search_params["num_active_paths"],
        )

        # Undo the batching: one (attn_cache, conv_cache) pair per stream.
        next_state_list = [
            torch.unbind(next_states[0], dim=2),
            torch.unbind(next_states[1], dim=2),
        ]

        for i, s in enumerate(stream_list):
            s.states = [next_state_list[0][i], next_state_list[1][i]]
            # processed_frames here counts post-subsampling encoder frames.
            s.processed_frames += encoder_out_lens[i]
            s.hyps = next_hyps_list[i]
            # Endpointing info: trailing blanks of the current best path.
            trailing_blanks = s.hyps.get_most_probable(True).num_trailing_blanks
            s.num_trailing_blank_frames = trailing_blanks

    def get_texts(self, stream: Stream) -> str:
        """Return the decoded text of the currently best hypothesis.

        The first ``context_size`` tokens are the blank decoder context
        and are stripped before decoding.
        """
        hyp = stream.hyps.get_most_probable(True).ys[
            self.beam_search_params["context_size"] :
        ]
        if hasattr(self, "sp"):
            # BPE model: decode token ids with sentencepiece.
            result = self.sp.decode(hyp)
        else:
            # NOTE(review): the other search class in this file does
            # .replace("▁", " ") after joining — confirm whether that is
            # also needed here for token_table decoding.
            result = [self.token_table[i] for i in hyp]
            result = "".join(result)

        return result
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import sentencepiece as spm
import torch
import websockets
from beam_search import FastBeamSearch, GreedySearch
from beam_search import FastBeamSearch, GreedySearch, ModifiedBeamSearch
from stream import Stream

from sherpa import (
Expand Down Expand Up @@ -313,6 +313,8 @@ def __init__(
beam_search_params,
device,
)
elif decoding_method == "modified_beam_search":
self.beam_search = ModifiedBeamSearch(beam_search_params)
else:
raise ValueError(
f"Decoding method {decoding_method} is not supported."
Expand Down

0 comments on commit b163428

Please sign in to comment.