From 6e6b61925784d1c76cb21f687c3f6c59e6d82698 Mon Sep 17 00:00:00 2001
From: Luca
Date: Wed, 6 Sep 2023 15:19:12 +0200
Subject: [PATCH] add option to save log to file

---
 whisper_online.py | 68 ++++++++++++++++++++---------------------------
 1 file changed, 29 insertions(+), 39 deletions(-)

diff --git a/whisper_online.py b/whisper_online.py
index 266fc72..82e5dd6 100644
--- a/whisper_online.py
+++ b/whisper_online.py
@@ -46,10 +46,6 @@ class ASRBase:
         raise NotImplemented("must be implemented in the child class")
 
 
-## requires imports:
-# import whisper
-# import whisper_timestamped
-
 class WhisperTimestampedASR(ASRBase):
     """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
     On the other hand, the installation for GPU could be easier.
@@ -64,7 +60,7 @@ class WhisperTimestampedASR(ASRBase):
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
         if model_dir is not None:
-            print("ignoring model_dir, not implemented",file=sys.stderr)
+            print("ignoring model_dir, not implemented",file=self.output)
         return whisper.load_model(modelsize, download_root=cache_dir)
 
     def transcribe(self, audio, init_prompt=""):
@@ -89,9 +85,6 @@ class WhisperTimestampedASR(ASRBase):
 
 class FasterWhisperASR(ASRBase):
     """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
-
-    Requires imports, if used:
-        import faster_whisper
     """
 
     sep = ""
@@ -101,11 +94,8 @@ class FasterWhisperASR(ASRBase):
         import faster_whisper
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
-        #from faster_whisper import WhisperModel
-
-
         if model_dir is not None:
-            print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=sys.stderr)
+            print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.output)
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = modelsize
@@ -153,7 +143,8 @@ class HypothesisBuffer:
 
 
-    def __init__(self):
+    def __init__(self, output=sys.stderr):
+        """output: where to store the log. Leave it unchanged to print to the terminal."""
         self.commited_in_buffer = []
         self.buffer = []
         self.new = []
 
@@ -161,6 +152,8 @@ class HypothesisBuffer:
         self.last_commited_time = 0
         self.last_commited_word = None
 
+        self.output = output
+
     def insert(self, new, offset):
         # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
         # the new tail is added to self.new
@@ -179,9 +172,9 @@ class HypothesisBuffer:
                 c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
                 tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
                 if c == tail:
-                    print("removing last",i,"words:",file=sys.stderr)
+                    print("removing last",i,"words:",file=self.output)
                     for j in range(i):
-                        print("\t",self.new.pop(0),file=sys.stderr)
+                        print("\t",self.new.pop(0),file=self.output)
                     break
 
     def flush(self):
@@ -218,12 +211,14 @@ class OnlineASRProcessor:
 
     SAMPLING_RATE = 16000
 
-    def __init__(self, asr, tokenizer):
+    def __init__(self, asr, tokenizer, output=sys.stderr):
         """asr: WhisperASR object
         tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
+        output: where to store the log. Leave it unchanged to print to the terminal.
         """
         self.asr = asr
         self.tokenizer = tokenizer
+        self.output = output
 
         self.init()
 
@@ -232,7 +227,7 @@ class OnlineASRProcessor:
         self.audio_buffer = np.array([],dtype=np.float32)
         self.buffer_time_offset = 0
 
-        self.transcript_buffer = HypothesisBuffer()
+        self.transcript_buffer = HypothesisBuffer(output=self.output)
         self.commited = []
         self.last_chunked_at = 0
 
@@ -263,13 +258,13 @@ class OnlineASRProcessor:
     def process_iter(self):
         """Runs on the current audio buffer.
         Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
-        The non-emty text is confirmed (commited) partial transcript.
+        The non-empty text is confirmed (committed) partial transcript.
         """
 
         prompt, non_prompt = self.prompt()
-        print("PROMPT:", prompt, file=sys.stderr)
-        print("CONTEXT:", non_prompt, file=sys.stderr)
-        print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=sys.stderr)
+        print("PROMPT:", prompt, file=self.output)
+        print("CONTEXT:", non_prompt, file=self.output)
+        print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.output)
         res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
 
         # transform to [(beg,end,"word1"), ...]
@@ -278,8 +273,8 @@ class OnlineASRProcessor:
         self.transcript_buffer.insert(tsw, self.buffer_time_offset)
         o = self.transcript_buffer.flush()
         self.commited.extend(o)
-        print(">>>>COMPLETE NOW:",self.to_flush(o),file=sys.stderr,flush=True)
-        print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
+        print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.output,flush=True)
+        print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
 
         # there is a newly confirmed text
         if o:
@@ -298,14 +293,14 @@ class OnlineASRProcessor:
 
 #        elif self.transcript_buffer.complete():
 #            self.silence_iters = 0
 #        elif not self.transcript_buffer.complete():
-#            # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
+#            # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
 #            self.silence_iters += 1
 #            if self.silence_iters >= 3:
 #                n = self.last_chunked_at
 ##                self.chunk_completed_sentence()
 ##                if n == self.last_chunked_at:
 #                self.chunk_at(self.last_chunked_at+self.chunk)
-#                print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=sys.stderr)
+#                print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.output)
 ##                self.silence_iters = 0
 
@@ -321,18 +316,18 @@ class OnlineASRProcessor:
             #while k>0 and self.commited[k][1] > l:
             #    k -= 1
             #t = self.commited[k][1]
-            print(f"chunking because of len",file=sys.stderr)
+            print(f"chunking because of len",file=self.output)
            #self.chunk_at(t)
 
-        print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=sys.stderr)
+        print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.output)
         return self.to_flush(o)
 
     def chunk_completed_sentence(self):
         if self.commited == []: return
-        print(self.commited,file=sys.stderr)
+        print(self.commited,file=self.output)
         sents = self.words_to_sentences(self.commited)
         for s in sents:
-            print("\t\tSENT:",s,file=sys.stderr)
+            print("\t\tSENT:",s,file=self.output)
         if len(sents) < 2:
             return
         while len(sents) > 2:
@@ -340,7 +335,7 @@ class OnlineASRProcessor:
 
         # we will continue with audio processing at this timestamp
         chunk_at = sents[-2][1]
-        print(f"--- sentence chunked at {chunk_at:2.2f}",file=sys.stderr)
+        print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.output)
         self.chunk_at(chunk_at)
 
     def chunk_completed_segment(self, res):
@@ -357,12 +352,12 @@ class OnlineASRProcessor:
                 ends.pop(-1)
                 e = ends[-2]+self.buffer_time_offset
             if e <= t:
-                print(f"--- segment chunked at {e:2.2f}",file=sys.stderr)
+                print(f"--- segment chunked at {e:2.2f}",file=self.output)
                 self.chunk_at(e)
             else:
-                print(f"--- last segment not within commited area",file=sys.stderr)
+                print(f"--- last segment not within commited area",file=self.output)
         else:
-            print(f"--- not enough segments to chunk",file=sys.stderr)
+            print(f"--- not enough segments to chunk",file=self.output)
 
 
 
@@ -408,7 +403,7 @@ class OnlineASRProcessor:
         """
         o = self.transcript_buffer.complete()
         f = self.to_flush(o)
-        print("last, noncommited:",f,file=sys.stderr)
+        print("last, noncommited:",f,file=self.output)
         return f
 
 
@@ -473,15 +468,10 @@ if __name__ == "__main__":
 
     t = time.time()
     print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
 
-    #asr = WhisperASR(lan=language, modelsize=size)
     if args.backend == "faster-whisper":
-        #from faster_whisper import WhisperModel
         asr_cls = FasterWhisperASR
     else:
-        #import whisper
-        #import whisper_timestamped
-        # from whisper_timestamped_model import WhisperTimestampedASR
         asr_cls = WhisperTimestampedASR
 
     asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
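
Usage note (not part of the patch): once applied, any writable file-like object can be passed as `output` to capture the log that previously went to `sys.stderr`. The following is a minimal sketch under stated assumptions: the patched `whisper_online.py` is importable, `FasterWhisperASR` accepts the `modelsize`/`lan` keywords with `cache_dir`/`model_dir` defaulting to `None` (matching the call in `__main__`), audio is fed with `insert_audio_chunk` from an unpatched part of the file, and the model size, language, and tokenizer stub are illustrative only.

```python
# Sketch, not part of the patch: redirect the streaming log to a file.
import numpy as np

from whisper_online import FasterWhisperASR, OnlineASRProcessor


class DummySentenceSplitter:
    """Illustrative stand-in: the OnlineASRProcessor docstring only requires
    a MosesTokenizer-like split() method."""
    def split(self, text):
        return [text]


# No model_dir here on purpose: the file=self.output prints inside load_model()
# fire only on the model_dir path, and no hunk in this patch sets `output`
# on the ASR classes themselves.
asr = FasterWhisperASR(modelsize="large-v2", lan="en")  # size/language illustrative

with open("whisper_online.log", "w") as logfile:
    online = OnlineASRProcessor(asr, DummySentenceSplitter(), output=logfile)
    audio_chunk = np.zeros(16000, dtype=np.float32)  # 1 s of 16 kHz silence as a stand-in
    online.insert_audio_chunk(audio_chunk)           # assumed helper from the unpatched file
    print(online.process_iter())  # (beg, end, "text") of the committed part
    print(online.finish())        # flush the remaining uncommitted transcript
```

The same `output` object propagates to the internal buffer via the `HypothesisBuffer(output=self.output)` hunk, so the `PROMPT:`/`CONTEXT:`/`>>>>COMPLETE NOW:` lines all land in the file; only the model-loading message in `__main__` still prints to `sys.stderr`.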