diff --git a/whisper_online.py b/whisper_online.py
index 51a9ab6..8efbbab 100644
--- a/whisper_online.py
+++ b/whisper_online.py
@@ -26,12 +26,15 @@ class ASRBase:
     sep = " "   # join transcribe words with this character (" " for whisper_timestamped,
                 # "" for faster-whisper because it emits the spaces when neeeded)
 
-    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
+    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
+        self.logfile = logfile
+        self.transcribe_kargs = {}
         self.original_language = lan
 
         self.model = self.load_model(modelsize, cache_dir, model_dir)
 
 
+
     def load_model(self, modelsize, cache_dir):
         raise NotImplemented("must be implemented in the child class")
 
@@ -50,15 +53,18 @@ class WhisperTimestampedASR(ASRBase):
     sep = " "
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
-        global whisper_timestamped # has to be global as it is used at each `transcribe` call
         import whisper
-        import whisper_timestamped
+        from whisper_timestamped import transcribe_timestamped
+        self.transcribe_timestamped = transcribe_timestamped
         if model_dir is not None:
             print("ignoring model_dir, not implemented",file=self.logfile)
         return whisper.load_model(modelsize, download_root=cache_dir)
 
     def transcribe(self, audio, init_prompt=""):
-        result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
+        result = self.transcribe_timestamped(self.model,
+                audio, language=self.original_language,
+                initial_prompt=init_prompt, verbose=None,
+                condition_on_previous_text=True, **self.transcribe_kargs)
         return result
 
     def ts_words(self,r):
@@ -74,7 +80,12 @@ def segments_end_ts(self, res):
         return [s["end"] for s in res["segments"]]
 
     def use_vad(self):
-        raise NotImplemented("Feature use_vad is not implemented for whisper_timestamped backend.")
+        self.transcribe_kargs["vad"] = True
+
+    def set_translate_task(self):
+        self.transcribe_kargs["task"] = "translate"
+
+
 
 
 class FasterWhisperASR(ASRBase):
@@ -135,7 +146,6 @@ def set_translate_task(self):
 class HypothesisBuffer:
 
     def __init__(self, logfile=sys.stderr):
-        """output: where to store the log. Leave it unchanged to print to terminal."""
         self.commited_in_buffer = []
         self.buffer = []
         self.new = []
@@ -205,7 +215,7 @@ class OnlineASRProcessor:
     def __init__(self, asr, tokenizer, logfile=sys.stderr):
         """asr: WhisperASR object
         tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
-        output: where to store the log. Leave it unchanged to print to terminal.
+        logfile: where to store the log.
         """
         self.asr = asr
         self.tokenizer = tokenizer
@@ -468,21 +478,24 @@ def split(self, sent):
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
     args = parser.parse_args()
 
+    # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
+    logfile = sys.stderr
+
     if args.offline and args.comp_unaware:
-        print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=sys.stderr)
+        print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
         sys.exit(1)
 
     audio_path = args.audio_path
 
     SAMPLING_RATE = 16000
     duration = len(load_audio(audio_path))/SAMPLING_RATE
-    print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
+    print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
     size = args.model
     language = args.lan
 
     t = time.time()
-    print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
+    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
 
     if args.backend == "faster-whisper":
         asr_cls = FasterWhisperASR
@@ -499,15 +512,15 @@ def split(self, sent):
 
     e = time.time()
-    print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
+    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
 
     if args.vad:
-        print("setting VAD filter",file=sys.stderr)
+        print("setting VAD filter",file=logfile)
         asr.use_vad()
 
     min_chunk = args.min_chunk_size
-    online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))
+    online = OnlineASRProcessor(asr,create_tokenizer(tgt_language),logfile=logfile)
 
     # load the audio into the LRU cache before we start the timer
@@ -529,10 +542,10 @@ def output_transcript(o, now=None):
         if now is None:
             now = time.time()-start
         if o[0] is not None:
-            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=sys.stderr,flush=True)
+            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
             print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
         else:
-            print(o,file=sys.stderr,flush=True)
+            print(o,file=logfile,flush=True)
 
     if args.offline: ## offline mode processing (for testing/debugging)
         a = load_audio(audio_path)
@@ -540,7 +553,7 @@ def output_transcript(o, now=None):
         try:
             o = online.process_iter()
         except AssertionError:
-            print("assertion error",file=sys.stderr)
+            print("assertion error",file=logfile)
             pass
         else:
            output_transcript(o)
@@ -553,12 +566,12 @@ def output_transcript(o, now=None):
            try:
                o = online.process_iter()
            except AssertionError:
-                print("assertion error",file=sys.stderr)
+                print("assertion error",file=logfile)
                pass
            else:
                output_transcript(o, now=end)
 
-            print(f"## last processed {end:.2f}s",file=sys.stderr,flush=True)
+            print(f"## last processed {end:.2f}s",file=logfile,flush=True)
 
            beg = end
            end += min_chunk
@@ -580,12 +593,12 @@ def output_transcript(o, now=None):
            try:
                o = online.process_iter()
            except AssertionError:
-                print("assertion error",file=sys.stderr)
+                print("assertion error",file=logfile)
                pass
            else:
                output_transcript(o)
            now = time.time() - start
-            print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr,flush=True)
+            print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
 
            if end >= duration:
                break