Skip to content

Commit

Permalink
Update align_and_segment.py (#5317)
Browse files Browse the repository at this point in the history
Fix MMS alignment code
  • Loading branch information
vineelpratap authored Sep 7, 2023
1 parent 4db2649 commit b5d89cd
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions examples/mms/data_prep/align_and_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,14 @@ def get_alignments(
blank = dictionary["<blank>"]

targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE)
input_lengths = torch.tensor(emissions.shape[0])
target_lengths = torch.tensor(targets.shape[0])


input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1)
target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1)
path, _ = F.forced_align(
emissions, targets, input_lengths, target_lengths, blank=blank
emissions.unsqueeze(0), targets.unsqueeze(0), input_lengths, target_lengths, blank=blank
)
path = path.to("cpu").tolist()
path = path.squeeze().to("cpu").tolist()

segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
return segments, stride

Expand Down

0 comments on commit b5d89cd

Please sign in to comment.