Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add timestamps for streaming ASR #119

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 85 additions & 14 deletions sherpa/bin/lstm_transducer_stateless/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def process(

processed_lens = (num_processed_frames >> 2) + encoder_out_lens
if self.decoding_method == "fast_beam_search_nbest":
next_hyp_list, next_trailing_blank_frames = fast_beam_search_nbest(
res = fast_beam_search_nbest(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
Expand All @@ -144,10 +144,7 @@ def process(
temperature=self.beam_search_params["temperature"],
)
elif self.decoding_method == "fast_beam_search_nbest_LG":
(
next_hyp_list,
next_trailing_blank_frames,
) = fast_beam_search_nbest_LG(
res = fast_beam_search_nbest_LG(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
Expand All @@ -159,10 +156,7 @@ def process(
temperature=self.beam_search_params["temperature"],
)
elif self.decoding_method == "fast_beam_search":
(
next_hyp_list,
next_trailing_blank_frames,
) = fast_beam_search_one_best(
res = fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
Expand All @@ -177,8 +171,12 @@ def process(
next_state_list = unstack_states(next_states)
for i, s in enumerate(stream_list):
s.states = next_state_list[i]
s.hyp = next_hyp_list[i]
s.num_trailing_blank_frames = next_trailing_blank_frames[i]
s.hyp = res.hyps[i]
s.num_trailing_blank_frames = res.num_trailing_blanks[i]
s.frame_offset += encoder_out.size(1)
s.segment_frame_offset += encoder_out.size(1)
s.timestamps = res.timestamps[i]
s.tokens = res.tokens[i]

def get_texts(self, stream: Stream) -> str:
"""
Expand All @@ -200,6 +198,22 @@ def get_texts(self, stream: Stream) -> str:

return result

def get_tokens(self, stream: "Stream") -> "List[str]":
    """Return the tokens decoded so far for the given stream.

    Args:
      stream:
        Stream to be processed.
    Returns:
      A list of token strings, one entry per decoded token ID in
      ``stream.tokens``. IDs are mapped through the BPE model
      (``self.sp``) when one is attached, otherwise through
      ``self.token_table``.
    """
    tokens = stream.tokens

    if hasattr(self, "sp"):
        # BPE-based models: map each token ID to its word piece.
        result = [self.sp.id_to_piece(i) for i in tokens]
    else:
        # Char/phone-based models: look up each ID in the token table.
        result = [self.token_table[i] for i in tokens]

    return result


class GreedySearch:
def __init__(
Expand Down Expand Up @@ -244,6 +258,10 @@ def init_stream(self, stream: Stream):
self.beam_search_params["blank_id"]
] * self.beam_search_params["context_size"]

# timestamps[i] is the frame number on which stream.hyp[i+context_size]
# is decoded
stream.timestamps = [] # containing frame numbers after subsampling

@torch.no_grad()
def process(
self,
Expand All @@ -266,9 +284,13 @@ def process(
chunk_length = server.chunk_length
batch_size = len(stream_list)
chunk_length_pad = server.chunk_length_pad
state_list, feature_list = [], []
decoder_out_list, hyp_list = [], []
state_list = []
feature_list = []
decoder_out_list = []
hyp_list = []
num_trailing_blank_frames_list = []
frame_offset_list = []
timestamps_list = []

for s in stream_list:
decoder_out_list.append(s.decoder_out)
Expand All @@ -282,6 +304,8 @@ def process(
feature_list.append(b)

num_trailing_blank_frames_list.append(s.num_trailing_blank_frames)
frame_offset_list.append(s.segment_frame_offset)
timestamps_list.append(s.timestamps)

features = torch.stack(feature_list, dim=0).to(device)
states = stack_states(state_list)
Expand Down Expand Up @@ -311,12 +335,15 @@ def process(
next_decoder_out,
next_hyp_list,
next_trailing_blank_frames,
next_timestamps,
) = streaming_greedy_search(
model=model,
encoder_out=encoder_out,
decoder_out=decoder_out,
hyps=hyp_list,
num_trailing_blank_frames=num_trailing_blank_frames_list,
frame_offset=frame_offset_list,
timestamps=timestamps_list,
)

next_decoder_out_list = next_decoder_out.split(1)
Expand All @@ -327,6 +354,9 @@ def process(
s.decoder_out = next_decoder_out_list[i]
s.hyp = next_hyp_list[i]
s.num_trailing_blank_frames = next_trailing_blank_frames[i]
s.timestamps = next_timestamps[i]
s.frame_offset += encoder_out.size(1)
s.segment_frame_offset += encoder_out.size(1)

def get_texts(self, stream: Stream) -> str:
"""
Expand All @@ -344,6 +374,21 @@ def get_texts(self, stream: Stream) -> str:

return result

def get_tokens(self, stream: "Stream") -> "List[str]":
    """Return the tokens decoded so far for the given stream.

    Args:
      stream:
        Stream to be processed.
    Returns:
      A list of token strings. The leading ``context_size`` entries of
      ``stream.hyp`` are decoder context (blanks), not real output, so
      they are stripped before mapping IDs to strings via the BPE model
      (``self.sp``) when present, otherwise via ``self.token_table``.
    """
    # Drop the decoder's warm-up context symbols.
    hyp = stream.hyp[self.beam_search_params["context_size"] :]

    if hasattr(self, "sp"):
        # BPE-based models: map each token ID to its word piece.
        result = [self.sp.id_to_piece(i) for i in hyp]
    else:
        # Char/phone-based models: look up each ID in the token table.
        result = [self.token_table[i] for i in hyp]

    return result


class ModifiedBeamSearch:
def __init__(self, beam_search_params: dict):
Expand Down Expand Up @@ -383,6 +428,7 @@ def process(
state_list = []
hyps_list = []
feature_list = []
frame_offset_list = []
for s in stream_list:
state_list.append(s.states)
hyps_list.append(s.hyps)
Expand All @@ -393,6 +439,7 @@ def process(

b = torch.cat(f, dim=0)
feature_list.append(b)
frame_offset_list.append(s.segment_frame_offset)

features = torch.stack(feature_list, dim=0).to(device)
states = stack_states(state_list)
Expand All @@ -415,15 +462,22 @@ def process(
model=model,
encoder_out=encoder_out,
hyps=hyps_list,
frame_offset=frame_offset_list,
num_active_paths=self.beam_search_params["num_active_paths"],
)

next_state_list = unstack_states(next_states)
for i, s in enumerate(stream_list):
s.states = next_state_list[i]
s.hyps = next_hyps_list[i]
trailing_blanks = s.hyps.get_most_probable(True).num_trailing_blanks

best_hyp = s.hyps.get_most_probable(True)

trailing_blanks = best_hyp.num_trailing_blanks
s.timestamps = best_hyp.timestamps
s.num_trailing_blank_frames = trailing_blanks
s.frame_offset += encoder_out.size(1)
s.segment_frame_offset += encoder_out.size(1)

def get_texts(self, stream: Stream) -> str:
hyp = stream.hyps.get_most_probable(True).ys[
Expand All @@ -436,3 +490,20 @@ def get_texts(self, stream: Stream) -> str:
result = "".join(result)

return result

def get_tokens(self, stream: "Stream") -> "List[str]":
    """Return the tokens of the best hypothesis decoded so far.

    Args:
      stream:
        Stream to be processed.
    Returns:
      A list of token strings from the most probable hypothesis in
      ``stream.hyps`` (length-normalized). The leading ``context_size``
      IDs are decoder context (blanks) and are stripped; the remaining
      IDs are mapped via the BPE model (``self.sp``) when present,
      otherwise via ``self.token_table``.
    """
    # Take the best path (length-normalized) and drop the decoder's
    # warm-up context symbols.
    hyp = stream.hyps.get_most_probable(True).ys[
        self.beam_search_params["context_size"] :
    ]

    if hasattr(self, "sp"):
        # BPE-based models: map each token ID to its word piece.
        result = [self.sp.id_to_piece(i) for i in hyp]
    else:
        # Char/phone-based models: look up each ID in the token table.
        result = [self.token_table[i] for i in hyp]

    return result
9 changes: 8 additions & 1 deletion sherpa/bin/lstm_transducer_stateless/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,15 @@ def __init__(
self.subsampling_factor = subsampling_factor
self.log_eps = math.log(1e-10)

# whenever an endpoint is detected, it is incremented
# increment on endpointing
self.segment = 0

# Number of frames decoded so far (after subsampling)
self.frame_offset = 0 # never reset

# frame offset within the current segment after subsampling
self.segment_frame_offset = 0 # reset on endpointing

def accept_waveform(
self,
sampling_rate: float,
Expand Down Expand Up @@ -225,5 +231,6 @@ def endpoint_detected(
self.num_trailing_blank_frames = 0
self.processed_frames = 0
self.segment += 1
self.segment_frame_offset = 0

return detected
24 changes: 24 additions & 0 deletions sherpa/bin/lstm_transducer_stateless/streaming_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
RnntLstmModel,
add_beam_search_arguments,
add_online_endpoint_arguments,
convert_timestamp,
)


Expand Down Expand Up @@ -298,6 +299,7 @@ def __init__(
raise ValueError(
f"Decoding method {decoding_method} is not supported."
)
self.decoding_method = decoding_method

if hasattr(self, "sp"):
self.beam_search.sp = self.sp
Expand Down Expand Up @@ -484,15 +486,27 @@ async def handle_connection_impl(
while len(stream.features) > self.chunk_length_pad:
await self.compute_and_decode(stream)
hyp = self.beam_search.get_texts(stream)
tokens = self.beam_search.get_tokens(stream)

segment = stream.segment
timestamps = convert_timestamp(
frames=stream.timestamps,
subsampling_factor=stream.subsampling_factor,
)

frame_offset = stream.frame_offset * stream.subsampling_factor

is_final = stream.endpoint_detected(self.online_endpoint_config)
if is_final:
self.beam_search.init_stream(stream)

message = {
"method": self.decoding_method,
"segment": segment,
"frame_offset": frame_offset,
"text": hyp,
"tokens": tokens,
"timestamps": timestamps,
"final": is_final,
}

Expand All @@ -509,10 +523,20 @@ async def handle_connection_impl(
stream.features = []

hyp = self.beam_search.get_texts(stream)
tokens = self.beam_search.get_tokens(stream)
frame_offset = stream.frame_offset * stream.subsampling_factor
timestamps = convert_timestamp(
frames=stream.timestamps,
subsampling_factor=stream.subsampling_factor,
)

message = {
"method": self.decoding_method,
"segment": stream.segment,
"frame_offset": frame_offset,
"text": hyp,
"tokens": tokens,
"timestamps": timestamps,
"final": True, # end of connection, always set final to True
}

Expand Down
21 changes: 20 additions & 1 deletion sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,23 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
async for message in socket:
result = json.loads(message)

method = result["method"]
segment = result["segment"]
is_final = result["final"]
text = result["text"]
tokens = result["tokens"]
timestamps = result["timestamps"]

if is_final:
ans.append(dict(segment=segment, text=text))
ans.append(
dict(
method=method,
segment=segment,
text=text,
tokens=tokens,
timestamps=timestamps,
)
)
logging.info(f"Final result of segment {segment}: {text}")
continue

Expand Down Expand Up @@ -121,8 +132,16 @@ async def run(server_addr: str, server_port: int, test_wav: str):
decoding_results = await receive_task
s = ""
for r in decoding_results:
s += f"method: {r['method']}\n"
s += f"segment: {r['segment']}\n"
s += f"text: {r['text']}\n"

token_time = []
for token, time in zip(r["tokens"], r["timestamps"]):
token_time.append((token, time))

s += f"timestamps: {r['timestamps']}\n"
s += f"(token, time): {token_time}\n"
logging.info(f"{test_wav}\n{s}")


Expand Down
4 changes: 4 additions & 0 deletions sherpa/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ struct Hypothesis {
// The predicted tokens so far. Newly predicated tokens are appended.
std::vector<int32_t> ys;

// timestamps[i] contains the frame number after subsampling
// on which ys[i] is decoded.
std::vector<int32_t> timestamps;

// The total score of ys in log space.
double log_prob = 0;

Expand Down
7 changes: 6 additions & 1 deletion sherpa/csrc/rnnt_beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ std::vector<std::vector<int32_t>> GreedySearch(
torch::Tensor StreamingGreedySearch(
RnntModel &model, // NOLINT
torch::Tensor encoder_out, torch::Tensor decoder_out,
const std::vector<int32_t> &frame_offset,
std::vector<std::vector<int32_t>> *hyps,
std::vector<int32_t> *num_trailing_blank_frames) {
std::vector<int32_t> *num_trailing_blank_frames,
std::vector<std::vector<int32_t>> *timestamps) {
TORCH_CHECK(encoder_out.dim() == 3, encoder_out.dim(), " vs ", 3);
TORCH_CHECK(decoder_out.dim() == 2, decoder_out.dim(), " vs ", 2);

Expand Down Expand Up @@ -246,6 +248,7 @@ torch::Tensor StreamingGreedySearch(
if (index != blank_id && index != unk_id) {
emitted = true;
(*hyps)[n].push_back(index);
(*timestamps)[n].push_back(t + frame_offset[n]);
(*num_trailing_blank_frames)[n] = 0;
} else {
(*num_trailing_blank_frames)[n] += 1;
Expand Down Expand Up @@ -430,6 +433,7 @@ std::vector<std::vector<int32_t>> ModifiedBeamSearch(
std::vector<Hypotheses> StreamingModifiedBeamSearch(
RnntModel &model, // NOLINT
torch::Tensor encoder_out, std::vector<Hypotheses> in_hyps,
const std::vector<int32_t> &frame_offset,
int32_t num_active_paths /*= 4*/) {
TORCH_CHECK(encoder_out.dim() == 3, encoder_out.dim(), " vs ", 3);

Expand Down Expand Up @@ -522,6 +526,7 @@ std::vector<Hypotheses> StreamingModifiedBeamSearch(
int32_t new_token = topk_token_indexes_acc[j];
if (new_token != blank_id && new_token != unk_id) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset[k]);
new_hyp.num_trailing_blanks = 0;
} else {
new_hyp.num_trailing_blanks += 1;
Expand Down
Loading