Skip to content

Commit

Permalink
Merge pull request #13 from Forced-Alignment-and-Vowel-Extraction/whisper-timestamped
Browse files Browse the repository at this point in the history

Migrate to whisper-timestamped
  • Loading branch information
chrisbrickhouse authored Mar 20, 2024
2 parents 2b9efaf + 09fc310 commit 4c2989e
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 29 deletions.
42 changes: 30 additions & 12 deletions formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import warnings

import textgrid

class Formatter():
def __init__(self):
pass

def to_TextGrid(self, diarized_transcription):
def to_TextGrid(self, diarized_transcription, by_phrase=True):
"""
Convert a diarized transcription dictionary to a TextGrid
Args:
diarized_transcription: Output of pipeline.assign_speakers()
by_phrase: Flag for whether the intervals should be by phrase (True) or word (False)
Returns:
A textgrid.TextGrid object populated with the diarized and
Expand All @@ -34,29 +37,44 @@ def to_TextGrid(self, diarized_transcription):
maxTime = diarized_transcription['segments'][-1]['end']
tg = textgrid.TextGrid(minTime=minTime,maxTime=maxTime)

speakers = [x['speaker'] for x in diarized_transcription['segments']]
speakers = [x['speaker'] for x in diarized_transcription['segments'] if 'speaker' in x]
for speaker in set(speakers):
tg.append(textgrid.IntervalTier(name=speaker,minTime=minTime,maxTime=maxTime))
# Create a lookup table of tier indices based on the given speaker name
tier_key = dict((name,index) for index, name in enumerate([x.name for x in tg.tiers]))

for segment in diarized_transcription['segments']:
for i in range(len(diarized_transcription['segments'])):
segment = diarized_transcription['segments'][i]
# There's no guarantee, weirdly, that a given word's assigned speaker
# is the same as the speaker assigned to the whole segment. Since
# the tiers are based on assigned /segment/ speakers, not assigned
# word speakers, we need to look up the tier in the segment loop
# not in the word loop. See Issue #7
if 'speaker' not in segment:
warnings.warn('No speaker for segment')
#print(segment)
continue
tier_index = tier_key[segment['speaker']]
tier = tg.tiers[tier_index]
minTime = segment['start']
maxTime = segment['end']
if i+1 == len(diarized_transcription['segments']):
maxTime = segment['end']
else:
maxTime = diarized_transcription['segments'][i+1]['start']
mark = segment['text']
tier.add(minTime,maxTime,mark)
# In testing, the word-level alignments are not very good. A future version
# might want to add an option for end users to enable the following loop.
#for word in segment['words']:
# minTime = word['start']
# maxTime = word['end']
# mark = word['word']
# tier.add(minTime,maxTime,mark)
if by_phrase:
tier.add(minTime,maxTime,mark)
continue
for word in segment['words']:
if 'speaker' not in word:
warnings.warn('No speaker assigned to word, using phrase-level speaker')
elif word['speaker'] != segment['speaker']:
warnings.warn('Mismatched speaker for word and phrase, using phrase-level speaker')
#print(word['speaker'],word)
#print(segment['speaker'],segment)
#raise ValueError('Word and segment have different speakers')
minTime = word['start']
maxTime = word['end']
mark = word['text']
tier.add(minTime,maxTime,mark)
return tg
25 changes: 16 additions & 9 deletions pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,33 @@
import psutil
import GPUtil
import matplotlib.pyplot as plt
import whisper
import whisper_timestamped as whisper
from whisperx import load_align_model, align
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

def transcribe(audio_file: str, model_name: str, device: str = "cpu") -> Dict[str, Any]:
def transcribe(
audio_file: str,
model_name: str,
device: str = "cpu",
detect_disfluencies: bool = True
) -> Dict[str, Any]:
"""
Transcribe an audio file using a whisper model.
Args:
audio_file: Path to the audio file to transcribe.
model_name: Name of the model to use for transcription.
device: The device to use for inference (e.g., "cpu" or "cuda").
detect_disfluencies: Flag for whether the transcription should include disfluencies, marked with [*]
Returns:
A dictionary representing the transcript segments and language code.
"""
model = whisper.load_model(model_name, device)
result = model.transcribe(audio_file)
model = whisper.load_model(model_name, device=device)
audio = whisper.load_audio(audio_file)
result = whisper.transcribe(model, audio_file,detect_disfluencies=detect_disfluencies)

language_code = result["language"]
language_code = result['language']
return {
"segments": result["segments"],
"language_code": language_code,
Expand Down Expand Up @@ -130,11 +137,11 @@ def transcribe_and_diarize(
spoken text, and the speaker ID.
"""
transcript = transcribe(audio_file, model_name, device)
aligned_segments = align_segments(
transcript["segments"], transcript["language_code"], audio_file, device
)
#aligned_segments = align_segments(
# transcript["segments"], transcript["language_code"], audio_file, device
#)
diarization_result = diarize(audio_file, hf_token)
results_segments_w_speakers = assign_speakers(diarization_result, aligned_segments)
results_segments_w_speakers = assign_speakers(diarization_result, transcript)

# Print the results in a user-friendly way
for i, segment in enumerate(results_segments_w_speakers['segments']):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
openai-whisper @ git+https://github.com/openai/whisper.git@b38a1f20f4b23f3f3099af2c3e0ca95627276ddf
whisperx @ git+https://github.com/m-bain/whisperx.git@49e0130e4e0c0d99d60715d76e65a71826a97109
whisper_timestamped
GPUtil
psutil
textgrid

Large diffs are not rendered by default.

31 changes: 24 additions & 7 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,44 @@
import math
import json
import numpy.testing as nptest
import pytest
import textgrid
import warnings

import formatter

class TestFormatter():
Format = formatter.Formatter()

def test_to_TextGrid(self):
for input_fname, ex_fname in self.provide_to_TextGrid():
for input_fname, by_phrase in self.provide_to_TextGrid():
with open(input_fname) as f:
case = json.load(f)
observed = self.Format.to_TextGrid(case)
observed = self.Format.to_TextGrid(case, by_phrase=by_phrase)

expected = textgrid.TextGrid()
expected.read(ex_fname)
assert observed.maxTime is not None
assert len(observed.tiers) > 0

nptest.assert_array_equal(observed,expected)
def test_no_speaker_warning(self):
for input_fname in self.provide_no_speaker_warning():
with open(input_fname) as f:
case = json.load(f)
with pytest.warns(UserWarning, match="No speaker for segment") as record:
_ = self.Format.to_TextGrid(case, by_phrase=False)

def provide_to_TextGrid(self):
return [
(
'tests/data/TestAudio_SnoopDogg_85SouthMedia_segments.json',
'tests/data/TestAudio_SnoopDogg_85SouthMedia.TextGrid'
'tests/data/TestAudio_SnoopDogg_85SouthMedia_WhisperTimestampSegments.json',
True
),
(
'tests/data/TestAudio_SnoopDogg_85SouthMedia_WhisperTimestampSegments.json',
False
),
]

def provide_no_speaker_warning(self):
return [
'tests/data/TestAudio_SnoopDogg_85SouthMedia.json',
]

0 comments on commit 4c2989e

Please sign in to comment.