Skip to content

Commit

Permalink
logfile reviewed, whisper_timestamped loading module and vad
Browse files Browse the repository at this point in the history
 PR #10, issues #9, #30
  • Loading branch information
Gldkslfmsd committed Nov 28, 2023
1 parent 39e06b5 commit 0a50eec
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions whisper_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)

def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
self.logfile = logfile

self.transcribe_kargs = {}
self.original_language = lan

self.model = self.load_model(modelsize, cache_dir, model_dir)


def load_model(self, modelsize, cache_dir, model_dir=None):
    """Load and return the backend model; abstract hook for subclasses.

    The base-class constructor calls this with three arguments
    (modelsize, cache_dir, model_dir), so the stub accepts the same
    arguments as the concrete overrides. ``model_dir`` defaults to None
    so existing two-argument callers keep working.

    Raises:
        NotImplementedError: always; child classes must override this.
    """
    # NotImplemented is a comparison sentinel and is not callable;
    # NotImplementedError is the exception meant for abstract methods.
    raise NotImplementedError("must be implemented in the child class")

Expand All @@ -50,15 +53,18 @@ class WhisperTimestampedASR(ASRBase):
sep = " "

def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
global whisper_timestamped # has to be global as it is used at each `transcribe` call
import whisper
import whisper_timestamped
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None:
print("ignoring model_dir, not implemented",file=self.logfile)
return whisper.load_model(modelsize, download_root=cache_dir)

def transcribe(self, audio, init_prompt=""):
result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
result = self.transcribe_timestamped(self.model,
audio, language=self.original_language,
initial_prompt=init_prompt, verbose=None,
condition_on_previous_text=True, **self.transcribe_kargs)
return result

def ts_words(self,r):
Expand All @@ -74,7 +80,12 @@ def segments_end_ts(self, res):
return [s["end"] for s in res["segments"]]

def use_vad(self):
raise NotImplemented("Feature use_vad is not implemented for whisper_timestamped backend.")
self.transcribe_kargs["vad"] = True

def set_translate_task(self):
    """Store task="translate" in the extra transcription keyword arguments.

    The entry is kept in ``self.transcribe_kargs``, which ``transcribe``
    unpacks into its ``transcribe_timestamped`` call, so it takes effect
    on subsequent transcriptions.
    """
    self.transcribe_kargs["task"] = "translate"




class FasterWhisperASR(ASRBase):
Expand Down Expand Up @@ -135,7 +146,6 @@ def set_translate_task(self):
class HypothesisBuffer:

def __init__(self, logfile=sys.stderr):
"""output: where to store the log. Leave it unchanged to print to terminal."""
self.commited_in_buffer = []
self.buffer = []
self.new = []
Expand Down Expand Up @@ -205,7 +215,7 @@ class OnlineASRProcessor:
def __init__(self, asr, tokenizer, logfile=sys.stderr):
"""asr: WhisperASR object
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
output: where to store the log. Leave it unchanged to print to terminal.
logfile: where to store the log.
"""
self.asr = asr
self.tokenizer = tokenizer
Expand Down Expand Up @@ -468,21 +478,24 @@ def split(self, sent):
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
args = parser.parse_args()

# reset to store stderr to different file stream, e.g. open(os.devnull,"w")
logfile = sys.stderr

if args.offline and args.comp_unaware:
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=sys.stderr)
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
sys.exit(1)

audio_path = args.audio_path

SAMPLING_RATE = 16000
duration = len(load_audio(audio_path))/SAMPLING_RATE
print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
print("Audio duration is: %2.2f seconds" % duration, file=logfile)

size = args.model
language = args.lan

t = time.time()
print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)

if args.backend == "faster-whisper":
asr_cls = FasterWhisperASR
Expand All @@ -499,15 +512,15 @@ def split(self, sent):


e = time.time()
print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)

if args.vad:
print("setting VAD filter",file=sys.stderr)
print("setting VAD filter",file=logfile)
asr.use_vad()


min_chunk = args.min_chunk_size
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))
online = OnlineASRProcessor(asr,create_tokenizer(tgt_language),logfile=logfile)


# load the audio into the LRU cache before we start the timer
Expand All @@ -529,18 +542,18 @@ def output_transcript(o, now=None):
if now is None:
now = time.time()-start
if o[0] is not None:
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=sys.stderr,flush=True)
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
else:
print(o,file=sys.stderr,flush=True)
print(o,file=logfile,flush=True)

if args.offline: ## offline mode processing (for testing/debugging)
a = load_audio(audio_path)
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=sys.stderr)
print("assertion error",file=logfile)
pass
else:
output_transcript(o)
Expand All @@ -553,12 +566,12 @@ def output_transcript(o, now=None):
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=sys.stderr)
print("assertion error",file=logfile)
pass
else:
output_transcript(o, now=end)

print(f"## last processed {end:.2f}s",file=sys.stderr,flush=True)
print(f"## last processed {end:.2f}s",file=logfile,flush=True)

beg = end
end += min_chunk
Expand All @@ -580,12 +593,12 @@ def output_transcript(o, now=None):
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=sys.stderr)
print("assertion error",file=logfile)
pass
else:
output_transcript(o)
now = time.time() - start
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr,flush=True)
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)

if end >= duration:
break
Expand Down

0 comments on commit 0a50eec

Please sign in to comment.